Skip to content

Commit dc34a42

Browse files
authored
Use target model and dynamically create wrapper classes in iron device.py (#2005)
1 parent 2951ec4 commit dc34a42

File tree

1 file changed

+42
-98
lines changed

1 file changed

+42
-98
lines changed

python/iron/device/device.py

+42-98
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88

99
from abc import abstractmethod
1010
from ... import ir # type: ignore
11-
from ...dialects.aie import AIEDevice, tile, TileOp # type: ignore
11+
from ...dialects.aie import AIEDevice, tile, TileOp, get_target_model # type: ignore
1212
from ..resolvable import Resolvable
1313
from .tile import Tile
1414

15+
import re
16+
1517

1618
class Device(Resolvable):
1719
"""
@@ -53,32 +55,31 @@ def op(self, op: TileOp):
5355
raise ValueError("Cannot set operation more than once.")
5456
self._op = op
5557

56-
def __init__(self, cols: int, rows: int) -> None:
58+
def __init__(self, device: AIEDevice) -> None:
5759
"""Initialize a representation of a device.
5860
5961
Args:
60-
cols (int): Number of columns on the device
61-
rows (int): Number of rows on the device.
62+
device (AIEDevice): aie device
6263
"""
63-
self._cols = cols
64-
self._rows = rows
64+
self._device = device
6565
self._tiles: list[list[Device.__DeviceTile]] = []
6666

6767
# Create all "physical" tiles belonging to the device at initialization to
68-
# ensure only one "physical" tile object is every created corresponding to the same
68+
# ensure only one "physical" tile object is ever created corresponding to the same
6969
# coordinates.
70-
for c in range(self._cols):
70+
tm = get_target_model(device)
71+
for c in range(tm.columns()):
7172
self._tiles.append([])
72-
for r in range(self._rows):
73+
for r in range(tm.rows()):
7374
self._tiles[c].append(Device.__DeviceTile(c, r))
7475

7576
@property
7677
def rows(self) -> int:
77-
return self._rows
78+
return get_target_model(self._device).rows()
7879

7980
@property
8081
def cols(self) -> int:
81-
return self._cols
82+
return get_target_model(self._device).columns()
8283

8384
@abstractmethod
8485
def get_shim_tiles(self) -> list[Tile]:
@@ -128,116 +129,59 @@ class NPUBase(Device):
128129
* The 2nd+ tiles in each column are compute tiles
129130
"""
130131

131-
def __init__(self, cols: int, rows: int) -> None:
132-
"""Initialize a device based on numbers of rows and columns.
132+
def __init__(self, device: AIEDevice) -> None:
133+
"""Initialize a device based on the AIEDevice.
133134
134135
Args:
135-
cols (int): Number of columns
136-
rows (int): Number of rows
136+
device (AIEDevice): aie device
137137
"""
138-
super().__init__(cols=cols, rows=rows)
138+
super().__init__(device=device)
139139

140140
def get_shim_tiles(self) -> list[Tile]:
141141
shim_tiles = []
142-
for col in range(self._cols):
142+
for col in range(self.cols):
143143
shim_tiles.append(Tile(col, 0))
144144
return shim_tiles
145145

146146
def get_mem_tiles(self) -> list[Tile]:
147147
mem_tiles = []
148-
for col in range(self._cols):
148+
for col in range(self.cols):
149149
mem_tiles.append(Tile(col, 1))
150150
return mem_tiles
151151

152152
def get_compute_tiles(self) -> list[Tile]:
153153
compute_tiles = []
154-
for col in range(self._cols):
155-
for row in range(2, self._rows):
154+
mem_tile_rows = get_target_model(self._device).get_num_mem_tile_rows()
155+
for col in range(self.cols):
156+
for row in range(1 + mem_tile_rows, self.rows):
156157
compute_tiles.append(Tile(col, row))
157158
return compute_tiles
158159

159160

160-
class NPU1Col1(NPUBase):
161-
"""A representation of a device that resolves to AIEDevice.npu1_1col"""
162-
163-
def __init__(self) -> None:
164-
super().__init__(cols=1, rows=6)
165-
166-
def resolve(
167-
self,
168-
loc: ir.Location | None = None,
169-
ip: ir.InsertionPoint | None = None,
170-
) -> None:
171-
return AIEDevice.npu1_1col
172-
173-
174-
class NPU1Col2(NPUBase):
175-
"""A representation of a device that resolves to AIEDevice.npu1_2col"""
176-
177-
def __init__(self) -> None:
178-
super().__init__(cols=2, rows=6)
179-
180-
def resolve(
181-
self,
182-
loc: ir.Location | None = None,
183-
ip: ir.InsertionPoint | None = None,
184-
) -> None:
185-
return AIEDevice.npu1_2col
186-
187-
188-
class NPU1Col3(NPUBase):
189-
"""A representation of a device that resolves to AIEDevice.npu1_3col"""
190-
191-
def __init__(self) -> None:
192-
super().__init__(cols=3, rows=6)
193-
194-
def resolve(
195-
self,
196-
loc: ir.Location | None = None,
197-
ip: ir.InsertionPoint | None = None,
198-
) -> None:
199-
return AIEDevice.npu1_3col
200-
201-
202-
class NPU1Col4(NPUBase):
203-
"""A representation of a device that resolves to AIEDevice.npu1_4col"""
204-
205-
def __init__(self) -> None:
206-
super().__init__(cols=4, rows=6)
207-
208-
def resolve(
209-
self,
210-
loc: ir.Location | None = None,
211-
ip: ir.InsertionPoint | None = None,
212-
) -> None:
213-
return AIEDevice.npu1_4col
214-
215-
216-
class NPU2(NPUBase):
217-
"""A representation of a device that resolves to AIEDevice.npu2"""
218-
219-
def __init__(self) -> None:
220-
super().__init__(cols=8, rows=6)
221-
222-
def resolve(
223-
self,
224-
loc: ir.Location | None = None,
225-
ip: ir.InsertionPoint | None = None,
226-
) -> None:
227-
return AIEDevice.npu2
228-
229-
230-
class XCVC1902(Device):
231-
"""A placeholder representation of a device that resolves to IEDevice.xcvc1902
232-
TODO: this needs to be implemented.
233-
"""
161+
def create_class(class_name, device):
234162

235-
def __init__(self) -> None:
236-
raise NotImplementedError("This device type is not yet implementated")
163+
def _device__init__(self) -> None:
164+
super(globals()[class_name], self).__init__(device=device)
237165

238-
def resolve(
166+
def _device_resolve(
239167
self,
240168
loc: ir.Location | None = None,
241169
ip: ir.InsertionPoint | None = None,
242170
) -> None:
243-
return AIEDevice.xcvc1902
171+
return device
172+
173+
base = NPUBase if "NPU" in class_name else Device
174+
globals()[class_name] = type(
175+
class_name,
176+
(base,),
177+
{
178+
"__init__": _device__init__,
179+
"resolve": _device_resolve,
180+
"__doc__": f"A representation of a device that resolves to {device}",
181+
},
182+
)
183+
184+
185+
for device in AIEDevice:
186+
class_name = re.sub(r"NPU(\d+)_(\d+)COL", r"NPU\1Col\2", device.name.upper())
187+
create_class(class_name, device)

0 commit comments

Comments
 (0)