Skip to content

Commit

Permalink
core: make VarOperand, VarOpResult, and VarRegion tuple instead of li…
Browse files Browse the repository at this point in the history
…st (#2767)

Comes with some minimal impacts on uses, but seems like a win in terms
of communicating immutability.
  • Loading branch information
superlopuh authored Jun 22, 2024
1 parent 67822b9 commit 0f6558a
Show file tree
Hide file tree
Showing 9 changed files with 28 additions and 26 deletions.
4 changes: 2 additions & 2 deletions tests/dialects/test_scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def body(_: tuple[BlockArgument, ...]) -> None:
step.result,
carried.result,
)
assert for_op.regions == [body]
assert for_op.regions == (body,)
assert for_op.attributes == {}

for_op.verify()
Expand Down Expand Up @@ -82,7 +82,7 @@ def body(_: tuple[BlockArgument, ...]) -> None:
upper.result,
step.result,
)
assert for_op.regions == [body]
assert for_op.regions == (body,)
assert for_op.attributes == {}

for_op.verify()
Expand Down
4 changes: 2 additions & 2 deletions tests/irdl/test_operation_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def test_two_var_region_builder():
region4 = Region()
op2 = TwoVarRegionOp.build(regions=[[region1, region2], [region3, region4]])
op2.verify()
assert op2.regions == [region1, region2, region3, region4]
assert op2.regions == (region1, region2, region3, region4)
assert op2.attributes[
AttrSizedRegionSegments.attribute_name
] == DenseArrayBase.from_list(i32, [2, 2])
Expand All @@ -536,7 +536,7 @@ def test_two_var_operand_builder3():
region4 = Region()
op2 = TwoVarRegionOp.build(regions=[[region1], [region2, region3, region4]])
op2.verify()
assert op2.regions == [region1, region2, region3, region4]
assert op2.regions == (region1, region2, region3, region4)
assert op2.attributes[
AttrSizedRegionSegments.attribute_name
] == DenseArrayBase.from_list(i32, [1, 3])
Expand Down
12 changes: 6 additions & 6 deletions xdsl/ir/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ class Operation(IRNode):
_operands: tuple[SSAValue, ...] = field(default=())
"""The operation operands."""

results: list[OpResult] = field(default_factory=list)
results: tuple[OpResult, ...] = field(default=())
"""The results created by the operation."""

successors: list[Block] = field(default_factory=list)
Expand All @@ -582,7 +582,7 @@ class Operation(IRNode):
attributes: dict[str, Attribute] = field(default_factory=dict)
"""The attributes attached to the operation."""

regions: list[Region] = field(default_factory=list)
regions: tuple[Region, ...] = field(default=())
"""Regions arguments of the operation."""

parent: Block | None = field(default=None, repr=False)
Expand Down Expand Up @@ -696,14 +696,14 @@ def __init__(
# This is assumed to exist by Operation.operand setter.
self.operands = operands

self.results = [
self.results = tuple(
OpResult(result_type, self, idx)
for (idx, result_type) in enumerate(result_types)
]
)
self.properties = dict(properties)
self.attributes = dict(attributes)
self.successors = list(successors)
self.regions = []
self.regions = ()
for region in regions:
self.add_region(region)

Expand Down Expand Up @@ -738,7 +738,7 @@ def add_region(self, region: Region) -> None:
raise Exception(
"Cannot add region that is already attached on an operation."
)
self.regions.append(region)
self.regions += (region,)
region.parent = self

def get_region_index(self, region: Region) -> int:
Expand Down
14 changes: 7 additions & 7 deletions xdsl/irdl/irdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,7 +1004,7 @@ def __init__(
self.constr = range_constr_coercion(attr)


VarOperand: TypeAlias = list[SSAValue]
VarOperand: TypeAlias = tuple[SSAValue, ...]


@dataclass(init=False)
Expand Down Expand Up @@ -1038,7 +1038,7 @@ def __init__(
self.constr = range_constr_coercion(attr)


VarOpResult: TypeAlias = list[OpResult]
VarOpResult: TypeAlias = tuple[OpResult, ...]


@dataclass(init=False)
Expand Down Expand Up @@ -1068,7 +1068,7 @@ class OptRegionDef(RegionDef, OptionalDef):
"""An IRDL optional region definition."""


VarRegion: TypeAlias = list[Region]
VarRegion: TypeAlias = tuple[Region, ...]
OptRegion: TypeAlias = Region | None


Expand Down Expand Up @@ -1839,7 +1839,7 @@ def get_construct_defs(

def get_op_constructs(
op: Operation, construct: VarIRConstruct
) -> Sequence[SSAValue] | list[OpResult] | list[Region] | list[Successor]:
) -> Sequence[SSAValue] | Sequence[OpResult] | Sequence[Region] | Sequence[Successor]:
"""
Get the list of arguments of the type in an operation.
For example, if the argument type is an operand, get the list of
Expand Down Expand Up @@ -2001,11 +2001,11 @@ def get_operand_result_or_region(
None
| SSAValue
| Sequence[SSAValue]
| list[OpResult]
| Sequence[OpResult]
| Region
| list[Region]
| Sequence[Region]
| Successor
| list[Successor]
| Sequence[Successor]
):
"""
Get an operand, result, or region.
Expand Down
2 changes: 1 addition & 1 deletion xdsl/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def print_region(
self._print_new_line()
self.print("}")

def print_regions(self, regions: list[Region]) -> None:
def print_regions(self, regions: Sequence[Region]) -> None:
if len(regions) == 0:
return

Expand Down
2 changes: 1 addition & 1 deletion xdsl/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ class IsolatedFromAbove(OpTrait):

def verify(self, op: Operation) -> None:
# Start by checking all the passed operation's regions
regions: list[Region] = op.regions.copy()
regions: list[Region] = list(op.regions)

# While regions are left to check
while regions:
Expand Down
4 changes: 2 additions & 2 deletions xdsl/transforms/experimental/lower_hls.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def match_and_rewrite(self, op: ParallelOp, rewriter: PatternRewriter, /):
# wrapping in for loops until we have exhausted the induction variables
parallel_block = op.body.detach_block(0)

if res != []:
if res:
parallel_block.insert_arg(res[0].type, 1)
cast(Operation, parallel_block.last_op).detach()
yieldop = Yield(res[0].op)
Expand All @@ -350,7 +350,7 @@ def match_and_rewrite(self, op: ParallelOp, rewriter: PatternRewriter, /):
for i in range(len(lb) - 1):
for_region.block.erase_arg(for_region.block.args[i])

if res != []:
if res:
for_op = For(lb[-1], ub[-1], step[-1], [res[0].op], for_region)
else:
for_op = For(lb[-1], ub[-1], step[-1], [], for_region)
Expand Down
10 changes: 6 additions & 4 deletions xdsl/transforms/lower_affine.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections.abc import Sequence

from xdsl.context import MLContext
from xdsl.dialects import affine, arith, builtin, memref, scf
from xdsl.ir import Operation, SSAValue
Expand All @@ -20,8 +22,8 @@

def affine_expr_ops(
expr: affine.AffineExpr,
dims: list[SSAValue],
symbols: list[SSAValue],
dims: Sequence[SSAValue],
symbols: Sequence[SSAValue],
) -> tuple[list[Operation], SSAValue]:
"""
Returns the operations that evaluate the affine expression when given input SSA
Expand Down Expand Up @@ -59,7 +61,7 @@ def affine_expr_ops(

def insert_affine_map_ops(
map: affine.AffineMapAttr | None,
dims: list[SSAValue],
dims: Sequence[SSAValue],
symbols: list[SSAValue],
) -> tuple[list[Operation], list[SSAValue]]:
"""
Expand All @@ -68,7 +70,7 @@ def insert_affine_map_ops(
"""
ops: list[Operation] = []
if map is None:
indices = dims
indices = list(dims)
else:
indices: list[SSAValue] = []
for expr in map.data.results:
Expand Down
2 changes: 1 addition & 1 deletion xdsl/transforms/lower_riscv_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def match_and_rewrite(self, op: riscv_func.SyscallOp, rewriter: PatternRewriter)

if op.result is None:
ops.append(riscv.EcallOp())
new_results = []
new_results = ()
else:
# The result will be stored to a0, move to register that will be used
ecall = riscv.EcallOp()
Expand Down

0 comments on commit 0f6558a

Please sign in to comment.