
About your sgemm wmma code: my tests show an error of around 2-3 #277


Closed
fpeanut opened this issue Apr 9, 2025 · 6 comments


fpeanut commented Apr 9, 2025

[Image: screenshot of the test output]
When I tested the blog author's code that you referenced, https://zhuanlan.zhihu.com/p/555339335, there was no such error.

@DefTruth DefTruth self-assigned this Apr 9, 2025
Member

DefTruth commented Apr 9, 2025

How did you test it? I'm not seeing an error this large on my side.

Member

DefTruth commented Apr 9, 2025

This is one random test run on my side; looking at the first 5 values, the differences are very small:

----------------------------------------------------------------------------------------------------------------------------------
                                                       M=4096, N=8192, K=4096
                  out_f32x4(t8x8sk): ['17.0506820', '-42.684185', '-5.3834080', '90.9154586', '98.7183151'], time:9.569120ms, swizzle: NOOP, TFLOPS: 28.73 (+0.00%)
                 out_f32x4(t8x8bcf): ['17.0506820', '-42.684185', '-5.3834080', '90.9154586', '98.7183151'], time:8.955335ms, swizzle: NOOP, TFLOPS: 30.69 (+6.85%)
                out_f32x4(t8x8dbuf): ['17.0506820', '-42.684185', '-5.3834080', '90.9154586', '98.7183151'], time:8.150339ms, swizzle: NOOP, TFLOPS: 33.73 (+9.88%)
                    out_f32(cublas): ['17.0506820', '-42.684185', '-5.3834080', '90.9154586', '98.7183151'], time:7.699084ms, swizzle: NOOP, TFLOPS: 35.70 (+5.86%)
                         out_f32_th: ['17.0577831', '-42.689128', '-5.3686733', '90.9134826', '98.7108764'], time:5.349826ms, swizzle: NOOP, TFLOPS: 51.38 (+43.91%)
--------------------------------------------------------------WMMA----------------------------------------------------------------
    out_tf32(mma2x4+warp2x4+stage3): ['17.0577831', '-42.689128', '-5.3686733', '90.9134826', '98.7108764'], time:7.280445ms, swizzle: NOOP, TFLOPS: 37.76
    out_tf32(mma2x4+warp2x4+stage2): ['17.0577831', '-42.689128', '-5.3686733', '90.9134826', '98.7108764'], time:6.516361ms, swizzle: NOOP, TFLOPS: 42.18
  out_tf32(mma2x4+...+stage3+dsmem): ['17.0577831', '-42.689128', '-5.3686733', '90.9134826', '98.7108764'], time:6.850147ms, swizzle: NOOP, TFLOPS: 40.13
  out_tf32(mma2x4+...+stage2+dsmem): ['17.0577831', '-42.689128', '-5.3686733', '90.9134826', '98.7108764'], time:6.511592ms, swizzle: NOOP, TFLOPS: 42.21
out_tf32(mma2x4+...+stage3+swizzle): ['17.0577831', '-42.689128', '-5.3686733', '90.9134826', '98.7108764'], time:7.203531ms, swizzle: 1024, TFLOPS: 38.16
out_tf32(mma2x4+...+stage2+swizzle): ['17.0577831', '-42.689128', '-5.3686733', '90.9134826', '98.7108764'], time:6.306934ms, swizzle: 1024, TFLOPS: 43.58
 out_tf32(...+stage3+dsmem+swizzle): ['17.0577831', '-42.689128', '-5.3686733', '90.9134826', '98.7108764'], time:6.800293ms, swizzle: 1024, TFLOPS: 40.42
 out_tf32(...+stage2+dsmem+swizzle): ['17.0577831', '-42.689128', '-5.3686733', '90.9134826', '98.7108764'], time:6.268692ms, swizzle: 1024, TFLOPS: 43.85
              out_tf32(cublas+tf32): ['17.0577831', '-42.689128', '-5.3686733', '90.9134826', '98.7108764'], time:5.119848ms, swizzle: NOOP, TFLOPS: 53.69 (+4.49%)
----------------------------------------------------------------------------------------------------------------------------------

Author

fpeanut commented Apr 10, 2025

int main() {
  int M = 256;
  int N = 256;
  int K = 256;

  std::vector<float> h_a(M * K);
  std::vector<float> h_b(K * N);
  std::vector<half> hh_a(M * K);
  std::vector<half> hh_b(K * N);
  std::vector<float> h_c_naive(M * N, 0.0f);
  std::vector<float> h_c_sliced(M * N, 0.0f);
  std::vector<float> h_c_sliced_8(M * N, 0.0f);
  std::vector<float> h_c_bnf(M * N, 0.0f);
  std::vector<float> h_c_wmma(M * N, 0.0f);
  std::vector<float> h_c(M * N, 0.0f);
  std::vector<half> hh_c(M * N, __float2half(0.0f));
  std::vector<half> hh_cpu(M * N, __float2half(0.0f));

  // Initialize matrices
  for (int i = 0; i < M * K; ++i) {
    h_a[i] = static_cast<float>(rand()) / RAND_MAX;
    hh_a[i] = __float2half(h_a[i]);
  }
  for (int i = 0; i < K * N; ++i) {
    h_b[i] = static_cast<float>(rand()) / RAND_MAX;
    hh_b[i] = __float2half(h_b[i]);
  }

  float *d_a, *d_b, *d_c_naive, *d_c_sliced,*d_c_sliced_8,*d_c_bnf,*d_c_wmma;
  half *dh_a, *dh_b, *dh_c;
  cudaMalloc(&d_a, M * K * sizeof(float));
  cudaMalloc(&d_b, K * N * sizeof(float));
  cudaMalloc(&d_c_naive, M * N * sizeof(float));
  cudaMalloc(&d_c_sliced, M * N * sizeof(float));
  cudaMalloc(&d_c_sliced_8, M * N * sizeof(float));
  cudaMalloc(&d_c_bnf, M * N * sizeof(float));
  cudaMalloc(&d_c_wmma, M * N * sizeof(float));

  cudaMalloc(&dh_a, M * K * sizeof(half));
  cudaMalloc(&dh_b, K * N * sizeof(half));
  cudaMalloc(&dh_c, M * N * sizeof(half));

  cudaMemcpy(d_a, h_a.data(), M * K * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_b, h_b.data(), K * N * sizeof(float), cudaMemcpyHostToDevice);

  cudaMemcpy(dh_a, hh_a.data(), M * K * sizeof(half), cudaMemcpyHostToDevice);
  cudaMemcpy(dh_b, hh_b.data(), K * N * sizeof(half), cudaMemcpyHostToDevice);

  // Naive kernel launch
  dim3 block_naive(32, 32);
  dim3 grid_naive((N + 32 - 1) / 32, (M + 32 - 1) / 32);
  auto start_naive = std::chrono::high_resolution_clock::now();
  sgemm_naive_f32_kernel<<<grid_naive, block_naive>>>(d_a, d_b, d_c_naive, M, N, K);
  cudaDeviceSynchronize();
  auto end_naive = std::chrono::high_resolution_clock::now();
  std::chrono::duration<double> duration_naive = end_naive - start_naive;
  std::cout << "Naive kernel execution time: " << duration_naive.count() << " seconds" << std::endl;

  // Sliced kernel launch
  dim3 block_sliced(32, 32);
  dim3 grid_sliced((N + 32 - 1) / 32, (M + 32 - 1) / 32);
  auto start_sliced = std::chrono::high_resolution_clock::now();
  sgemm_sliced_k_f32_kernel<32, 32, 32><<<grid_sliced, block_sliced>>>(d_a, d_b, d_c_sliced, M, N, K);
  cudaDeviceSynchronize();
  auto end_sliced = std::chrono::high_resolution_clock::now();
  std::chrono::duration<double> duration_sliced = end_sliced - start_sliced;
  std::cout << "Sliced kernel execution time: " << duration_sliced.count() << " seconds" << std::endl;

  // 8x8 sliced (f32x4) kernel launch
  dim3 block_sliced_8(128 / TM_, 128 / TN_);
  dim3 grid_sliced_8((N + 128 - 1) / 128, (M + 128 - 1) / 128);
  auto start_sliced_8 = std::chrono::high_resolution_clock::now();
  sgemm_t_8x8_sliced_k_f32x4_kernel<128, 128, 8><<<grid_sliced_8, block_sliced_8>>>(d_a, d_b, d_c_sliced_8, M, N, K);
  cudaDeviceSynchronize();
  auto end_sliced_8 = std::chrono::high_resolution_clock::now();
  std::chrono::duration<double> duration_sliced_8 = end_sliced_8 - start_sliced_8;
  std::cout << "Sliced kernel execution time: " << duration_sliced.count() << " seconds" << std::endl;

  // bcf + double-buffered kernel launch
  dim3 block_bnf(128 / TM_, 128 / TN_);
  dim3 grid_bnf((N + 128 - 1) / 128, (M + 128 - 1) / 128);
  auto start_bnf = std::chrono::high_resolution_clock::now();
  sgemm_t_8x8_sliced_k_f32x4_bcf_dbuf_kernel<128, 128, 8><<<grid_bnf, block_bnf>>>(d_a, d_b, d_c_bnf, M, N, K);
  cudaDeviceSynchronize();
  auto end_bnf = std::chrono::high_resolution_clock::now();
  std::chrono::duration<double> duration_bnf = end_bnf - start_bnf;
  std::cout << "Sliced kernel execution time: " << duration_bnf.count() << " seconds" << std::endl;

  // WMMA tf32 kernel launch
  dim3 block_wmma(256);
  dim3 grid_wmma((N + 128 - 1) / 128, (M + 128 - 1) / 128);
  sgemm_wmma_m16n16k8_mma4x2_warp2x4_stages_kernel<<<grid_wmma,block_wmma>>>(d_a, d_b, d_c_wmma, M, N, K);

  // CPU reference results (float and half)
  cpu_gemm(h_a.data(), h_b.data(), h_c.data(), M, N, K);
  cpuF16F16Gemm(hh_a.data(), hh_b.data(), hh_cpu.data(), M, N, K);

  // WMMA half kernel launch
  dim3 block_wmmah(256);
  dim3 grid_wmmah((N + 256 - 1) / 256, (M + 128 - 1) / 128);
  myHGEMMAlignedV1<<<grid_wmmah,block_wmmah>>>(dh_a,dh_b,dh_c,M,N,K);

  cudaMemcpy(h_c_naive.data(), d_c_naive, M * N * sizeof(float), cudaMemcpyDeviceToHost);
  cudaMemcpy(h_c_sliced.data(), d_c_sliced, M * N * sizeof(float), cudaMemcpyDeviceToHost);
  cudaMemcpy(h_c_sliced_8.data(), d_c_sliced_8, M * N * sizeof(float), cudaMemcpyDeviceToHost);
  cudaMemcpy(h_c_bnf.data(), d_c_bnf, M * N * sizeof(float), cudaMemcpyDeviceToHost);
  cudaMemcpy(h_c_wmma.data(), d_c_wmma, M * N * sizeof(float), cudaMemcpyDeviceToHost);
  cudaMemcpy(hh_c.data(), dh_c, M * N * sizeof(half), cudaMemcpyDeviceToHost);

  // Verification (optional)
  // ...
  // Print a slice of the results (elements 1000-1009)
  std::cout << "Naive result (elements 1000-1009):" << std::endl;
  for(int i = 1000; i < 1010; ++i){
      std::cout << h_c_naive[i] << " ";
  }
  std::cout << std::endl;

  std::cout << "Sliced Result (first 10 elements):" << std::endl;
  for(int i = 1000; i < 1010; ++i){
      std::cout << h_c_sliced[i] << " ";
  }
  std::cout << std::endl;

  std::cout << "Sliced Result (first 10 elements):" << std::endl;
  for(int i = 1000; i < 1010; ++i){
      std::cout << h_c_sliced_8[i] << " ";
  }
  std::cout << std::endl;

  std::cout << "bnf Result (first 10 elements):" << std::endl;
  for(int i = 1000; i < 1010; ++i){
      std::cout << h_c_bnf[i] << " ";
  }
  std::cout << std::endl;

  std::cout << "wmma Result (first 10 elements):" << std::endl;
  for(int i = 1000; i < 1010; ++i){
      std::cout << h_c_wmma[i] << " ";
  }
  std::cout << std::endl;

  std::cout << "wmma half Result (first 10 elements):" << std::endl;
  for(int i = 1000; i < 1010; ++i){
      std::cout <<__half2float(hh_c[i]) << " ";
  }
  std::cout << std::endl;

  std::cout << "cpu Result (first 10 elements):" << std::endl;
  for(int i = 1000; i < 1010; ++i){
      std::cout << h_c[i] << " ";
  }
  std::cout << std::endl;

  std::cout << "cpu half Result (first 10 elements):" << std::endl;
  for(int i = 1000; i < 1010; ++i){
      std::cout << __half2float(hh_cpu[i])  << " ";
  }
  std::cout << std::endl;

  cudaFree(d_a);
  cudaFree(d_b);
  cudaFree(d_c_naive);
  cudaFree(d_c_sliced);
  cudaFree(d_c_sliced_8);
  cudaFree(d_c_bnf);
  cudaFree(d_c_wmma);
  cudaFree(dh_a);
  cudaFree(dh_b);
  cudaFree(dh_c);

  return 0;
}

This is my test code; I ran the comparison from plain C++.
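
As a side note, rather than eyeballing ten values it may be easier to report the maximum absolute difference against the CPU reference. Below is a minimal sketch assuming the h_c / h_c_wmma vectors from the test code above; report_error is just an illustrative helper, not part of the repository:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative helper: compare a GPU result against the CPU reference and
// print the maximum and mean absolute error over all M*N elements.
static void report_error(const std::vector<float>& out,
                         const std::vector<float>& ref,
                         const char* name) {
  double max_abs = 0.0, sum_abs = 0.0;
  for (size_t i = 0; i < out.size(); ++i) {
    double d = std::fabs(static_cast<double>(out[i]) - static_cast<double>(ref[i]));
    max_abs = std::max(max_abs, d);
    sum_abs += d;
  }
  std::printf("%s: max_abs_err = %g, mean_abs_err = %g\n",
              name, max_abs, sum_abs / out.size());
}

// Usage inside main, after the cudaMemcpy calls:
//   report_error(h_c_wmma, h_c, "wmma tf32");
//   report_error(h_c_sliced_8, h_c, "sliced 8x8");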

Author

fpeanut commented Apr 10, 2025

I worked through the computation in your code by hand in detail, and the computation itself is correct. However, I did find that the WMMA API gives a somewhat larger precision error with tf32 data, while the error with half is smaller. I don't yet know what causes this.

@DefTruth
Member

The tf32 kernels can't be used directly; you first need a float32 -> tf32 data type conversion beforehand. You can refer to my PyTorch binding code:

f32x4_tf32x4_kernel<<<((Na + T * 4 - 1) / (T * 4)), T>>>(
    reinterpret_cast<float*>(a.data_ptr()),
    reinterpret_cast<float*>(a.data_ptr()),
    Na);

f32x4_tf32x4_kernel<<<((Nb + T * 4 - 1) / (T * 4)), T>>>(
    reinterpret_cast<float*>(b.data_ptr()),
    reinterpret_cast<float*>(b.data_ptr()),
    Nb);
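
For context, here is a minimal sketch of what such an element-wise float32 -> tf32 rounding pass could look like. This is an assumption about what f32x4_tf32x4_kernel does (the real kernel is presumably vectorized), not the repository's actual implementation; it relies on nvcuda::wmma::__float_to_tf32 from <mma.h> (CUDA 11+, sm_80+), which rounds a float to TF32 precision while keeping float storage:

#include <cuda_runtime.h>
#include <mma.h>

// Hypothetical scalar sketch of an in-place float32 -> tf32 rounding kernel.
// Assumption: the repository's f32x4_tf32x4_kernel performs the same rounding,
// just vectorized with float4 loads/stores.
__global__ void f32_to_tf32_kernel(const float* in, float* out, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    // Rounds the 23-bit mantissa down to TF32's 10 bits; the value is still
    // stored as a float, which is what the tf32 WMMA fragments consume.
    out[idx] = nvcuda::wmma::__float_to_tf32(in[idx]);
  }
}

// Usage before launching the tf32 WMMA GEMM (in-place, as in the binding above):
//   int threads = 256;
//   f32_to_tf32_kernel<<<(M * K + threads - 1) / threads, threads>>>(d_a, d_a, M * K);
//   f32_to_tf32_kernel<<<(K * N + threads - 1) / threads, threads>>>(d_b, d_b, K * N);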

Author

fpeanut commented Apr 10, 2025

Got it, thanks a lot for clearing that up!
