
[BUG] conv2d int8 doesn't work with python #1990

Open
IzanCatalan opened this issue Dec 16, 2024 · 1 comment
Labels
? - Needs Triage bug Something isn't working

Comments

@IzanCatalan

Describe the bug
Hello, I am trying to perform a Conv2d forward propagation with int8 in Python, using an example similar to the one at https://github.com/NVIDIA/cutlass/blob/main/examples/python/03_basic_conv2d.ipynb. My GPU is an NVIDIA V100. However, I encounter an error with every integer configuration I try (int8, int16, or int32), and the output is as follows:

Traceback (most recent call last):
  File "prova.py", line 60, in <module>
    plan = cutlass.Conv2dFprop(element=dtype, element_accumulator=torch.float32)
  File "/mnt/beegfs/gap/[email protected]/cutlass/python/cutlass/op/conv.py", line 920, in __init__
    super().__init__(
  File "/mnt/beegfs/gap/[email protected]/cutlass/python/cutlass/op/conv.py", line 283, in __init__
    self._reset_operations()
  File "/mnt/beegfs/gap/[email protected]/cutlass/python/cutlass/op/conv.py", line 310, in _reset_operations
    raise Exception(f'No kernel configuration found for supported data type and layout '
Exception: No kernel configuration found for supported data type and layout combination (<DataType.s8: 10>, <DataType.s8: 10>, <DataType.f32: 18>)x(<LayoutType.TensorNHWC: 10>, <LayoutType.TensorNHWC: 10>)

My code is below. I would like to know whether any configuration for integer types is available in Python:


import torch
import random

import cutlass

# This controls whether the C++ kernel declaration will be printed at each step.
# Set to `False` to omit this information.
print_module = True

# Input tensor: [N, H, W, C] under the channel-last layout
N, H, W, C = [32, 28, 28, 64]

# Weight tensor: [K, R, S, C] under the channel-last layout
K, R, S = [128, 3, 3]

# Stride, and padding
stride = (2, 2)
padding = (1, 1)
dilation = (1, 1)

print("HOST: output size")
# Compute the output size [N, P, Q, K]
N, P, Q, K = cutlass.Conv2d.output_size((N, H, W, C), (K, R, S, C), padding, stride, dilation)

dtype = torch.int8
type_A = torch.int8
type_B = torch.int8
type_C = torch.int8
type_D = torch.int8

torch.manual_seed(1234)

print("HOST: create tensors")

# Ensure the initialized values are within the valid int8 range (-128 to 127)
input = torch.randint(
    low=-128, high=127, size=(N, C, H, W), dtype=type_A, device="cuda"
).to(memory_format=torch.channels_last)

weight = torch.randint(
    low=-128, high=127, size=(K, C, R, S), dtype=type_B, device="cuda"
).to(memory_format=torch.channels_last)

tensor_C = torch.randint(
    low=-128, high=127, size=(N, K, P, Q), dtype=type_C, device="cuda"
).to(memory_format=torch.channels_last)

output = torch.zeros_like(tensor_C)

tensor_D = torch.randint(
    low=-128, high=127, size=(N, C, H, W), dtype=type_D, device="cuda"
).to(memory_format=torch.channels_last)


alpha = 1.0
beta = 0.0

print("HOST: run first conv")
# Specifying `element_accumulator` is not required if it is the same as `element`
plan = cutlass.Conv2dFprop(element=dtype, element_accumulator=torch.float32)
plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)


output_torch = alpha * torch.ops.aten.conv2d(
    input, weight, stride=stride, padding=padding, dilation=dilation
) + beta * tensor_C

# print (output)
print(torch.equal(output_torch, output))

print("HOST: run second conv")
plan.run(tensor_D, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)

print(output)

@jackkosaian
Contributor

I think you should be able to get something similar to what you want working by adding the following line in python/cutlass/library_defaults.py (here):

                (cutlass_library.DataType.s8, cutlass_library.DataType.s8, cutlass_library.DataType.s32),

You'll need to use element_accumulator=torch.int32.
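
With that entry added, a minimal sketch of the adjusted reproduction script might look like the following. This is untested: it keeps the int8 tensors and shapes from the script above and only switches the accumulator to int32, assuming the corresponding int8 kernels actually get generated for your architecture once the line above is in place.

import torch
import cutlass

# Shapes, stride, padding, and dilation from the reproduction script above
N, H, W, C = 32, 28, 28, 64
K, R, S = 128, 3, 3
stride, padding, dilation = (2, 2), (1, 1), (1, 1)

# Compute the output size [N, P, Q, K]
N, P, Q, K = cutlass.Conv2d.output_size((N, H, W, C), (K, R, S, C), padding, stride, dilation)

torch.manual_seed(1234)

# int8 activations, weights, and output, NHWC via channels_last, as in the script above
input = torch.randint(-128, 127, (N, C, H, W), dtype=torch.int8, device="cuda").to(memory_format=torch.channels_last)
weight = torch.randint(-128, 127, (K, C, R, S), dtype=torch.int8, device="cuda").to(memory_format=torch.channels_last)
tensor_C = torch.randint(-128, 127, (N, K, P, Q), dtype=torch.int8, device="cuda").to(memory_format=torch.channels_last)
output = torch.zeros_like(tensor_C)

# The only intended change from the script above: accumulate in int32
plan = cutlass.Conv2dFprop(element=torch.int8, element_accumulator=torch.int32)
plan.run(input, weight, tensor_C, output, stride, padding, dilation, 1, 0, print_module=True)
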

Can you try this out?
