
[Bug] from __future__ import annotations breaks type annotation containing local variable #18379

@oraluben

Description


A simple reproducer is attached below. For detailed background, see tile-ai/tilelang#1079.

Expected behavior

The test program exits without error.

Actual behavior

$ TVM_BACKTRACE=1 python test.py
error: Unexpected type for TIR Arg: ffi.String
 --> /home/yyc/repo/tvm/test.py:9:5
   |  
 9 |      def f(A: T.Buffer((M,), "float32")):
   |      ^^^^^^^^                            
Traceback (most recent call last):
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 309, in _wrapper
    return func(self, node)
           ^^^^^^^^^^^^^^^^
  File "/home/yyc/repo/tvm/python/tvm/script/parser/tir/parser.py", line 409, in visit_function_def
    param = T.arg(arg.arg, ann)
            ^^^^^^^^^^^^^^^^^^^
  File "/home/yyc/repo/tvm/python/tvm/script/ir_builder/tir/ir.py", line 200, in arg
    return _ffi_api.Arg(name, obj)  # type: ignore[attr-defined] # pylint: disable=no-member
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "python/tvm_ffi/cython/function.pxi", line 758, in core.Function.__call__
  File "<unknown>", line 0, in tvm::runtime::detail::LogFatal::Entry::Finalize()
ValueError: Unexpected type for TIR Arg: ffi.String

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/yyc/repo/tvm/test.py", line 17, in <module>
    f()
  File "/home/yyc/repo/tvm/test.py", line 8, in f
    @T.prim_func
     ^^^^^^^^^^^
  File "/home/yyc/repo/tvm/python/tvm/script/parser/tir/entry.py", line 72, in prim_func
    return decorator_wrapper(func)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yyc/repo/tvm/python/tvm/script/parser/tir/entry.py", line 65, in decorator_wrapper
    f = parse(func, utils.inspect_function_capture(func), check_well_formed=check_well_formed)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/entry.py", line 103, in parse
    parser.parse(extra_vars=extra_vars)
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 379, in parse
    self.visit(node)
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 638, in visit
    self.report_error(node, err)
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 587, in report_error
    raise err
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 636, in visit
    func(node)
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/doc.py", line 256, in generic_visit
    self.visit(value)
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 624, in visit
    self.visit(item)
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 638, in visit
    self.report_error(node, err)
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 587, in report_error
    raise err
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 636, in visit
    func(node)
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 684, in visit_FunctionDef
    _dispatch_wrapper(func)(self, node)
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 311, in _wrapper
    self.report_error(node, err)
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 607, in report_error
    raise diag_err
  File "/home/yyc/repo/tvm/python/tvm/script/parser/core/parser.py", line 309, in _wrapper
    return func(self, node)
           ^^^^^^^^^^^^^^^^
  File "/home/yyc/repo/tvm/python/tvm/script/parser/tir/parser.py", line 409, in visit_function_def
    param = T.arg(arg.arg, ann)
            ^^^^^^^^^^^^^^^^^^^
  File "/home/yyc/repo/tvm/python/tvm/script/ir_builder/tir/ir.py", line 200, in arg
    return _ffi_api.Arg(name, obj)  # type: ignore[attr-defined] # pylint: disable=no-member
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "python/tvm_ffi/cython/function.pxi", line 758, in core.Function.__call__
  File "<unknown>", line 0, in tvm::runtime::detail::LogFatal::Entry::Finalize()
tvm.error.DiagnosticError: one or more error diagnostics were emitted, please check diagnostic render for output.
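The traceback shows `T.arg` receiving a string (`ffi.String`) where a buffer annotation was expected. The Python side of this can be shown without TVM. Under PEP 563, parameter annotations are stored in `__annotations__` as unevaluated source strings. A name that appears only in an annotation (here `M`) is never a closure variable of the decorated function, with or without the future import. The sketch below compiles the reproducer's nested-function shape both ways. `Buffer`, `SRC`, and `build` are hypothetical stand-ins (not TVM APIs), with `Buffer` replacing `T.Buffer`.

```python
import __future__
import inspect


def Buffer(shape, dtype):
    """Hypothetical stand-in for T.Buffer so the sketch runs without TVM."""
    return ("Buffer", shape, dtype)


# Same shape as the reproducer: M is a local of `outer` and is used only in
# the parameter annotation of the inner function.
SRC = '''
def outer(M=1):
    def inner(A: Buffer((M,), "float32")):
        pass
    return inner
'''


def build(future_annotations):
    """Compile SRC with or without PEP 563 and return the inner function."""
    flags = __future__.annotations.compiler_flag if future_annotations else 0
    # dont_inherit=True keeps this file's own future flags out of the compile.
    code = compile(SRC, "<repro>", "exec", flags=flags, dont_inherit=True)
    namespace = {"Buffer": Buffer}
    exec(code, namespace)
    return namespace["outer"]()


for future in (False, True):
    inner = build(future)
    print("from __future__ import annotations:", future)
    print("  __annotations__['A']:", repr(inner.__annotations__["A"]))
    print("  closure nonlocals:", inspect.getclosurevars(inner).nonlocals)
```

In both modes the inner function's closure is empty. A capture step based on closure variables therefore cannot see `M`; `utils.inspect_function_capture` in the traceback may work this way, but that is not confirmed here. What differs is `__annotations__`: without the future import it holds the already-evaluated object, and with it only a string. Suppose the TIR parser falls back to `__annotations__` when evaluating the annotation fails, which is an assumption not checked against TVM's source. In that case the string would reach `T.arg` and produce this error.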

Environment

Any environment details, such as: Operating System, TVM version, etc

Steps to reproduce

from __future__ import annotations

from tvm.script import tir as T


def f(M=1):

    @T.prim_func
    def f(A: T.Buffer((M,), "float32")):
        pass

    return f


f()
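One possible way to recover `M` is sketched below under assumptions; it is not taken from TVM. At decoration time, the frame that executes the decorated `def` still holds the enclosing function's locals. String annotations can therefore be evaluated against the function's globals plus that frame's `f_locals`. `resolve_string_annotations`, the `resolved_annotations` attribute, and the `Buffer` stand-in for `T.Buffer` are hypothetical names.

```python
from __future__ import annotations  # same PEP 563 mode as the reproducer

import inspect


def Buffer(shape, dtype):
    """Hypothetical stand-in for T.Buffer so the sketch runs without TVM."""
    return ("Buffer", shape, dtype)


def resolve_string_annotations(func):
    """Hypothetical decorator: evaluate PEP 563 string annotations with the
    function's globals plus the locals of the scope that applied the decorator."""
    caller = inspect.currentframe().f_back  # frame executing the decorated `def`
    try:
        namespace = {**func.__globals__, **caller.f_locals}
    finally:
        del caller  # drop the frame reference promptly
    func.resolved_annotations = {
        name: eval(ann, namespace) if isinstance(ann, str) else ann
        for name, ann in func.__annotations__.items()
    }
    return func


def f(M=1):
    @resolve_string_annotations
    def g(A: Buffer((M,), "float32")):
        pass

    return g


g = f(4)
print(repr(g.__annotations__["A"]))  # still the unevaluated string
print(g.resolved_annotations["A"])   # evaluated with M taken from f's frame
```

This only works when the decorator is applied directly in the scope that owns `M`. A wrapper that calls the decorator from a different frame would read the wrong `f_locals`. It also runs `eval` on annotation text, which is acceptable for trusted source but worth keeping in mind.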

Triage

Please refer to the list of label tags here to find the relevant tags and add them below in a bullet format (example below).

  • needs-triage

Metadata

Assignees: No one assigned
Labels: needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests