
[BUG] conv2d int8 doesn't work with python #1990

Closed
@IzanCatalan

Description


Describe the bug
Hello, I am trying to perform Conv2D forward propagation with int8 in Python, using an example similar to the one at https://github.com/NVIDIA/cutlass/blob/main/examples/python/03_basic_conv2d.ipynb. My GPU is an NVIDIA V100. However, I get an error with every integer configuration I try (int8, int16, or int32), and the output is as follows:

Traceback (most recent call last):
  File "prova.py", line 60, in <module>
    plan = cutlass.Conv2dFprop(element=dtype, element_accumulator=torch.float32)
  File "/mnt/beegfs/gap/[email protected]/cutlass/python/cutlass/op/conv.py", line 920, in __init__
    super().__init__(
  File "/mnt/beegfs/gap/[email protected]/cutlass/python/cutlass/op/conv.py", line 283, in __init__
    self._reset_operations()
  File "/mnt/beegfs/gap/[email protected]/cutlass/python/cutlass/op/conv.py", line 310, in _reset_operations
    raise Exception(f'No kernel configuration found for supported data type and layout '
Exception: No kernel configuration found for supported data type and layout combination (<DataType.s8: 10>, <DataType.s8: 10>, <DataType.f32: 18>)x(<LayoutType.TensorNHWC: 10>, <LayoutType.TensorNHWC: 10>)
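
The exception is raised from _reset_operations when no generated kernel matches the data-type and layout combination in the message, here (s8, s8, f32) over NHWC layouts. As a quick check, the sketch below (an assumption, reusing only the Conv2dFprop constructor from the traceback) probes which accumulator types, if any, have a registered int8 kernel on this GPU:

import torch
import cutlass

# Hedged probe: try each candidate accumulator type and report whether an
# int8 x int8 kernel is registered for it on the current GPU.
for acc in (torch.int32, torch.float32):
    try:
        cutlass.Conv2dFprop(element=torch.int8, element_accumulator=acc)
        print(f"int8 kernel available with accumulator {acc}")
    except Exception as exc:
        print(f"no int8 kernel with accumulator {acc}: {exc}")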

My code is below, and I would like to know whether any configuration for integer types is available in Python:


import torch
import random

import cutlass

# This controls whether the generated C++ kernel declaration is printed at each step.
# Set to `False` to omit this information.
print_module = True

# Input tensor: [N, H, W, C] under the channel-last layout
N, H, W, C = [32, 28, 28, 64]

# Weight tensor: [K, R, S, C] under the channel-last layout
K, R, S = [128, 3, 3]

# Stride, padding, and dilation
stride = (2, 2)
padding = (1, 1)
dilation = (1, 1)

print("HOST: output size")
# Compute the output size [N, P, Q, K]
N, P, Q, K = cutlass.Conv2d.output_size((N, H, W, C), (K, R, S, C), padding, stride, dilation)

dtype = torch.int8
type_A = torch.int8
type_B = torch.int8
type_C = torch.int8
type_D = torch.int8

torch.manual_seed(1234)

print("HOST: create tensors")

# Make sure the initialized values stay within the valid int8 range (-128 to 127)
input = torch.randint(
    low=-128, high=127, size=(N, C, H, W), dtype=type_A, device="cuda"
).to(memory_format=torch.channels_last)

weight = torch.randint(
    low=-128, high=127, size=(K, C, R, S), dtype=type_B, device="cuda"
).to(memory_format=torch.channels_last)

tensor_C = torch.randint(
    low=-128, high=127, size=(N, K, P, Q), dtype=type_C, device="cuda"
).to(memory_format=torch.channels_last)

output = torch.zeros_like(tensor_C)

tensor_D = torch.randint(
    low=-128, high=127, size=(N, C, H, W), dtype=type_D, device="cuda"
).to(memory_format=torch.channels_last)


alpha = 1.0
beta = 0.0

print("HOST: run first conv")
# Specifying `element_accumulator` is not required if it is the same as `element`
plan = cutlass.Conv2dFprop(element=dtype, element_accumulator=torch.float32)
plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)


output_torch = alpha * torch.ops.aten.conv2d(
    input, weight, stride=stride, padding=padding, dilation=dilation
) + beta * tensor_C

# print (output)
print(torch.equal(output_torch, output))

print("HOST: run second conv")
plan.run(tensor_D, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)

print(output)
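
One variation worth noting (an assumption, not a confirmed configuration): int8 paths in CUTLASS typically accumulate in int32 rather than float32, so requesting an int32 accumulator may be the combination that is actually registered. The sketch below reuses the tensors and parameters from the code above:

# Hedged variant of the repro above: same tensors and run() arguments, but with an
# int32 accumulator instead of float32 (assumption, not a confirmed fix).
plan_i32 = cutlass.Conv2dFprop(element=dtype, element_accumulator=torch.int32)
plan_i32.run(input, weight, tensor_C, output, stride, padding, dilation,
             alpha, beta, print_module=print_module)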
