Description
Which component has the problem?
CUTLASS C++
Bug Report
Describe the bug
Bug location: cutlass/epilogue/threadblock/output_tile_thread_map.h, in struct RowArrangement
In the creation of the warp arrangement within a threadblock for the epilogue of a GEMM, kAccessWidth is computed using a ternary expression:
static int const kAccessWidth =
(Detail::kTargetAccessRows > Detail::kShapeRow ?
kWarpSize / Detail::kShapeRow
: const_min(
Detail::kShapeWidth,
const_min(kWarpSize, Detail::kTargetMemoryAccessWidth)
));
kAccessWidth and kAccessRows represent the 2D arrangement in which one warp accesses the output tile with the preferred 256-byte width.
In the else case, we correctly check that kAccessWidth does not exceed Detail::kShapeWidth.
However, the same check should be applied in the then-case, too.
With a threadblock shape of 160x160x32, we can trigger this case and obtain the following values:
debug::kv<debug::kShapeRow, 1>,
debug::kv<debug::kShapeWidth, 20>,
debug::kv<debug::kTargetMemoryAccessWidth, 16>,
debug::kv<debug::kTargetAccessRows, 2>,
debug::kv<debug::kAccessWidth, 32>,
debug::kv<debug::kAccessRows, 1>,
debug::kv<debug::kIterationsRow, 1>,
debug::kv<debug::kDeltaRow, 1>,
debug::kv<debug::kIterationsColumn, 0>,
debug::kv<debug::kDeltaColumn, 256>
Since kAccessWidth is larger than Detail::kShapeWidth, kIterationsColumn becomes 0 and the assertion fails.
It is arguable whether 160x160 is an acceptable choice for a threadblock size, but the value should be computed under consistent constraints and be robust against unusual inputs. Here, however, the min constraint is applied only in the else case, not in the then case.
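For illustration, here is a minimal standalone sketch (my own reconstruction, not the CUTLASS source: const_min is replaced by std::min, and the constants are hard-coded from the debug dump above) that reproduces the arithmetic:

#include <algorithm>

// Values taken from the debug dump for the 160x160x32 threadblock shape.
constexpr int kWarpSize                = 32;
constexpr int kShapeRow                = 1;
constexpr int kShapeWidth              = 20;  // 160 columns / 8 elements per access
constexpr int kTargetMemoryAccessWidth = 16;
constexpr int kTargetAccessRows        = 2;

// The then-branch is taken because kTargetAccessRows (2) > kShapeRow (1),
// and its result is not clamped to kShapeWidth.
constexpr int kAccessWidth =
    kTargetAccessRows > kShapeRow
        ? kWarpSize / kShapeRow                               // = 32
        : std::min(kShapeWidth,
                   std::min(kWarpSize, kTargetMemoryAccessWidth));

// Integer division truncates 20 / 32 to 0.
constexpr int kIterationsColumn = kShapeWidth / kAccessWidth;

static_assert(kAccessWidth == 32, "exceeds kShapeWidth = 20");
static_assert(kIterationsColumn == 0, "iteration count collapses to zero");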
Steps/Code to reproduce bug
- Copy the CUDA file and the Makefile below into some directory.
- Set CUTLASS_DIR in the top line of the Makefile to the path to your CUTLASS directory.
- Run make error to compile the code for the threadblock shape 160x160x32, where the erroneous computation arises. Accordingly, the assertion static_assert(kAccessWidth * kElementsPerAccess <= Shape::kColumn, "Accessing too many elements per access"); in the struct RowArrangement fails.
- (Optional) Run make to compile the code for the threadblock shape 128x128x32, where the compilation error does not occur.
Expected behavior
The memory access width should not be larger than the actual width Detail::kShapeWidth that is derived from the shape of the output tile.
This is enforced correctly in the else case of the ternary operation. There, const_min(Detail::kShapeWidth, ...) is already included.
I suggest applying the same clamp in the then case of the ternary expression.
That is, change:
static int const kAccessWidth =
(Detail::kTargetAccessRows > Detail::kShapeRow ?
kWarpSize / Detail::kShapeRow
: const_min(
Detail::kShapeWidth,
const_min(kWarpSize, Detail::kTargetMemoryAccessWidth)
));
to
static int const kAccessWidth =
(Detail::kTargetAccessRows > Detail::kShapeRow ?
const_min(Detail::kShapeWidth, kWarpSize / Detail::kShapeRow)
: const_min(
Detail::kShapeWidth,
const_min(kWarpSize, Detail::kTargetMemoryAccessWidth)
));
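Under the same hard-coded values as in the sketch above, the clamped then-branch yields kAccessWidth = min(20, 32) = 20, so kIterationsColumn = 20 / 20 = 1 and the static assertion holds:

#include <algorithm>

// Same assumed values as before (160x160x32 threadblock, 8 elements per access).
constexpr int kWarpSize = 32, kShapeRow = 1, kShapeWidth = 20,
              kTargetMemoryAccessWidth = 16, kTargetAccessRows = 2;

// Proposed fix: the then-branch is clamped to kShapeWidth as well.
constexpr int kAccessWidthFixed =
    kTargetAccessRows > kShapeRow
        ? std::min(kShapeWidth, kWarpSize / kShapeRow)        // min(20, 32) = 20
        : std::min(kShapeWidth,
                   std::min(kWarpSize, kTargetMemoryAccessWidth));

static_assert(kAccessWidthFixed <= kShapeWidth, "no longer exceeds the tile width");
static_assert(kShapeWidth / kAccessWidthFixed == 1, "kIterationsColumn becomes 1");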
Environment details (please complete the following information):
I compiled the file in the following environment:
- OS: Ubuntu 24.04.3 LTS
- CUDA Developer Toolkit: 13.0
Files to reproduce
Upload failed, so I paste them here:
Makefile
CUTLASS_DIR=/path/to/cutlass
NVCC=nvcc
XFLAGS=-Xcompiler=-Wno-psabi -Xcompiler=-fno-strict-aliasing
DEFINES=-DNDEBUG
INCLUDES=-I${CUTLASS_DIR}/include -I${CUTLASS_DIR}/tools/util/include
NVCCFLAGS=-std=c++17 -O3 ${XFLAGS} ${INCLUDES} \
--expt-relaxed-constexpr
LDFLAGS=
LDLIBS=-lcuda
success: access_width_exceeded.cu
${NVCC} ${NVCCFLAGS} ${DEFINES} ${LDFLAGS} -o $@ $< ${LDLIBS}
error: access_width_exceeded.cu
${NVCC} ${NVCCFLAGS} ${DEFINES} -DDEBUG_THREADBLOCK_FAIL ${LDFLAGS} -o $@ $< ${LDLIBS}
access_width_exceeded.cu
#include <iostream>
#include <string>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "cuda_runtime.h"
/**
* Panic wrapper for unwinding CUTLASS errors
*/
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
if (error != cutlass::Status::kSuccess) { \
std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \
<< std::endl; \
exit(EXIT_FAILURE); \
} \
}
/**
* Panic wrapper for unwinding CUDA runtime errors
*/
#define CUDA_CHECK(status) \
{ \
cudaError_t error = status; \
if (error != cudaSuccess) { \
std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \
<< " at line: " << __LINE__ << std::endl; \
exit(EXIT_FAILURE); \
} \
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations (cutlass_tensorop_h16816gemm_128x128_32x4_nn_align8)
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::half_t; // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::half_t; // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementC = cutlass::half_t; // Element type for C and D matrix operands
using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C/D matrices in units of elements (up to 16 bytes)
// Multiply-accumulate blocking/pipelining details
using ElementAccumulator = cutlass::half_t; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm80; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
#ifdef DEBUG_THREADBLOCK_FAIL
using ThreadblockShape = cutlass::gemm::GemmShape<160, 160, 32>; // Threadblock-level tile size (concept: GemmShape)
#else
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock-level tile size (concept: GemmShape)
#endif
using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; // Warp-level tile size (concept: GemmShape)
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // Instruction-level tile size (concept: GemmShape)
constexpr int NumStages = 3; // Number of global->shared pipeline stages used in the GEMM mainloop
// Epilogue output operator
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
ElementC, // Element type for C and D matrix operands
AlignmentC, // Memory access granularity of C and D matrix in units of elements
ElementAccumulator, // Element type for internal accumulation
ElementAccumulator>; // Data type used to compute linear combination
// Reference device GEMM implementation type
using DeviceGemmReference = cutlass::reference::device::Gemm<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementAccumulator,
ElementAccumulator>;
// Classic data-parallel device GEMM implementation type
using DeviceGemmBasic = cutlass::gemm::device::GemmUniversal<
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ElementAccumulator,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
NumStages,
AlignmentA,
AlignmentB>;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Result structure
struct Result
{
double avg_runtime_ms;
double gflops;
cutlass::Status status;
cudaError_t error;
bool passed;
Result(
double avg_runtime_ms = 0,
double gflops = 0,
cutlass::Status status = cutlass::Status::kSuccess,
cudaError_t error = cudaSuccess)
:
avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(true)
{}
};
/// Command line options parsing
struct Options
{
std::string command_name;
bool help;
cutlass::gemm::GemmCoord problem_size;
float alpha;
float beta;
int split_k_factor;
int avail_sms;
bool reference_check;
int iterations;
bool init_dyn_only;
bool streamk;
bool baseline;
bool run;
int cohort_m;
int cohort_n;
cutlass::HostTensor<ElementA, LayoutA> tensor_a;
cutlass::HostTensor<ElementB, LayoutB> tensor_b;
cutlass::HostTensor<ElementC, LayoutC> tensor_c;
cutlass::HostTensor<ElementC, LayoutC> tensor_d;
cutlass::HostTensor<ElementC, LayoutC> tensor_ref_d;
Options(std::string command_name) :
command_name(command_name),
help(false),
problem_size({2048, 2048, 2048}),
alpha(1.0f),
beta(0.0f),
split_k_factor(1),
avail_sms(-1), // Number of device SMs to use is unlimited
reference_check(true),
iterations(10000),
init_dyn_only(false),
streamk(false),
baseline(false),
run(false),
cohort_m(8),
cohort_n(4)
{}
bool valid() const
{
return true;
}
void parse(int argc, char const **args)
{
cutlass::CommandLine cmd(argc, args);
if (cmd.check_cmd_line_flag("help")) {
help = true;
}
cmd.get_cmd_line_argument("m", problem_size.m());
cmd.get_cmd_line_argument("n", problem_size.n());
cmd.get_cmd_line_argument("k", problem_size.k());
cmd.get_cmd_line_argument("alpha", alpha);
cmd.get_cmd_line_argument("beta", beta);
cmd.get_cmd_line_argument("split", split_k_factor);
cmd.get_cmd_line_argument("iterations", iterations);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const
{
out
<< "Performs a GEMM computation.\n"
<< "\n"
<< "Options:\n"
<< "\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --m=<int> GEMM M dimension\n"
<< " --n=<int> GEMM N dimension\n"
<< " --k=<int> GEMM K dimension\n"
<< " --alpha=<f32> Epilogue scalar alpha\n"
<< " --beta=<f32> Epilogue scalar beta\n\n"
<< " --split=<int> Split-K factor to emulate\n\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
out
<< "\n\nExamples:\n\n"
<< "$ " << command_name << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n";
return out;
}
/// Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
return 2.0 * double(problem_size.product()) / double(1.0e9) / runtime_s;
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM evaluation
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Populates a DeviceGemmBasic::Arguments structure from the given commandline options
typename DeviceGemmBasic::Arguments args_from_options(
const DeviceGemmBasic &device_gemm,
const Options &options,
cutlass::HostTensor<ElementA, LayoutA> &tensor_a,
cutlass::HostTensor<ElementB, LayoutB> &tensor_b,
cutlass::HostTensor<ElementC, LayoutC> &tensor_c,
cutlass::HostTensor<ElementC, LayoutC> &tensor_d)
{
return typename DeviceGemmBasic::Arguments(
cutlass::gemm::GemmUniversalMode::kGemm, // universal mode
options.problem_size, // problem_size
options.split_k_factor, // batch count / splitk slices
{ // epilogue parameters
ElementAccumulator(options.alpha),
ElementAccumulator(options.beta)
},
tensor_a.device_data(), // ptr_A
tensor_b.device_data(), // ptr_B
tensor_c.device_data(), // ptr_C
tensor_d.device_data(), // ptr_D
options.problem_size.mk().product(), // batch_stride_A
options.problem_size.nk().product(), // batch_stride_B
options.problem_size.mn().product(), // batch_stride_C
options.problem_size.mn().product(), // batch_stride_D
tensor_a.layout().stride(0), // stride_a
tensor_b.layout().stride(0), // stride_b
tensor_c.layout().stride(0), // stride_c
tensor_d.layout().stride(0)); // stride_d
}
template <typename DeviceGemmT>
Result check(std::string description, Options &options)
{
// Display test description
std::cout << std::endl << description << std::endl;
// Zero-initialize test output matrix D
cutlass::reference::host::TensorFill(options.tensor_d.host_view());
options.tensor_d.sync_device();
// Instantiate CUTLASS kernel depending on templates
DeviceGemmT device_gemm;
// Create a structure of gemm kernel arguments suitable for invoking an instance of DeviceGemmT
auto arguments = args_from_options(device_gemm, options, options.tensor_a, options.tensor_b, options.tensor_c, options.tensor_d);
// Using the arguments, query for extra workspace required for matrix multiplication computation
size_t workspace_size = DeviceGemmT::get_workspace_size(arguments);
// Allocate workspace memory
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Check the problem size is supported or not
CUTLASS_CHECK(device_gemm.can_implement(arguments));
// Initialize CUTLASS kernel with arguments and workspace pointer
CUTLASS_CHECK(device_gemm.initialize(arguments, workspace.get()));
// Correctness / Warmup iteration
CUTLASS_CHECK(device_gemm());
// Copy output data from CUTLASS and reference kernel to host for comparison
options.tensor_d.sync_host();
// Check if output from CUTLASS kernel and reference kernel are equal or not
Result result;
result.passed = cutlass::reference::host::TensorEquals(
options.tensor_d.host_view(),
options.tensor_ref_d.host_view());
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
return result;
}
/// Program entrypoint
int main(int argc, const char **argv)
{
// Current device must have compute capability at least 80
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
if (!((props.major * 10 + props.minor) >= 80))
{
std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
<< std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are a no-op.
return 0;
}
// Parse commandline options
Options options("ampere_streamk_gemm");
options.parse(argc, argv);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
if (!options.valid()) {
std::cerr << "Invalid problem." << std::endl;
return -1;
}
//
// Initialize GEMM datasets
//
// Initialize tensors using CUTLASS helper functions
options.tensor_a.resize(options.problem_size.mk()); // <- Create matrix A with dimensions M x K
options.tensor_b.resize(options.problem_size.kn()); // <- Create matrix B with dimensions K x N
options.tensor_c.resize(options.problem_size.mn()); // <- Create matrix C with dimensions M x N
options.tensor_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from CUTLASS kernel
options.tensor_ref_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from reference kernel
// Fill matrix A on host with uniform-random data [-2, 2]
cutlass::reference::host::TensorFillRandomUniform(
options.tensor_a.host_view(),
1,
ElementA(2),
ElementA(-2),
0);
// Fill matrix B on host with uniform-random data [-2, 2]
cutlass::reference::host::TensorFillRandomUniform(
options.tensor_b.host_view(),
1,
ElementB(2),
ElementB(-2),
0);
// Fill matrix C on host with uniform-random data [-2, 2]
cutlass::reference::host::TensorFillRandomUniform(
options.tensor_c.host_view(),
1,
ElementC(2),
ElementC(-2),
0);
//
// Compute reference output
//
// Copy data from host to GPU
options.tensor_a.sync_device();
options.tensor_b.sync_device();
options.tensor_c.sync_device();
// Zero-initialize reference output matrix D
cutlass::reference::host::TensorFill(options.tensor_ref_d.host_view());
options.tensor_ref_d.sync_device();
// Create instantiation for device reference gemm kernel
DeviceGemmReference gemm_reference;
// Launch device reference gemm kernel
gemm_reference(
options.problem_size,
ElementAccumulator(options.alpha),
options.tensor_a.device_ref(),
options.tensor_b.device_ref(),
ElementAccumulator(options.beta),
options.tensor_c.device_ref(),
options.tensor_ref_d.device_ref());
// Wait for kernels to finish
CUDA_CHECK(cudaDeviceSynchronize());
// Copy output data from reference kernel to host for comparison
options.tensor_ref_d.sync_host();
//
// Evaluate CUTLASS kernels
//
Result basic_dp = check<DeviceGemmBasic>("Basic data-parallel GEMM", options);
return 0;
}