From 1cd37d680e606f837d5676585662e52c67002a35 Mon Sep 17 00:00:00 2001 From: max Date: Wed, 26 Jul 2023 13:00:22 -0500 Subject: [PATCH] improve location tracking (and test it) --- mlir_utils/dialects/ext/arith.py | 39 +++++++---- mlir_utils/dialects/ext/func.py | 6 +- mlir_utils/dialects/ext/scf.py | 14 ++-- mlir_utils/util.py | 21 ++---- pyproject.toml | 2 +- tests/test_location_tracking.py | 111 +++++++++++++++++++++++++++++++ 6 files changed, 160 insertions(+), 33 deletions(-) create mode 100644 tests/test_location_tracking.py diff --git a/mlir_utils/dialects/ext/arith.py b/mlir_utils/dialects/ext/arith.py index f3c9241..f082b32 100644 --- a/mlir_utils/dialects/ext/arith.py +++ b/mlir_utils/dialects/ext/arith.py @@ -13,21 +13,22 @@ _is_index_type, ) from mlir.ir import ( + Attribute, + Context, + DenseElementsAttr, + IndexType, + IntegerAttr, + IntegerType, + Location, OpView, Operation, + RankedTensorType, Type, Value, - IndexType, - RankedTensorType, - IntegerAttr, - IntegerType, - DenseElementsAttr, register_attribute_builder, - Context, - Attribute, ) -from mlir_utils.util import get_result_or_results, maybe_cast +from mlir_utils.util import get_result_or_results, maybe_cast, get_user_code_loc try: from mlir_utils.dialects.arith import * @@ -41,6 +42,8 @@ def constant( value: Union[int, float, bool, np.ndarray], type: Optional[Type] = None, index: Optional[bool] = None, + *, + loc: Location = None, ) -> arith_dialect.ConstantOp: """Instantiate arith.constant with value `value`. @@ -56,6 +59,8 @@ def constant( Returns: ir.OpView instance that corresponds to instantiated arith.constant op. 
""" + if loc is None: + loc = get_user_code_loc() if index is not None and index: type = IndexType.get() if type is None: @@ -73,8 +78,9 @@ def constant( value, type=type, ) - - return maybe_cast(get_result_or_results(arith_dialect.ConstantOp(type, value))) + return maybe_cast( + get_result_or_results(arith_dialect.ConstantOp(type, value, loc=loc)) + ) class ArithValueMeta(type(Value)): @@ -217,7 +223,12 @@ def _arith_CmpFPredicateAttr(predicate: str | Attribute, context: Context): def _binary_op( - lhs: "ArithValue", rhs: "ArithValue", op: str, predicate: str = None + lhs: "ArithValue", + rhs: "ArithValue", + op: str, + predicate: str = None, + *, + loc: Location = None, ) -> "ArithValue": """Generic for handling infix binary operator dispatch. @@ -230,6 +241,8 @@ def _binary_op( Returns: Result of binary operation. This will be a handle to an arith(add|sub|mul) op. """ + if loc is None: + loc = get_user_code_loc() if not isinstance(rhs, lhs.__class__): rhs = lhs.__class__(rhs, dtype=lhs.type) @@ -258,9 +271,9 @@ def _binary_op( predicate = "s" + predicate else: predicate = "u" + predicate - return lhs.__class__(op(predicate, lhs, rhs), dtype=lhs.dtype) + return lhs.__class__(op(predicate, lhs, rhs, loc=loc), dtype=lhs.dtype) else: - return lhs.__class__(op(lhs, rhs), dtype=lhs.dtype) + return lhs.__class__(op(lhs, rhs, loc=loc), dtype=lhs.dtype) class ArithValue(Value, metaclass=ArithValueMeta): diff --git a/mlir_utils/dialects/ext/func.py b/mlir_utils/dialects/ext/func.py index 68b384d..33385cb 100644 --- a/mlir_utils/dialects/ext/func.py +++ b/mlir_utils/dialects/ext/func.py @@ -8,6 +8,7 @@ TypeAttr, FlatSymbolRefAttr, Type, + Location, ) from mlir_utils.util import ( @@ -103,13 +104,16 @@ def emit(self): # this is the func op itself (funcs never have a resulting ssa value) return maybe_cast(get_result_or_results(func_op)) - def __call__(self, *call_args): + def __call__(self, *call_args, loc: Location = None): + if loc is None: + loc = get_user_code_loc() if 
not self.emitted: self.emit() call_op = self.call_op_ctor( [r.type for r in self.results], FlatSymbolRefAttr.get(self.func_name), call_args, + loc=loc, ) return maybe_cast(get_result_or_results(call_op)) diff --git a/mlir_utils/dialects/ext/scf.py b/mlir_utils/dialects/ext/scf.py index 995956a..f36dba2 100644 --- a/mlir_utils/dialects/ext/scf.py +++ b/mlir_utils/dialects/ext/scf.py @@ -6,7 +6,7 @@ import libcst as cst import libcst.matchers as m from bytecode import ConcreteBytecode, ConcreteInstr -from mlir.dialects import scf +from mlir.dialects.scf import IfOp, ForOp from mlir.ir import InsertionPoint, Value, OpResultList, OpResult from mlir_utils.ast.canonicalize import ( @@ -24,6 +24,7 @@ maybe_cast, _update_caller_vars, get_result_or_results, + get_user_code_loc, ) logger = logging.getLogger(__name__) @@ -49,7 +50,9 @@ def _for( stop = constant(stop, index=True) if isinstance(step, int): step = constant(step, index=True) - return scf.ForOp(start, stop, step, iter_args, loc=loc, ip=ip) + if loc is None: + loc = get_user_code_loc() + return ForOp(start, stop, step, iter_args, loc=loc, ip=ip) for_ = region_op(_for, terminator=yield__) @@ -91,7 +94,9 @@ def _if(cond, results_=None, *, has_else=False, loc=None, ip=None): results_ = [] if results_: has_else = True - return scf.IfOp(cond, results_, hasElse=has_else, loc=loc, ip=ip) + if loc is None: + loc = get_user_code_loc() + return IfOp(cond, results_, hasElse=has_else, loc=loc, ip=ip) if_ = region_op(_if, terminator=yield__) @@ -100,7 +105,7 @@ def _if(cond, results_=None, *, has_else=False, loc=None, ip=None): class IfStack: - __current_if_op: list[scf.IfOp] = [] + __current_if_op: list[IfOp] = [] __if_ip: list[InsertionPoint] = [] @staticmethod @@ -423,6 +428,7 @@ def patch_bytecode(self, code: ConcreteBytecode, f): f.__globals__[end_if.__name__] = end_if f.__globals__[stack_if.__name__] = stack_if f.__globals__[stack_yield.__name__] = stack_yield + f.__globals__[yield_.__name__] = yield_ 
f.__globals__["_placeholder_opaque_t"] = _placeholder_opaque_t return code diff --git a/mlir_utils/util.py b/mlir_utils/util.py index f34f05c..365b6c1 100644 --- a/mlir_utils/util.py +++ b/mlir_utils/util.py @@ -134,7 +134,8 @@ def builder_wrapper(body_builder): f"for {body_builder=} either missing a type annotation or type annotation isn't a mlir type: {sig}" ) - op.regions[0].blocks.append(*types) + arg_locs = [get_user_code_loc()] * len(sig.parameters) + op.regions[0].blocks.append(*types, arg_locs=arg_locs) with InsertionPoint(op.regions[0].blocks[0]): results = body_builder( *[maybe_cast(a) for a in op.regions[0].blocks[0].arguments] @@ -209,17 +210,9 @@ def get_user_code_loc(): mlir_utis_root_path = Path(mlir_utils.__path__[0]) prev_frame = inspect.currentframe().f_back - stack = traceback.StackSummary.extract(traceback.walk_stack(prev_frame)) - - user_frame = next( - ( - fr - for fr in stack - if not Path(fr.filename).is_relative_to(mlir_utis_root_path) - ), - None, + while Path(prev_frame.f_code.co_filename).is_relative_to(mlir_utis_root_path): + prev_frame = prev_frame.f_back + frame_info = inspect.getframeinfo(prev_frame) + return Location.file( + frame_info.filename, frame_info.lineno, frame_info.positions.col_offset ) - if user_frame is None: - warnings.warn("couldn't find user code frame") - return - return Location.file(user_frame.filename, user_frame.lineno, user_frame.colno or 0) diff --git a/pyproject.toml b/pyproject.toml index 79ab788..4dd8b59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "mlir-python-utils" -version = "0.0.2" +version = "0.0.3" description = "The missing pieces (as far as boilerplate reduction goes) of the upstream MLIR python bindings." 
requires-python = ">=3.11" license = { file = "LICENSE" } diff --git a/tests/test_location_tracking.py b/tests/test_location_tracking.py new file mode 100644 index 0000000..c026a23 --- /dev/null +++ b/tests/test_location_tracking.py @@ -0,0 +1,111 @@ +from pathlib import Path +from textwrap import dedent +from os import sep +import pytest + +from mlir_utils.ast.canonicalize import canonicalize +from mlir_utils.dialects.ext.arith import constant +from mlir_utils.dialects.ext.scf import ( + canonicalizer, + stack_if, +) +from mlir_utils.dialects.ext.tensor import S +from mlir_utils.dialects.tensor import generate, yield_ as tensor_yield, rank + +# noinspection PyUnresolvedReferences +from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext +from mlir_utils.types import f64_t, index_t, tensor_t + +# needed since the fixture isn't defined here nor conftest.py +pytest.mark.usefixtures("ctx") + + +THIS_DIR = str(Path(__file__).parent.absolute()) + + +def get_asm(operation): + return operation.get_asm(enable_debug_info=True, pretty_debug_info=True).replace( + THIS_DIR, "THIS_DIR" + ) + + +def test_if_replace_yield_5(ctx: MLIRContext): + @canonicalize(using=canonicalizer) + def iffoo(): + one = constant(1.0) + two = constant(2.0) + if res := stack_if(one < two, (f64_t, f64_t, f64_t)): + three = constant(3.0) + yield three, three, three + else: + four = constant(4.0) + yield four, four, four + return + + iffoo() + ctx.module.operation.verify() + correct = dedent( + f"""\ + module {{ + %cst = arith.constant 1.000000e+00 : f64 THIS_DIR{sep}test_location_tracking.py:35:10 + %cst_0 = arith.constant 2.000000e+00 : f64 THIS_DIR{sep}test_location_tracking.py:36:10 + %0 = arith.cmpf olt, %cst, %cst_0 : f64 THIS_DIR{sep}test_location_tracking.py:37:23 + %1:3 = scf.if %0 -> (f64, f64, f64) {{ + %cst_1 = arith.constant 3.000000e+00 : f64 THIS_DIR{sep}test_location_tracking.py:38:16 + scf.yield %cst_1, %cst_1, %cst_1 : f64, f64, f64 THIS_DIR{sep}test_location_tracking.py:39:8 
+ }} else {{ + %cst_1 = arith.constant 4.000000e+00 : f64 THIS_DIR{sep}test_location_tracking.py:41:24 + scf.yield %cst_1, %cst_1, %cst_1 : f64, f64, f64 THIS_DIR{sep}test_location_tracking.py:42:8 + }} THIS_DIR{sep}test_location_tracking.py:37:14 + }} [unknown] + #loc = [unknown] + #loc1 = THIS_DIR{sep}test_location_tracking.py:35:10 + #loc2 = THIS_DIR{sep}test_location_tracking.py:36:10 + #loc3 = THIS_DIR{sep}test_location_tracking.py:37:23 + #loc4 = THIS_DIR{sep}test_location_tracking.py:37:14 + #loc5 = THIS_DIR{sep}test_location_tracking.py:38:16 + #loc6 = THIS_DIR{sep}test_location_tracking.py:39:8 + #loc7 = THIS_DIR{sep}test_location_tracking.py:41:24 + #loc8 = THIS_DIR{sep}test_location_tracking.py:42:8 + """ + ) + asm = get_asm(ctx.module.operation) + filecheck(correct, asm) + + +def test_block_args(ctx: MLIRContext): + one = constant(1, index_t) + two = constant(2, index_t) + + @generate(tensor_t(S, 3, S, f64_t), dynamic_extents=[one, two]) + def demo_fun1(i: index_t, j: index_t, k: index_t): + one = constant(1.0) + tensor_yield(one) + + r = rank(demo_fun1) + + ctx.module.operation.verify() + + correct = dedent( + f"""\ + #loc3 = THIS_DIR{sep}test_location_tracking.py:80:5 + module {{ + %c1 = arith.constant 1 : index THIS_DIR{sep}test_location_tracking.py:77:10 + %c2 = arith.constant 2 : index THIS_DIR{sep}test_location_tracking.py:78:10 + %generated = tensor.generate %c1, %c2 {{ + ^bb0(%arg0: index THIS_DIR{sep}test_location_tracking.py:80:5, %arg1: index THIS_DIR{sep}test_location_tracking.py:80:5, %arg2: index THIS_DIR{sep}test_location_tracking.py:80:5): + %cst = arith.constant 1.000000e+00 : f64 THIS_DIR{sep}test_location_tracking.py:82:14 + tensor.yield %cst : f64 THIS_DIR{sep}test_location_tracking.py:83:8 + }} : tensor THIS_DIR{sep}test_location_tracking.py:80:5 + %rank = tensor.rank %generated : tensor THIS_DIR{sep}test_location_tracking.py:85:8 + }} [unknown] + #loc = [unknown] + #loc1 = THIS_DIR{sep}test_location_tracking.py:77:10 + #loc2 = 
THIS_DIR{sep}test_location_tracking.py:78:10 + #loc4 = THIS_DIR{sep}test_location_tracking.py:82:14 + #loc5 = THIS_DIR{sep}test_location_tracking.py:83:8 + #loc6 = THIS_DIR{sep}test_location_tracking.py:85:8 + """ + ) + asm = get_asm(ctx.module.operation) + filecheck(correct, asm)