diff --git a/programming_examples/basic/dma_transpose/Makefile b/programming_examples/basic/dma_transpose/Makefile
index 87771d1ab0..212cedabf3 100644
--- a/programming_examples/basic/dma_transpose/Makefile
+++ b/programming_examples/basic/dma_transpose/Makefile
@@ -44,5 +44,9 @@ endif
run: ${targetname}.exe build/final.xclbin
${powershell} ./$< -x build/final.xclbin -i build/insts.txt -k MLIR_AIE --M ${M} --K ${K}
+generate_access_map: ${srcdir}/aie2.py
+ mkdir -p ${@D}
+ python3 $< --generate-access-map ${M} ${K}
+
clean:
rm -rf build _build inst ${targetname}.exe
diff --git a/programming_examples/basic/dma_transpose/README.md b/programming_examples/basic/dma_transpose/README.md
index 5a73dde0e3..45125d377c 100644
--- a/programming_examples/basic/dma_transpose/README.md
+++ b/programming_examples/basic/dma_transpose/README.md
@@ -15,11 +15,24 @@ This reference design can be run on a Ryzen™ AI NPU.
In the [design](./aie2.py), a 2-D array in a row-major layout is read from external memory to `ComputeTile2` with a transposed layout,
by using an implicit copy via the compute tile's Data Movement Accelerator (DMA). The data is read from and written to external memory through the Shim tile (`col`, 0).
+This data movement transformation can be visualized as a map which shows the order in which the data is streamed (e.g., in transposed layout):
+
+
+
Visualization of the Transpose Data Transformation for M=32, K=16.
+
+
+
The implicit copy is performed using the `object_fifo_link` operation that specifies how input data arriving via `of_in` should be sent further via `of_out` by specifically leveraging the compute tile's DMA. This operation and its functionality are described in more depth in [Section-2b](../../../programming_guide/section-2/section-2b/README.md/#object-fifo-link) of the programming guide.
To compile and run the design for NPU:
-```
+```bash
make
make run
+```
+
+To generate a data visualization of the transpose (like that above), run:
+```bash
+make generate_access_map
```
\ No newline at end of file
diff --git a/programming_examples/basic/dma_transpose/aie2.py b/programming_examples/basic/dma_transpose/aie2.py
index d5299f4a06..bbc3a056d9 100644
--- a/programming_examples/basic/dma_transpose/aie2.py
+++ b/programming_examples/basic/dma_transpose/aie2.py
@@ -5,6 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#
# (c) Copyright 2024 Advanced Micro Devices, Inc. or its affiliates
+import argparse
import numpy as np
import sys
@@ -12,20 +13,24 @@
from aie.dialects.aiex import *
from aie.extras.context import mlir_mod_ctx
from aie.helpers.dialects.ext.scf import _for as range_
+from aie.helpers.tensortiler.tensortiler2d import TensorTile
-N = 4096
-M = 64
-K = 64
-if len(sys.argv) == 3:
- M = int(sys.argv[1])
- K = int(sys.argv[2])
- N = M * K
+def my_passthrough(M, K, N, generate_acccess_map=False):
+ tensor_ty = np.ndarray[(M, K), np.dtype[np.int32]]
+ data_transform = TensorTile(
+ tensor_height=M,
+ tensor_width=K,
+ sizes=[1, 1, K, M],
+ strides=[1, 1, 1, K],
+ offset=0,
+ )
+ if generate_acccess_map:
+ data_transform.visualize(
+ plot_access_count=False, file_path="transpose_data.png"
+ )
+ return
-tensor_ty = np.ndarray[(M, K), np.dtype[np.int32]]
-
-
-def my_passthrough():
with mlir_mod_ctx() as ctx:
@device(AIEDevice.npu1_1col)
@@ -56,8 +61,7 @@ def sequence(A, B, C):
metadata=of_in,
bd_id=1,
mem=A,
- sizes=[1, 1, K, M],
- strides=[1, 1, 1, K],
+ tensor_tile=data_transform,
issue_token=True,
)
npu_dma_memcpy_nd(metadata=of_out, bd_id=0, mem=C, sizes=[1, 1, 1, N])
@@ -66,4 +70,24 @@ def sequence(A, B, C):
print(ctx.module)
-my_passthrough()
+if __name__ == "__main__":
+ p = argparse.ArgumentParser()
+ p.add_argument("dims", help="M K", type=int, nargs="*", default=[64, 64])
+ p.add_argument(
+ "--generate-access-map",
+ action="store_true",
+ help="Produce a file showing data access order",
+ )
+ args = p.parse_args()
+
+ if len(args.dims) != 2:
+ print(
+ "ERROR: Must provide either no dimensions or both M and K", file=sys.stderr
+ )
+ exit(-1)
+ my_passthrough(
+ M=args.dims[0],
+ K=args.dims[1],
+ N=args.dims[0] * args.dims[1],
+ generate_acccess_map=args.generate_access_map,
+ )
diff --git a/programming_examples/basic/dma_transpose/transpose_data.png b/programming_examples/basic/dma_transpose/transpose_data.png
new file mode 100644
index 0000000000..33112790be
Binary files /dev/null and b/programming_examples/basic/dma_transpose/transpose_data.png differ
diff --git a/programming_examples/basic/matrix_multiplication/whole_array/tiling_scratch.py b/programming_examples/basic/matrix_multiplication/whole_array/tiling_scratch.py
new file mode 100644
index 0000000000..835efd87ea
--- /dev/null
+++ b/programming_examples/basic/matrix_multiplication/whole_array/tiling_scratch.py
@@ -0,0 +1,92 @@
+import numpy as np
+
+from aie.helpers.tensortiler.tensortiler2d import TensorTiler2D, TensorTile
+from util import construct_test
+
+
+# RUN: %python %s | FileCheck %s
+def ceildiv(a, b):
+ return -(a // -b)
+
+
+def run_checks(n_aie_cols, n_aie_rows, M, N, K, m, n, k):
+ tb_max_n_rows = 4
+ tb_n_rows = tb_max_n_rows // 2
+
+ # Define tilers
+ c_tiler = TensorTiler2D(M, N, m * n_aie_rows, n)
+ c_iter = c_tiler.tile_iter(
+ tile_repeat_step_horizontal=n_aie_cols, iter_step=tb_n_rows
+ )
+
+ for tb in range(ceildiv(M // m // n_aie_rows, tb_max_n_rows)):
+ for pingpong in [0, 1]:
+ row_base = tb * tb_max_n_rows + pingpong * tb_max_n_rows // 2
+ tb_n_rows = min([tb_max_n_rows // 2, M // m // n_aie_rows - row_base])
+ print(tb_n_rows)
+ if tb_n_rows <= 0:
+ # for small input sizes, we may not even need a "pong" iteration
+ break
+
+ for col in range(n_aie_cols):
+ C_row_offset = row_base * m * n_aie_rows * N
+ C_col_offset = col * n
+ C_offset = C_col_offset + C_row_offset
+ C_sizes = [tb_n_rows, N // n // n_aie_cols, m * n_aie_rows, n]
+ C_strides = [m * n_aie_rows * N, n * n_aie_cols, N, 1]
+ expected_c_tile = TensorTile(
+ M, N, offset=C_offset, sizes=C_sizes, strides=C_strides
+ )
+
+ c_tile = next(c_iter)
+ if c_tile != expected_c_tile:
+ # equivalence for tensor tile checks offset, size, stride
+ # but there may be different but equivalent transformations
+
+ reference_access, reference_count = expected_c_tile.access_tensors()
+ c_access, c_count = c_tile.access_tensors()
+
+ """
+ assert (reference_access == c_access).all(), (
+ f"C access orders do not match. "
+ f"Expected ({expected_c_tile}), got ({c_tile})"
+ )
+ assert (reference_count == c_count).all()
+ """
+ print(f"Expected: {expected_c_tile}")
+ print(f"Actual: {c_tile}")
+
+
+def matrix_whole_array_tiling_sweep():
+ n_aie_cols_sweep = [1, 2, 4] # TODO: when partial, add 3
+ n_aie_rows_sweep = [1, 2, 4] # TODO: when partial, add 3
+ M_sweep = range(512, 4096, 512)
+ K_sweep = range(512, 4096, 512)
+ N_sweep = range(512, 4096, 512)
+ m_sweep = [16, 32, 64]
+ n_sweep = [16, 32, 64]
+ k_sweep = [16, 32, 64]
+
+ for n_aie_cols in n_aie_cols_sweep:
+ for n_aie_rows in n_aie_rows_sweep:
+ for M in M_sweep:
+ for N in N_sweep:
+ for K in K_sweep:
+ for m in m_sweep:
+ for n in n_sweep:
+ for k in k_sweep:
+ run_checks(
+ n_aie_cols=n_aie_cols,
+ n_aie_rows=n_aie_rows,
+ M=M,
+ N=N,
+ K=K,
+ m=m,
+ k=k,
+ n=n,
+ )
+ return
+
+
+if __name__ == "__main__":
+ matrix_whole_array_tiling_sweep()
diff --git a/programming_examples/basic/matrix_scalar_add/aie2.py b/programming_examples/basic/matrix_scalar_add/aie2.py
index 87d75fd88b..75a980d208 100644
--- a/programming_examples/basic/matrix_scalar_add/aie2.py
+++ b/programming_examples/basic/matrix_scalar_add/aie2.py
@@ -12,6 +12,7 @@
from aie.dialects.aiex import *
from aie.extras.context import mlir_mod_ctx
from aie.helpers.dialects.ext.scf import _for as range_
+from aie.helpers.tensortiler.tensortiler2d import TensorTiler2D
# Size of the entire image
IMAGE_HEIGHT = 16
@@ -68,14 +69,16 @@ def core_body():
of_out1.release(ObjectFifoPort.Produce, 1)
# To/from AIE-array data movement
+ tiler = TensorTiler2D(IMAGE_HEIGHT, IMAGE_WIDTH, TILE_HEIGHT, TILE_WIDTH)
+ t = next(tiler.tile_iter()) # Only transfer one (first) tile of data
+
@runtime_sequence(tile_ty, tile_ty, tile_ty)
def sequence(inTensor, notUsed, outTensor):
npu_dma_memcpy_nd(
metadata=of_in1,
bd_id=1,
mem=inTensor,
- sizes=[1, 1, TILE_HEIGHT, TILE_WIDTH],
- strides=[1, 1, IMAGE_WIDTH, 1],
+ tensor_tile=t,
issue_token=True,
)
@@ -83,8 +86,7 @@ def sequence(inTensor, notUsed, outTensor):
metadata=of_out1,
bd_id=0,
mem=outTensor,
- sizes=[1, 1, TILE_HEIGHT, TILE_WIDTH],
- strides=[1, 1, IMAGE_WIDTH, 1],
+ tensor_tile=t,
)
dma_wait(of_in1, of_out1)
diff --git a/programming_examples/basic/row_wise_bias_add/aie2.py b/programming_examples/basic/row_wise_bias_add/aie2.py
index 1589473636..eac3b6d702 100644
--- a/programming_examples/basic/row_wise_bias_add/aie2.py
+++ b/programming_examples/basic/row_wise_bias_add/aie2.py
@@ -11,6 +11,7 @@
from aie.dialects.aiex import *
from aie.extras.context import mlir_mod_ctx
from aie.helpers.dialects.ext.scf import _for as range_
+from aie.helpers.tensortiler.tensortiler2d import TensorTiler2D
def row_wise_bias_add(M, N, m, n):
@@ -48,28 +49,32 @@ def core_body():
in_fifo.release(ObjectFifoPort.Consume, 1)
bias_fifo.release(ObjectFifoPort.Consume, 1)
+ tiler = TensorTiler2D(M, N, m, n, tensor_col_major=True)
+ t = next(
+ tiler.tile_iter(tile_group_height=M // m, tile_group_width=N // n)
+ ) # Transfer all tiles at once
+ bias_tiler = TensorTiler2D(1, N, 1, n)
+ bias_t = next(bias_tiler.tile_iter(tile_group_width=N // n))
+
@runtime_sequence(tensor_ty, bias_ty, tensor_ty)
def sequence(inp, bias, out):
npu_dma_memcpy_nd(
metadata=in_fifo,
bd_id=0,
mem=inp,
- sizes=[1, N // n, M, n],
- strides=[0, n, N, 1],
+ tensor_tile=t,
)
npu_dma_memcpy_nd(
metadata=bias_fifo,
bd_id=1,
mem=bias,
- sizes=[1, 1, N // n, n],
- strides=[0, 0, n, 1],
+ tensor_tile=bias_t,
)
npu_dma_memcpy_nd(
metadata=out_fifo,
bd_id=2,
mem=out,
- sizes=[1, N // n, M, n],
- strides=[0, n, N, 1],
+ tensor_tile=t,
)
# of_out will only complete after of_in completes, so we just wait on of_out instead of both
dma_wait(out_fifo)
diff --git a/programming_examples/basic/tiling_exploration/per_tile/Makefile b/programming_examples/basic/tiling_exploration/per_tile/Makefile
new file mode 100644
index 0000000000..ccec0077b0
--- /dev/null
+++ b/programming_examples/basic/tiling_exploration/per_tile/Makefile
@@ -0,0 +1,39 @@
+##===- Makefile -----------------------------------------------------------===##
+#
+# This file licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#
+# Copyright (C) 2024, Advanced Micro Devices, Inc.
+#
+##===----------------------------------------------------------------------===##
+
+srcdir := $(shell dirname $(realpath $(firstword $(MAKEFILE_LIST))))
+
+include ${srcdir}/../../../makefile-common
+
+tensor_height = 32
+tensor_width = 32
+tile_height = 4
+tile_width = 4
+data_str=${tensor_height}_${tensor_width}_${tile_height}_${tile_width}
+
+.PHONY: all template clean
+
+all: build/final_${data_str}.xclbin
+
+build/aie_${data_str}.mlir: ${srcdir}/aie2.py
+ mkdir -p ${@D}
+ python3 $< --tensor-height ${tensor_height} --tensor-width ${tensor_width} --tile-height ${tile_height} --tile-width ${tile_width} > $@
+
+build/final_${data_str}.xclbin: build/aie_${data_str}.mlir
+ mkdir -p ${@D}
+ cd ${@D} && aiecc.py --aie-generate-cdo --aie-generate-npu --no-compile-host \
+ --no-xchesscc --no-xbridge \
+ --xclbin-name=${@F} --npu-insts-name=insts_${data_str}.txt $(<:%=../%)
+
+run: build/final_${data_str}.xclbin build/insts_${data_str}.txt
+ ${powershell} python3 ${srcdir}/test.py -x build/final_${data_str}.xclbin -i build/insts_${data_str}.txt -k MLIR_AIE --tensor-height ${tensor_height} --tensor-width ${tensor_width} --tile-height ${tile_height} --tile-width ${tile_width}
+
+clean:
+ rm -rf build
diff --git a/programming_examples/basic/tiling_exploration/per_tile/README.md b/programming_examples/basic/tiling_exploration/per_tile/README.md
new file mode 100644
index 0000000000..cfb0936362
--- /dev/null
+++ b/programming_examples/basic/tiling_exploration/per_tile/README.md
@@ -0,0 +1,31 @@
+
+
+# Tiling Exploration
+
+This IRON design flow example, called "Tiling Exploration", demonstrates how data may be `tiled` on input/output. This is a common data transformation pattern, and this example is meant to be interactive.
+
+## Source Files Overview
+
+TODO
+
+## Design Overview
+
+TODO
+
+## Design Component Details
+
+### AIE Array Structural Design
+
+TODO
+
+## Usage
+
+TODO
diff --git a/programming_examples/basic/tiling_exploration/per_tile/aie2.py b/programming_examples/basic/tiling_exploration/per_tile/aie2.py
new file mode 100644
index 0000000000..dfd6fc9495
--- /dev/null
+++ b/programming_examples/basic/tiling_exploration/per_tile/aie2.py
@@ -0,0 +1,88 @@
+# tiling_exploration/aie2.py -*- Python -*-
+#
+# This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#
+# (c) Copyright 2024 Advanced Micro Devices, Inc. or its affiliates
+import argparse
+import numpy as np
+import sys
+
+from aie.dialects.aie import *
+from aie.dialects.aiex import *
+from aie.dialects import arith
+from aie.extras.context import mlir_mod_ctx
+from aie.helpers.dialects.ext.scf import _for as range_
+from aie.helpers.tensortiler.tensortiler2d import TensorTiler2D
+
+
+def generate_module(tensor_height, tensor_width, tile_height, tile_width):
+ @device(AIEDevice.npu1_1col)
+ def device_body():
+ # define types
+ tensor_size = tensor_height * tensor_width
+ tile_size = tile_height * tile_width
+ flattened_tensor = np.ndarray[(tensor_size,), np.dtype[TensorTiler2D.DTYPE]]
+ flattened_tile = np.ndarray[(tile_size,), np.dtype[TensorTiler2D.DTYPE]]
+
+ # Tile declarations
+ ShimTile = tile(0, 0)
+ ComputeTile2 = tile(0, 2)
+
+ # AIE-array data movement with object fifos
+ of_out = object_fifo("out", ComputeTile2, ShimTile, 2, flattened_tile)
+
+ # Set up compute tiles
+
+ # Compute tile 2
+ @core(ComputeTile2)
+ def core_body():
+ # TODO: better way to get mutable constant than buffer??
+ access_counter = buffer(
+ ComputeTile2,
+ np.ndarray[(1,), np.dtype[TensorTiler2D.DTYPE]],
+ "access_counter",
+ initial_value=np.array([0], dtype=TensorTiler2D.DTYPE),
+ )
+ for _ in range_(sys.maxsize):
+ elemOut = of_out.acquire(ObjectFifoPort.Produce, 1)
+ for i in range_(tile_size):
+ elemOut[i] = access_counter[0]
+ access_counter[0] += 1
+ of_out.release(ObjectFifoPort.Produce, 1)
+
+ @runtime_sequence(flattened_tensor)
+ def sequence(access_count):
+ tiler = TensorTiler2D(tensor_height, tensor_width, tile_height, tile_width)
+ for t in tiler.tile_iter():
+ npu_dma_memcpy_nd(
+ metadata=of_out,
+ bd_id=1,
+ mem=access_count,
+ tensor_tile=t,
+ )
+ dma_wait(of_out)
+
+
+def main(opts):
+ with mlir_mod_ctx() as ctx:
+ generate_module(
+ opts.tensor_height, opts.tensor_width, opts.tile_height, opts.tile_width
+ )
+ print(ctx.module)
+
+
+def get_arg_parser():
+ p = argparse.ArgumentParser()
+ p.add_argument("--tensor-height", required=True, help="Tensor height", type=int)
+ p.add_argument("--tensor-width", required=True, help="Tensor width", type=int)
+ p.add_argument("--tile-height", required=True, help="Tile height", type=int)
+ p.add_argument("--tile-width", required=True, help="Tile width", type=int)
+ return p
+
+
+if __name__ == "__main__":
+ p = get_arg_parser()
+ opts = p.parse_args()
+ main(opts)
diff --git a/programming_examples/basic/tiling_exploration/per_tile/run_makefile.lit b/programming_examples/basic/tiling_exploration/per_tile/run_makefile.lit
new file mode 100644
index 0000000000..bdafa81a59
--- /dev/null
+++ b/programming_examples/basic/tiling_exploration/per_tile/run_makefile.lit
@@ -0,0 +1,10 @@
+// (c) Copyright 2024 Advanced Micro Devices, Inc.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// REQUIRES: ryzen_ai, peano
+//
+// RUN: make -f %S/Makefile clean
+// RUN: make -f %S/Makefile
+// RUN: %run_on_npu make -f %S/Makefile run | FileCheck %s
+// CHECK: Running...
+// CHECK: PASS!
diff --git a/programming_examples/basic/tiling_exploration/per_tile/test.py b/programming_examples/basic/tiling_exploration/per_tile/test.py
new file mode 100644
index 0000000000..d9adaaf2fc
--- /dev/null
+++ b/programming_examples/basic/tiling_exploration/per_tile/test.py
@@ -0,0 +1,87 @@
+# tiling_exploration/test.py -*- Python -*-
+#
+# This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#
+# (c) Copyright 2024 Advanced Micro Devices, Inc. or its affiliates
+import argparse
+import numpy as np
+
+from aie.helpers.tensortiler.tensortiler2d import TensorTiler2D
+from aie.utils.xrt import setup_aie, execute as execute_on_aie
+
+
+def main(opts):
+ print("Running...\n")
+
+ dtype = TensorTiler2D.DTYPE
+ data_size = opts.tensor_height * opts.tensor_width
+
+ reference_tiler = TensorTiler2D(
+ opts.tensor_height, opts.tensor_width, opts.tile_height, opts.tile_width
+ )
+ reference_access_order = reference_tiler.access_order()
+
+ app = setup_aie(
+ opts.xclbin,
+ opts.instr,
+ None,
+ None,
+ None,
+ None,
+ data_size,
+ dtype,
+ )
+ aie_output = execute_on_aie(app)
+ aie_output = aie_output.reshape((opts.tensor_height, opts.tensor_width))
+
+ # Copy output results and verify they are correct
+ errors = 0
+ if opts.verbosity >= 1:
+ print("Verifying results ...")
+ e = np.equal(reference_access_order, aie_output)
+ errors = np.size(e) - np.count_nonzero(e)
+
+ if not errors:
+ print("\nPASS!\n")
+ exit(0)
+ else:
+ print("\nError count: ", errors)
+ print("\nFailed.\n")
+ exit(-1)
+
+
+def get_arg_parser():
+ p = argparse.ArgumentParser()
+ p.add_argument(
+ "-x", "--xclbin", default="final.xclbin", dest="xclbin", help="the xclbin path"
+ )
+ p.add_argument(
+ "-k",
+ "--kernel",
+ dest="kernel",
+ default="MLIR_AIE",
+ help="the kernel name in the XCLBIN (for instance MLIR_AIE)",
+ )
+ p.add_argument(
+ "-v", "--verbosity", default=0, type=int, help="the verbosity of the output"
+ )
+ p.add_argument(
+ "-i",
+ "--instr",
+ dest="instr",
+ default="instr.txt",
+ help="path of file containing userspace instructions sent to the NPU",
+ )
+ p.add_argument("--tensor-height", required=True, help="Tensor height", type=int)
+ p.add_argument("--tensor-width", required=True, help="Tensor width", type=int)
+ p.add_argument("--tile-height", required=True, help="Tile height", type=int)
+ p.add_argument("--tile-width", required=True, help="Tile width", type=int)
+ return p
+
+
+if __name__ == "__main__":
+ p = get_arg_parser()
+ opts = p.parse_args()
+ main(opts)
diff --git a/programming_examples/basic/tiling_exploration/single_transform/Makefile b/programming_examples/basic/tiling_exploration/single_transform/Makefile
new file mode 100644
index 0000000000..ccec0077b0
--- /dev/null
+++ b/programming_examples/basic/tiling_exploration/single_transform/Makefile
@@ -0,0 +1,39 @@
+##===- Makefile -----------------------------------------------------------===##
+#
+# This file licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#
+# Copyright (C) 2024, Advanced Micro Devices, Inc.
+#
+##===----------------------------------------------------------------------===##
+
+srcdir := $(shell dirname $(realpath $(firstword $(MAKEFILE_LIST))))
+
+include ${srcdir}/../../../makefile-common
+
+tensor_height = 32
+tensor_width = 32
+tile_height = 4
+tile_width = 4
+data_str=${tensor_height}_${tensor_width}_${tile_height}_${tile_width}
+
+.PHONY: all template clean
+
+all: build/final_${data_str}.xclbin
+
+build/aie_${data_str}.mlir: ${srcdir}/aie2.py
+ mkdir -p ${@D}
+ python3 $< --tensor-height ${tensor_height} --tensor-width ${tensor_width} --tile-height ${tile_height} --tile-width ${tile_width} > $@
+
+build/final_${data_str}.xclbin: build/aie_${data_str}.mlir
+ mkdir -p ${@D}
+ cd ${@D} && aiecc.py --aie-generate-cdo --aie-generate-npu --no-compile-host \
+ --no-xchesscc --no-xbridge \
+ --xclbin-name=${@F} --npu-insts-name=insts_${data_str}.txt $(<:%=../%)
+
+run: build/final_${data_str}.xclbin build/insts_${data_str}.txt
+ ${powershell} python3 ${srcdir}/test.py -x build/final_${data_str}.xclbin -i build/insts_${data_str}.txt -k MLIR_AIE --tensor-height ${tensor_height} --tensor-width ${tensor_width} --tile-height ${tile_height} --tile-width ${tile_width}
+
+clean:
+ rm -rf build
diff --git a/programming_examples/basic/tiling_exploration/single_transform/README.md b/programming_examples/basic/tiling_exploration/single_transform/README.md
new file mode 100644
index 0000000000..cfb0936362
--- /dev/null
+++ b/programming_examples/basic/tiling_exploration/single_transform/README.md
@@ -0,0 +1,31 @@
+
+
+# Tiling Exploration
+
+This IRON design flow example, called "Tiling Exploration", demonstrates how data may be `tiled` on input/output. This is a common data transformation pattern, and this example is meant to be interactive.
+
+## Source Files Overview
+
+TODO
+
+## Design Overview
+
+TODO
+
+## Design Component Details
+
+### AIE Array Structural Design
+
+TODO
+
+## Usage
+
+TODO
diff --git a/programming_examples/basic/tiling_exploration/single_transform/aie2.py b/programming_examples/basic/tiling_exploration/single_transform/aie2.py
new file mode 100644
index 0000000000..d1b8f98951
--- /dev/null
+++ b/programming_examples/basic/tiling_exploration/single_transform/aie2.py
@@ -0,0 +1,80 @@
+# tiling_exploration/aie2.py -*- Python -*-
+#
+# This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#
+# (c) Copyright 2024 Advanced Micro Devices, Inc. or its affiliates
+import argparse
+import numpy as np
+import sys
+
+from aie.dialects.aie import *
+from aie.dialects.aiex import *
+from aie.dialects import arith
+from aie.extras.context import mlir_mod_ctx
+from aie.helpers.dialects.ext.scf import _for as range_
+from aie.helpers.tensortiler.tensortiler2d import TensorTiler2D
+
+
+def generate_module(tensor_height, tensor_width, tile_height, tile_width):
+ @device(AIEDevice.npu1_1col)
+ def device_body():
+ # define types
+ tensor_size = tensor_height * tensor_width
+ flattened_tensor = np.ndarray[(tensor_size,), np.dtype[TensorTiler2D.DTYPE]]
+
+ # Tile declarations
+ ShimTile = tile(0, 0)
+ ComputeTile2 = tile(0, 2)
+
+ # AIE-array data movement with object fifos
+ of_out = object_fifo("out", ComputeTile2, ShimTile, 2, flattened_tensor)
+
+ # Set up compute tiles
+
+ # Compute tile 2
+ @core(ComputeTile2)
+ def core_body():
+ for _ in range_(sys.maxsize):
+ elemOut = of_out.acquire(ObjectFifoPort.Produce, 1)
+ for i in range_(tensor_size):
+ # TODO: fix need for cast here.
+ elemOut[i] = arith.index_cast(T.i32(), i)
+ of_out.release(ObjectFifoPort.Produce, 1)
+
+ @runtime_sequence(flattened_tensor)
+ def sequence(access_count):
+ t = TensorTiler2D(
+ tensor_height, tensor_width, tile_height, tile_width
+ ).as_tile()
+ npu_dma_memcpy_nd(
+ metadata=of_out,
+ bd_id=1,
+ mem=access_count,
+ tensor_tile=t,
+ )
+ dma_wait(of_out)
+
+
+def main(opts):
+ with mlir_mod_ctx() as ctx:
+ generate_module(
+ opts.tensor_height, opts.tensor_width, opts.tile_height, opts.tile_width
+ )
+ print(ctx.module)
+
+
+def get_arg_parser():
+ p = argparse.ArgumentParser()
+ p.add_argument("--tensor-height", required=True, help="Tensor height", type=int)
+ p.add_argument("--tensor-width", required=True, help="Tensor width", type=int)
+ p.add_argument("--tile-height", required=True, help="Tile height", type=int)
+ p.add_argument("--tile-width", required=True, help="Tile width", type=int)
+ return p
+
+
+if __name__ == "__main__":
+ p = get_arg_parser()
+ opts = p.parse_args()
+ main(opts)
diff --git a/programming_examples/basic/tiling_exploration/single_transform/run_makefile.lit b/programming_examples/basic/tiling_exploration/single_transform/run_makefile.lit
new file mode 100644
index 0000000000..bdafa81a59
--- /dev/null
+++ b/programming_examples/basic/tiling_exploration/single_transform/run_makefile.lit
@@ -0,0 +1,10 @@
+// (c) Copyright 2024 Advanced Micro Devices, Inc.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// REQUIRES: ryzen_ai, peano
+//
+// RUN: make -f %S/Makefile clean
+// RUN: make -f %S/Makefile
+// RUN: %run_on_npu make -f %S/Makefile run | FileCheck %s
+// CHECK: Running...
+// CHECK: PASS!
diff --git a/programming_examples/basic/tiling_exploration/single_transform/test.py b/programming_examples/basic/tiling_exploration/single_transform/test.py
new file mode 100644
index 0000000000..d9adaaf2fc
--- /dev/null
+++ b/programming_examples/basic/tiling_exploration/single_transform/test.py
@@ -0,0 +1,87 @@
+# tiling_exploration/test.py -*- Python -*-
+#
+# This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#
+# (c) Copyright 2024 Advanced Micro Devices, Inc. or its affiliates
+import argparse
+import numpy as np
+
+from aie.helpers.tensortiler.tensortiler2d import TensorTiler2D
+from aie.utils.xrt import setup_aie, execute as execute_on_aie
+
+
+def main(opts):
+ print("Running...\n")
+
+ dtype = TensorTiler2D.DTYPE
+ data_size = opts.tensor_height * opts.tensor_width
+
+ reference_tiler = TensorTiler2D(
+ opts.tensor_height, opts.tensor_width, opts.tile_height, opts.tile_width
+ )
+ reference_access_order = reference_tiler.access_order()
+
+ app = setup_aie(
+ opts.xclbin,
+ opts.instr,
+ None,
+ None,
+ None,
+ None,
+ data_size,
+ dtype,
+ )
+ aie_output = execute_on_aie(app)
+ aie_output = aie_output.reshape((opts.tensor_height, opts.tensor_width))
+
+ # Copy output results and verify they are correct
+ errors = 0
+ if opts.verbosity >= 1:
+ print("Verifying results ...")
+ e = np.equal(reference_access_order, aie_output)
+ errors = np.size(e) - np.count_nonzero(e)
+
+ if not errors:
+ print("\nPASS!\n")
+ exit(0)
+ else:
+ print("\nError count: ", errors)
+ print("\nFailed.\n")
+ exit(-1)
+
+
+def get_arg_parser():
+ p = argparse.ArgumentParser()
+ p.add_argument(
+ "-x", "--xclbin", default="final.xclbin", dest="xclbin", help="the xclbin path"
+ )
+ p.add_argument(
+ "-k",
+ "--kernel",
+ dest="kernel",
+ default="MLIR_AIE",
+ help="the kernel name in the XCLBIN (for instance MLIR_AIE)",
+ )
+ p.add_argument(
+ "-v", "--verbosity", default=0, type=int, help="the verbosity of the output"
+ )
+ p.add_argument(
+ "-i",
+ "--instr",
+ dest="instr",
+ default="instr.txt",
+ help="path of file containing userspace instructions sent to the NPU",
+ )
+ p.add_argument("--tensor-height", required=True, help="Tensor height", type=int)
+ p.add_argument("--tensor-width", required=True, help="Tensor width", type=int)
+ p.add_argument("--tile-height", required=True, help="Tile height", type=int)
+ p.add_argument("--tile-width", required=True, help="Tile width", type=int)
+ return p
+
+
+if __name__ == "__main__":
+ p = get_arg_parser()
+ opts = p.parse_args()
+ main(opts)
diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt
index 97d2dbceb3..fa95393a48 100644
--- a/python/CMakeLists.txt
+++ b/python/CMakeLists.txt
@@ -42,6 +42,7 @@ declare_mlir_python_sources(AIEPythonSources.Helpers
helpers/*.py
helpers/dialects/ext/*.py
helpers/runtime/*.py
+ helpers/tensortiler/*.py
)
declare_mlir_dialect_python_bindings(
diff --git a/python/dialects/aiex.py b/python/dialects/aiex.py
index 0f5c329e1b..c3cd2d124b 100644
--- a/python/dialects/aiex.py
+++ b/python/dialects/aiex.py
@@ -28,6 +28,7 @@
# noinspection PyUnresolvedReferences
from ..extras import types as T
from ..helpers.util import try_convert_np_type_to_mlir_type, np_ndarray_type_get_shape
+from ..helpers.tensortiler.tensortiler2d import TensorTile
# Comes from _aie
register_dialect(get_dialect_registry())
@@ -55,19 +56,29 @@ def __init__(
metadata: str | ObjectFifoCreateOp,
bd_id,
mem,
- offsets: MixedValues = None,
- sizes: MixedValues = None,
- strides: MixedValues = None,
+ tensor_tile: TensorTile | None = None,
+ offsets: MixedValues | None = None,
+ sizes: MixedValues | None = None,
+ strides: MixedValues | None = None,
issue_token: bool | None = None,
):
x = 0
y = 0
- if offsets is None:
- offsets = [0] * 4
- if sizes is None:
- sizes = [0] * 4
- if strides is None:
- strides = [0] * 3 + [1]
+ if tensor_tile and not (offsets is None and sizes is None and strides is None):
+ raise ValueError(
+ "NpuDmaMemcpyNd can take either a tensor_tile OR (sizes and/or strides and/or offsets), but not both."
+ )
+ if tensor_tile:
+ sizes = tensor_tile.sizes.copy()
+ strides = tensor_tile.strides.copy()
+ offsets = [0, 0, 0, tensor_tile.offset]
+ else:
+ if offsets is None:
+ offsets = [0] * 4
+ if sizes is None:
+ sizes = [0] * 4
+ if strides is None:
+ strides = [0] * 3 + [1]
dynamic_offsets, _packed_offsets, static_offsets = _dispatch_mixed_values(
offsets
)
diff --git a/python/helpers/tensortiler/__init__.py b/python/helpers/tensortiler/__init__.py
new file mode 100644
index 0000000000..5641295b7f
--- /dev/null
+++ b/python/helpers/tensortiler/__init__.py
@@ -0,0 +1,5 @@
+from .tensortile import TensorTile
+from .tensortilesequence import (
+ TensorTileSequence,
+)
+from .tensortiler2d import TensorTiler2D
diff --git a/python/helpers/tensortiler/tensortile.py b/python/helpers/tensortiler/tensortile.py
new file mode 100644
index 0000000000..b9df1d1a8a
--- /dev/null
+++ b/python/helpers/tensortiler/tensortile.py
@@ -0,0 +1,120 @@
+from copy import deepcopy
+import numpy as np
+import itertools
+from typing import Sequence
+
+from .utils import (
+ validate_and_clean_sizes_strides,
+ validate_offset,
+ validate_tensor_dims,
+)
+from .visualization2d import visualize_from_access_tensors
+
+
+class TensorTile:
+ _DTYPE = np.int32
+
+ def __init__(
+ self,
+ tensor_dims: Sequence[int],
+ offset: int,
+ sizes: Sequence[int],
+ strides: Sequence[int],
+ ):
+ self._tensor_dims = validate_tensor_dims(tensor_dims)
+ self._offset = validate_offset(offset)
+ if self._offset >= np.prod(tensor_dims):
+ raise ValueError(
+ f"Offset too large: {self._offset}. Max value allowed for tensor: {np.prod(tensor_dims)}"
+ )
+ self._sizes, self._strides = validate_and_clean_sizes_strides(sizes, strides)
+
+ @property
+ def tensor_dims(self) -> Sequence[int]:
+ # Copy to prevent callers from mutating self
+ return deepcopy(self._tensor_dims)
+
+ @property
+ def offset(self) -> int:
+ return self._offset
+
+ @property
+ def sizes(self) -> Sequence[int]:
+ # Copy to prevent callers from mutating self
+ return deepcopy(self._sizes)
+
+ @property
+ def strides(self) -> Sequence[int]:
+ # Copy to prevent callers from mutating self
+ return deepcopy(self._strides)
+
+ @property
+ def transformation_dims(self) -> Sequence[tuple[int, int]]:
+ return list(zip(self._sizes, self._strides))
+
+ def access_tensors(self) -> tuple[np.ndarray, np.ndarray]:
+ # TODO: should access order be a list of lists instead of generating two separate tensors?
+ # TODO: for performance, should cache and return copies? Or just cache?
+
+ # Initialize access order and count maps
+ total_elems = np.prod(self._tensor_dims)
+ access_order_tensor = np.full(total_elems, -1, dtype=self._DTYPE)
+ access_count_tensor = np.full(total_elems, 0, dtype=self._DTYPE)
+
+ # Use itertools.product to collapse len(sizes) nested forloop into one forloop
+ access_count = 0
+ for dims in itertools.product(*[range(0, n) for n in self._sizes]):
+ access_idx = (
+ self._offset + np.sum(np.multiply(dims, self._strides))
+ ) % total_elems
+ access_count_tensor[access_idx] += 1
+ access_order_tensor[access_idx] = access_count
+ access_count += 1
+
+ access_order_tensor = access_order_tensor.reshape(self._tensor_dims)
+ access_count_tensor = access_count_tensor.reshape(self._tensor_dims)
+ return access_order_tensor, access_count_tensor
+
+ def visualize(
+ self,
+ show_arrows: bool | None = None,
+ title: str = None,
+ file_path: str | None = None,
+ show_plot: bool = True,
+ plot_access_count: bool = False,
+ ) -> None:
+ access_order, access_count = self.access_tensors()
+ if title is None:
+ title = str(self)
+ if not plot_access_count:
+ access_count = None
+ if len(self._tensor_dims) == 2:
+ visualize_from_access_tensors(
+ access_order,
+ access_count,
+ title=title,
+ show_arrows=show_arrows,
+ file_path=file_path,
+ show_plot=show_plot,
+ )
+ else:
+ raise NotImplementedError(
+ "Visualization is only currently supported for 1- or 2-dimensional tensors"
+ )
+
+ def __str__(self) -> str:
+ return f"TensorTile(offset={self._offset}, sizes={self._sizes}, strides={self._strides})"
+
+ def __eq__(self, other):
+ if isinstance(other, self.__class__):
+ return (
+ self._tensor_dims == other._tensor_dims
+ and self._offset == other._offset
+ and self._sizes == other._sizes
+ and self._strides == other._strides
+ )
+ else:
+ return False
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
diff --git a/python/helpers/tensortiler/tensortiler2d.py b/python/helpers/tensortiler/tensortiler2d.py
new file mode 100644
index 0000000000..c70481b587
--- /dev/null
+++ b/python/helpers/tensortiler/tensortiler2d.py
@@ -0,0 +1,337 @@
+from copy import deepcopy
+from functools import partial
+import numpy as np
+from typing import Sequence
+
+from .tensortilesequence import TensorTileSequence
+from .utils import ceildiv, validate_and_clean_sizes_strides, validate_tensor_dims
+
+
+class TensorTiler2D:
+ """
+ This is a generator (similar to factory pattern) class which produces TensorTileSequence
+ objects for common 2-dimensional tiling patterns.
+ """
+
+ _DTYPE = np.int32
+ _NUM_DIMS = 2
+
+    def __init__(self):
+        """Always raise: this class only groups classmethod generators and holds no state."""
+        message = f"{self.__class__} cannot be instantiated. Use it as a factory/generator of TensorTileSequences."
+        raise Exception(message)
+
+    @classmethod
+    def simple_tiler(
+        cls,
+        tensor_dims: Sequence[int],
+        tile_dims: Sequence[int] | None = None,
+        tile_col_major: bool = False,
+        iter_col_major: bool = False,
+        pattern_repeat: int = 1,
+    ) -> TensorTileSequence:
+        """
+        Generate a tile sequence by delegating to group_tiler with its default tile groups.
+
+        When tile_dims is None, a copy of tensor_dims is used as the tile
+        shape. All other arguments are forwarded to group_tiler unchanged.
+        """
+        effective_tile_dims = deepcopy(tensor_dims) if tile_dims is None else tile_dims
+        # A simple tiling is the special case of a group tiling without explicit groups.
+        return cls.group_tiler(
+            tensor_dims=tensor_dims,
+            tile_dims=effective_tile_dims,
+            tile_col_major=tile_col_major,
+            iter_col_major=iter_col_major,
+            pattern_repeat=pattern_repeat,
+        )
+
+    @classmethod
+    def group_tiler(
+        cls,
+        tensor_dims: Sequence[int],
+        tile_dims: Sequence[int],
+        tile_group_dims: Sequence[int] | None = None,
+        tile_col_major: bool = False,
+        tile_group_col_major: bool = False,
+        iter_col_major: bool = False,
+        pattern_repeat: int = 1,
+        allow_partial: bool = False,
+    ) -> TensorTileSequence:
+        """
+        Generate a tile sequence by delegating to step_tiler with its default steps.
+
+        tile_group_dims is passed to step_tiler as tile_group_repeats; when it
+        is None, a group of one tile per dimension is used. All other arguments
+        are forwarded unchanged.
+        """
+        group_repeats = (
+            (1,) * cls._NUM_DIMS if tile_group_dims is None else tile_group_dims
+        )
+        # A group tiling is the special case of a step tiling without explicit steps.
+        return cls.step_tiler(
+            tensor_dims=tensor_dims,
+            tile_dims=tile_dims,
+            tile_group_repeats=group_repeats,
+            tile_col_major=tile_col_major,
+            tile_group_col_major=tile_group_col_major,
+            iter_col_major=iter_col_major,
+            pattern_repeat=pattern_repeat,
+            allow_partial=allow_partial,
+        )
+
+    @classmethod
+    def step_tiler(
+        cls,
+        tensor_dims: Sequence[int],
+        tile_dims: Sequence[int],
+        tile_group_repeats: Sequence[int],
+        tile_group_steps: Sequence[int] | None = None,
+        tile_col_major: bool = False,
+        tile_group_col_major: bool = False,
+        iter_col_major: bool = False,
+        allow_partial: bool = False,
+        pattern_repeat: int = 1,
+    ) -> TensorTileSequence:
+        """
+        Generate a TensorTileSequence in which every step covers a group of tiles.
+
+        Args:
+            tensor_dims: Dimensions of the tensor being tiled (two entries after validation).
+            tile_dims: Dimensions of one tile; each must evenly divide the
+                matching tensor dimension.
+            tile_group_repeats: Number of tiles in a group, per dimension.
+            tile_group_steps: Distance, counted in tiles, between consecutive
+                tiles of a group, per dimension. Defaults to 1 in each
+                dimension. Every step must be >= 1.
+            tile_col_major: Selects the column-major variant of the innermost
+                (within-tile) sizes/strides.
+            tile_group_col_major: Selects which group repeat is placed in the
+                outermost dimension when both repeats are > 1.
+            iter_col_major: Selects whether consecutive sequence steps advance
+                down a column or along a row of tile groups.
+            allow_partial: When False, every tensor dimension must be divisible
+                by tile_dim * repeat * step.
+            pattern_repeat: Repeat count inserted into a free outer dimension of
+                every generated tile's sizes; must be an int >= 1.
+
+        Returns:
+            A TensorTileSequence with one tile group per step.
+
+        Raises:
+            ValueError: if any argument fails the validation described above.
+        """
+        if tile_group_steps is None:
+            tile_group_steps = (1,) * cls._NUM_DIMS
+        tensor_dims = validate_tensor_dims(tensor_dims, expected_dims=cls._NUM_DIMS)
+        tile_dims = validate_tensor_dims(tile_dims, expected_dims=cls._NUM_DIMS)
+        tile_group_repeats = validate_tensor_dims(
+            tile_group_repeats, expected_dims=cls._NUM_DIMS
+        )
+        if len(tile_group_steps) != cls._NUM_DIMS:
+            raise ValueError(f"Expected {cls._NUM_DIMS} dimensions of tile group steps")
+        for i, s in enumerate(tile_group_steps):
+            # A step of 0 is used as a divisor further down (modulo, ceildiv and
+            # floor division), so reject it here with a clear error instead of
+            # letting a ZeroDivisionError surface later.
+            if s < 1:
+                raise ValueError(
+                    f"Tile group step dimension {i} must be >= 1, but got {s}"
+                )
+
+        for i, (tensor_dim, tile_dim) in enumerate(zip(tensor_dims, tile_dims)):
+            if tensor_dim % tile_dim != 0:
+                raise ValueError(
+                    f"Tensor dimension {i} ({tensor_dim}) is not divisible by tile dim ({tile_dim})"
+                )
+
+        if not isinstance(pattern_repeat, int) or pattern_repeat < 1:
+            raise ValueError(f"Pattern repeat must be >= 1 but is {pattern_repeat}")
+        if not allow_partial:
+            for i, (tensor_dim, tile_dim, repeat_dim, step_dim) in enumerate(
+                zip(tensor_dims, tile_dims, tile_group_repeats, tile_group_steps)
+            ):
+                if tensor_dim % (tile_dim * repeat_dim * step_dim) != 0:
+                    raise ValueError(
+                        f"allow_partial={allow_partial} but tensor does not divide evenly into tile groups in dimension {i}"
+                    )
+                if tile_dim * repeat_dim * step_dim > tensor_dim:
+                    raise ValueError(
+                        f"Tile pattern exceeds tensor size in dimension {i} ({tile_dim}x{repeat_dim}x{step_dim} > {tensor_dim})"
+                    )
+
+        steps_per_dim = cls.__get_num_steps(
+            tensor_dims=tensor_dims,
+            tile_dims=tile_dims,
+            step_dims=tile_group_steps,
+            repeat_dims=tile_group_repeats,
+        )
+        num_steps = np.prod(steps_per_dim)
+
+        def offset_fn(step_num: int, _prev_offset: int) -> int:
+            # Convert the per-dimension tile offset of this step into a flat
+            # element offset into the tensor.
+            tile_offsets = cls.__tile_offset_by_step_num(
+                step_num,
+                tile_group_steps,
+                tile_group_repeats,
+                steps_per_dim,
+                iter_col_major,
+            )
+            total_offset = 0
+            num_dims = len(tile_offsets)
+            for dim, (offset, tile_dim) in enumerate(zip(tile_offsets, tile_dims)):
+                total_offset += (
+                    offset
+                    * tile_dim
+                    * (np.prod(tensor_dims[dim + 1 :]) if dim < num_dims - 1 else 1)
+                )
+            return total_offset
+
+        def sizes_or_strides_fn(step_num, _prev_sizes, is_sizes):
+            # Sizes and strides are computed together; is_sizes picks which of
+            # the two is returned.
+            tile_offsets = cls.__tile_offset_by_step_num(
+                step_num,
+                tile_group_steps,
+                tile_group_repeats,
+                steps_per_dim,
+                iter_col_major,
+            )
+
+            iter_sizes, iter_strides = cls.__sizes_strides_for_step_tile_group(
+                tensor_dims,
+                tile_dims,
+                tile_group_steps,
+                tile_group_repeats,
+                tile_offsets,
+                tile_col_major,
+                tile_group_col_major,
+                pattern_repeat=pattern_repeat,
+            )
+            if is_sizes:
+                return iter_sizes
+            else:
+                return iter_strides
+
+        sizes_fn = partial(sizes_or_strides_fn, is_sizes=True)
+        strides_fn = partial(sizes_or_strides_fn, is_sizes=False)
+
+        return TensorTileSequence(
+            tensor_dims,
+            num_steps,
+            sizes_fn=sizes_fn,
+            strides_fn=strides_fn,
+            offset_fn=offset_fn,
+        )
+
+    @classmethod
+    def __get_num_steps(
+        cls,
+        tensor_dims: Sequence[int],
+        tile_dims: Sequence[int],
+        step_dims: Sequence[int],
+        repeat_dims: Sequence[int],
+    ) -> Sequence[int]:
+        """
+        Count, per dimension, how many steps the tiling needs.
+
+        Each dimension gets tensor_dim // (tile_dim * repeat_dim) steps, plus
+        one extra step per iteration needed to consume the tiles left over
+        after those steps.
+        """
+        steps_per_dim = []
+        for tensor_dim, tile_dim, step_dim, repeat_dim in zip(
+            tensor_dims, tile_dims, step_dims, repeat_dims
+        ):
+            step_count = tensor_dim // (tile_dim * repeat_dim)
+            leftover_tiles = (tensor_dim // tile_dim) - step_count * repeat_dim
+            # Each additional step consumes at most repeat_dim of the leftover tiles.
+            while leftover_tiles > 0:
+                consumed = min(repeat_dim, ceildiv(leftover_tiles, step_dim))
+                leftover_tiles -= consumed
+                step_count += 1
+            steps_per_dim.append(step_count)
+        return steps_per_dim
+
+    @classmethod
+    def __tile_offset_by_step_num(
+        cls,
+        step_num: int,
+        tile_group_steps: Sequence[int],
+        tile_group_repeats: Sequence[int],
+        num_steps: Sequence[int],
+        iter_col_major: bool,
+    ) -> Sequence[int]:
+        """
+        Map a step number to the (column-direction, row-direction) tile offset of its tile group.
+
+        The step number is first split into a position along each dimension
+        (which dimension varies fastest depends on iter_col_major). Each
+        position is then split into a chunk index and an index within the
+        chunk, where a chunk spans step * repeat tiles.
+        """
+        # TODO: this code is still specific to two dimensions
+        steps_per_col, steps_per_row = num_steps
+        step_height, step_width = tile_group_steps
+        repeat_height, repeat_width = tile_group_repeats
+
+        if iter_col_major:
+            row_idx, col_idx = divmod(step_num, steps_per_col)
+        else:
+            col_idx, row_idx = divmod(step_num, steps_per_row)
+
+        col_chunk_idx, col_in_chunk_idx = divmod(col_idx, step_height)
+        row_chunk_idx, row_in_chunk_idx = divmod(row_idx, step_width)
+
+        offset_in_col = col_chunk_idx * step_height * repeat_height + col_in_chunk_idx
+        offset_in_row = row_chunk_idx * step_width * repeat_width + row_in_chunk_idx
+        return (offset_in_col, offset_in_row)
+
+ @classmethod
+ def __sizes_strides_for_step_tile_group(
+ cls,
+ tensor_dims: Sequence[int],
+ tile_dims: Sequence[int],
+ tile_group_steps: Sequence[int],
+ tile_group_repeats: Sequence[int],
+ tile_offsets: Sequence[int],
+ tile_col_major: bool,
+ tile_group_col_major: bool,
+ pattern_repeat: int,
+ ) -> tuple[Sequence[int], Sequence[int]]:
+ """
+ Compute the sizes and strides of the tile group that starts at tile_offsets.
+
+ Both results are built as 4-entry lists: the two innermost entries
+ describe the traversal of one tile, the two outermost entries describe
+ how the tile is repeated across the group. Steps and repeat counts are
+ first reduced so they fit in the tiles remaining after tile_offsets.
+ The lists are passed through validate_and_clean_sizes_strides and
+ pattern_repeat is then written into one of the two outermost sizes.
+
+ Raises:
+ ValueError: if pattern_repeat != 1 and neither outer dimension is free to hold it.
+ """
+ # TODO: this code is still specific to two dimensions
+ # TODO: this code assumes sizes/strides of len 4
+
+ # Interior method, assumes all validation already done
+ tensor_height, tensor_width = tensor_dims
+ tile_height, tile_width = tile_dims
+ tile_step_height, tile_step_width = tile_group_steps
+ tile_repeat_height, tile_repeat_width = tile_group_repeats
+ tile_offset_height, tile_offset_width = tile_offsets
+
+ # Number of tiles from this group's starting tile to the end of the tensor, per dimension
+ tiles_remaining_height = tensor_height // tile_height - tile_offset_height
+ tiles_remaining_width = tensor_width // tile_width - tile_offset_width
+
+ # use tile offsets to prune step
+ if tile_step_height > tiles_remaining_height:
+ tile_step_height = 1
+ if tile_step_width > tiles_remaining_width:
+ tile_step_width = 1
+
+ # use tile offsets to prune repeat count
+ tile_repeat_width = min(
+ tile_repeat_width,
+ ceildiv(tiles_remaining_width, tile_step_width),
+ )
+ tile_repeat_height = min(
+ tile_repeat_height,
+ ceildiv(tiles_remaining_height, tile_step_height),
+ )
+ # NOTE: tile_group_repeats is rebound here but is not read again in this method
+ tile_group_repeats = (tile_repeat_height, tile_repeat_width)
+
+ # Adjacent repeated tiles (step of 1) are merged into one larger tile when
+ # the group order and the tile order allow it, which frees a repeat dimension
+ if (
+ tile_group_col_major
+ and tile_step_height == 1
+ and tile_repeat_height > 1
+ and not tile_col_major
+ ):
+ # Can combine into one big tile vertically
+ tile_height *= tile_repeat_height
+ tile_repeat_height = 1
+ elif (
+ not tile_group_col_major
+ and tile_step_width == 1
+ and tile_repeat_width > 1
+ and tile_col_major
+ ):
+ # Can combine into one big tile horizontally
+ tile_width *= tile_repeat_width
+ tile_repeat_width = 1
+
+ # The two innermost dimensions walk a single tile; tile_col_major swaps
+ # which of the two tile dimensions is innermost
+ if not tile_col_major:
+ iter_sizes = [1, 1, tile_height, tile_width]
+ iter_strides = [0, 0, tensor_width, 1]
+ else:
+ iter_sizes = [1, 1, tile_width, tile_height]
+ iter_strides = [0, 0, 1, tensor_width]
+
+ # The two outermost dimensions hold the group repeats. When both repeats
+ # are > 1, tile_group_col_major selects which repeat is outermost;
+ # a single repeat > 1 always goes in dimension 1
+ if tile_repeat_width > 1 and tile_repeat_height > 1:
+ idx_order = [0, 1]
+ if tile_group_col_major:
+ idx_order = [1, 0]
+ iter_sizes[idx_order[0]] = tile_repeat_height
+ iter_sizes[idx_order[1]] = tile_repeat_width
+ iter_strides[idx_order[0]] = tensor_width * tile_height * tile_step_height
+ iter_strides[idx_order[1]] = tile_width * tile_step_width
+ elif tile_repeat_height > 1:
+ iter_sizes[1], iter_strides[1] = (
+ tile_repeat_height,
+ tensor_width * tile_height * tile_step_height,
+ )
+ elif tile_repeat_width > 1:
+ iter_sizes[1], iter_strides[1] = (
+ tile_repeat_width,
+ tile_width * tile_step_width,
+ )
+
+ iter_sizes, iter_strides = validate_and_clean_sizes_strides(
+ iter_sizes, iter_strides
+ )
+
+ # Place pattern_repeat in dimension 1 when the check below passes,
+ # otherwise in dimension 0 if that dimension is unused (size 1, stride 0)
+ if pattern_repeat != 1:
+ # NOTE(review): this tests iter_strides[0] although it is index 1 that
+ # gets overwritten -- confirm whether iter_strides[1] was intended
+ if iter_sizes[1] == 1 and iter_strides[0] == 0:
+ iter_sizes[1] = pattern_repeat
+ else:
+ if iter_sizes[0] != 1 or iter_strides[0] != 0:
+ raise ValueError(
+ f"Ran out of dimensions for repeat (sizes={iter_sizes}, strides={iter_strides})"
+ )
+ iter_sizes[0] = pattern_repeat
+
+ return iter_sizes, iter_strides
diff --git a/python/helpers/tensortiler/tensortilesequence.py b/python/helpers/tensortiler/tensortilesequence.py
new file mode 100644
index 0000000000..f57b44fd44
--- /dev/null
+++ b/python/helpers/tensortiler/tensortilesequence.py
@@ -0,0 +1,247 @@
+from __future__ import annotations
+from collections import abc
+from copy import deepcopy
+import matplotlib.animation as animation
+import numpy as np
+from typing import Callable, Sequence
+
+from .tensortile import TensorTile
+from .utils import (
+ validate_and_clean_sizes_strides,
+ validate_offset,
+ validate_tensor_dims,
+)
+from .visualization2d import animate_from_access_tensors, visualize_from_access_tensors
+
+
+class TensorTileSequence(abc.MutableSequence, abc.Iterable):
+ """
+ TensorTileSequence is a MutableSequence and an Iterable which is a thin wrapper around a list[TensorTiles].
+ """
+
+    def __init__(
+        self,
+        tensor_dims: Sequence[int],
+        num_steps: int,
+        offset: int | None = None,
+        sizes: Sequence[int] | None = None,
+        strides: Sequence[int] | None = None,
+        offset_fn: Callable[[int, int], int] | None = None,
+        sizes_fn: Callable[[int, Sequence[int]], Sequence[int]] | None = None,
+        strides_fn: Callable[[int, Sequence[int]], Sequence[int]] | None = None,
+    ):
+        """
+        Build a sequence of num_steps tiles over a tensor of shape tensor_dims.
+
+        For every step, each of offset/sizes/strides is produced by calling the
+        matching *_fn as fn(step, previous_value), where previous_value is the
+        value from the previous step (or the constructor argument for step 0).
+        When a *_fn is not given, the matching constant argument is used for
+        every step and is therefore required. All tiles are created eagerly.
+
+        Raises:
+            ValueError: if num_steps is negative, if num_steps is 0 but any
+                offset/sizes/strides information was supplied, or if neither a
+                constant nor a function is supplied for one of the three values.
+        """
+        self._current_step = 0
+
+        # Check tensor dims, offset, sizes, strides
+        self._tensor_dims = validate_tensor_dims(tensor_dims)
+        if offset is not None:
+            offset = validate_offset(offset)
+        sizes, strides = validate_and_clean_sizes_strides(
+            sizes, strides, allow_none=True
+        )
+
+        # Validate num steps (0 is allowed and yields an empty sequence)
+        if num_steps < 0:
+            raise ValueError(
+                f"Number of steps must be non-negative (but is {num_steps})"
+            )
+
+        if num_steps == 0:
+            supplied = (offset, sizes, strides, offset_fn, sizes_fn, strides_fn)
+            if any(value is not None for value in supplied):
+                raise ValueError(
+                    "If num_steps=0, no sizes/strides/offset information may be specified"
+                )
+            self._tiles = []
+            return
+
+        # Make sure values are not None if iteration functions are None; also set default iter fn.
+        # The constants are bound to dedicated names so the default functions do not
+        # depend on the loop variables that are reassigned below.
+        if offset_fn is None:
+            if offset is None:
+                raise ValueError("Offset must be provided if offset_fn is None")
+            fixed_offset = offset
+            offset_fn = lambda _step, _prev_offset: fixed_offset
+        if sizes_fn is None:
+            if sizes is None:
+                raise ValueError("Sizes must be provided if size_fn is None")
+            fixed_sizes = sizes
+            sizes_fn = lambda _step, _prev_sizes: fixed_sizes
+        if strides_fn is None:
+            if strides is None:
+                raise ValueError("Strides must be provided if stride_fn is None")
+            fixed_strides = strides
+            strides_fn = lambda _step, _prev_strides: fixed_strides
+
+        # Pre-calculate tiles, because better for error handling up-front (and for visualizing full iter)
+        # This is somewhat against the mentality behind iterations, but should be okay at the scale this
+        # class will be used for (e.g., no scalability concerns with keeping all tiles in mem)
+        self._tiles = []
+        for step in range(num_steps):
+            offset = offset_fn(step, offset)
+            sizes = sizes_fn(step, sizes)
+            strides = strides_fn(step, strides)
+
+            self._tiles.append(
+                TensorTile(
+                    self._tensor_dims,
+                    offset,
+                    sizes,
+                    strides,
+                )
+            )
+
+    @classmethod
+    def from_tiles(cls, tiles: Sequence[TensorTile]) -> TensorTileSequence:
+        """
+        Build a sequence from existing tiles that all share the same tensor dims.
+
+        The sequence is constructed from the first tile's offset, sizes and
+        strides; the remaining tiles are appended in order.
+
+        Raises:
+            ValueError: if tiles is empty or the tiles disagree on tensor dims.
+        """
+        if len(tiles) < 1:
+            raise ValueError(
+                "Received no tiles; must have at least one tile to create a tile sequence."
+            )
+        first_tile = tiles[0]
+        tensor_dims = first_tile.tensor_dims
+        for candidate in tiles:
+            if candidate.tensor_dims != tensor_dims:
+                raise ValueError(
+                    f"Tiles have multiple tensor dimensions (found {tensor_dims} and {candidate.tensor_dims})"
+                )
+
+        sequence = cls(
+            tensor_dims,
+            num_steps=1,
+            offset=first_tile.offset,
+            sizes=first_tile.sizes,
+            strides=first_tile.strides,
+        )
+        for later_tile in tiles[1:]:
+            sequence.append(later_tile)
+        return sequence
+
+    def access_tensors(self) -> tuple[np.ndarray, np.ndarray]:
+        """
+        Combine the access tensors of every tile in the sequence.
+
+        Returns:
+            A pair (access_order, access_count) of arrays shaped like the
+            tensor. access_order holds, for each element, the position of its
+            most recent access in the whole sequence, or -1 if it is never
+            accessed. access_count holds how many times each element is accessed.
+        """
+        combined_order = np.full(self._tensor_dims, -1, dtype=TensorTile._DTYPE)
+        combined_count = np.zeros(self._tensor_dims, dtype=TensorTile._DTYPE)
+
+        for tile in self._tiles:
+            tile_order, tile_count = tile.access_tensors()
+            touched = tile_order != -1
+            # Continue numbering after the highest order assigned so far. The
+            # later access overwrites the earlier one; adding them would produce
+            # a meaningless sum for elements touched by more than one tile.
+            next_order = combined_order.max() + 1
+            combined_order[touched] = tile_order[touched] + next_order
+            combined_count += tile_count
+
+        return (combined_order, combined_count)
+
+    def animate(
+        self, title: str = None, animate_access_count: bool = False
+    ) -> animation.FuncAnimation:
+        """
+        Create an animation with one frame per tile, preceded by an empty frame.
+
+        The initial frame marks every element as not accessed (-1 for order,
+        0 for count). Access-count frames are only collected when
+        animate_access_count is set. The frames are passed to
+        animate_from_access_tensors.
+
+        Raises:
+            NotImplementedError: if the tensor dims do not have exactly two entries.
+        """
+        if title is None:
+            title = "TensorTileSequence Animation"
+        if len(self._tensor_dims) != 2:
+            raise NotImplementedError(
+                "Visualization is only currently supported for 1- or 2-dimensional tensors"
+            )
+
+        total_elems = np.prod(self._tensor_dims)
+        blank_order = np.full(total_elems, -1, TensorTile._DTYPE).reshape(
+            self._tensor_dims
+        )
+        order_frames = [blank_order]
+        count_frames = None
+        if animate_access_count:
+            blank_count = np.full(total_elems, 0, TensorTile._DTYPE).reshape(
+                self._tensor_dims
+            )
+            count_frames = [blank_count]
+
+        for tile in self._tiles:
+            tile_order, tile_count = tile.access_tensors()
+            order_frames.append(tile_order)
+            if animate_access_count:
+                count_frames.append(tile_count)
+
+        return animate_from_access_tensors(
+            order_frames,
+            count_frames,
+            title=title,
+        )
+
+    def visualize(
+        self,
+        title: str = None,
+        file_path: str | None = None,
+        show_plot: bool = True,
+        plot_access_count: bool = False,
+    ) -> None:
+        """
+        Plot the combined access pattern of all tiles in the sequence.
+
+        The combined access tensors are passed to visualize_from_access_tensors
+        with arrows disabled; the access-count tensor is only included when
+        plot_access_count is set.
+
+        Raises:
+            NotImplementedError: if the tensor dims do not have exactly two entries.
+        """
+        is_two_dimensional = len(self._tensor_dims) == 2
+        if not is_two_dimensional:
+            raise NotImplementedError(
+                "Visualization is only currently supported for 1- or 2-dimensional tensors"
+            )
+
+        plot_title = "TensorTileSequence" if title is None else title
+        order_tensor, count_tensor = self.access_tensors()
+
+        visualize_from_access_tensors(
+            order_tensor,
+            count_tensor if plot_access_count else None,
+            title=plot_title,
+            show_arrows=False,
+            file_path=file_path,
+            show_plot=show_plot,
+        )
+
+    def __contains__(self, tile: TensorTile):
+        """Return whether an equal tile is stored in the sequence."""
+        is_present = tile in self._tiles
+        return is_present
+
+    def __iter__(self):
+        """Iterate over a deep copy of the stored tiles, so the originals are not handed out."""
+        tiles_snapshot = deepcopy(self._tiles)
+        return iter(tiles_snapshot)
+
+    def __len__(self) -> int:
+        """Return the number of tiles stored in the sequence."""
+        tile_count = len(self._tiles)
+        return tile_count
+
+    def __getitem__(self, idx: int) -> TensorTile:
+        """Return the stored tile at idx (the stored object itself, not a copy)."""
+        selected = self._tiles[idx]
+        return selected
+
+    def __setitem__(self, idx: int, tile: TensorTile):
+        """
+        Replace the tile at idx with a deep copy of tile.
+
+        Raises:
+            ValueError: if the tile's tensor dims differ from the sequence's.
+        """
+        dims_mismatch = self._tensor_dims != tile.tensor_dims
+        if dims_mismatch:
+            raise ValueError(
+                f"Cannot add tile with tensor dims {tile.tensor_dims} to sequence of tiles with tensor dims {self._tensor_dims}"
+            )
+        tile_copy = deepcopy(tile)
+        self._tiles[idx] = tile_copy
+
+    def __delitem__(self, idx: int):
+        """Remove the tile at idx from the sequence."""
+        stored_tiles = self._tiles
+        del stored_tiles[idx]
+
+    def insert(self, idx: int, tile: TensorTile):
+        """
+        Insert a deep copy of tile before position idx.
+
+        The tile is copied, as __setitem__ does, so later changes to the
+        caller's object do not alter the sequence.
+
+        Raises:
+            ValueError: if the tile's tensor dims differ from the sequence's.
+        """
+        if self._tensor_dims != tile.tensor_dims:
+            raise ValueError(
+                f"Cannot add tile with tensor dims {tile.tensor_dims} to sequence of tiles with tensor dims {self._tensor_dims}"
+            )
+        self._tiles.insert(idx, deepcopy(tile))
+
+    def __eq__(self, other):
+        """
+        Compare two sequences.
+
+        Returns False for any object that is not an instance of this
+        sequence's class; otherwise the sequences are equal when their tiles
+        and their current step match.
+        """
+        if not isinstance(other, self.__class__):
+            return False
+        return (
+            self._tiles == other._tiles
+            and self._current_step == other._current_step
+        )
+
+    def __ne__(self, other):
+        """Return the negation of __eq__ for the same operand."""
+        is_equal = self.__eq__(other)
+        return not is_equal
diff --git a/python/helpers/tensortiler/utils.py b/python/helpers/tensortiler/utils.py
new file mode 100644
index 0000000000..c997e72d94
--- /dev/null
+++ b/python/helpers/tensortiler/utils.py
@@ -0,0 +1,114 @@
+from copy import deepcopy
+from typing import Sequence
+
+
+def ceildiv(a, b):
+    """Return a divided by b, rounded up, using only floor division."""
+    # Floor-dividing by the negated divisor rounds the negated quotient down,
+    # which is the true quotient rounded up.
+    negated_quotient = a // -b
+    return -negated_quotient
+
+
+def validate_and_clean_sizes_strides(
+    sizes: Sequence[int] | None,
+    strides: Sequence[int] | None,
+    allow_none: bool = False,
+    expected_dims: int | None = None,
+) -> tuple[Sequence[int] | None, Sequence[int] | None]:
+    """
+    Validate sizes and strides and normalize leading size-1 dimensions.
+
+    Args:
+        sizes: Size per dimension; every value must be >= 1.
+        strides: Stride per dimension; every value must be >= 0.
+        allow_none: When False, both sizes and strides must be provided.
+        expected_dims: If given (must be >= 1) and both sequences are
+            provided, both must have exactly this many entries. Otherwise the
+            two sequences only have to have the same length.
+
+    Returns:
+        Copies of (sizes, strides). When both are provided, the stride of
+        every leading dimension whose size is 1 is set to 0; a tuple of
+        strides is converted to a list if it has to be modified.
+
+    Raises:
+        ValueError: if any of the checks above fails.
+    """
+    if not allow_none:
+        if sizes is None:
+            raise ValueError("Sizes is None, but expected Sequence[int]")
+        if strides is None:
+            raise ValueError("Strides is None, but expected Sequence[int]")
+    # After this point can assume None is ok for sizes/strides
+
+    if expected_dims is not None and expected_dims < 1:
+        raise ValueError(f"Expected dimensions ({expected_dims}) should be >= 1")
+
+    if sizes is None and strides is None:
+        # nothing to do
+        return None, None
+
+    # Validate dimensions
+    if sizes is not None and len(sizes) == 0:
+        raise ValueError("len(sizes) must be >0")
+    if strides is not None and len(strides) == 0:
+        raise ValueError("len(strides) must be >0")
+
+    # Empty sequences were rejected above, so from here on "provided" simply
+    # means "not None".
+    have_sizes = sizes is not None
+    have_strides = strides is not None
+
+    if have_sizes and have_strides:
+        if expected_dims:
+            # These two errors used to be raised as bare strings, which is
+            # itself a TypeError; raise a proper ValueError instead.
+            if len(sizes) != expected_dims:
+                raise ValueError(
+                    f"Num dimensions of sizes ({sizes}) is not expected number of dimensions ({expected_dims})"
+                )
+            if len(strides) != expected_dims:
+                raise ValueError(
+                    f"Num dimensions of strides ({strides}) is not expected number of dimensions ({expected_dims})"
+                )
+        elif len(strides) != len(sizes):
+            raise ValueError(
+                f"len(sizes) ({len(sizes)}) != len(strides) ({len(strides)})"
+            )
+    num_dims = len(strides) if have_strides else len(sizes)
+
+    # Validate sizes/strides values
+    if have_sizes:
+        sizes = deepcopy(sizes)
+        for s in sizes:
+            if s < 1:
+                raise ValueError(f"All sizes must be >= 1, but got {sizes}")
+    if have_strides:
+        strides = deepcopy(strides)
+        for s in strides:
+            if s < 0:
+                raise ValueError(f"All strides must be >= 0, but got {strides}")
+
+    # Clean (set size=1, stride=0 for as many dims as possible)
+    if have_sizes and have_strides:
+        for i in range(num_dims):
+            if sizes[i] != 1:
+                break
+            if isinstance(strides, tuple):
+                # Tuple is immutable, so convert if necessary
+                strides = list(strides)
+            strides[i] = 0
+    return sizes, strides
+
+
+def validate_tensor_dims(
+    tensor_dims: Sequence[int], expected_dims: int | None = None
+) -> Sequence[int]:
+    """
+    Validate tensor dimensions and return a copy of them.
+
+    A single dimension n is returned as the two dimensions [1, n]. When
+    expected_dims is given, the (possibly expanded) result must have exactly
+    that many entries.
+
+    Raises:
+        ValueError: if expected_dims is < 1, if tensor_dims is empty, if any
+            dimension is <= 0, or if the number of dimensions does not match
+            expected_dims.
+    """
+    has_expectation = expected_dims is not None
+    if has_expectation and expected_dims < 1:
+        raise ValueError(f"Expected dimensions ({expected_dims}) should be >= 1")
+
+    # Work on a copy so the caller's sequence is never shared with the result.
+    dims = deepcopy(tensor_dims)
+
+    if len(dims) == 0:
+        raise ValueError(
+            f"Number of tensor dimensions must be >= 1 (dimensions={dims})"
+        )
+    if any(d <= 0 for d in dims):
+        raise ValueError(f"Each tensor dimension must be >= 1 (dimensions={dims})")
+
+    # We can treat a 1-dimensional tensor as a 2-dimensional tensor
+    if len(dims) == 1:
+        dims = [1, dims[0]]
+
+    if has_expectation and len(dims) != expected_dims:
+        raise ValueError(
+            f"Tensor dimension ({dims}) does not match expected dimension ({expected_dims})"
+        )
+
+    return dims
+
+
+def validate_offset(offset: int):
+    """
+    Return offset unchanged after checking that it is not negative.
+
+    Raises:
+        ValueError: if offset is < 0.
+    """
+    is_negative = offset < 0
+    if is_negative:
+        raise ValueError(f"Offset must be >= 0 (offset={offset})")
+    return offset
diff --git a/python/helpers/tensortiler/visualization2d.py b/python/helpers/tensortiler/visualization2d.py
new file mode 100644
index 0000000000..7244475f8e
--- /dev/null
+++ b/python/helpers/tensortiler/visualization2d.py
@@ -0,0 +1,172 @@
+import matplotlib.animation as animation
+import matplotlib.patheffects as pe
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+import sys
+
+from .utils import ceildiv
+
+
+def animate_from_access_tensors(
+    access_order_tensors: list[np.ndarray],
+    access_count_tensors: list[np.ndarray] | None,
+    title: str = "Animated Access Visualization",
+) -> animation.FuncAnimation:
+    """
+    Build an animation that draws one access-order heatmap per frame.
+
+    When access_count_tensors is given, a second subplot draws the matching
+    access-count heatmap for every frame and the figure height is doubled.
+    The pyplot figure is closed before returning; only the animation object
+    is handed back.
+
+    Raises:
+        ValueError: if no access order tensor is given, or if access count
+            tensors are given but are empty or differ in number from the
+            access order tensors.
+    """
+    num_frames = len(access_order_tensors)
+    with_counts = access_count_tensors is not None
+
+    if num_frames < 1:
+        raise ValueError("At least one access order tensor is required.")
+    if with_counts:
+        if len(access_count_tensors) < 1:
+            raise ValueError(
+                "access_count_tensor should be None or requires at least one tensor"
+            )
+        if len(access_count_tensors) != num_frames:
+            raise ValueError(
+                "Number of access count tensors and number of access order tensors should be equal"
+            )
+
+    # Figure size is derived from the shape of the first frame.
+    first_frame = access_order_tensors[0]
+    tensor_height, tensor_width = first_frame.shape
+    fig_width = 5 if tensor_width < 32 else 7
+    fig_height = min(fig_width, fig_width * ceildiv(tensor_height, tensor_width))
+
+    if with_counts:
+        fig_height *= 2
+        fig, (ax_order, ax_count) = plt.subplots(2, 1)
+    else:
+        fig, ax_order = plt.subplots()
+
+    fig.set_figheight(fig_height)
+    fig.set_figwidth(fig_width)
+    fig.suptitle(title)
+    xs = np.arange(first_frame.shape[1])
+    ys = np.arange(first_frame.shape[0])
+
+    # Put the origin at the top-left so the plot reads like the tensor.
+    ax_order.xaxis.tick_top()
+    ax_order.invert_yaxis()
+    ax_order.set_title("Access Order Animation")
+
+    if with_counts:
+        ax_count.xaxis.tick_top()
+        ax_count.invert_yaxis()
+        ax_count.set_title("Access Counts")
+
+    def animate_order(i):
+        order_mesh = ax_order.pcolormesh(access_order_tensors[i])
+        if not with_counts:
+            return order_mesh
+        count_mesh = ax_count.pcolormesh(
+            xs, ys, access_count_tensors[i], cmap="gnuplot2"
+        )
+        return (
+            order_mesh,
+            count_mesh,
+        )
+
+    frame_interval = max(400, 100 + 5 * num_frames)
+    _animation = animation.FuncAnimation(
+        fig,
+        animate_order,
+        frames=num_frames,
+        interval=frame_interval,
+    )
+
+    plt.tight_layout()
+    plt.close()
+    return _animation
+
+
+def visualize_from_access_tensors(
+    access_order_tensor: np.ndarray,
+    access_count_tensor: np.ndarray | None,
+    title: str = "Access Visualization",
+    show_arrows: bool | None = None,
+    file_path: str | None = None,
+    show_plot: bool = True,
+):
+    """
+    Draw a heatmap of the access order and, optionally, of the access counts.
+
+    Args:
+        access_order_tensor: 2-D array of access order values; -1 marks an
+            element that is never accessed.
+        access_count_tensor: 2-D array of access counts, or None to omit the
+            second heatmap.
+        title: Figure title.
+        show_arrows: Whether to draw an arrow between consecutively accessed
+            elements. None enables arrows only for tensors with fewer than
+            1024 elements.
+        file_path: If set, the figure is written to this path (an existing
+            file is overwritten, with a warning on stderr).
+        show_plot: Whether to display the figure.
+    """
+    tensor_height, tensor_width = access_order_tensor.shape
+    if tensor_height * tensor_width >= 1024:
+        if show_arrows:
+            print(
+                "show_arrows not recommended for tensor sizes > 1024 elements",
+                file=sys.stderr,
+            )
+        if show_arrows is None:
+            show_arrows = False
+    elif show_arrows is None:
+        # Set to true by default only for 'small' tensor sizes
+        show_arrows = True
+
+    fig_width = 7
+    if tensor_width < 32:
+        fig_width = 5
+    height_width_ratio = ceildiv(tensor_height, tensor_width)
+    fig_height = min(fig_width, fig_width * height_width_ratio)
+
+    if access_count_tensor is not None:
+        fig_height *= 2
+        fig, (ax_order, ax_count) = plt.subplots(2, 1)
+    else:
+        fig, ax_order = plt.subplots()
+
+    fig.set_figheight(fig_height)
+    fig.set_figwidth(fig_width)
+    fig.suptitle(title)
+    xs = np.arange(access_order_tensor.shape[1])
+    ys = np.arange(access_order_tensor.shape[0])
+
+    _access_heatmap = ax_order.pcolormesh(xs, ys, access_order_tensor, cmap="gnuplot2")
+    ax_order.xaxis.tick_top()
+    ax_order.invert_yaxis()
+    ax_order.set_title("Access Order")
+
+    if access_count_tensor is not None:
+        max_count = np.max(access_count_tensor)
+        _count_heatmap = ax_count.pcolormesh(
+            xs, ys, access_count_tensor, cmap="gnuplot2"
+        )
+        ax_count.xaxis.tick_top()
+        ax_count.invert_yaxis()
+        ax_count.set_title(f"Access Counts (max={max_count})")
+
+    # Add arrows to show access order
+    if show_arrows:
+        # Thanks to https://stackoverflow.com/questions/37719304/python-imshow-set-certain-value-to-defined-color
+        # Thanks to tmdavison answer here https://stackoverflow.com/a/40890587/7871710
+
+        order_dict = {}
+        for (row, col), order in np.ndenumerate(access_order_tensor):
+            if order != -1:
+                order_dict[order] = (row, col)
+
+        # Connect consecutive *recorded* order values. Order values can be
+        # missing (an element accessed more than once only keeps its latest
+        # order) and the map can be empty, so do not assume a contiguous range.
+        order_keys = sorted(order_dict)
+        for current_key, next_key in zip(order_keys, order_keys[1:]):
+            y1, x1 = order_dict[current_key]
+            y2, x2 = order_dict[next_key]
+            ax_order.arrow(
+                x1,
+                y1,
+                x2 - x1,
+                y2 - y1,
+                length_includes_head=True,
+                head_width=0.1,
+                head_length=0.15,
+                overhang=0.2,
+                path_effects=[pe.withStroke(linewidth=3, foreground="white")],
+            )
+
+    plt.tight_layout()
+    # Save before showing: once a blocking show() returns, the figure may have
+    # been closed, and saving afterwards can write out an empty image.
+    if file_path:
+        if os.path.exists(file_path):
+            print(
+                f"Warning: {file_path} already exists and will be overwritten",
+                file=sys.stderr,
+            )
+        plt.savefig(file_path)
+    if show_plot:
+        plt.show()
+    plt.close()
diff --git a/python/requirements.txt b/python/requirements.txt
index 780c4f4481..6efee7f4bd 100644
--- a/python/requirements.txt
+++ b/python/requirements.txt
@@ -12,3 +12,4 @@ rich
setuptools
wheel
ml_dtypes
+matplotlib
diff --git a/python/utils/xrt.py b/python/utils/xrt.py
index 276e9e2cde..bd4fc1e3a8 100644
--- a/python/utils/xrt.py
+++ b/python/utils/xrt.py
@@ -131,7 +131,8 @@ def setup_aie(
):
app = AIE_Application(xclbin_path, insts_path, kernel_name)
- app.register_buffer(3, shape=in_0_shape, dtype=in_0_dtype)
+ if in_0_shape or in_0_dtype:
+ app.register_buffer(3, shape=in_0_shape, dtype=in_0_dtype)
if in_1_shape or in_1_dtype:
app.register_buffer(4, shape=in_1_shape, dtype=in_1_dtype)
@@ -159,8 +160,9 @@ def write_out_trace(trace, file_name):
f.write(out_str)
-def execute(app, input_one, input_two=None):
- app.buffers[3].write(input_one)
+def execute(app, input_one=None, input_two=None):
+ if not (input_one is None):
+ app.buffers[3].write(input_one)
if not (input_two is None):
app.buffers[4].write(input_two)
app.run()
diff --git a/test/python/tensortiler/group_tiler.py b/test/python/tensortiler/group_tiler.py
new file mode 100644
index 0000000000..527ce8397e
--- /dev/null
+++ b/test/python/tensortiler/group_tiler.py
@@ -0,0 +1,576 @@
+import numpy as np
+
+from aie.helpers.tensortiler import TensorTile, TensorTiler2D
+from util import construct_test
+
+# RUN: %python %s | FileCheck %s
+
+
+# CHECK-LABEL: group_tiler
+@construct_test
+def group_tiler():
+ # Default tile group dims
+ tiles = TensorTiler2D.group_tiler((3 * 5, 2 * 7), tile_dims=(3, 2))
+ assert len(tiles) == 5 * 7
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 1, 6, 7, 12, 13, 18, 19, 24, 25, 30, 31, 36, 37],
+ [ 2, 3, 8, 9, 14, 15, 20, 21, 26, 27, 32, 33, 38, 39],
+ [ 4, 5, 10, 11, 16, 17, 22, 23, 28, 29, 34, 35, 40, 41],
+ [ 42, 43, 48, 49, 54, 55, 60, 61, 66, 67, 72, 73, 78, 79],
+ [ 44, 45, 50, 51, 56, 57, 62, 63, 68, 69, 74, 75, 80, 81],
+ [ 46, 47, 52, 53, 58, 59, 64, 65, 70, 71, 76, 77, 82, 83],
+ [ 84, 85, 90, 91, 96, 97, 102, 103, 108, 109, 114, 115, 120, 121],
+ [ 86, 87, 92, 93, 98, 99, 104, 105, 110, 111, 116, 117, 122, 123],
+ [ 88, 89, 94, 95, 100, 101, 106, 107, 112, 113, 118, 119, 124, 125],
+ [126, 127, 132, 133, 138, 139, 144, 145, 150, 151, 156, 157, 162, 163],
+ [128, 129, 134, 135, 140, 141, 146, 147, 152, 153, 158, 159, 164, 165],
+ [130, 131, 136, 137, 142, 143, 148, 149, 154, 155, 160, 161, 166, 167],
+ [168, 169, 174, 175, 180, 181, 186, 187, 192, 193, 198, 199, 204, 205],
+ [170, 171, 176, 177, 182, 183, 188, 189, 194, 195, 200, 201, 206, 207],
+ [172, 173, 178, 179, 184, 185, 190, 191, 196, 197, 202, 203, 208, 209]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ tiles = TensorTiler2D.group_tiler(
+ (3 * 5 * 3, 2 * 7 * 2), tile_dims=(3, 2), tile_group_dims=(5, 7)
+ )
+ assert len(tiles) == 3 * 2
+ tile0_0 = TensorTile(
+ (3 * 5 * 3, 2 * 7 * 2), offset=0, sizes=[5, 7, 3, 2], strides=[84, 2, 28, 1]
+ )
+ assert tiles[0] == tile0_0
+ tile0_1 = TensorTile(
+ (3 * 5 * 3, 2 * 7 * 2), offset=14, sizes=[5, 7, 3, 2], strides=[84, 2, 28, 1]
+ )
+ assert tiles[1] == tile0_1
+ tile1_0 = TensorTile(
+ (3 * 5 * 3, 2 * 7 * 2), offset=420, sizes=[5, 7, 3, 2], strides=[84, 2, 28, 1]
+ )
+ assert tiles[2] == tile1_0
+
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 1, 6, 7, 12, 13, 18, 19, 24, 25, 30, 31, 36, 37, 210, 211, 216, 217, 222, 223, 228, 229, 234, 235, 240, 241, 246, 247],
+ [ 2, 3, 8, 9, 14, 15, 20, 21, 26, 27, 32, 33, 38, 39, 212, 213, 218, 219, 224, 225, 230, 231, 236, 237, 242, 243, 248, 249],
+ [ 4, 5, 10, 11, 16, 17, 22, 23, 28, 29, 34, 35, 40, 41, 214, 215, 220, 221, 226, 227, 232, 233, 238, 239, 244, 245, 250, 251],
+ [ 42, 43, 48, 49, 54, 55, 60, 61, 66, 67, 72, 73, 78, 79, 252, 253, 258, 259, 264, 265, 270, 271, 276, 277, 282, 283, 288, 289],
+ [ 44, 45, 50, 51, 56, 57, 62, 63, 68, 69, 74, 75, 80, 81, 254, 255, 260, 261, 266, 267, 272, 273, 278, 279, 284, 285, 290, 291],
+ [ 46, 47, 52, 53, 58, 59, 64, 65, 70, 71, 76, 77, 82, 83, 256, 257, 262, 263, 268, 269, 274, 275, 280, 281, 286, 287, 292, 293],
+ [ 84, 85, 90, 91, 96, 97, 102, 103, 108, 109, 114, 115, 120, 121, 294, 295, 300, 301, 306, 307, 312, 313, 318, 319, 324, 325, 330, 331],
+ [ 86, 87, 92, 93, 98, 99, 104, 105, 110, 111, 116, 117, 122, 123, 296, 297, 302, 303, 308, 309, 314, 315, 320, 321, 326, 327, 332, 333],
+ [ 88, 89, 94, 95, 100, 101, 106, 107, 112, 113, 118, 119, 124, 125, 298, 299, 304, 305, 310, 311, 316, 317, 322, 323, 328, 329, 334, 335],
+ [ 126, 127, 132, 133, 138, 139, 144, 145, 150, 151, 156, 157, 162, 163, 336, 337, 342, 343, 348, 349, 354, 355, 360, 361, 366, 367, 372, 373],
+ [ 128, 129, 134, 135, 140, 141, 146, 147, 152, 153, 158, 159, 164, 165, 338, 339, 344, 345, 350, 351, 356, 357, 362, 363, 368, 369, 374, 375],
+ [ 130, 131, 136, 137, 142, 143, 148, 149, 154, 155, 160, 161, 166, 167, 340, 341, 346, 347, 352, 353, 358, 359, 364, 365, 370, 371, 376, 377],
+ [ 168, 169, 174, 175, 180, 181, 186, 187, 192, 193, 198, 199, 204, 205, 378, 379, 384, 385, 390, 391, 396, 397, 402, 403, 408, 409, 414, 415],
+ [ 170, 171, 176, 177, 182, 183, 188, 189, 194, 195, 200, 201, 206, 207, 380, 381, 386, 387, 392, 393, 398, 399, 404, 405, 410, 411, 416, 417],
+ [ 172, 173, 178, 179, 184, 185, 190, 191, 196, 197, 202, 203, 208, 209, 382, 383, 388, 389, 394, 395, 400, 401, 406, 407, 412, 413, 418, 419],
+ [ 420, 421, 426, 427, 432, 433, 438, 439, 444, 445, 450, 451, 456, 457, 630, 631, 636, 637, 642, 643, 648, 649, 654, 655, 660, 661, 666, 667],
+ [ 422, 423, 428, 429, 434, 435, 440, 441, 446, 447, 452, 453, 458, 459, 632, 633, 638, 639, 644, 645, 650, 651, 656, 657, 662, 663, 668, 669],
+ [ 424, 425, 430, 431, 436, 437, 442, 443, 448, 449, 454, 455, 460, 461, 634, 635, 640, 641, 646, 647, 652, 653, 658, 659, 664, 665, 670, 671],
+ [ 462, 463, 468, 469, 474, 475, 480, 481, 486, 487, 492, 493, 498, 499, 672, 673, 678, 679, 684, 685, 690, 691, 696, 697, 702, 703, 708, 709],
+ [ 464, 465, 470, 471, 476, 477, 482, 483, 488, 489, 494, 495, 500, 501, 674, 675, 680, 681, 686, 687, 692, 693, 698, 699, 704, 705, 710, 711],
+ [ 466, 467, 472, 473, 478, 479, 484, 485, 490, 491, 496, 497, 502, 503, 676, 677, 682, 683, 688, 689, 694, 695, 700, 701, 706, 707, 712, 713],
+ [ 504, 505, 510, 511, 516, 517, 522, 523, 528, 529, 534, 535, 540, 541, 714, 715, 720, 721, 726, 727, 732, 733, 738, 739, 744, 745, 750, 751],
+ [ 506, 507, 512, 513, 518, 519, 524, 525, 530, 531, 536, 537, 542, 543, 716, 717, 722, 723, 728, 729, 734, 735, 740, 741, 746, 747, 752, 753],
+ [ 508, 509, 514, 515, 520, 521, 526, 527, 532, 533, 538, 539, 544, 545, 718, 719, 724, 725, 730, 731, 736, 737, 742, 743, 748, 749, 754, 755],
+ [ 546, 547, 552, 553, 558, 559, 564, 565, 570, 571, 576, 577, 582, 583, 756, 757, 762, 763, 768, 769, 774, 775, 780, 781, 786, 787, 792, 793],
+ [ 548, 549, 554, 555, 560, 561, 566, 567, 572, 573, 578, 579, 584, 585, 758, 759, 764, 765, 770, 771, 776, 777, 782, 783, 788, 789, 794, 795],
+ [ 550, 551, 556, 557, 562, 563, 568, 569, 574, 575, 580, 581, 586, 587, 760, 761, 766, 767, 772, 773, 778, 779, 784, 785, 790, 791, 796, 797],
+ [ 588, 589, 594, 595, 600, 601, 606, 607, 612, 613, 618, 619, 624, 625, 798, 799, 804, 805, 810, 811, 816, 817, 822, 823, 828, 829, 834, 835],
+ [ 590, 591, 596, 597, 602, 603, 608, 609, 614, 615, 620, 621, 626, 627, 800, 801, 806, 807, 812, 813, 818, 819, 824, 825, 830, 831, 836, 837],
+ [ 592, 593, 598, 599, 604, 605, 610, 611, 616, 617, 622, 623, 628, 629, 802, 803, 808, 809, 814, 815, 820, 821, 826, 827, 832, 833, 838, 839],
+ [ 840, 841, 846, 847, 852, 853, 858, 859, 864, 865, 870, 871, 876, 877, 1050, 1051, 1056, 1057, 1062, 1063, 1068, 1069, 1074, 1075, 1080, 1081, 1086, 1087],
+ [ 842, 843, 848, 849, 854, 855, 860, 861, 866, 867, 872, 873, 878, 879, 1052, 1053, 1058, 1059, 1064, 1065, 1070, 1071, 1076, 1077, 1082, 1083, 1088, 1089],
+ [ 844, 845, 850, 851, 856, 857, 862, 863, 868, 869, 874, 875, 880, 881, 1054, 1055, 1060, 1061, 1066, 1067, 1072, 1073, 1078, 1079, 1084, 1085, 1090, 1091],
+ [ 882, 883, 888, 889, 894, 895, 900, 901, 906, 907, 912, 913, 918, 919, 1092, 1093, 1098, 1099, 1104, 1105, 1110, 1111, 1116, 1117, 1122, 1123, 1128, 1129],
+ [ 884, 885, 890, 891, 896, 897, 902, 903, 908, 909, 914, 915, 920, 921, 1094, 1095, 1100, 1101, 1106, 1107, 1112, 1113, 1118, 1119, 1124, 1125, 1130, 1131],
+ [ 886, 887, 892, 893, 898, 899, 904, 905, 910, 911, 916, 917, 922, 923, 1096, 1097, 1102, 1103, 1108, 1109, 1114, 1115, 1120, 1121, 1126, 1127, 1132, 1133],
+ [ 924, 925, 930, 931, 936, 937, 942, 943, 948, 949, 954, 955, 960, 961, 1134, 1135, 1140, 1141, 1146, 1147, 1152, 1153, 1158, 1159, 1164, 1165, 1170, 1171],
+ [ 926, 927, 932, 933, 938, 939, 944, 945, 950, 951, 956, 957, 962, 963, 1136, 1137, 1142, 1143, 1148, 1149, 1154, 1155, 1160, 1161, 1166, 1167, 1172, 1173],
+ [ 928, 929, 934, 935, 940, 941, 946, 947, 952, 953, 958, 959, 964, 965, 1138, 1139, 1144, 1145, 1150, 1151, 1156, 1157, 1162, 1163, 1168, 1169, 1174, 1175],
+ [ 966, 967, 972, 973, 978, 979, 984, 985, 990, 991, 996, 997, 1002, 1003, 1176, 1177, 1182, 1183, 1188, 1189, 1194, 1195, 1200, 1201, 1206, 1207, 1212, 1213],
+ [ 968, 969, 974, 975, 980, 981, 986, 987, 992, 993, 998, 999, 1004, 1005, 1178, 1179, 1184, 1185, 1190, 1191, 1196, 1197, 1202, 1203, 1208, 1209, 1214, 1215],
+ [ 970, 971, 976, 977, 982, 983, 988, 989, 994, 995, 1000, 1001, 1006, 1007, 1180, 1181, 1186, 1187, 1192, 1193, 1198, 1199, 1204, 1205, 1210, 1211, 1216, 1217],
+ [1008, 1009, 1014, 1015, 1020, 1021, 1026, 1027, 1032, 1033, 1038, 1039, 1044, 1045, 1218, 1219, 1224, 1225, 1230, 1231, 1236, 1237, 1242, 1243, 1248, 1249, 1254, 1255],
+ [1010, 1011, 1016, 1017, 1022, 1023, 1028, 1029, 1034, 1035, 1040, 1041, 1046, 1047, 1220, 1221, 1226, 1227, 1232, 1233, 1238, 1239, 1244, 1245, 1250, 1251, 1256, 1257],
+ [1012, 1013, 1018, 1019, 1024, 1025, 1030, 1031, 1036, 1037, 1042, 1043, 1048, 1049, 1222, 1223, 1228, 1229, 1234, 1235, 1240, 1241, 1246, 1247, 1252, 1253, 1258, 1259]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # iter_col_major
+ tiles_col_iter = TensorTiler2D.group_tiler(
+ (3 * 5 * 3, 2 * 7 * 2),
+ tile_dims=(3, 2),
+ tile_group_dims=(5, 7),
+ iter_col_major=True,
+ )
+ assert tiles_col_iter[0] == tile0_0
+ assert tiles_col_iter[1] == tile1_0
+ assert tiles_col_iter[3] == tile0_1
+
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 1, 6, 7, 12, 13, 18, 19, 24, 25, 30, 31, 36, 37, 630, 631, 636, 637, 642, 643, 648, 649, 654, 655, 660, 661, 666, 667],
+ [ 2, 3, 8, 9, 14, 15, 20, 21, 26, 27, 32, 33, 38, 39, 632, 633, 638, 639, 644, 645, 650, 651, 656, 657, 662, 663, 668, 669],
+ [ 4, 5, 10, 11, 16, 17, 22, 23, 28, 29, 34, 35, 40, 41, 634, 635, 640, 641, 646, 647, 652, 653, 658, 659, 664, 665, 670, 671],
+ [ 42, 43, 48, 49, 54, 55, 60, 61, 66, 67, 72, 73, 78, 79, 672, 673, 678, 679, 684, 685, 690, 691, 696, 697, 702, 703, 708, 709],
+ [ 44, 45, 50, 51, 56, 57, 62, 63, 68, 69, 74, 75, 80, 81, 674, 675, 680, 681, 686, 687, 692, 693, 698, 699, 704, 705, 710, 711],
+ [ 46, 47, 52, 53, 58, 59, 64, 65, 70, 71, 76, 77, 82, 83, 676, 677, 682, 683, 688, 689, 694, 695, 700, 701, 706, 707, 712, 713],
+ [ 84, 85, 90, 91, 96, 97, 102, 103, 108, 109, 114, 115, 120, 121, 714, 715, 720, 721, 726, 727, 732, 733, 738, 739, 744, 745, 750, 751],
+ [ 86, 87, 92, 93, 98, 99, 104, 105, 110, 111, 116, 117, 122, 123, 716, 717, 722, 723, 728, 729, 734, 735, 740, 741, 746, 747, 752, 753],
+ [ 88, 89, 94, 95, 100, 101, 106, 107, 112, 113, 118, 119, 124, 125, 718, 719, 724, 725, 730, 731, 736, 737, 742, 743, 748, 749, 754, 755],
+ [ 126, 127, 132, 133, 138, 139, 144, 145, 150, 151, 156, 157, 162, 163, 756, 757, 762, 763, 768, 769, 774, 775, 780, 781, 786, 787, 792, 793],
+ [ 128, 129, 134, 135, 140, 141, 146, 147, 152, 153, 158, 159, 164, 165, 758, 759, 764, 765, 770, 771, 776, 777, 782, 783, 788, 789, 794, 795],
+ [ 130, 131, 136, 137, 142, 143, 148, 149, 154, 155, 160, 161, 166, 167, 760, 761, 766, 767, 772, 773, 778, 779, 784, 785, 790, 791, 796, 797],
+ [ 168, 169, 174, 175, 180, 181, 186, 187, 192, 193, 198, 199, 204, 205, 798, 799, 804, 805, 810, 811, 816, 817, 822, 823, 828, 829, 834, 835],
+ [ 170, 171, 176, 177, 182, 183, 188, 189, 194, 195, 200, 201, 206, 207, 800, 801, 806, 807, 812, 813, 818, 819, 824, 825, 830, 831, 836, 837],
+ [ 172, 173, 178, 179, 184, 185, 190, 191, 196, 197, 202, 203, 208, 209, 802, 803, 808, 809, 814, 815, 820, 821, 826, 827, 832, 833, 838, 839],
+ [ 210, 211, 216, 217, 222, 223, 228, 229, 234, 235, 240, 241, 246, 247, 840, 841, 846, 847, 852, 853, 858, 859, 864, 865, 870, 871, 876, 877],
+ [ 212, 213, 218, 219, 224, 225, 230, 231, 236, 237, 242, 243, 248, 249, 842, 843, 848, 849, 854, 855, 860, 861, 866, 867, 872, 873, 878, 879],
+ [ 214, 215, 220, 221, 226, 227, 232, 233, 238, 239, 244, 245, 250, 251, 844, 845, 850, 851, 856, 857, 862, 863, 868, 869, 874, 875, 880, 881],
+ [ 252, 253, 258, 259, 264, 265, 270, 271, 276, 277, 282, 283, 288, 289, 882, 883, 888, 889, 894, 895, 900, 901, 906, 907, 912, 913, 918, 919],
+ [ 254, 255, 260, 261, 266, 267, 272, 273, 278, 279, 284, 285, 290, 291, 884, 885, 890, 891, 896, 897, 902, 903, 908, 909, 914, 915, 920, 921],
+ [ 256, 257, 262, 263, 268, 269, 274, 275, 280, 281, 286, 287, 292, 293, 886, 887, 892, 893, 898, 899, 904, 905, 910, 911, 916, 917, 922, 923],
+ [ 294, 295, 300, 301, 306, 307, 312, 313, 318, 319, 324, 325, 330, 331, 924, 925, 930, 931, 936, 937, 942, 943, 948, 949, 954, 955, 960, 961],
+ [ 296, 297, 302, 303, 308, 309, 314, 315, 320, 321, 326, 327, 332, 333, 926, 927, 932, 933, 938, 939, 944, 945, 950, 951, 956, 957, 962, 963],
+ [ 298, 299, 304, 305, 310, 311, 316, 317, 322, 323, 328, 329, 334, 335, 928, 929, 934, 935, 940, 941, 946, 947, 952, 953, 958, 959, 964, 965],
+ [ 336, 337, 342, 343, 348, 349, 354, 355, 360, 361, 366, 367, 372, 373, 966, 967, 972, 973, 978, 979, 984, 985, 990, 991, 996, 997, 1002, 1003],
+ [ 338, 339, 344, 345, 350, 351, 356, 357, 362, 363, 368, 369, 374, 375, 968, 969, 974, 975, 980, 981, 986, 987, 992, 993, 998, 999, 1004, 1005],
+ [ 340, 341, 346, 347, 352, 353, 358, 359, 364, 365, 370, 371, 376, 377, 970, 971, 976, 977, 982, 983, 988, 989, 994, 995, 1000, 1001, 1006, 1007],
+ [ 378, 379, 384, 385, 390, 391, 396, 397, 402, 403, 408, 409, 414, 415, 1008, 1009, 1014, 1015, 1020, 1021, 1026, 1027, 1032, 1033, 1038, 1039, 1044, 1045],
+ [ 380, 381, 386, 387, 392, 393, 398, 399, 404, 405, 410, 411, 416, 417, 1010, 1011, 1016, 1017, 1022, 1023, 1028, 1029, 1034, 1035, 1040, 1041, 1046, 1047],
+ [ 382, 383, 388, 389, 394, 395, 400, 401, 406, 407, 412, 413, 418, 419, 1012, 1013, 1018, 1019, 1024, 1025, 1030, 1031, 1036, 1037, 1042, 1043, 1048, 1049],
+ [ 420, 421, 426, 427, 432, 433, 438, 439, 444, 445, 450, 451, 456, 457, 1050, 1051, 1056, 1057, 1062, 1063, 1068, 1069, 1074, 1075, 1080, 1081, 1086, 1087],
+ [ 422, 423, 428, 429, 434, 435, 440, 441, 446, 447, 452, 453, 458, 459, 1052, 1053, 1058, 1059, 1064, 1065, 1070, 1071, 1076, 1077, 1082, 1083, 1088, 1089],
+ [ 424, 425, 430, 431, 436, 437, 442, 443, 448, 449, 454, 455, 460, 461, 1054, 1055, 1060, 1061, 1066, 1067, 1072, 1073, 1078, 1079, 1084, 1085, 1090, 1091],
+ [ 462, 463, 468, 469, 474, 475, 480, 481, 486, 487, 492, 493, 498, 499, 1092, 1093, 1098, 1099, 1104, 1105, 1110, 1111, 1116, 1117, 1122, 1123, 1128, 1129],
+ [ 464, 465, 470, 471, 476, 477, 482, 483, 488, 489, 494, 495, 500, 501, 1094, 1095, 1100, 1101, 1106, 1107, 1112, 1113, 1118, 1119, 1124, 1125, 1130, 1131],
+ [ 466, 467, 472, 473, 478, 479, 484, 485, 490, 491, 496, 497, 502, 503, 1096, 1097, 1102, 1103, 1108, 1109, 1114, 1115, 1120, 1121, 1126, 1127, 1132, 1133],
+ [ 504, 505, 510, 511, 516, 517, 522, 523, 528, 529, 534, 535, 540, 541, 1134, 1135, 1140, 1141, 1146, 1147, 1152, 1153, 1158, 1159, 1164, 1165, 1170, 1171],
+ [ 506, 507, 512, 513, 518, 519, 524, 525, 530, 531, 536, 537, 542, 543, 1136, 1137, 1142, 1143, 1148, 1149, 1154, 1155, 1160, 1161, 1166, 1167, 1172, 1173],
+ [ 508, 509, 514, 515, 520, 521, 526, 527, 532, 533, 538, 539, 544, 545, 1138, 1139, 1144, 1145, 1150, 1151, 1156, 1157, 1162, 1163, 1168, 1169, 1174, 1175],
+ [ 546, 547, 552, 553, 558, 559, 564, 565, 570, 571, 576, 577, 582, 583, 1176, 1177, 1182, 1183, 1188, 1189, 1194, 1195, 1200, 1201, 1206, 1207, 1212, 1213],
+ [ 548, 549, 554, 555, 560, 561, 566, 567, 572, 573, 578, 579, 584, 585, 1178, 1179, 1184, 1185, 1190, 1191, 1196, 1197, 1202, 1203, 1208, 1209, 1214, 1215],
+ [ 550, 551, 556, 557, 562, 563, 568, 569, 574, 575, 580, 581, 586, 587, 1180, 1181, 1186, 1187, 1192, 1193, 1198, 1199, 1204, 1205, 1210, 1211, 1216, 1217],
+ [ 588, 589, 594, 595, 600, 601, 606, 607, 612, 613, 618, 619, 624, 625, 1218, 1219, 1224, 1225, 1230, 1231, 1236, 1237, 1242, 1243, 1248, 1249, 1254, 1255],
+ [ 590, 591, 596, 597, 602, 603, 608, 609, 614, 615, 620, 621, 626, 627, 1220, 1221, 1226, 1227, 1232, 1233, 1238, 1239, 1244, 1245, 1250, 1251, 1256, 1257],
+ [ 592, 593, 598, 599, 604, 605, 610, 611, 616, 617, 622, 623, 628, 629, 1222, 1223, 1228, 1229, 1234, 1235, 1240, 1241, 1246, 1247, 1252, 1253, 1258, 1259]])
+ # fmt: on
+ access_order, access_count = tiles_col_iter.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # tile_col_major
+ tiles_tile_col_major = TensorTiler2D.group_tiler(
+ (3 * 5 * 3, 2 * 7 * 2),
+ tile_dims=(3, 2),
+ tile_group_dims=(5, 7),
+ tile_col_major=True,
+ )
+ tile0_0 = TensorTile(
+ (3 * 5 * 3, 2 * 7 * 2), offset=0, sizes=[1, 5, 14, 3], strides=[0, 84, 1, 28]
+ )
+ assert tiles_tile_col_major[0] == tile0_0
+ tile0_1 = TensorTile(
+ (3 * 5 * 3, 2 * 7 * 2), offset=14, sizes=[1, 5, 14, 3], strides=[0, 84, 1, 28]
+ )
+ assert tiles_tile_col_major[1] == tile0_1
+ tile1_0 = TensorTile(
+ (3 * 5 * 3, 2 * 7 * 2), offset=420, sizes=[1, 5, 14, 3], strides=[0, 84, 1, 28]
+ )
+ assert tiles_tile_col_major[2] == tile1_0
+
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 210, 213, 216, 219, 222, 225, 228, 231, 234, 237, 240, 243, 246, 249],
+ [ 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 211, 214, 217, 220, 223, 226, 229, 232, 235, 238, 241, 244, 247, 250],
+ [ 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41, 212, 215, 218, 221, 224, 227, 230, 233, 236, 239, 242, 245, 248, 251],
+ [ 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 252, 255, 258, 261, 264, 267, 270, 273, 276, 279, 282, 285, 288, 291],
+ [ 43, 46, 49, 52, 55, 58, 61, 64, 67, 70, 73, 76, 79, 82, 253, 256, 259, 262, 265, 268, 271, 274, 277, 280, 283, 286, 289, 292],
+ [ 44, 47, 50, 53, 56, 59, 62, 65, 68, 71, 74, 77, 80, 83, 254, 257, 260, 263, 266, 269, 272, 275, 278, 281, 284, 287, 290, 293],
+ [ 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 294, 297, 300, 303, 306, 309, 312, 315, 318, 321, 324, 327, 330, 333],
+ [ 85, 88, 91, 94, 97, 100, 103, 106, 109, 112, 115, 118, 121, 124, 295, 298, 301, 304, 307, 310, 313, 316, 319, 322, 325, 328, 331, 334],
+ [ 86, 89, 92, 95, 98, 101, 104, 107, 110, 113, 116, 119, 122, 125, 296, 299, 302, 305, 308, 311, 314, 317, 320, 323, 326, 329, 332, 335],
+ [ 126, 129, 132, 135, 138, 141, 144, 147, 150, 153, 156, 159, 162, 165, 336, 339, 342, 345, 348, 351, 354, 357, 360, 363, 366, 369, 372, 375],
+ [ 127, 130, 133, 136, 139, 142, 145, 148, 151, 154, 157, 160, 163, 166, 337, 340, 343, 346, 349, 352, 355, 358, 361, 364, 367, 370, 373, 376],
+ [ 128, 131, 134, 137, 140, 143, 146, 149, 152, 155, 158, 161, 164, 167, 338, 341, 344, 347, 350, 353, 356, 359, 362, 365, 368, 371, 374, 377],
+ [ 168, 171, 174, 177, 180, 183, 186, 189, 192, 195, 198, 201, 204, 207, 378, 381, 384, 387, 390, 393, 396, 399, 402, 405, 408, 411, 414, 417],
+ [ 169, 172, 175, 178, 181, 184, 187, 190, 193, 196, 199, 202, 205, 208, 379, 382, 385, 388, 391, 394, 397, 400, 403, 406, 409, 412, 415, 418],
+ [ 170, 173, 176, 179, 182, 185, 188, 191, 194, 197, 200, 203, 206, 209, 380, 383, 386, 389, 392, 395, 398, 401, 404, 407, 410, 413, 416, 419],
+ [ 420, 423, 426, 429, 432, 435, 438, 441, 444, 447, 450, 453, 456, 459, 630, 633, 636, 639, 642, 645, 648, 651, 654, 657, 660, 663, 666, 669],
+ [ 421, 424, 427, 430, 433, 436, 439, 442, 445, 448, 451, 454, 457, 460, 631, 634, 637, 640, 643, 646, 649, 652, 655, 658, 661, 664, 667, 670],
+ [ 422, 425, 428, 431, 434, 437, 440, 443, 446, 449, 452, 455, 458, 461, 632, 635, 638, 641, 644, 647, 650, 653, 656, 659, 662, 665, 668, 671],
+ [ 462, 465, 468, 471, 474, 477, 480, 483, 486, 489, 492, 495, 498, 501, 672, 675, 678, 681, 684, 687, 690, 693, 696, 699, 702, 705, 708, 711],
+ [ 463, 466, 469, 472, 475, 478, 481, 484, 487, 490, 493, 496, 499, 502, 673, 676, 679, 682, 685, 688, 691, 694, 697, 700, 703, 706, 709, 712],
+ [ 464, 467, 470, 473, 476, 479, 482, 485, 488, 491, 494, 497, 500, 503, 674, 677, 680, 683, 686, 689, 692, 695, 698, 701, 704, 707, 710, 713],
+ [ 504, 507, 510, 513, 516, 519, 522, 525, 528, 531, 534, 537, 540, 543, 714, 717, 720, 723, 726, 729, 732, 735, 738, 741, 744, 747, 750, 753],
+ [ 505, 508, 511, 514, 517, 520, 523, 526, 529, 532, 535, 538, 541, 544, 715, 718, 721, 724, 727, 730, 733, 736, 739, 742, 745, 748, 751, 754],
+ [ 506, 509, 512, 515, 518, 521, 524, 527, 530, 533, 536, 539, 542, 545, 716, 719, 722, 725, 728, 731, 734, 737, 740, 743, 746, 749, 752, 755],
+ [ 546, 549, 552, 555, 558, 561, 564, 567, 570, 573, 576, 579, 582, 585, 756, 759, 762, 765, 768, 771, 774, 777, 780, 783, 786, 789, 792, 795],
+ [ 547, 550, 553, 556, 559, 562, 565, 568, 571, 574, 577, 580, 583, 586, 757, 760, 763, 766, 769, 772, 775, 778, 781, 784, 787, 790, 793, 796],
+ [ 548, 551, 554, 557, 560, 563, 566, 569, 572, 575, 578, 581, 584, 587, 758, 761, 764, 767, 770, 773, 776, 779, 782, 785, 788, 791, 794, 797],
+ [ 588, 591, 594, 597, 600, 603, 606, 609, 612, 615, 618, 621, 624, 627, 798, 801, 804, 807, 810, 813, 816, 819, 822, 825, 828, 831, 834, 837],
+ [ 589, 592, 595, 598, 601, 604, 607, 610, 613, 616, 619, 622, 625, 628, 799, 802, 805, 808, 811, 814, 817, 820, 823, 826, 829, 832, 835, 838],
+ [ 590, 593, 596, 599, 602, 605, 608, 611, 614, 617, 620, 623, 626, 629, 800, 803, 806, 809, 812, 815, 818, 821, 824, 827, 830, 833, 836, 839],
+ [ 840, 843, 846, 849, 852, 855, 858, 861, 864, 867, 870, 873, 876, 879, 1050, 1053, 1056, 1059, 1062, 1065, 1068, 1071, 1074, 1077, 1080, 1083, 1086, 1089],
+ [ 841, 844, 847, 850, 853, 856, 859, 862, 865, 868, 871, 874, 877, 880, 1051, 1054, 1057, 1060, 1063, 1066, 1069, 1072, 1075, 1078, 1081, 1084, 1087, 1090],
+ [ 842, 845, 848, 851, 854, 857, 860, 863, 866, 869, 872, 875, 878, 881, 1052, 1055, 1058, 1061, 1064, 1067, 1070, 1073, 1076, 1079, 1082, 1085, 1088, 1091],
+ [ 882, 885, 888, 891, 894, 897, 900, 903, 906, 909, 912, 915, 918, 921, 1092, 1095, 1098, 1101, 1104, 1107, 1110, 1113, 1116, 1119, 1122, 1125, 1128, 1131],
+ [ 883, 886, 889, 892, 895, 898, 901, 904, 907, 910, 913, 916, 919, 922, 1093, 1096, 1099, 1102, 1105, 1108, 1111, 1114, 1117, 1120, 1123, 1126, 1129, 1132],
+ [ 884, 887, 890, 893, 896, 899, 902, 905, 908, 911, 914, 917, 920, 923, 1094, 1097, 1100, 1103, 1106, 1109, 1112, 1115, 1118, 1121, 1124, 1127, 1130, 1133],
+ [ 924, 927, 930, 933, 936, 939, 942, 945, 948, 951, 954, 957, 960, 963, 1134, 1137, 1140, 1143, 1146, 1149, 1152, 1155, 1158, 1161, 1164, 1167, 1170, 1173],
+ [ 925, 928, 931, 934, 937, 940, 943, 946, 949, 952, 955, 958, 961, 964, 1135, 1138, 1141, 1144, 1147, 1150, 1153, 1156, 1159, 1162, 1165, 1168, 1171, 1174],
+ [ 926, 929, 932, 935, 938, 941, 944, 947, 950, 953, 956, 959, 962, 965, 1136, 1139, 1142, 1145, 1148, 1151, 1154, 1157, 1160, 1163, 1166, 1169, 1172, 1175],
+ [ 966, 969, 972, 975, 978, 981, 984, 987, 990, 993, 996, 999, 1002, 1005, 1176, 1179, 1182, 1185, 1188, 1191, 1194, 1197, 1200, 1203, 1206, 1209, 1212, 1215],
+ [ 967, 970, 973, 976, 979, 982, 985, 988, 991, 994, 997, 1000, 1003, 1006, 1177, 1180, 1183, 1186, 1189, 1192, 1195, 1198, 1201, 1204, 1207, 1210, 1213, 1216],
+ [ 968, 971, 974, 977, 980, 983, 986, 989, 992, 995, 998, 1001, 1004, 1007, 1178, 1181, 1184, 1187, 1190, 1193, 1196, 1199, 1202, 1205, 1208, 1211, 1214, 1217],
+ [1008, 1011, 1014, 1017, 1020, 1023, 1026, 1029, 1032, 1035, 1038, 1041, 1044, 1047, 1218, 1221, 1224, 1227, 1230, 1233, 1236, 1239, 1242, 1245, 1248, 1251, 1254, 1257],
+ [1009, 1012, 1015, 1018, 1021, 1024, 1027, 1030, 1033, 1036, 1039, 1042, 1045, 1048, 1219, 1222, 1225, 1228, 1231, 1234, 1237, 1240, 1243, 1246, 1249, 1252, 1255, 1258],
+ [1010, 1013, 1016, 1019, 1022, 1025, 1028, 1031, 1034, 1037, 1040, 1043, 1046, 1049, 1220, 1223, 1226, 1229, 1232, 1235, 1238, 1241, 1244, 1247, 1250, 1253, 1256, 1259]])
+ # fmt: on
+ access_order, access_count = tiles_tile_col_major.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # iter_col_major and tile_col_major
+ tiles_tile_col_major_col_iter = TensorTiler2D.group_tiler(
+ (3 * 5 * 3, 2 * 7 * 2),
+ tile_dims=(3, 2),
+ tile_group_dims=(5, 7),
+ iter_col_major=True,
+ tile_col_major=True,
+ )
+ assert tiles_tile_col_major_col_iter[0] == tile0_0
+ assert tiles_tile_col_major_col_iter[1] == tile1_0
+ assert tiles_tile_col_major_col_iter[3] == tile0_1
+
+ # tile_col_major and pattern_repeat
+ tiles_tile_col_major_pattern_repeat = TensorTiler2D.group_tiler(
+ (3 * 5 * 3, 2 * 7 * 2),
+ tile_dims=(3, 2),
+ tile_group_dims=(5, 7),
+ tile_col_major=True,
+ pattern_repeat=2,
+ )
+ assert tiles_tile_col_major_pattern_repeat[0] == TensorTile(
+ (3 * 5 * 3, 2 * 7 * 2), offset=0, sizes=[2, 5, 14, 3], strides=[0, 84, 1, 28]
+ )
+
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 210, 213, 216, 219, 222, 225, 228, 231, 234, 237, 240, 243, 246, 249, 630, 633, 636, 639, 642, 645, 648, 651, 654, 657, 660, 663, 666, 669],
+ [ 211, 214, 217, 220, 223, 226, 229, 232, 235, 238, 241, 244, 247, 250, 631, 634, 637, 640, 643, 646, 649, 652, 655, 658, 661, 664, 667, 670],
+ [ 212, 215, 218, 221, 224, 227, 230, 233, 236, 239, 242, 245, 248, 251, 632, 635, 638, 641, 644, 647, 650, 653, 656, 659, 662, 665, 668, 671],
+ [ 252, 255, 258, 261, 264, 267, 270, 273, 276, 279, 282, 285, 288, 291, 672, 675, 678, 681, 684, 687, 690, 693, 696, 699, 702, 705, 708, 711],
+ [ 253, 256, 259, 262, 265, 268, 271, 274, 277, 280, 283, 286, 289, 292, 673, 676, 679, 682, 685, 688, 691, 694, 697, 700, 703, 706, 709, 712],
+ [ 254, 257, 260, 263, 266, 269, 272, 275, 278, 281, 284, 287, 290, 293, 674, 677, 680, 683, 686, 689, 692, 695, 698, 701, 704, 707, 710, 713],
+ [ 294, 297, 300, 303, 306, 309, 312, 315, 318, 321, 324, 327, 330, 333, 714, 717, 720, 723, 726, 729, 732, 735, 738, 741, 744, 747, 750, 753],
+ [ 295, 298, 301, 304, 307, 310, 313, 316, 319, 322, 325, 328, 331, 334, 715, 718, 721, 724, 727, 730, 733, 736, 739, 742, 745, 748, 751, 754],
+ [ 296, 299, 302, 305, 308, 311, 314, 317, 320, 323, 326, 329, 332, 335, 716, 719, 722, 725, 728, 731, 734, 737, 740, 743, 746, 749, 752, 755],
+ [ 336, 339, 342, 345, 348, 351, 354, 357, 360, 363, 366, 369, 372, 375, 756, 759, 762, 765, 768, 771, 774, 777, 780, 783, 786, 789, 792, 795],
+ [ 337, 340, 343, 346, 349, 352, 355, 358, 361, 364, 367, 370, 373, 376, 757, 760, 763, 766, 769, 772, 775, 778, 781, 784, 787, 790, 793, 796],
+ [ 338, 341, 344, 347, 350, 353, 356, 359, 362, 365, 368, 371, 374, 377, 758, 761, 764, 767, 770, 773, 776, 779, 782, 785, 788, 791, 794, 797],
+ [ 378, 381, 384, 387, 390, 393, 396, 399, 402, 405, 408, 411, 414, 417, 798, 801, 804, 807, 810, 813, 816, 819, 822, 825, 828, 831, 834, 837],
+ [ 379, 382, 385, 388, 391, 394, 397, 400, 403, 406, 409, 412, 415, 418, 799, 802, 805, 808, 811, 814, 817, 820, 823, 826, 829, 832, 835, 838],
+ [ 380, 383, 386, 389, 392, 395, 398, 401, 404, 407, 410, 413, 416, 419, 800, 803, 806, 809, 812, 815, 818, 821, 824, 827, 830, 833, 836, 839],
+ [1050, 1053, 1056, 1059, 1062, 1065, 1068, 1071, 1074, 1077, 1080, 1083, 1086, 1089, 1470, 1473, 1476, 1479, 1482, 1485, 1488, 1491, 1494, 1497, 1500, 1503, 1506, 1509],
+ [1051, 1054, 1057, 1060, 1063, 1066, 1069, 1072, 1075, 1078, 1081, 1084, 1087, 1090, 1471, 1474, 1477, 1480, 1483, 1486, 1489, 1492, 1495, 1498, 1501, 1504, 1507, 1510],
+ [1052, 1055, 1058, 1061, 1064, 1067, 1070, 1073, 1076, 1079, 1082, 1085, 1088, 1091, 1472, 1475, 1478, 1481, 1484, 1487, 1490, 1493, 1496, 1499, 1502, 1505, 1508, 1511],
+ [1092, 1095, 1098, 1101, 1104, 1107, 1110, 1113, 1116, 1119, 1122, 1125, 1128, 1131, 1512, 1515, 1518, 1521, 1524, 1527, 1530, 1533, 1536, 1539, 1542, 1545, 1548, 1551],
+ [1093, 1096, 1099, 1102, 1105, 1108, 1111, 1114, 1117, 1120, 1123, 1126, 1129, 1132, 1513, 1516, 1519, 1522, 1525, 1528, 1531, 1534, 1537, 1540, 1543, 1546, 1549, 1552],
+ [1094, 1097, 1100, 1103, 1106, 1109, 1112, 1115, 1118, 1121, 1124, 1127, 1130, 1133, 1514, 1517, 1520, 1523, 1526, 1529, 1532, 1535, 1538, 1541, 1544, 1547, 1550, 1553],
+ [1134, 1137, 1140, 1143, 1146, 1149, 1152, 1155, 1158, 1161, 1164, 1167, 1170, 1173, 1554, 1557, 1560, 1563, 1566, 1569, 1572, 1575, 1578, 1581, 1584, 1587, 1590, 1593],
+ [1135, 1138, 1141, 1144, 1147, 1150, 1153, 1156, 1159, 1162, 1165, 1168, 1171, 1174, 1555, 1558, 1561, 1564, 1567, 1570, 1573, 1576, 1579, 1582, 1585, 1588, 1591, 1594],
+ [1136, 1139, 1142, 1145, 1148, 1151, 1154, 1157, 1160, 1163, 1166, 1169, 1172, 1175, 1556, 1559, 1562, 1565, 1568, 1571, 1574, 1577, 1580, 1583, 1586, 1589, 1592, 1595],
+ [1176, 1179, 1182, 1185, 1188, 1191, 1194, 1197, 1200, 1203, 1206, 1209, 1212, 1215, 1596, 1599, 1602, 1605, 1608, 1611, 1614, 1617, 1620, 1623, 1626, 1629, 1632, 1635],
+ [1177, 1180, 1183, 1186, 1189, 1192, 1195, 1198, 1201, 1204, 1207, 1210, 1213, 1216, 1597, 1600, 1603, 1606, 1609, 1612, 1615, 1618, 1621, 1624, 1627, 1630, 1633, 1636],
+ [1178, 1181, 1184, 1187, 1190, 1193, 1196, 1199, 1202, 1205, 1208, 1211, 1214, 1217, 1598, 1601, 1604, 1607, 1610, 1613, 1616, 1619, 1622, 1625, 1628, 1631, 1634, 1637],
+ [1218, 1221, 1224, 1227, 1230, 1233, 1236, 1239, 1242, 1245, 1248, 1251, 1254, 1257, 1638, 1641, 1644, 1647, 1650, 1653, 1656, 1659, 1662, 1665, 1668, 1671, 1674, 1677],
+ [1219, 1222, 1225, 1228, 1231, 1234, 1237, 1240, 1243, 1246, 1249, 1252, 1255, 1258, 1639, 1642, 1645, 1648, 1651, 1654, 1657, 1660, 1663, 1666, 1669, 1672, 1675, 1678],
+ [1220, 1223, 1226, 1229, 1232, 1235, 1238, 1241, 1244, 1247, 1250, 1253, 1256, 1259, 1640, 1643, 1646, 1649, 1652, 1655, 1658, 1661, 1664, 1667, 1670, 1673, 1676, 1679],
+ [1890, 1893, 1896, 1899, 1902, 1905, 1908, 1911, 1914, 1917, 1920, 1923, 1926, 1929, 2310, 2313, 2316, 2319, 2322, 2325, 2328, 2331, 2334, 2337, 2340, 2343, 2346, 2349],
+ [1891, 1894, 1897, 1900, 1903, 1906, 1909, 1912, 1915, 1918, 1921, 1924, 1927, 1930, 2311, 2314, 2317, 2320, 2323, 2326, 2329, 2332, 2335, 2338, 2341, 2344, 2347, 2350],
+ [1892, 1895, 1898, 1901, 1904, 1907, 1910, 1913, 1916, 1919, 1922, 1925, 1928, 1931, 2312, 2315, 2318, 2321, 2324, 2327, 2330, 2333, 2336, 2339, 2342, 2345, 2348, 2351],
+ [1932, 1935, 1938, 1941, 1944, 1947, 1950, 1953, 1956, 1959, 1962, 1965, 1968, 1971, 2352, 2355, 2358, 2361, 2364, 2367, 2370, 2373, 2376, 2379, 2382, 2385, 2388, 2391],
+ [1933, 1936, 1939, 1942, 1945, 1948, 1951, 1954, 1957, 1960, 1963, 1966, 1969, 1972, 2353, 2356, 2359, 2362, 2365, 2368, 2371, 2374, 2377, 2380, 2383, 2386, 2389, 2392],
+ [1934, 1937, 1940, 1943, 1946, 1949, 1952, 1955, 1958, 1961, 1964, 1967, 1970, 1973, 2354, 2357, 2360, 2363, 2366, 2369, 2372, 2375, 2378, 2381, 2384, 2387, 2390, 2393],
+ [1974, 1977, 1980, 1983, 1986, 1989, 1992, 1995, 1998, 2001, 2004, 2007, 2010, 2013, 2394, 2397, 2400, 2403, 2406, 2409, 2412, 2415, 2418, 2421, 2424, 2427, 2430, 2433],
+ [1975, 1978, 1981, 1984, 1987, 1990, 1993, 1996, 1999, 2002, 2005, 2008, 2011, 2014, 2395, 2398, 2401, 2404, 2407, 2410, 2413, 2416, 2419, 2422, 2425, 2428, 2431, 2434],
+ [1976, 1979, 1982, 1985, 1988, 1991, 1994, 1997, 2000, 2003, 2006, 2009, 2012, 2015, 2396, 2399, 2402, 2405, 2408, 2411, 2414, 2417, 2420, 2423, 2426, 2429, 2432, 2435],
+ [2016, 2019, 2022, 2025, 2028, 2031, 2034, 2037, 2040, 2043, 2046, 2049, 2052, 2055, 2436, 2439, 2442, 2445, 2448, 2451, 2454, 2457, 2460, 2463, 2466, 2469, 2472, 2475],
+ [2017, 2020, 2023, 2026, 2029, 2032, 2035, 2038, 2041, 2044, 2047, 2050, 2053, 2056, 2437, 2440, 2443, 2446, 2449, 2452, 2455, 2458, 2461, 2464, 2467, 2470, 2473, 2476],
+ [2018, 2021, 2024, 2027, 2030, 2033, 2036, 2039, 2042, 2045, 2048, 2051, 2054, 2057, 2438, 2441, 2444, 2447, 2450, 2453, 2456, 2459, 2462, 2465, 2468, 2471, 2474, 2477],
+ [2058, 2061, 2064, 2067, 2070, 2073, 2076, 2079, 2082, 2085, 2088, 2091, 2094, 2097, 2478, 2481, 2484, 2487, 2490, 2493, 2496, 2499, 2502, 2505, 2508, 2511, 2514, 2517],
+ [2059, 2062, 2065, 2068, 2071, 2074, 2077, 2080, 2083, 2086, 2089, 2092, 2095, 2098, 2479, 2482, 2485, 2488, 2491, 2494, 2497, 2500, 2503, 2506, 2509, 2512, 2515, 2518],
+ [2060, 2063, 2066, 2069, 2072, 2075, 2078, 2081, 2084, 2087, 2090, 2093, 2096, 2099, 2480, 2483, 2486, 2489, 2492, 2495, 2498, 2501, 2504, 2507, 2510, 2513, 2516, 2519]])
+ # fmt: on
+ access_order, access_count = tiles_tile_col_major_pattern_repeat.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 2).all()
+
+ # tile_group_col_major
+ tiles_group_col_major = TensorTiler2D.group_tiler(
+ (3 * 5 * 3, 2 * 7 * 2),
+ tile_dims=(3, 2),
+ tile_group_dims=(5, 7),
+ tile_group_col_major=True,
+ )
+ tile0_0 = TensorTile(
+ (3 * 5 * 3, 2 * 7 * 2), offset=0, sizes=[1, 7, 15, 2], strides=[0, 2, 28, 1]
+ )
+ assert tiles_group_col_major[0] == tile0_0
+ tile0_1 = TensorTile(
+ (3 * 5 * 3, 2 * 7 * 2), offset=14, sizes=[1, 7, 15, 2], strides=[0, 2, 28, 1]
+ )
+ assert tiles_group_col_major[1] == tile0_1
+ tile1_0 = TensorTile(
+ (3 * 5 * 3, 2 * 7 * 2), offset=420, sizes=[1, 7, 15, 2], strides=[0, 2, 28, 1]
+ )
+ assert tiles_group_col_major[2] == tile1_0
+
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 1, 30, 31, 60, 61, 90, 91, 120, 121, 150, 151, 180, 181, 210, 211, 240, 241, 270, 271, 300, 301, 330, 331, 360, 361, 390, 391],
+ [ 2, 3, 32, 33, 62, 63, 92, 93, 122, 123, 152, 153, 182, 183, 212, 213, 242, 243, 272, 273, 302, 303, 332, 333, 362, 363, 392, 393],
+ [ 4, 5, 34, 35, 64, 65, 94, 95, 124, 125, 154, 155, 184, 185, 214, 215, 244, 245, 274, 275, 304, 305, 334, 335, 364, 365, 394, 395],
+ [ 6, 7, 36, 37, 66, 67, 96, 97, 126, 127, 156, 157, 186, 187, 216, 217, 246, 247, 276, 277, 306, 307, 336, 337, 366, 367, 396, 397],
+ [ 8, 9, 38, 39, 68, 69, 98, 99, 128, 129, 158, 159, 188, 189, 218, 219, 248, 249, 278, 279, 308, 309, 338, 339, 368, 369, 398, 399],
+ [ 10, 11, 40, 41, 70, 71, 100, 101, 130, 131, 160, 161, 190, 191, 220, 221, 250, 251, 280, 281, 310, 311, 340, 341, 370, 371, 400, 401],
+ [ 12, 13, 42, 43, 72, 73, 102, 103, 132, 133, 162, 163, 192, 193, 222, 223, 252, 253, 282, 283, 312, 313, 342, 343, 372, 373, 402, 403],
+ [ 14, 15, 44, 45, 74, 75, 104, 105, 134, 135, 164, 165, 194, 195, 224, 225, 254, 255, 284, 285, 314, 315, 344, 345, 374, 375, 404, 405],
+ [ 16, 17, 46, 47, 76, 77, 106, 107, 136, 137, 166, 167, 196, 197, 226, 227, 256, 257, 286, 287, 316, 317, 346, 347, 376, 377, 406, 407],
+ [ 18, 19, 48, 49, 78, 79, 108, 109, 138, 139, 168, 169, 198, 199, 228, 229, 258, 259, 288, 289, 318, 319, 348, 349, 378, 379, 408, 409],
+ [ 20, 21, 50, 51, 80, 81, 110, 111, 140, 141, 170, 171, 200, 201, 230, 231, 260, 261, 290, 291, 320, 321, 350, 351, 380, 381, 410, 411],
+ [ 22, 23, 52, 53, 82, 83, 112, 113, 142, 143, 172, 173, 202, 203, 232, 233, 262, 263, 292, 293, 322, 323, 352, 353, 382, 383, 412, 413],
+ [ 24, 25, 54, 55, 84, 85, 114, 115, 144, 145, 174, 175, 204, 205, 234, 235, 264, 265, 294, 295, 324, 325, 354, 355, 384, 385, 414, 415],
+ [ 26, 27, 56, 57, 86, 87, 116, 117, 146, 147, 176, 177, 206, 207, 236, 237, 266, 267, 296, 297, 326, 327, 356, 357, 386, 387, 416, 417],
+ [ 28, 29, 58, 59, 88, 89, 118, 119, 148, 149, 178, 179, 208, 209, 238, 239, 268, 269, 298, 299, 328, 329, 358, 359, 388, 389, 418, 419],
+ [ 420, 421, 450, 451, 480, 481, 510, 511, 540, 541, 570, 571, 600, 601, 630, 631, 660, 661, 690, 691, 720, 721, 750, 751, 780, 781, 810, 811],
+ [ 422, 423, 452, 453, 482, 483, 512, 513, 542, 543, 572, 573, 602, 603, 632, 633, 662, 663, 692, 693, 722, 723, 752, 753, 782, 783, 812, 813],
+ [ 424, 425, 454, 455, 484, 485, 514, 515, 544, 545, 574, 575, 604, 605, 634, 635, 664, 665, 694, 695, 724, 725, 754, 755, 784, 785, 814, 815],
+ [ 426, 427, 456, 457, 486, 487, 516, 517, 546, 547, 576, 577, 606, 607, 636, 637, 666, 667, 696, 697, 726, 727, 756, 757, 786, 787, 816, 817],
+ [ 428, 429, 458, 459, 488, 489, 518, 519, 548, 549, 578, 579, 608, 609, 638, 639, 668, 669, 698, 699, 728, 729, 758, 759, 788, 789, 818, 819],
+ [ 430, 431, 460, 461, 490, 491, 520, 521, 550, 551, 580, 581, 610, 611, 640, 641, 670, 671, 700, 701, 730, 731, 760, 761, 790, 791, 820, 821],
+ [ 432, 433, 462, 463, 492, 493, 522, 523, 552, 553, 582, 583, 612, 613, 642, 643, 672, 673, 702, 703, 732, 733, 762, 763, 792, 793, 822, 823],
+ [ 434, 435, 464, 465, 494, 495, 524, 525, 554, 555, 584, 585, 614, 615, 644, 645, 674, 675, 704, 705, 734, 735, 764, 765, 794, 795, 824, 825],
+ [ 436, 437, 466, 467, 496, 497, 526, 527, 556, 557, 586, 587, 616, 617, 646, 647, 676, 677, 706, 707, 736, 737, 766, 767, 796, 797, 826, 827],
+ [ 438, 439, 468, 469, 498, 499, 528, 529, 558, 559, 588, 589, 618, 619, 648, 649, 678, 679, 708, 709, 738, 739, 768, 769, 798, 799, 828, 829],
+ [ 440, 441, 470, 471, 500, 501, 530, 531, 560, 561, 590, 591, 620, 621, 650, 651, 680, 681, 710, 711, 740, 741, 770, 771, 800, 801, 830, 831],
+ [ 442, 443, 472, 473, 502, 503, 532, 533, 562, 563, 592, 593, 622, 623, 652, 653, 682, 683, 712, 713, 742, 743, 772, 773, 802, 803, 832, 833],
+ [ 444, 445, 474, 475, 504, 505, 534, 535, 564, 565, 594, 595, 624, 625, 654, 655, 684, 685, 714, 715, 744, 745, 774, 775, 804, 805, 834, 835],
+ [ 446, 447, 476, 477, 506, 507, 536, 537, 566, 567, 596, 597, 626, 627, 656, 657, 686, 687, 716, 717, 746, 747, 776, 777, 806, 807, 836, 837],
+ [ 448, 449, 478, 479, 508, 509, 538, 539, 568, 569, 598, 599, 628, 629, 658, 659, 688, 689, 718, 719, 748, 749, 778, 779, 808, 809, 838, 839],
+ [ 840, 841, 870, 871, 900, 901, 930, 931, 960, 961, 990, 991, 1020, 1021, 1050, 1051, 1080, 1081, 1110, 1111, 1140, 1141, 1170, 1171, 1200, 1201, 1230, 1231],
+ [ 842, 843, 872, 873, 902, 903, 932, 933, 962, 963, 992, 993, 1022, 1023, 1052, 1053, 1082, 1083, 1112, 1113, 1142, 1143, 1172, 1173, 1202, 1203, 1232, 1233],
+ [ 844, 845, 874, 875, 904, 905, 934, 935, 964, 965, 994, 995, 1024, 1025, 1054, 1055, 1084, 1085, 1114, 1115, 1144, 1145, 1174, 1175, 1204, 1205, 1234, 1235],
+ [ 846, 847, 876, 877, 906, 907, 936, 937, 966, 967, 996, 997, 1026, 1027, 1056, 1057, 1086, 1087, 1116, 1117, 1146, 1147, 1176, 1177, 1206, 1207, 1236, 1237],
+ [ 848, 849, 878, 879, 908, 909, 938, 939, 968, 969, 998, 999, 1028, 1029, 1058, 1059, 1088, 1089, 1118, 1119, 1148, 1149, 1178, 1179, 1208, 1209, 1238, 1239],
+ [ 850, 851, 880, 881, 910, 911, 940, 941, 970, 971, 1000, 1001, 1030, 1031, 1060, 1061, 1090, 1091, 1120, 1121, 1150, 1151, 1180, 1181, 1210, 1211, 1240, 1241],
+ [ 852, 853, 882, 883, 912, 913, 942, 943, 972, 973, 1002, 1003, 1032, 1033, 1062, 1063, 1092, 1093, 1122, 1123, 1152, 1153, 1182, 1183, 1212, 1213, 1242, 1243],
+ [ 854, 855, 884, 885, 914, 915, 944, 945, 974, 975, 1004, 1005, 1034, 1035, 1064, 1065, 1094, 1095, 1124, 1125, 1154, 1155, 1184, 1185, 1214, 1215, 1244, 1245],
+ [ 856, 857, 886, 887, 916, 917, 946, 947, 976, 977, 1006, 1007, 1036, 1037, 1066, 1067, 1096, 1097, 1126, 1127, 1156, 1157, 1186, 1187, 1216, 1217, 1246, 1247],
+ [ 858, 859, 888, 889, 918, 919, 948, 949, 978, 979, 1008, 1009, 1038, 1039, 1068, 1069, 1098, 1099, 1128, 1129, 1158, 1159, 1188, 1189, 1218, 1219, 1248, 1249],
+ [ 860, 861, 890, 891, 920, 921, 950, 951, 980, 981, 1010, 1011, 1040, 1041, 1070, 1071, 1100, 1101, 1130, 1131, 1160, 1161, 1190, 1191, 1220, 1221, 1250, 1251],
+ [ 862, 863, 892, 893, 922, 923, 952, 953, 982, 983, 1012, 1013, 1042, 1043, 1072, 1073, 1102, 1103, 1132, 1133, 1162, 1163, 1192, 1193, 1222, 1223, 1252, 1253],
+ [ 864, 865, 894, 895, 924, 925, 954, 955, 984, 985, 1014, 1015, 1044, 1045, 1074, 1075, 1104, 1105, 1134, 1135, 1164, 1165, 1194, 1195, 1224, 1225, 1254, 1255],
+ [ 866, 867, 896, 897, 926, 927, 956, 957, 986, 987, 1016, 1017, 1046, 1047, 1076, 1077, 1106, 1107, 1136, 1137, 1166, 1167, 1196, 1197, 1226, 1227, 1256, 1257],
+ [ 868, 869, 898, 899, 928, 929, 958, 959, 988, 989, 1018, 1019, 1048, 1049, 1078, 1079, 1108, 1109, 1138, 1139, 1168, 1169, 1198, 1199, 1228, 1229, 1258, 1259]])
+ # fmt: on
+ access_order, access_count = tiles_group_col_major.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # tile_group_col_major and tile_col_major
+ tiles_group_col_major = TensorTiler2D.group_tiler(
+ (3 * 5 * 3, 2 * 7 * 2),
+ tile_dims=(3, 2),
+ tile_group_dims=(5, 7),
+ tile_col_major=True,
+ tile_group_col_major=True,
+ )
+ tile0_0 = TensorTile(
+ (3 * 5 * 3, 2 * 7 * 2), offset=0, sizes=[7, 5, 2, 3], strides=[2, 84, 1, 28]
+ )
+ assert tiles_group_col_major[0] == tile0_0
+ tile0_1 = TensorTile(
+ (3 * 5 * 3, 2 * 7 * 2), offset=14, sizes=[7, 5, 2, 3], strides=[2, 84, 1, 28]
+ )
+ assert tiles_group_col_major[1] == tile0_1
+ tile1_0 = TensorTile(
+ (3 * 5 * 3, 2 * 7 * 2), offset=420, sizes=[7, 5, 2, 3], strides=[2, 84, 1, 28]
+ )
+ assert tiles_group_col_major[2] == tile1_0
+
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 3, 30, 33, 60, 63, 90, 93, 120, 123, 150, 153, 180, 183, 210, 213, 240, 243, 270, 273, 300, 303, 330, 333, 360, 363, 390, 393],
+ [ 1, 4, 31, 34, 61, 64, 91, 94, 121, 124, 151, 154, 181, 184, 211, 214, 241, 244, 271, 274, 301, 304, 331, 334, 361, 364, 391, 394],
+ [ 2, 5, 32, 35, 62, 65, 92, 95, 122, 125, 152, 155, 182, 185, 212, 215, 242, 245, 272, 275, 302, 305, 332, 335, 362, 365, 392, 395],
+ [ 6, 9, 36, 39, 66, 69, 96, 99, 126, 129, 156, 159, 186, 189, 216, 219, 246, 249, 276, 279, 306, 309, 336, 339, 366, 369, 396, 399],
+ [ 7, 10, 37, 40, 67, 70, 97, 100, 127, 130, 157, 160, 187, 190, 217, 220, 247, 250, 277, 280, 307, 310, 337, 340, 367, 370, 397, 400],
+ [ 8, 11, 38, 41, 68, 71, 98, 101, 128, 131, 158, 161, 188, 191, 218, 221, 248, 251, 278, 281, 308, 311, 338, 341, 368, 371, 398, 401],
+ [ 12, 15, 42, 45, 72, 75, 102, 105, 132, 135, 162, 165, 192, 195, 222, 225, 252, 255, 282, 285, 312, 315, 342, 345, 372, 375, 402, 405],
+ [ 13, 16, 43, 46, 73, 76, 103, 106, 133, 136, 163, 166, 193, 196, 223, 226, 253, 256, 283, 286, 313, 316, 343, 346, 373, 376, 403, 406],
+ [ 14, 17, 44, 47, 74, 77, 104, 107, 134, 137, 164, 167, 194, 197, 224, 227, 254, 257, 284, 287, 314, 317, 344, 347, 374, 377, 404, 407],
+ [ 18, 21, 48, 51, 78, 81, 108, 111, 138, 141, 168, 171, 198, 201, 228, 231, 258, 261, 288, 291, 318, 321, 348, 351, 378, 381, 408, 411],
+ [ 19, 22, 49, 52, 79, 82, 109, 112, 139, 142, 169, 172, 199, 202, 229, 232, 259, 262, 289, 292, 319, 322, 349, 352, 379, 382, 409, 412],
+ [ 20, 23, 50, 53, 80, 83, 110, 113, 140, 143, 170, 173, 200, 203, 230, 233, 260, 263, 290, 293, 320, 323, 350, 353, 380, 383, 410, 413],
+ [ 24, 27, 54, 57, 84, 87, 114, 117, 144, 147, 174, 177, 204, 207, 234, 237, 264, 267, 294, 297, 324, 327, 354, 357, 384, 387, 414, 417],
+ [ 25, 28, 55, 58, 85, 88, 115, 118, 145, 148, 175, 178, 205, 208, 235, 238, 265, 268, 295, 298, 325, 328, 355, 358, 385, 388, 415, 418],
+ [ 26, 29, 56, 59, 86, 89, 116, 119, 146, 149, 176, 179, 206, 209, 236, 239, 266, 269, 296, 299, 326, 329, 356, 359, 386, 389, 416, 419],
+ [ 420, 423, 450, 453, 480, 483, 510, 513, 540, 543, 570, 573, 600, 603, 630, 633, 660, 663, 690, 693, 720, 723, 750, 753, 780, 783, 810, 813],
+ [ 421, 424, 451, 454, 481, 484, 511, 514, 541, 544, 571, 574, 601, 604, 631, 634, 661, 664, 691, 694, 721, 724, 751, 754, 781, 784, 811, 814],
+ [ 422, 425, 452, 455, 482, 485, 512, 515, 542, 545, 572, 575, 602, 605, 632, 635, 662, 665, 692, 695, 722, 725, 752, 755, 782, 785, 812, 815],
+ [ 426, 429, 456, 459, 486, 489, 516, 519, 546, 549, 576, 579, 606, 609, 636, 639, 666, 669, 696, 699, 726, 729, 756, 759, 786, 789, 816, 819],
+ [ 427, 430, 457, 460, 487, 490, 517, 520, 547, 550, 577, 580, 607, 610, 637, 640, 667, 670, 697, 700, 727, 730, 757, 760, 787, 790, 817, 820],
+ [ 428, 431, 458, 461, 488, 491, 518, 521, 548, 551, 578, 581, 608, 611, 638, 641, 668, 671, 698, 701, 728, 731, 758, 761, 788, 791, 818, 821],
+ [ 432, 435, 462, 465, 492, 495, 522, 525, 552, 555, 582, 585, 612, 615, 642, 645, 672, 675, 702, 705, 732, 735, 762, 765, 792, 795, 822, 825],
+ [ 433, 436, 463, 466, 493, 496, 523, 526, 553, 556, 583, 586, 613, 616, 643, 646, 673, 676, 703, 706, 733, 736, 763, 766, 793, 796, 823, 826],
+ [ 434, 437, 464, 467, 494, 497, 524, 527, 554, 557, 584, 587, 614, 617, 644, 647, 674, 677, 704, 707, 734, 737, 764, 767, 794, 797, 824, 827],
+ [ 438, 441, 468, 471, 498, 501, 528, 531, 558, 561, 588, 591, 618, 621, 648, 651, 678, 681, 708, 711, 738, 741, 768, 771, 798, 801, 828, 831],
+ [ 439, 442, 469, 472, 499, 502, 529, 532, 559, 562, 589, 592, 619, 622, 649, 652, 679, 682, 709, 712, 739, 742, 769, 772, 799, 802, 829, 832],
+ [ 440, 443, 470, 473, 500, 503, 530, 533, 560, 563, 590, 593, 620, 623, 650, 653, 680, 683, 710, 713, 740, 743, 770, 773, 800, 803, 830, 833],
+ [ 444, 447, 474, 477, 504, 507, 534, 537, 564, 567, 594, 597, 624, 627, 654, 657, 684, 687, 714, 717, 744, 747, 774, 777, 804, 807, 834, 837],
+ [ 445, 448, 475, 478, 505, 508, 535, 538, 565, 568, 595, 598, 625, 628, 655, 658, 685, 688, 715, 718, 745, 748, 775, 778, 805, 808, 835, 838],
+ [ 446, 449, 476, 479, 506, 509, 536, 539, 566, 569, 596, 599, 626, 629, 656, 659, 686, 689, 716, 719, 746, 749, 776, 779, 806, 809, 836, 839],
+ [ 840, 843, 870, 873, 900, 903, 930, 933, 960, 963, 990, 993, 1020, 1023, 1050, 1053, 1080, 1083, 1110, 1113, 1140, 1143, 1170, 1173, 1200, 1203, 1230, 1233],
+ [ 841, 844, 871, 874, 901, 904, 931, 934, 961, 964, 991, 994, 1021, 1024, 1051, 1054, 1081, 1084, 1111, 1114, 1141, 1144, 1171, 1174, 1201, 1204, 1231, 1234],
+ [ 842, 845, 872, 875, 902, 905, 932, 935, 962, 965, 992, 995, 1022, 1025, 1052, 1055, 1082, 1085, 1112, 1115, 1142, 1145, 1172, 1175, 1202, 1205, 1232, 1235],
+ [ 846, 849, 876, 879, 906, 909, 936, 939, 966, 969, 996, 999, 1026, 1029, 1056, 1059, 1086, 1089, 1116, 1119, 1146, 1149, 1176, 1179, 1206, 1209, 1236, 1239],
+ [ 847, 850, 877, 880, 907, 910, 937, 940, 967, 970, 997, 1000, 1027, 1030, 1057, 1060, 1087, 1090, 1117, 1120, 1147, 1150, 1177, 1180, 1207, 1210, 1237, 1240],
+ [ 848, 851, 878, 881, 908, 911, 938, 941, 968, 971, 998, 1001, 1028, 1031, 1058, 1061, 1088, 1091, 1118, 1121, 1148, 1151, 1178, 1181, 1208, 1211, 1238, 1241],
+ [ 852, 855, 882, 885, 912, 915, 942, 945, 972, 975, 1002, 1005, 1032, 1035, 1062, 1065, 1092, 1095, 1122, 1125, 1152, 1155, 1182, 1185, 1212, 1215, 1242, 1245],
+ [ 853, 856, 883, 886, 913, 916, 943, 946, 973, 976, 1003, 1006, 1033, 1036, 1063, 1066, 1093, 1096, 1123, 1126, 1153, 1156, 1183, 1186, 1213, 1216, 1243, 1246],
+ [ 854, 857, 884, 887, 914, 917, 944, 947, 974, 977, 1004, 1007, 1034, 1037, 1064, 1067, 1094, 1097, 1124, 1127, 1154, 1157, 1184, 1187, 1214, 1217, 1244, 1247],
+ [ 858, 861, 888, 891, 918, 921, 948, 951, 978, 981, 1008, 1011, 1038, 1041, 1068, 1071, 1098, 1101, 1128, 1131, 1158, 1161, 1188, 1191, 1218, 1221, 1248, 1251],
+ [ 859, 862, 889, 892, 919, 922, 949, 952, 979, 982, 1009, 1012, 1039, 1042, 1069, 1072, 1099, 1102, 1129, 1132, 1159, 1162, 1189, 1192, 1219, 1222, 1249, 1252],
+ [ 860, 863, 890, 893, 920, 923, 950, 953, 980, 983, 1010, 1013, 1040, 1043, 1070, 1073, 1100, 1103, 1130, 1133, 1160, 1163, 1190, 1193, 1220, 1223, 1250, 1253],
+ [ 864, 867, 894, 897, 924, 927, 954, 957, 984, 987, 1014, 1017, 1044, 1047, 1074, 1077, 1104, 1107, 1134, 1137, 1164, 1167, 1194, 1197, 1224, 1227, 1254, 1257],
+ [ 865, 868, 895, 898, 925, 928, 955, 958, 985, 988, 1015, 1018, 1045, 1048, 1075, 1078, 1105, 1108, 1135, 1138, 1165, 1168, 1195, 1198, 1225, 1228, 1255, 1258],
+ [ 866, 869, 896, 899, 926, 929, 956, 959, 986, 989, 1016, 1019, 1046, 1049, 1076, 1079, 1106, 1109, 1136, 1139, 1166, 1169, 1196, 1199, 1226, 1229, 1256, 1259]])
+ # fmt: on
+ access_order, access_count = tiles_group_col_major.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # CHECK: Pass!
+ print("Pass!")
+
+
+# CHECK-LABEL: group_tiler_invalid
+@construct_test
+def group_tiler_invalid():
+    """Check that each malformed group_tiler call below raises ValueError."""
+    # The `else` branch runs only when no exception was raised, so a call that
+    # is wrongly accepted fails the test instead of being swallowed as a pass.
+    try:
+        TensorTiler2D.group_tiler(
+            (), (3, 2), (1, 1), tile_col_major=True, pattern_repeat=5
+        )
+    except ValueError:
+        pass
+    else:
+        raise AssertionError("Bad tensor dims, should fail.")
+    try:
+        TensorTiler2D.group_tiler(
+            (10, 9, 4), (3, 2), (1, 1), tile_col_major=True, pattern_repeat=5
+        )
+    except ValueError:
+        pass
+    else:
+        raise AssertionError("Too many tensor dims, should fail.")
+    try:
+        TensorTiler2D.group_tiler(
+            (9, 4), (3, -1), (1, 1), tile_col_major=True, pattern_repeat=5
+        )
+    except ValueError:
+        pass
+    else:
+        raise AssertionError("Bad tile dims, should fail.")
+    try:
+        TensorTiler2D.group_tiler(
+            (9, 4), (3,), (1, 1), tile_col_major=True, pattern_repeat=5
+        )
+    except ValueError:
+        pass
+    else:
+        raise AssertionError("Too few tile dims, should fail.")
+    try:
+        TensorTiler2D.group_tiler(
+            (9, 4), (1, 1, 1), (1, 1), tile_col_major=True, pattern_repeat=5
+        )
+    except ValueError:
+        pass
+    else:
+        raise AssertionError("Too many tile dims, should fail.")
+    try:
+        TensorTiler2D.group_tiler(
+            (9, 4), (3, 2), (1, 1), tile_col_major=True, pattern_repeat=0
+        )
+    except ValueError:
+        pass
+    else:
+        raise AssertionError("Invalid repeat.")
+    try:
+        TensorTiler2D.group_tiler(
+            (9, 4), (4, 2), (1, 1), tile_col_major=True, pattern_repeat=5
+        )
+    except ValueError:
+        pass
+    else:
+        raise AssertionError("Indivisible tile (height)")
+    try:
+        TensorTiler2D.group_tiler(
+            (9, 4), (3, 3), (1, 1), tile_col_major=True, pattern_repeat=5
+        )
+    except ValueError:
+        pass
+    else:
+        raise AssertionError("Indivisible tile (width)")
+    try:
+        TensorTiler2D.group_tiler(
+            (9, 4), (3, 2), (1,), tile_col_major=True, pattern_repeat=5
+        )
+    except ValueError:
+        pass
+    else:
+        raise AssertionError("Too few tile group dims, should fail.")
+    try:
+        TensorTiler2D.group_tiler(
+            (9, 4), (3, 2), (1, -1), tile_col_major=True, pattern_repeat=5
+        )
+    except ValueError:
+        pass
+    else:
+        raise AssertionError("Bad tile group dims, should fail.")
+    try:
+        TensorTiler2D.group_tiler((9, 4), (3, 2), (1, 1, 1), tile_col_major=True)
+    except ValueError:
+        pass
+    else:
+        raise AssertionError("Too many tile group dims, should fail.")
+    try:
+        TensorTiler2D.group_tiler((18, 8), (3, 2), (2, 3), tile_col_major=True)
+    except ValueError:
+        pass
+    else:
+        raise AssertionError(
+            "Indivisible by tile repeat width (but without allow_partial)."
+        )
+    try:
+        TensorTiler2D.group_tiler((18, 8), (3, 2), (4, 2), tile_col_major=True)
+    except ValueError:
+        pass
+    else:
+        raise AssertionError(
+            "Indivisible by tile repeat height (but without allow_partial)."
+        )
+
+    # CHECK: Pass!
+    print("Pass!")
diff --git a/test/python/tensortiler/group_tiler_partial.py b/test/python/tensortiler/group_tiler_partial.py
new file mode 100644
index 0000000000..8f2f54370a
--- /dev/null
+++ b/test/python/tensortiler/group_tiler_partial.py
@@ -0,0 +1,1427 @@
+import numpy as np
+
+from aie.helpers.tensortiler import TensorTile, TensorTileSequence, TensorTiler2D
+from util import construct_test
+
+# RUN: %python %s | FileCheck %s
+
+
+# CHECK-LABEL: group_tiler_partial_row
+@construct_test
+def group_tiler_partial_row():
+
+ tensor_dims = (3 * 5 * 3, 2 * 6 * 2)
+
+ # All row major
+ tiles = TensorTiler2D.group_tiler(
+ tensor_dims, tile_dims=(3, 2), tile_group_dims=(5, 7), allow_partial=True
+ )
+ reference_tiles = TensorTileSequence.from_tiles(
+ [
+ TensorTile(
+ tensor_dims, offset=0, sizes=[5, 7, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=14, sizes=[5, 5, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=360, sizes=[5, 7, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=374, sizes=[5, 5, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=720, sizes=[5, 7, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=734, sizes=[5, 5, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ ]
+ )
+ assert tiles == reference_tiles
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 1, 6, 7, 12, 13, 18, 19, 24, 25, 30, 31, 36, 37, 210, 211, 216, 217, 222, 223, 228, 229, 234, 235],
+ [ 2, 3, 8, 9, 14, 15, 20, 21, 26, 27, 32, 33, 38, 39, 212, 213, 218, 219, 224, 225, 230, 231, 236, 237],
+ [ 4, 5, 10, 11, 16, 17, 22, 23, 28, 29, 34, 35, 40, 41, 214, 215, 220, 221, 226, 227, 232, 233, 238, 239],
+ [ 42, 43, 48, 49, 54, 55, 60, 61, 66, 67, 72, 73, 78, 79, 240, 241, 246, 247, 252, 253, 258, 259, 264, 265],
+ [ 44, 45, 50, 51, 56, 57, 62, 63, 68, 69, 74, 75, 80, 81, 242, 243, 248, 249, 254, 255, 260, 261, 266, 267],
+ [ 46, 47, 52, 53, 58, 59, 64, 65, 70, 71, 76, 77, 82, 83, 244, 245, 250, 251, 256, 257, 262, 263, 268, 269],
+ [ 84, 85, 90, 91, 96, 97, 102, 103, 108, 109, 114, 115, 120, 121, 270, 271, 276, 277, 282, 283, 288, 289, 294, 295],
+ [ 86, 87, 92, 93, 98, 99, 104, 105, 110, 111, 116, 117, 122, 123, 272, 273, 278, 279, 284, 285, 290, 291, 296, 297],
+ [ 88, 89, 94, 95, 100, 101, 106, 107, 112, 113, 118, 119, 124, 125, 274, 275, 280, 281, 286, 287, 292, 293, 298, 299],
+ [ 126, 127, 132, 133, 138, 139, 144, 145, 150, 151, 156, 157, 162, 163, 300, 301, 306, 307, 312, 313, 318, 319, 324, 325],
+ [ 128, 129, 134, 135, 140, 141, 146, 147, 152, 153, 158, 159, 164, 165, 302, 303, 308, 309, 314, 315, 320, 321, 326, 327],
+ [ 130, 131, 136, 137, 142, 143, 148, 149, 154, 155, 160, 161, 166, 167, 304, 305, 310, 311, 316, 317, 322, 323, 328, 329],
+ [ 168, 169, 174, 175, 180, 181, 186, 187, 192, 193, 198, 199, 204, 205, 330, 331, 336, 337, 342, 343, 348, 349, 354, 355],
+ [ 170, 171, 176, 177, 182, 183, 188, 189, 194, 195, 200, 201, 206, 207, 332, 333, 338, 339, 344, 345, 350, 351, 356, 357],
+ [ 172, 173, 178, 179, 184, 185, 190, 191, 196, 197, 202, 203, 208, 209, 334, 335, 340, 341, 346, 347, 352, 353, 358, 359],
+ [ 360, 361, 366, 367, 372, 373, 378, 379, 384, 385, 390, 391, 396, 397, 570, 571, 576, 577, 582, 583, 588, 589, 594, 595],
+ [ 362, 363, 368, 369, 374, 375, 380, 381, 386, 387, 392, 393, 398, 399, 572, 573, 578, 579, 584, 585, 590, 591, 596, 597],
+ [ 364, 365, 370, 371, 376, 377, 382, 383, 388, 389, 394, 395, 400, 401, 574, 575, 580, 581, 586, 587, 592, 593, 598, 599],
+ [ 402, 403, 408, 409, 414, 415, 420, 421, 426, 427, 432, 433, 438, 439, 600, 601, 606, 607, 612, 613, 618, 619, 624, 625],
+ [ 404, 405, 410, 411, 416, 417, 422, 423, 428, 429, 434, 435, 440, 441, 602, 603, 608, 609, 614, 615, 620, 621, 626, 627],
+ [ 406, 407, 412, 413, 418, 419, 424, 425, 430, 431, 436, 437, 442, 443, 604, 605, 610, 611, 616, 617, 622, 623, 628, 629],
+ [ 444, 445, 450, 451, 456, 457, 462, 463, 468, 469, 474, 475, 480, 481, 630, 631, 636, 637, 642, 643, 648, 649, 654, 655],
+ [ 446, 447, 452, 453, 458, 459, 464, 465, 470, 471, 476, 477, 482, 483, 632, 633, 638, 639, 644, 645, 650, 651, 656, 657],
+ [ 448, 449, 454, 455, 460, 461, 466, 467, 472, 473, 478, 479, 484, 485, 634, 635, 640, 641, 646, 647, 652, 653, 658, 659],
+ [ 486, 487, 492, 493, 498, 499, 504, 505, 510, 511, 516, 517, 522, 523, 660, 661, 666, 667, 672, 673, 678, 679, 684, 685],
+ [ 488, 489, 494, 495, 500, 501, 506, 507, 512, 513, 518, 519, 524, 525, 662, 663, 668, 669, 674, 675, 680, 681, 686, 687],
+ [ 490, 491, 496, 497, 502, 503, 508, 509, 514, 515, 520, 521, 526, 527, 664, 665, 670, 671, 676, 677, 682, 683, 688, 689],
+ [ 528, 529, 534, 535, 540, 541, 546, 547, 552, 553, 558, 559, 564, 565, 690, 691, 696, 697, 702, 703, 708, 709, 714, 715],
+ [ 530, 531, 536, 537, 542, 543, 548, 549, 554, 555, 560, 561, 566, 567, 692, 693, 698, 699, 704, 705, 710, 711, 716, 717],
+ [ 532, 533, 538, 539, 544, 545, 550, 551, 556, 557, 562, 563, 568, 569, 694, 695, 700, 701, 706, 707, 712, 713, 718, 719],
+ [ 720, 721, 726, 727, 732, 733, 738, 739, 744, 745, 750, 751, 756, 757, 930, 931, 936, 937, 942, 943, 948, 949, 954, 955],
+ [ 722, 723, 728, 729, 734, 735, 740, 741, 746, 747, 752, 753, 758, 759, 932, 933, 938, 939, 944, 945, 950, 951, 956, 957],
+ [ 724, 725, 730, 731, 736, 737, 742, 743, 748, 749, 754, 755, 760, 761, 934, 935, 940, 941, 946, 947, 952, 953, 958, 959],
+ [ 762, 763, 768, 769, 774, 775, 780, 781, 786, 787, 792, 793, 798, 799, 960, 961, 966, 967, 972, 973, 978, 979, 984, 985],
+ [ 764, 765, 770, 771, 776, 777, 782, 783, 788, 789, 794, 795, 800, 801, 962, 963, 968, 969, 974, 975, 980, 981, 986, 987],
+ [ 766, 767, 772, 773, 778, 779, 784, 785, 790, 791, 796, 797, 802, 803, 964, 965, 970, 971, 976, 977, 982, 983, 988, 989],
+ [ 804, 805, 810, 811, 816, 817, 822, 823, 828, 829, 834, 835, 840, 841, 990, 991, 996, 997, 1002, 1003, 1008, 1009, 1014, 1015],
+ [ 806, 807, 812, 813, 818, 819, 824, 825, 830, 831, 836, 837, 842, 843, 992, 993, 998, 999, 1004, 1005, 1010, 1011, 1016, 1017],
+ [ 808, 809, 814, 815, 820, 821, 826, 827, 832, 833, 838, 839, 844, 845, 994, 995, 1000, 1001, 1006, 1007, 1012, 1013, 1018, 1019],
+ [ 846, 847, 852, 853, 858, 859, 864, 865, 870, 871, 876, 877, 882, 883, 1020, 1021, 1026, 1027, 1032, 1033, 1038, 1039, 1044, 1045],
+ [ 848, 849, 854, 855, 860, 861, 866, 867, 872, 873, 878, 879, 884, 885, 1022, 1023, 1028, 1029, 1034, 1035, 1040, 1041, 1046, 1047],
+ [ 850, 851, 856, 857, 862, 863, 868, 869, 874, 875, 880, 881, 886, 887, 1024, 1025, 1030, 1031, 1036, 1037, 1042, 1043, 1048, 1049],
+ [ 888, 889, 894, 895, 900, 901, 906, 907, 912, 913, 918, 919, 924, 925, 1050, 1051, 1056, 1057, 1062, 1063, 1068, 1069, 1074, 1075],
+ [ 890, 891, 896, 897, 902, 903, 908, 909, 914, 915, 920, 921, 926, 927, 1052, 1053, 1058, 1059, 1064, 1065, 1070, 1071, 1076, 1077],
+ [ 892, 893, 898, 899, 904, 905, 910, 911, 916, 917, 922, 923, 928, 929, 1054, 1055, 1060, 1061, 1066, 1067, 1072, 1073, 1078, 1079]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # Tile col major
+ tiles = TensorTiler2D.group_tiler(
+ tensor_dims,
+ tile_dims=(3, 2),
+ tile_group_dims=(5, 7),
+ tile_col_major=True,
+ allow_partial=True,
+ )
+ reference_tiles = TensorTileSequence.from_tiles(
+ [
+ TensorTile(
+ tensor_dims, offset=0, sizes=[1, 5, 14, 3], strides=[0, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=14, sizes=[1, 5, 10, 3], strides=[0, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=360, sizes=[1, 5, 14, 3], strides=[0, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=374, sizes=[1, 5, 10, 3], strides=[0, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=720, sizes=[1, 5, 14, 3], strides=[0, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=734, sizes=[1, 5, 10, 3], strides=[0, 72, 1, 24]
+ ),
+ ]
+ )
+ assert tiles == reference_tiles
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 210, 213, 216, 219, 222, 225, 228, 231, 234, 237],
+ [ 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 211, 214, 217, 220, 223, 226, 229, 232, 235, 238],
+ [ 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41, 212, 215, 218, 221, 224, 227, 230, 233, 236, 239],
+ [ 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 240, 243, 246, 249, 252, 255, 258, 261, 264, 267],
+ [ 43, 46, 49, 52, 55, 58, 61, 64, 67, 70, 73, 76, 79, 82, 241, 244, 247, 250, 253, 256, 259, 262, 265, 268],
+ [ 44, 47, 50, 53, 56, 59, 62, 65, 68, 71, 74, 77, 80, 83, 242, 245, 248, 251, 254, 257, 260, 263, 266, 269],
+ [ 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 270, 273, 276, 279, 282, 285, 288, 291, 294, 297],
+ [ 85, 88, 91, 94, 97, 100, 103, 106, 109, 112, 115, 118, 121, 124, 271, 274, 277, 280, 283, 286, 289, 292, 295, 298],
+ [ 86, 89, 92, 95, 98, 101, 104, 107, 110, 113, 116, 119, 122, 125, 272, 275, 278, 281, 284, 287, 290, 293, 296, 299],
+ [ 126, 129, 132, 135, 138, 141, 144, 147, 150, 153, 156, 159, 162, 165, 300, 303, 306, 309, 312, 315, 318, 321, 324, 327],
+ [ 127, 130, 133, 136, 139, 142, 145, 148, 151, 154, 157, 160, 163, 166, 301, 304, 307, 310, 313, 316, 319, 322, 325, 328],
+ [ 128, 131, 134, 137, 140, 143, 146, 149, 152, 155, 158, 161, 164, 167, 302, 305, 308, 311, 314, 317, 320, 323, 326, 329],
+ [ 168, 171, 174, 177, 180, 183, 186, 189, 192, 195, 198, 201, 204, 207, 330, 333, 336, 339, 342, 345, 348, 351, 354, 357],
+ [ 169, 172, 175, 178, 181, 184, 187, 190, 193, 196, 199, 202, 205, 208, 331, 334, 337, 340, 343, 346, 349, 352, 355, 358],
+ [ 170, 173, 176, 179, 182, 185, 188, 191, 194, 197, 200, 203, 206, 209, 332, 335, 338, 341, 344, 347, 350, 353, 356, 359],
+ [ 360, 363, 366, 369, 372, 375, 378, 381, 384, 387, 390, 393, 396, 399, 570, 573, 576, 579, 582, 585, 588, 591, 594, 597],
+ [ 361, 364, 367, 370, 373, 376, 379, 382, 385, 388, 391, 394, 397, 400, 571, 574, 577, 580, 583, 586, 589, 592, 595, 598],
+ [ 362, 365, 368, 371, 374, 377, 380, 383, 386, 389, 392, 395, 398, 401, 572, 575, 578, 581, 584, 587, 590, 593, 596, 599],
+ [ 402, 405, 408, 411, 414, 417, 420, 423, 426, 429, 432, 435, 438, 441, 600, 603, 606, 609, 612, 615, 618, 621, 624, 627],
+ [ 403, 406, 409, 412, 415, 418, 421, 424, 427, 430, 433, 436, 439, 442, 601, 604, 607, 610, 613, 616, 619, 622, 625, 628],
+ [ 404, 407, 410, 413, 416, 419, 422, 425, 428, 431, 434, 437, 440, 443, 602, 605, 608, 611, 614, 617, 620, 623, 626, 629],
+ [ 444, 447, 450, 453, 456, 459, 462, 465, 468, 471, 474, 477, 480, 483, 630, 633, 636, 639, 642, 645, 648, 651, 654, 657],
+ [ 445, 448, 451, 454, 457, 460, 463, 466, 469, 472, 475, 478, 481, 484, 631, 634, 637, 640, 643, 646, 649, 652, 655, 658],
+ [ 446, 449, 452, 455, 458, 461, 464, 467, 470, 473, 476, 479, 482, 485, 632, 635, 638, 641, 644, 647, 650, 653, 656, 659],
+ [ 486, 489, 492, 495, 498, 501, 504, 507, 510, 513, 516, 519, 522, 525, 660, 663, 666, 669, 672, 675, 678, 681, 684, 687],
+ [ 487, 490, 493, 496, 499, 502, 505, 508, 511, 514, 517, 520, 523, 526, 661, 664, 667, 670, 673, 676, 679, 682, 685, 688],
+ [ 488, 491, 494, 497, 500, 503, 506, 509, 512, 515, 518, 521, 524, 527, 662, 665, 668, 671, 674, 677, 680, 683, 686, 689],
+ [ 528, 531, 534, 537, 540, 543, 546, 549, 552, 555, 558, 561, 564, 567, 690, 693, 696, 699, 702, 705, 708, 711, 714, 717],
+ [ 529, 532, 535, 538, 541, 544, 547, 550, 553, 556, 559, 562, 565, 568, 691, 694, 697, 700, 703, 706, 709, 712, 715, 718],
+ [ 530, 533, 536, 539, 542, 545, 548, 551, 554, 557, 560, 563, 566, 569, 692, 695, 698, 701, 704, 707, 710, 713, 716, 719],
+ [ 720, 723, 726, 729, 732, 735, 738, 741, 744, 747, 750, 753, 756, 759, 930, 933, 936, 939, 942, 945, 948, 951, 954, 957],
+ [ 721, 724, 727, 730, 733, 736, 739, 742, 745, 748, 751, 754, 757, 760, 931, 934, 937, 940, 943, 946, 949, 952, 955, 958],
+ [ 722, 725, 728, 731, 734, 737, 740, 743, 746, 749, 752, 755, 758, 761, 932, 935, 938, 941, 944, 947, 950, 953, 956, 959],
+ [ 762, 765, 768, 771, 774, 777, 780, 783, 786, 789, 792, 795, 798, 801, 960, 963, 966, 969, 972, 975, 978, 981, 984, 987],
+ [ 763, 766, 769, 772, 775, 778, 781, 784, 787, 790, 793, 796, 799, 802, 961, 964, 967, 970, 973, 976, 979, 982, 985, 988],
+ [ 764, 767, 770, 773, 776, 779, 782, 785, 788, 791, 794, 797, 800, 803, 962, 965, 968, 971, 974, 977, 980, 983, 986, 989],
+ [ 804, 807, 810, 813, 816, 819, 822, 825, 828, 831, 834, 837, 840, 843, 990, 993, 996, 999, 1002, 1005, 1008, 1011, 1014, 1017],
+ [ 805, 808, 811, 814, 817, 820, 823, 826, 829, 832, 835, 838, 841, 844, 991, 994, 997, 1000, 1003, 1006, 1009, 1012, 1015, 1018],
+ [ 806, 809, 812, 815, 818, 821, 824, 827, 830, 833, 836, 839, 842, 845, 992, 995, 998, 1001, 1004, 1007, 1010, 1013, 1016, 1019],
+ [ 846, 849, 852, 855, 858, 861, 864, 867, 870, 873, 876, 879, 882, 885, 1020, 1023, 1026, 1029, 1032, 1035, 1038, 1041, 1044, 1047],
+ [ 847, 850, 853, 856, 859, 862, 865, 868, 871, 874, 877, 880, 883, 886, 1021, 1024, 1027, 1030, 1033, 1036, 1039, 1042, 1045, 1048],
+ [ 848, 851, 854, 857, 860, 863, 866, 869, 872, 875, 878, 881, 884, 887, 1022, 1025, 1028, 1031, 1034, 1037, 1040, 1043, 1046, 1049],
+ [ 888, 891, 894, 897, 900, 903, 906, 909, 912, 915, 918, 921, 924, 927, 1050, 1053, 1056, 1059, 1062, 1065, 1068, 1071, 1074, 1077],
+ [ 889, 892, 895, 898, 901, 904, 907, 910, 913, 916, 919, 922, 925, 928, 1051, 1054, 1057, 1060, 1063, 1066, 1069, 1072, 1075, 1078],
+ [ 890, 893, 896, 899, 902, 905, 908, 911, 914, 917, 920, 923, 926, 929, 1052, 1055, 1058, 1061, 1064, 1067, 1070, 1073, 1076, 1079]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # Tile group col major
+ tiles = TensorTiler2D.group_tiler(
+ tensor_dims,
+ tile_dims=(3, 2),
+ tile_group_dims=(5, 7),
+ tile_group_col_major=True,
+ allow_partial=True,
+ )
+ reference_tiles = TensorTileSequence.from_tiles(
+ [
+ TensorTile(
+ tensor_dims, offset=0, sizes=[1, 7, 15, 2], strides=[0, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=14, sizes=[1, 5, 15, 2], strides=[0, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=360, sizes=[1, 7, 15, 2], strides=[0, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=374, sizes=[1, 5, 15, 2], strides=[0, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=720, sizes=[1, 7, 15, 2], strides=[0, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=734, sizes=[1, 5, 15, 2], strides=[0, 2, 24, 1]
+ ),
+ ]
+ )
+ assert tiles == reference_tiles
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 1, 30, 31, 60, 61, 90, 91, 120, 121, 150, 151, 180, 181, 210, 211, 240, 241, 270, 271, 300, 301, 330, 331],
+ [ 2, 3, 32, 33, 62, 63, 92, 93, 122, 123, 152, 153, 182, 183, 212, 213, 242, 243, 272, 273, 302, 303, 332, 333],
+ [ 4, 5, 34, 35, 64, 65, 94, 95, 124, 125, 154, 155, 184, 185, 214, 215, 244, 245, 274, 275, 304, 305, 334, 335],
+ [ 6, 7, 36, 37, 66, 67, 96, 97, 126, 127, 156, 157, 186, 187, 216, 217, 246, 247, 276, 277, 306, 307, 336, 337],
+ [ 8, 9, 38, 39, 68, 69, 98, 99, 128, 129, 158, 159, 188, 189, 218, 219, 248, 249, 278, 279, 308, 309, 338, 339],
+ [ 10, 11, 40, 41, 70, 71, 100, 101, 130, 131, 160, 161, 190, 191, 220, 221, 250, 251, 280, 281, 310, 311, 340, 341],
+ [ 12, 13, 42, 43, 72, 73, 102, 103, 132, 133, 162, 163, 192, 193, 222, 223, 252, 253, 282, 283, 312, 313, 342, 343],
+ [ 14, 15, 44, 45, 74, 75, 104, 105, 134, 135, 164, 165, 194, 195, 224, 225, 254, 255, 284, 285, 314, 315, 344, 345],
+ [ 16, 17, 46, 47, 76, 77, 106, 107, 136, 137, 166, 167, 196, 197, 226, 227, 256, 257, 286, 287, 316, 317, 346, 347],
+ [ 18, 19, 48, 49, 78, 79, 108, 109, 138, 139, 168, 169, 198, 199, 228, 229, 258, 259, 288, 289, 318, 319, 348, 349],
+ [ 20, 21, 50, 51, 80, 81, 110, 111, 140, 141, 170, 171, 200, 201, 230, 231, 260, 261, 290, 291, 320, 321, 350, 351],
+ [ 22, 23, 52, 53, 82, 83, 112, 113, 142, 143, 172, 173, 202, 203, 232, 233, 262, 263, 292, 293, 322, 323, 352, 353],
+ [ 24, 25, 54, 55, 84, 85, 114, 115, 144, 145, 174, 175, 204, 205, 234, 235, 264, 265, 294, 295, 324, 325, 354, 355],
+ [ 26, 27, 56, 57, 86, 87, 116, 117, 146, 147, 176, 177, 206, 207, 236, 237, 266, 267, 296, 297, 326, 327, 356, 357],
+ [ 28, 29, 58, 59, 88, 89, 118, 119, 148, 149, 178, 179, 208, 209, 238, 239, 268, 269, 298, 299, 328, 329, 358, 359],
+ [ 360, 361, 390, 391, 420, 421, 450, 451, 480, 481, 510, 511, 540, 541, 570, 571, 600, 601, 630, 631, 660, 661, 690, 691],
+ [ 362, 363, 392, 393, 422, 423, 452, 453, 482, 483, 512, 513, 542, 543, 572, 573, 602, 603, 632, 633, 662, 663, 692, 693],
+ [ 364, 365, 394, 395, 424, 425, 454, 455, 484, 485, 514, 515, 544, 545, 574, 575, 604, 605, 634, 635, 664, 665, 694, 695],
+ [ 366, 367, 396, 397, 426, 427, 456, 457, 486, 487, 516, 517, 546, 547, 576, 577, 606, 607, 636, 637, 666, 667, 696, 697],
+ [ 368, 369, 398, 399, 428, 429, 458, 459, 488, 489, 518, 519, 548, 549, 578, 579, 608, 609, 638, 639, 668, 669, 698, 699],
+ [ 370, 371, 400, 401, 430, 431, 460, 461, 490, 491, 520, 521, 550, 551, 580, 581, 610, 611, 640, 641, 670, 671, 700, 701],
+ [ 372, 373, 402, 403, 432, 433, 462, 463, 492, 493, 522, 523, 552, 553, 582, 583, 612, 613, 642, 643, 672, 673, 702, 703],
+ [ 374, 375, 404, 405, 434, 435, 464, 465, 494, 495, 524, 525, 554, 555, 584, 585, 614, 615, 644, 645, 674, 675, 704, 705],
+ [ 376, 377, 406, 407, 436, 437, 466, 467, 496, 497, 526, 527, 556, 557, 586, 587, 616, 617, 646, 647, 676, 677, 706, 707],
+ [ 378, 379, 408, 409, 438, 439, 468, 469, 498, 499, 528, 529, 558, 559, 588, 589, 618, 619, 648, 649, 678, 679, 708, 709],
+ [ 380, 381, 410, 411, 440, 441, 470, 471, 500, 501, 530, 531, 560, 561, 590, 591, 620, 621, 650, 651, 680, 681, 710, 711],
+ [ 382, 383, 412, 413, 442, 443, 472, 473, 502, 503, 532, 533, 562, 563, 592, 593, 622, 623, 652, 653, 682, 683, 712, 713],
+ [ 384, 385, 414, 415, 444, 445, 474, 475, 504, 505, 534, 535, 564, 565, 594, 595, 624, 625, 654, 655, 684, 685, 714, 715],
+ [ 386, 387, 416, 417, 446, 447, 476, 477, 506, 507, 536, 537, 566, 567, 596, 597, 626, 627, 656, 657, 686, 687, 716, 717],
+ [ 388, 389, 418, 419, 448, 449, 478, 479, 508, 509, 538, 539, 568, 569, 598, 599, 628, 629, 658, 659, 688, 689, 718, 719],
+ [ 720, 721, 750, 751, 780, 781, 810, 811, 840, 841, 870, 871, 900, 901, 930, 931, 960, 961, 990, 991, 1020, 1021, 1050, 1051],
+ [ 722, 723, 752, 753, 782, 783, 812, 813, 842, 843, 872, 873, 902, 903, 932, 933, 962, 963, 992, 993, 1022, 1023, 1052, 1053],
+ [ 724, 725, 754, 755, 784, 785, 814, 815, 844, 845, 874, 875, 904, 905, 934, 935, 964, 965, 994, 995, 1024, 1025, 1054, 1055],
+ [ 726, 727, 756, 757, 786, 787, 816, 817, 846, 847, 876, 877, 906, 907, 936, 937, 966, 967, 996, 997, 1026, 1027, 1056, 1057],
+ [ 728, 729, 758, 759, 788, 789, 818, 819, 848, 849, 878, 879, 908, 909, 938, 939, 968, 969, 998, 999, 1028, 1029, 1058, 1059],
+ [ 730, 731, 760, 761, 790, 791, 820, 821, 850, 851, 880, 881, 910, 911, 940, 941, 970, 971, 1000, 1001, 1030, 1031, 1060, 1061],
+ [ 732, 733, 762, 763, 792, 793, 822, 823, 852, 853, 882, 883, 912, 913, 942, 943, 972, 973, 1002, 1003, 1032, 1033, 1062, 1063],
+ [ 734, 735, 764, 765, 794, 795, 824, 825, 854, 855, 884, 885, 914, 915, 944, 945, 974, 975, 1004, 1005, 1034, 1035, 1064, 1065],
+ [ 736, 737, 766, 767, 796, 797, 826, 827, 856, 857, 886, 887, 916, 917, 946, 947, 976, 977, 1006, 1007, 1036, 1037, 1066, 1067],
+ [ 738, 739, 768, 769, 798, 799, 828, 829, 858, 859, 888, 889, 918, 919, 948, 949, 978, 979, 1008, 1009, 1038, 1039, 1068, 1069],
+ [ 740, 741, 770, 771, 800, 801, 830, 831, 860, 861, 890, 891, 920, 921, 950, 951, 980, 981, 1010, 1011, 1040, 1041, 1070, 1071],
+ [ 742, 743, 772, 773, 802, 803, 832, 833, 862, 863, 892, 893, 922, 923, 952, 953, 982, 983, 1012, 1013, 1042, 1043, 1072, 1073],
+ [ 744, 745, 774, 775, 804, 805, 834, 835, 864, 865, 894, 895, 924, 925, 954, 955, 984, 985, 1014, 1015, 1044, 1045, 1074, 1075],
+ [ 746, 747, 776, 777, 806, 807, 836, 837, 866, 867, 896, 897, 926, 927, 956, 957, 986, 987, 1016, 1017, 1046, 1047, 1076, 1077],
+ [ 748, 749, 778, 779, 808, 809, 838, 839, 868, 869, 898, 899, 928, 929, 958, 959, 988, 989, 1018, 1019, 1048, 1049, 1078, 1079]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # iter col major
+ tiles = TensorTiler2D.group_tiler(
+ tensor_dims,
+ tile_dims=(3, 2),
+ tile_group_dims=(5, 7),
+ iter_col_major=True,
+ allow_partial=True,
+ )
+ reference_tiles = TensorTileSequence.from_tiles(
+ [
+ TensorTile(
+ tensor_dims, offset=0, sizes=[5, 7, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=360, sizes=[5, 7, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=720, sizes=[5, 7, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=14, sizes=[5, 5, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=374, sizes=[5, 5, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=734, sizes=[5, 5, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ ]
+ )
+ assert tiles == reference_tiles
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 1, 6, 7, 12, 13, 18, 19, 24, 25, 30, 31, 36, 37, 630, 631, 636, 637, 642, 643, 648, 649, 654, 655],
+ [ 2, 3, 8, 9, 14, 15, 20, 21, 26, 27, 32, 33, 38, 39, 632, 633, 638, 639, 644, 645, 650, 651, 656, 657],
+ [ 4, 5, 10, 11, 16, 17, 22, 23, 28, 29, 34, 35, 40, 41, 634, 635, 640, 641, 646, 647, 652, 653, 658, 659],
+ [ 42, 43, 48, 49, 54, 55, 60, 61, 66, 67, 72, 73, 78, 79, 660, 661, 666, 667, 672, 673, 678, 679, 684, 685],
+ [ 44, 45, 50, 51, 56, 57, 62, 63, 68, 69, 74, 75, 80, 81, 662, 663, 668, 669, 674, 675, 680, 681, 686, 687],
+ [ 46, 47, 52, 53, 58, 59, 64, 65, 70, 71, 76, 77, 82, 83, 664, 665, 670, 671, 676, 677, 682, 683, 688, 689],
+ [ 84, 85, 90, 91, 96, 97, 102, 103, 108, 109, 114, 115, 120, 121, 690, 691, 696, 697, 702, 703, 708, 709, 714, 715],
+ [ 86, 87, 92, 93, 98, 99, 104, 105, 110, 111, 116, 117, 122, 123, 692, 693, 698, 699, 704, 705, 710, 711, 716, 717],
+ [ 88, 89, 94, 95, 100, 101, 106, 107, 112, 113, 118, 119, 124, 125, 694, 695, 700, 701, 706, 707, 712, 713, 718, 719],
+ [ 126, 127, 132, 133, 138, 139, 144, 145, 150, 151, 156, 157, 162, 163, 720, 721, 726, 727, 732, 733, 738, 739, 744, 745],
+ [ 128, 129, 134, 135, 140, 141, 146, 147, 152, 153, 158, 159, 164, 165, 722, 723, 728, 729, 734, 735, 740, 741, 746, 747],
+ [ 130, 131, 136, 137, 142, 143, 148, 149, 154, 155, 160, 161, 166, 167, 724, 725, 730, 731, 736, 737, 742, 743, 748, 749],
+ [ 168, 169, 174, 175, 180, 181, 186, 187, 192, 193, 198, 199, 204, 205, 750, 751, 756, 757, 762, 763, 768, 769, 774, 775],
+ [ 170, 171, 176, 177, 182, 183, 188, 189, 194, 195, 200, 201, 206, 207, 752, 753, 758, 759, 764, 765, 770, 771, 776, 777],
+ [ 172, 173, 178, 179, 184, 185, 190, 191, 196, 197, 202, 203, 208, 209, 754, 755, 760, 761, 766, 767, 772, 773, 778, 779],
+ [ 210, 211, 216, 217, 222, 223, 228, 229, 234, 235, 240, 241, 246, 247, 780, 781, 786, 787, 792, 793, 798, 799, 804, 805],
+ [ 212, 213, 218, 219, 224, 225, 230, 231, 236, 237, 242, 243, 248, 249, 782, 783, 788, 789, 794, 795, 800, 801, 806, 807],
+ [ 214, 215, 220, 221, 226, 227, 232, 233, 238, 239, 244, 245, 250, 251, 784, 785, 790, 791, 796, 797, 802, 803, 808, 809],
+ [ 252, 253, 258, 259, 264, 265, 270, 271, 276, 277, 282, 283, 288, 289, 810, 811, 816, 817, 822, 823, 828, 829, 834, 835],
+ [ 254, 255, 260, 261, 266, 267, 272, 273, 278, 279, 284, 285, 290, 291, 812, 813, 818, 819, 824, 825, 830, 831, 836, 837],
+ [ 256, 257, 262, 263, 268, 269, 274, 275, 280, 281, 286, 287, 292, 293, 814, 815, 820, 821, 826, 827, 832, 833, 838, 839],
+ [ 294, 295, 300, 301, 306, 307, 312, 313, 318, 319, 324, 325, 330, 331, 840, 841, 846, 847, 852, 853, 858, 859, 864, 865],
+ [ 296, 297, 302, 303, 308, 309, 314, 315, 320, 321, 326, 327, 332, 333, 842, 843, 848, 849, 854, 855, 860, 861, 866, 867],
+ [ 298, 299, 304, 305, 310, 311, 316, 317, 322, 323, 328, 329, 334, 335, 844, 845, 850, 851, 856, 857, 862, 863, 868, 869],
+ [ 336, 337, 342, 343, 348, 349, 354, 355, 360, 361, 366, 367, 372, 373, 870, 871, 876, 877, 882, 883, 888, 889, 894, 895],
+ [ 338, 339, 344, 345, 350, 351, 356, 357, 362, 363, 368, 369, 374, 375, 872, 873, 878, 879, 884, 885, 890, 891, 896, 897],
+ [ 340, 341, 346, 347, 352, 353, 358, 359, 364, 365, 370, 371, 376, 377, 874, 875, 880, 881, 886, 887, 892, 893, 898, 899],
+ [ 378, 379, 384, 385, 390, 391, 396, 397, 402, 403, 408, 409, 414, 415, 900, 901, 906, 907, 912, 913, 918, 919, 924, 925],
+ [ 380, 381, 386, 387, 392, 393, 398, 399, 404, 405, 410, 411, 416, 417, 902, 903, 908, 909, 914, 915, 920, 921, 926, 927],
+ [ 382, 383, 388, 389, 394, 395, 400, 401, 406, 407, 412, 413, 418, 419, 904, 905, 910, 911, 916, 917, 922, 923, 928, 929],
+ [ 420, 421, 426, 427, 432, 433, 438, 439, 444, 445, 450, 451, 456, 457, 930, 931, 936, 937, 942, 943, 948, 949, 954, 955],
+ [ 422, 423, 428, 429, 434, 435, 440, 441, 446, 447, 452, 453, 458, 459, 932, 933, 938, 939, 944, 945, 950, 951, 956, 957],
+ [ 424, 425, 430, 431, 436, 437, 442, 443, 448, 449, 454, 455, 460, 461, 934, 935, 940, 941, 946, 947, 952, 953, 958, 959],
+ [ 462, 463, 468, 469, 474, 475, 480, 481, 486, 487, 492, 493, 498, 499, 960, 961, 966, 967, 972, 973, 978, 979, 984, 985],
+ [ 464, 465, 470, 471, 476, 477, 482, 483, 488, 489, 494, 495, 500, 501, 962, 963, 968, 969, 974, 975, 980, 981, 986, 987],
+ [ 466, 467, 472, 473, 478, 479, 484, 485, 490, 491, 496, 497, 502, 503, 964, 965, 970, 971, 976, 977, 982, 983, 988, 989],
+ [ 504, 505, 510, 511, 516, 517, 522, 523, 528, 529, 534, 535, 540, 541, 990, 991, 996, 997, 1002, 1003, 1008, 1009, 1014, 1015],
+ [ 506, 507, 512, 513, 518, 519, 524, 525, 530, 531, 536, 537, 542, 543, 992, 993, 998, 999, 1004, 1005, 1010, 1011, 1016, 1017],
+ [ 508, 509, 514, 515, 520, 521, 526, 527, 532, 533, 538, 539, 544, 545, 994, 995, 1000, 1001, 1006, 1007, 1012, 1013, 1018, 1019],
+ [ 546, 547, 552, 553, 558, 559, 564, 565, 570, 571, 576, 577, 582, 583, 1020, 1021, 1026, 1027, 1032, 1033, 1038, 1039, 1044, 1045],
+ [ 548, 549, 554, 555, 560, 561, 566, 567, 572, 573, 578, 579, 584, 585, 1022, 1023, 1028, 1029, 1034, 1035, 1040, 1041, 1046, 1047],
+ [ 550, 551, 556, 557, 562, 563, 568, 569, 574, 575, 580, 581, 586, 587, 1024, 1025, 1030, 1031, 1036, 1037, 1042, 1043, 1048, 1049],
+ [ 588, 589, 594, 595, 600, 601, 606, 607, 612, 613, 618, 619, 624, 625, 1050, 1051, 1056, 1057, 1062, 1063, 1068, 1069, 1074, 1075],
+ [ 590, 591, 596, 597, 602, 603, 608, 609, 614, 615, 620, 621, 626, 627, 1052, 1053, 1058, 1059, 1064, 1065, 1070, 1071, 1076, 1077],
+ [ 592, 593, 598, 599, 604, 605, 610, 611, 616, 617, 622, 623, 628, 629, 1054, 1055, 1060, 1061, 1066, 1067, 1072, 1073, 1078, 1079]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # all col major
+ tiles = TensorTiler2D.group_tiler(
+ tensor_dims,
+ tile_dims=(3, 2),
+ tile_group_dims=(5, 7),
+ tile_col_major=True,
+ tile_group_col_major=True,
+ iter_col_major=True,
+ allow_partial=True,
+ )
+ reference_tiles = TensorTileSequence.from_tiles(
+ [
+ TensorTile(
+ tensor_dims, offset=0, sizes=[7, 5, 2, 3], strides=[2, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=360, sizes=[7, 5, 2, 3], strides=[2, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=720, sizes=[7, 5, 2, 3], strides=[2, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=14, sizes=[5, 5, 2, 3], strides=[2, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=374, sizes=[5, 5, 2, 3], strides=[2, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=734, sizes=[5, 5, 2, 3], strides=[2, 72, 1, 24]
+ ),
+ ]
+ )
+ assert tiles == reference_tiles
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 3, 30, 33, 60, 63, 90, 93, 120, 123, 150, 153, 180, 183, 630, 633, 660, 663, 690, 693, 720, 723, 750, 753],
+ [ 1, 4, 31, 34, 61, 64, 91, 94, 121, 124, 151, 154, 181, 184, 631, 634, 661, 664, 691, 694, 721, 724, 751, 754],
+ [ 2, 5, 32, 35, 62, 65, 92, 95, 122, 125, 152, 155, 182, 185, 632, 635, 662, 665, 692, 695, 722, 725, 752, 755],
+ [ 6, 9, 36, 39, 66, 69, 96, 99, 126, 129, 156, 159, 186, 189, 636, 639, 666, 669, 696, 699, 726, 729, 756, 759],
+ [ 7, 10, 37, 40, 67, 70, 97, 100, 127, 130, 157, 160, 187, 190, 637, 640, 667, 670, 697, 700, 727, 730, 757, 760],
+ [ 8, 11, 38, 41, 68, 71, 98, 101, 128, 131, 158, 161, 188, 191, 638, 641, 668, 671, 698, 701, 728, 731, 758, 761],
+ [ 12, 15, 42, 45, 72, 75, 102, 105, 132, 135, 162, 165, 192, 195, 642, 645, 672, 675, 702, 705, 732, 735, 762, 765],
+ [ 13, 16, 43, 46, 73, 76, 103, 106, 133, 136, 163, 166, 193, 196, 643, 646, 673, 676, 703, 706, 733, 736, 763, 766],
+ [ 14, 17, 44, 47, 74, 77, 104, 107, 134, 137, 164, 167, 194, 197, 644, 647, 674, 677, 704, 707, 734, 737, 764, 767],
+ [ 18, 21, 48, 51, 78, 81, 108, 111, 138, 141, 168, 171, 198, 201, 648, 651, 678, 681, 708, 711, 738, 741, 768, 771],
+ [ 19, 22, 49, 52, 79, 82, 109, 112, 139, 142, 169, 172, 199, 202, 649, 652, 679, 682, 709, 712, 739, 742, 769, 772],
+ [ 20, 23, 50, 53, 80, 83, 110, 113, 140, 143, 170, 173, 200, 203, 650, 653, 680, 683, 710, 713, 740, 743, 770, 773],
+ [ 24, 27, 54, 57, 84, 87, 114, 117, 144, 147, 174, 177, 204, 207, 654, 657, 684, 687, 714, 717, 744, 747, 774, 777],
+ [ 25, 28, 55, 58, 85, 88, 115, 118, 145, 148, 175, 178, 205, 208, 655, 658, 685, 688, 715, 718, 745, 748, 775, 778],
+ [ 26, 29, 56, 59, 86, 89, 116, 119, 146, 149, 176, 179, 206, 209, 656, 659, 686, 689, 716, 719, 746, 749, 776, 779],
+ [ 210, 213, 240, 243, 270, 273, 300, 303, 330, 333, 360, 363, 390, 393, 780, 783, 810, 813, 840, 843, 870, 873, 900, 903],
+ [ 211, 214, 241, 244, 271, 274, 301, 304, 331, 334, 361, 364, 391, 394, 781, 784, 811, 814, 841, 844, 871, 874, 901, 904],
+ [ 212, 215, 242, 245, 272, 275, 302, 305, 332, 335, 362, 365, 392, 395, 782, 785, 812, 815, 842, 845, 872, 875, 902, 905],
+ [ 216, 219, 246, 249, 276, 279, 306, 309, 336, 339, 366, 369, 396, 399, 786, 789, 816, 819, 846, 849, 876, 879, 906, 909],
+ [ 217, 220, 247, 250, 277, 280, 307, 310, 337, 340, 367, 370, 397, 400, 787, 790, 817, 820, 847, 850, 877, 880, 907, 910],
+ [ 218, 221, 248, 251, 278, 281, 308, 311, 338, 341, 368, 371, 398, 401, 788, 791, 818, 821, 848, 851, 878, 881, 908, 911],
+ [ 222, 225, 252, 255, 282, 285, 312, 315, 342, 345, 372, 375, 402, 405, 792, 795, 822, 825, 852, 855, 882, 885, 912, 915],
+ [ 223, 226, 253, 256, 283, 286, 313, 316, 343, 346, 373, 376, 403, 406, 793, 796, 823, 826, 853, 856, 883, 886, 913, 916],
+ [ 224, 227, 254, 257, 284, 287, 314, 317, 344, 347, 374, 377, 404, 407, 794, 797, 824, 827, 854, 857, 884, 887, 914, 917],
+ [ 228, 231, 258, 261, 288, 291, 318, 321, 348, 351, 378, 381, 408, 411, 798, 801, 828, 831, 858, 861, 888, 891, 918, 921],
+ [ 229, 232, 259, 262, 289, 292, 319, 322, 349, 352, 379, 382, 409, 412, 799, 802, 829, 832, 859, 862, 889, 892, 919, 922],
+ [ 230, 233, 260, 263, 290, 293, 320, 323, 350, 353, 380, 383, 410, 413, 800, 803, 830, 833, 860, 863, 890, 893, 920, 923],
+ [ 234, 237, 264, 267, 294, 297, 324, 327, 354, 357, 384, 387, 414, 417, 804, 807, 834, 837, 864, 867, 894, 897, 924, 927],
+ [ 235, 238, 265, 268, 295, 298, 325, 328, 355, 358, 385, 388, 415, 418, 805, 808, 835, 838, 865, 868, 895, 898, 925, 928],
+ [ 236, 239, 266, 269, 296, 299, 326, 329, 356, 359, 386, 389, 416, 419, 806, 809, 836, 839, 866, 869, 896, 899, 926, 929],
+ [ 420, 423, 450, 453, 480, 483, 510, 513, 540, 543, 570, 573, 600, 603, 930, 933, 960, 963, 990, 993, 1020, 1023, 1050, 1053],
+ [ 421, 424, 451, 454, 481, 484, 511, 514, 541, 544, 571, 574, 601, 604, 931, 934, 961, 964, 991, 994, 1021, 1024, 1051, 1054],
+ [ 422, 425, 452, 455, 482, 485, 512, 515, 542, 545, 572, 575, 602, 605, 932, 935, 962, 965, 992, 995, 1022, 1025, 1052, 1055],
+ [ 426, 429, 456, 459, 486, 489, 516, 519, 546, 549, 576, 579, 606, 609, 936, 939, 966, 969, 996, 999, 1026, 1029, 1056, 1059],
+ [ 427, 430, 457, 460, 487, 490, 517, 520, 547, 550, 577, 580, 607, 610, 937, 940, 967, 970, 997, 1000, 1027, 1030, 1057, 1060],
+ [ 428, 431, 458, 461, 488, 491, 518, 521, 548, 551, 578, 581, 608, 611, 938, 941, 968, 971, 998, 1001, 1028, 1031, 1058, 1061],
+ [ 432, 435, 462, 465, 492, 495, 522, 525, 552, 555, 582, 585, 612, 615, 942, 945, 972, 975, 1002, 1005, 1032, 1035, 1062, 1065],
+ [ 433, 436, 463, 466, 493, 496, 523, 526, 553, 556, 583, 586, 613, 616, 943, 946, 973, 976, 1003, 1006, 1033, 1036, 1063, 1066],
+ [ 434, 437, 464, 467, 494, 497, 524, 527, 554, 557, 584, 587, 614, 617, 944, 947, 974, 977, 1004, 1007, 1034, 1037, 1064, 1067],
+ [ 438, 441, 468, 471, 498, 501, 528, 531, 558, 561, 588, 591, 618, 621, 948, 951, 978, 981, 1008, 1011, 1038, 1041, 1068, 1071],
+ [ 439, 442, 469, 472, 499, 502, 529, 532, 559, 562, 589, 592, 619, 622, 949, 952, 979, 982, 1009, 1012, 1039, 1042, 1069, 1072],
+ [ 440, 443, 470, 473, 500, 503, 530, 533, 560, 563, 590, 593, 620, 623, 950, 953, 980, 983, 1010, 1013, 1040, 1043, 1070, 1073],
+ [ 444, 447, 474, 477, 504, 507, 534, 537, 564, 567, 594, 597, 624, 627, 954, 957, 984, 987, 1014, 1017, 1044, 1047, 1074, 1077],
+ [ 445, 448, 475, 478, 505, 508, 535, 538, 565, 568, 595, 598, 625, 628, 955, 958, 985, 988, 1015, 1018, 1045, 1048, 1075, 1078],
+ [ 446, 449, 476, 479, 506, 509, 536, 539, 566, 569, 596, 599, 626, 629, 956, 959, 986, 989, 1016, 1019, 1046, 1049, 1076, 1079]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # pattern repeat
+ tiles = TensorTiler2D.group_tiler(
+ tensor_dims,
+ tile_dims=(3, 2),
+ tile_group_dims=(5, 7),
+ tile_col_major=True,
+ allow_partial=True,
+ pattern_repeat=4,
+ )
+ reference_tiles = TensorTileSequence.from_tiles(
+ [
+ TensorTile(
+ tensor_dims, offset=0, sizes=[4, 5, 14, 3], strides=[0, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=14, sizes=[4, 5, 10, 3], strides=[0, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=360, sizes=[4, 5, 14, 3], strides=[0, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=374, sizes=[4, 5, 10, 3], strides=[0, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=720, sizes=[4, 5, 14, 3], strides=[0, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=734, sizes=[4, 5, 10, 3], strides=[0, 72, 1, 24]
+ ),
+ ]
+ )
+ assert tiles == reference_tiles
+
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 630, 633, 636, 639, 642, 645, 648, 651, 654, 657, 660, 663, 666, 669, 1290, 1293, 1296, 1299, 1302, 1305, 1308, 1311, 1314, 1317],
+ [ 631, 634, 637, 640, 643, 646, 649, 652, 655, 658, 661, 664, 667, 670, 1291, 1294, 1297, 1300, 1303, 1306, 1309, 1312, 1315, 1318],
+ [ 632, 635, 638, 641, 644, 647, 650, 653, 656, 659, 662, 665, 668, 671, 1292, 1295, 1298, 1301, 1304, 1307, 1310, 1313, 1316, 1319],
+ [ 672, 675, 678, 681, 684, 687, 690, 693, 696, 699, 702, 705, 708, 711, 1320, 1323, 1326, 1329, 1332, 1335, 1338, 1341, 1344, 1347],
+ [ 673, 676, 679, 682, 685, 688, 691, 694, 697, 700, 703, 706, 709, 712, 1321, 1324, 1327, 1330, 1333, 1336, 1339, 1342, 1345, 1348],
+ [ 674, 677, 680, 683, 686, 689, 692, 695, 698, 701, 704, 707, 710, 713, 1322, 1325, 1328, 1331, 1334, 1337, 1340, 1343, 1346, 1349],
+ [ 714, 717, 720, 723, 726, 729, 732, 735, 738, 741, 744, 747, 750, 753, 1350, 1353, 1356, 1359, 1362, 1365, 1368, 1371, 1374, 1377],
+ [ 715, 718, 721, 724, 727, 730, 733, 736, 739, 742, 745, 748, 751, 754, 1351, 1354, 1357, 1360, 1363, 1366, 1369, 1372, 1375, 1378],
+ [ 716, 719, 722, 725, 728, 731, 734, 737, 740, 743, 746, 749, 752, 755, 1352, 1355, 1358, 1361, 1364, 1367, 1370, 1373, 1376, 1379],
+ [ 756, 759, 762, 765, 768, 771, 774, 777, 780, 783, 786, 789, 792, 795, 1380, 1383, 1386, 1389, 1392, 1395, 1398, 1401, 1404, 1407],
+ [ 757, 760, 763, 766, 769, 772, 775, 778, 781, 784, 787, 790, 793, 796, 1381, 1384, 1387, 1390, 1393, 1396, 1399, 1402, 1405, 1408],
+ [ 758, 761, 764, 767, 770, 773, 776, 779, 782, 785, 788, 791, 794, 797, 1382, 1385, 1388, 1391, 1394, 1397, 1400, 1403, 1406, 1409],
+ [ 798, 801, 804, 807, 810, 813, 816, 819, 822, 825, 828, 831, 834, 837, 1410, 1413, 1416, 1419, 1422, 1425, 1428, 1431, 1434, 1437],
+ [ 799, 802, 805, 808, 811, 814, 817, 820, 823, 826, 829, 832, 835, 838, 1411, 1414, 1417, 1420, 1423, 1426, 1429, 1432, 1435, 1438],
+ [ 800, 803, 806, 809, 812, 815, 818, 821, 824, 827, 830, 833, 836, 839, 1412, 1415, 1418, 1421, 1424, 1427, 1430, 1433, 1436, 1439],
+ [2070, 2073, 2076, 2079, 2082, 2085, 2088, 2091, 2094, 2097, 2100, 2103, 2106, 2109, 2730, 2733, 2736, 2739, 2742, 2745, 2748, 2751, 2754, 2757],
+ [2071, 2074, 2077, 2080, 2083, 2086, 2089, 2092, 2095, 2098, 2101, 2104, 2107, 2110, 2731, 2734, 2737, 2740, 2743, 2746, 2749, 2752, 2755, 2758],
+ [2072, 2075, 2078, 2081, 2084, 2087, 2090, 2093, 2096, 2099, 2102, 2105, 2108, 2111, 2732, 2735, 2738, 2741, 2744, 2747, 2750, 2753, 2756, 2759],
+ [2112, 2115, 2118, 2121, 2124, 2127, 2130, 2133, 2136, 2139, 2142, 2145, 2148, 2151, 2760, 2763, 2766, 2769, 2772, 2775, 2778, 2781, 2784, 2787],
+ [2113, 2116, 2119, 2122, 2125, 2128, 2131, 2134, 2137, 2140, 2143, 2146, 2149, 2152, 2761, 2764, 2767, 2770, 2773, 2776, 2779, 2782, 2785, 2788],
+ [2114, 2117, 2120, 2123, 2126, 2129, 2132, 2135, 2138, 2141, 2144, 2147, 2150, 2153, 2762, 2765, 2768, 2771, 2774, 2777, 2780, 2783, 2786, 2789],
+ [2154, 2157, 2160, 2163, 2166, 2169, 2172, 2175, 2178, 2181, 2184, 2187, 2190, 2193, 2790, 2793, 2796, 2799, 2802, 2805, 2808, 2811, 2814, 2817],
+ [2155, 2158, 2161, 2164, 2167, 2170, 2173, 2176, 2179, 2182, 2185, 2188, 2191, 2194, 2791, 2794, 2797, 2800, 2803, 2806, 2809, 2812, 2815, 2818],
+ [2156, 2159, 2162, 2165, 2168, 2171, 2174, 2177, 2180, 2183, 2186, 2189, 2192, 2195, 2792, 2795, 2798, 2801, 2804, 2807, 2810, 2813, 2816, 2819],
+ [2196, 2199, 2202, 2205, 2208, 2211, 2214, 2217, 2220, 2223, 2226, 2229, 2232, 2235, 2820, 2823, 2826, 2829, 2832, 2835, 2838, 2841, 2844, 2847],
+ [2197, 2200, 2203, 2206, 2209, 2212, 2215, 2218, 2221, 2224, 2227, 2230, 2233, 2236, 2821, 2824, 2827, 2830, 2833, 2836, 2839, 2842, 2845, 2848],
+ [2198, 2201, 2204, 2207, 2210, 2213, 2216, 2219, 2222, 2225, 2228, 2231, 2234, 2237, 2822, 2825, 2828, 2831, 2834, 2837, 2840, 2843, 2846, 2849],
+ [2238, 2241, 2244, 2247, 2250, 2253, 2256, 2259, 2262, 2265, 2268, 2271, 2274, 2277, 2850, 2853, 2856, 2859, 2862, 2865, 2868, 2871, 2874, 2877],
+ [2239, 2242, 2245, 2248, 2251, 2254, 2257, 2260, 2263, 2266, 2269, 2272, 2275, 2278, 2851, 2854, 2857, 2860, 2863, 2866, 2869, 2872, 2875, 2878],
+ [2240, 2243, 2246, 2249, 2252, 2255, 2258, 2261, 2264, 2267, 2270, 2273, 2276, 2279, 2852, 2855, 2858, 2861, 2864, 2867, 2870, 2873, 2876, 2879],
+ [3510, 3513, 3516, 3519, 3522, 3525, 3528, 3531, 3534, 3537, 3540, 3543, 3546, 3549, 4170, 4173, 4176, 4179, 4182, 4185, 4188, 4191, 4194, 4197],
+ [3511, 3514, 3517, 3520, 3523, 3526, 3529, 3532, 3535, 3538, 3541, 3544, 3547, 3550, 4171, 4174, 4177, 4180, 4183, 4186, 4189, 4192, 4195, 4198],
+ [3512, 3515, 3518, 3521, 3524, 3527, 3530, 3533, 3536, 3539, 3542, 3545, 3548, 3551, 4172, 4175, 4178, 4181, 4184, 4187, 4190, 4193, 4196, 4199],
+ [3552, 3555, 3558, 3561, 3564, 3567, 3570, 3573, 3576, 3579, 3582, 3585, 3588, 3591, 4200, 4203, 4206, 4209, 4212, 4215, 4218, 4221, 4224, 4227],
+ [3553, 3556, 3559, 3562, 3565, 3568, 3571, 3574, 3577, 3580, 3583, 3586, 3589, 3592, 4201, 4204, 4207, 4210, 4213, 4216, 4219, 4222, 4225, 4228],
+ [3554, 3557, 3560, 3563, 3566, 3569, 3572, 3575, 3578, 3581, 3584, 3587, 3590, 3593, 4202, 4205, 4208, 4211, 4214, 4217, 4220, 4223, 4226, 4229],
+ [3594, 3597, 3600, 3603, 3606, 3609, 3612, 3615, 3618, 3621, 3624, 3627, 3630, 3633, 4230, 4233, 4236, 4239, 4242, 4245, 4248, 4251, 4254, 4257],
+ [3595, 3598, 3601, 3604, 3607, 3610, 3613, 3616, 3619, 3622, 3625, 3628, 3631, 3634, 4231, 4234, 4237, 4240, 4243, 4246, 4249, 4252, 4255, 4258],
+ [3596, 3599, 3602, 3605, 3608, 3611, 3614, 3617, 3620, 3623, 3626, 3629, 3632, 3635, 4232, 4235, 4238, 4241, 4244, 4247, 4250, 4253, 4256, 4259],
+ [3636, 3639, 3642, 3645, 3648, 3651, 3654, 3657, 3660, 3663, 3666, 3669, 3672, 3675, 4260, 4263, 4266, 4269, 4272, 4275, 4278, 4281, 4284, 4287],
+ [3637, 3640, 3643, 3646, 3649, 3652, 3655, 3658, 3661, 3664, 3667, 3670, 3673, 3676, 4261, 4264, 4267, 4270, 4273, 4276, 4279, 4282, 4285, 4288],
+ [3638, 3641, 3644, 3647, 3650, 3653, 3656, 3659, 3662, 3665, 3668, 3671, 3674, 3677, 4262, 4265, 4268, 4271, 4274, 4277, 4280, 4283, 4286, 4289],
+ [3678, 3681, 3684, 3687, 3690, 3693, 3696, 3699, 3702, 3705, 3708, 3711, 3714, 3717, 4290, 4293, 4296, 4299, 4302, 4305, 4308, 4311, 4314, 4317],
+ [3679, 3682, 3685, 3688, 3691, 3694, 3697, 3700, 3703, 3706, 3709, 3712, 3715, 3718, 4291, 4294, 4297, 4300, 4303, 4306, 4309, 4312, 4315, 4318],
+ [3680, 3683, 3686, 3689, 3692, 3695, 3698, 3701, 3704, 3707, 3710, 3713, 3716, 3719, 4292, 4295, 4298, 4301, 4304, 4307, 4310, 4313, 4316, 4319]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 4).all()
+
+ # CHECK: Pass!
+ print("Pass!")
+
+
+# CHECK-LABEL: group_tiler_partial_col
+@construct_test
+def group_tiler_partial_col():
+
+ # All row major
+ tensor_dims = (3 * 4 * 3, 2 * 7 * 2)
+ tiles = TensorTiler2D.group_tiler(
+ tensor_dims, tile_dims=(3, 2), tile_group_dims=(5, 7), allow_partial=True
+ )
+ reference_tiles = TensorTileSequence.from_tiles(
+ [
+ TensorTile(
+ tensor_dims, offset=0, sizes=[5, 7, 3, 2], strides=[84, 2, 28, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=14, sizes=[5, 7, 3, 2], strides=[84, 2, 28, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=420, sizes=[5, 7, 3, 2], strides=[84, 2, 28, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=434, sizes=[5, 7, 3, 2], strides=[84, 2, 28, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=840, sizes=[2, 7, 3, 2], strides=[84, 2, 28, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=854, sizes=[2, 7, 3, 2], strides=[84, 2, 28, 1]
+ ),
+ ]
+ )
+ assert tiles == reference_tiles
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 1, 6, 7, 12, 13, 18, 19, 24, 25, 30, 31, 36, 37, 210, 211, 216, 217, 222, 223, 228, 229, 234, 235, 240, 241, 246, 247],
+ [ 2, 3, 8, 9, 14, 15, 20, 21, 26, 27, 32, 33, 38, 39, 212, 213, 218, 219, 224, 225, 230, 231, 236, 237, 242, 243, 248, 249],
+ [ 4, 5, 10, 11, 16, 17, 22, 23, 28, 29, 34, 35, 40, 41, 214, 215, 220, 221, 226, 227, 232, 233, 238, 239, 244, 245, 250, 251],
+ [ 42, 43, 48, 49, 54, 55, 60, 61, 66, 67, 72, 73, 78, 79, 252, 253, 258, 259, 264, 265, 270, 271, 276, 277, 282, 283, 288, 289],
+ [ 44, 45, 50, 51, 56, 57, 62, 63, 68, 69, 74, 75, 80, 81, 254, 255, 260, 261, 266, 267, 272, 273, 278, 279, 284, 285, 290, 291],
+ [ 46, 47, 52, 53, 58, 59, 64, 65, 70, 71, 76, 77, 82, 83, 256, 257, 262, 263, 268, 269, 274, 275, 280, 281, 286, 287, 292, 293],
+ [ 84, 85, 90, 91, 96, 97, 102, 103, 108, 109, 114, 115, 120, 121, 294, 295, 300, 301, 306, 307, 312, 313, 318, 319, 324, 325, 330, 331],
+ [ 86, 87, 92, 93, 98, 99, 104, 105, 110, 111, 116, 117, 122, 123, 296, 297, 302, 303, 308, 309, 314, 315, 320, 321, 326, 327, 332, 333],
+ [ 88, 89, 94, 95, 100, 101, 106, 107, 112, 113, 118, 119, 124, 125, 298, 299, 304, 305, 310, 311, 316, 317, 322, 323, 328, 329, 334, 335],
+ [ 126, 127, 132, 133, 138, 139, 144, 145, 150, 151, 156, 157, 162, 163, 336, 337, 342, 343, 348, 349, 354, 355, 360, 361, 366, 367, 372, 373],
+ [ 128, 129, 134, 135, 140, 141, 146, 147, 152, 153, 158, 159, 164, 165, 338, 339, 344, 345, 350, 351, 356, 357, 362, 363, 368, 369, 374, 375],
+ [ 130, 131, 136, 137, 142, 143, 148, 149, 154, 155, 160, 161, 166, 167, 340, 341, 346, 347, 352, 353, 358, 359, 364, 365, 370, 371, 376, 377],
+ [ 168, 169, 174, 175, 180, 181, 186, 187, 192, 193, 198, 199, 204, 205, 378, 379, 384, 385, 390, 391, 396, 397, 402, 403, 408, 409, 414, 415],
+ [ 170, 171, 176, 177, 182, 183, 188, 189, 194, 195, 200, 201, 206, 207, 380, 381, 386, 387, 392, 393, 398, 399, 404, 405, 410, 411, 416, 417],
+ [ 172, 173, 178, 179, 184, 185, 190, 191, 196, 197, 202, 203, 208, 209, 382, 383, 388, 389, 394, 395, 400, 401, 406, 407, 412, 413, 418, 419],
+ [ 420, 421, 426, 427, 432, 433, 438, 439, 444, 445, 450, 451, 456, 457, 630, 631, 636, 637, 642, 643, 648, 649, 654, 655, 660, 661, 666, 667],
+ [ 422, 423, 428, 429, 434, 435, 440, 441, 446, 447, 452, 453, 458, 459, 632, 633, 638, 639, 644, 645, 650, 651, 656, 657, 662, 663, 668, 669],
+ [ 424, 425, 430, 431, 436, 437, 442, 443, 448, 449, 454, 455, 460, 461, 634, 635, 640, 641, 646, 647, 652, 653, 658, 659, 664, 665, 670, 671],
+ [ 462, 463, 468, 469, 474, 475, 480, 481, 486, 487, 492, 493, 498, 499, 672, 673, 678, 679, 684, 685, 690, 691, 696, 697, 702, 703, 708, 709],
+ [ 464, 465, 470, 471, 476, 477, 482, 483, 488, 489, 494, 495, 500, 501, 674, 675, 680, 681, 686, 687, 692, 693, 698, 699, 704, 705, 710, 711],
+ [ 466, 467, 472, 473, 478, 479, 484, 485, 490, 491, 496, 497, 502, 503, 676, 677, 682, 683, 688, 689, 694, 695, 700, 701, 706, 707, 712, 713],
+ [ 504, 505, 510, 511, 516, 517, 522, 523, 528, 529, 534, 535, 540, 541, 714, 715, 720, 721, 726, 727, 732, 733, 738, 739, 744, 745, 750, 751],
+ [ 506, 507, 512, 513, 518, 519, 524, 525, 530, 531, 536, 537, 542, 543, 716, 717, 722, 723, 728, 729, 734, 735, 740, 741, 746, 747, 752, 753],
+ [ 508, 509, 514, 515, 520, 521, 526, 527, 532, 533, 538, 539, 544, 545, 718, 719, 724, 725, 730, 731, 736, 737, 742, 743, 748, 749, 754, 755],
+ [ 546, 547, 552, 553, 558, 559, 564, 565, 570, 571, 576, 577, 582, 583, 756, 757, 762, 763, 768, 769, 774, 775, 780, 781, 786, 787, 792, 793],
+ [ 548, 549, 554, 555, 560, 561, 566, 567, 572, 573, 578, 579, 584, 585, 758, 759, 764, 765, 770, 771, 776, 777, 782, 783, 788, 789, 794, 795],
+ [ 550, 551, 556, 557, 562, 563, 568, 569, 574, 575, 580, 581, 586, 587, 760, 761, 766, 767, 772, 773, 778, 779, 784, 785, 790, 791, 796, 797],
+ [ 588, 589, 594, 595, 600, 601, 606, 607, 612, 613, 618, 619, 624, 625, 798, 799, 804, 805, 810, 811, 816, 817, 822, 823, 828, 829, 834, 835],
+ [ 590, 591, 596, 597, 602, 603, 608, 609, 614, 615, 620, 621, 626, 627, 800, 801, 806, 807, 812, 813, 818, 819, 824, 825, 830, 831, 836, 837],
+ [ 592, 593, 598, 599, 604, 605, 610, 611, 616, 617, 622, 623, 628, 629, 802, 803, 808, 809, 814, 815, 820, 821, 826, 827, 832, 833, 838, 839],
+ [ 840, 841, 846, 847, 852, 853, 858, 859, 864, 865, 870, 871, 876, 877, 924, 925, 930, 931, 936, 937, 942, 943, 948, 949, 954, 955, 960, 961],
+ [ 842, 843, 848, 849, 854, 855, 860, 861, 866, 867, 872, 873, 878, 879, 926, 927, 932, 933, 938, 939, 944, 945, 950, 951, 956, 957, 962, 963],
+ [ 844, 845, 850, 851, 856, 857, 862, 863, 868, 869, 874, 875, 880, 881, 928, 929, 934, 935, 940, 941, 946, 947, 952, 953, 958, 959, 964, 965],
+ [ 882, 883, 888, 889, 894, 895, 900, 901, 906, 907, 912, 913, 918, 919, 966, 967, 972, 973, 978, 979, 984, 985, 990, 991, 996, 997, 1002, 1003],
+ [ 884, 885, 890, 891, 896, 897, 902, 903, 908, 909, 914, 915, 920, 921, 968, 969, 974, 975, 980, 981, 986, 987, 992, 993, 998, 999, 1004, 1005],
+ [ 886, 887, 892, 893, 898, 899, 904, 905, 910, 911, 916, 917, 922, 923, 970, 971, 976, 977, 982, 983, 988, 989, 994, 995, 1000, 1001, 1006, 1007]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # Tile col major
+ tiles = TensorTiler2D.group_tiler(
+ tensor_dims,
+ tile_dims=(3, 2),
+ tile_group_dims=(5, 7),
+ tile_col_major=True,
+ allow_partial=True,
+ )
+ reference_tiles = TensorTileSequence.from_tiles(
+ [
+ TensorTile(
+ tensor_dims, offset=0, sizes=[1, 5, 14, 3], strides=[0, 84, 1, 28]
+ ),
+ TensorTile(
+ tensor_dims, offset=14, sizes=[1, 5, 14, 3], strides=[0, 84, 1, 28]
+ ),
+ TensorTile(
+ tensor_dims, offset=420, sizes=[1, 5, 14, 3], strides=[0, 84, 1, 28]
+ ),
+ TensorTile(
+ tensor_dims, offset=434, sizes=[1, 5, 14, 3], strides=[0, 84, 1, 28]
+ ),
+ TensorTile(
+ tensor_dims, offset=840, sizes=[1, 2, 14, 3], strides=[0, 84, 1, 28]
+ ),
+ TensorTile(
+ tensor_dims, offset=854, sizes=[1, 2, 14, 3], strides=[0, 84, 1, 28]
+ ),
+ ]
+ )
+ assert tiles == reference_tiles
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 210, 213, 216, 219, 222, 225, 228, 231, 234, 237, 240, 243, 246, 249],
+ [ 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 211, 214, 217, 220, 223, 226, 229, 232, 235, 238, 241, 244, 247, 250],
+ [ 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41, 212, 215, 218, 221, 224, 227, 230, 233, 236, 239, 242, 245, 248, 251],
+ [ 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 252, 255, 258, 261, 264, 267, 270, 273, 276, 279, 282, 285, 288, 291],
+ [ 43, 46, 49, 52, 55, 58, 61, 64, 67, 70, 73, 76, 79, 82, 253, 256, 259, 262, 265, 268, 271, 274, 277, 280, 283, 286, 289, 292],
+ [ 44, 47, 50, 53, 56, 59, 62, 65, 68, 71, 74, 77, 80, 83, 254, 257, 260, 263, 266, 269, 272, 275, 278, 281, 284, 287, 290, 293],
+ [ 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 294, 297, 300, 303, 306, 309, 312, 315, 318, 321, 324, 327, 330, 333],
+ [ 85, 88, 91, 94, 97, 100, 103, 106, 109, 112, 115, 118, 121, 124, 295, 298, 301, 304, 307, 310, 313, 316, 319, 322, 325, 328, 331, 334],
+ [ 86, 89, 92, 95, 98, 101, 104, 107, 110, 113, 116, 119, 122, 125, 296, 299, 302, 305, 308, 311, 314, 317, 320, 323, 326, 329, 332, 335],
+ [ 126, 129, 132, 135, 138, 141, 144, 147, 150, 153, 156, 159, 162, 165, 336, 339, 342, 345, 348, 351, 354, 357, 360, 363, 366, 369, 372, 375],
+ [ 127, 130, 133, 136, 139, 142, 145, 148, 151, 154, 157, 160, 163, 166, 337, 340, 343, 346, 349, 352, 355, 358, 361, 364, 367, 370, 373, 376],
+ [ 128, 131, 134, 137, 140, 143, 146, 149, 152, 155, 158, 161, 164, 167, 338, 341, 344, 347, 350, 353, 356, 359, 362, 365, 368, 371, 374, 377],
+ [ 168, 171, 174, 177, 180, 183, 186, 189, 192, 195, 198, 201, 204, 207, 378, 381, 384, 387, 390, 393, 396, 399, 402, 405, 408, 411, 414, 417],
+ [ 169, 172, 175, 178, 181, 184, 187, 190, 193, 196, 199, 202, 205, 208, 379, 382, 385, 388, 391, 394, 397, 400, 403, 406, 409, 412, 415, 418],
+ [ 170, 173, 176, 179, 182, 185, 188, 191, 194, 197, 200, 203, 206, 209, 380, 383, 386, 389, 392, 395, 398, 401, 404, 407, 410, 413, 416, 419],
+ [ 420, 423, 426, 429, 432, 435, 438, 441, 444, 447, 450, 453, 456, 459, 630, 633, 636, 639, 642, 645, 648, 651, 654, 657, 660, 663, 666, 669],
+ [ 421, 424, 427, 430, 433, 436, 439, 442, 445, 448, 451, 454, 457, 460, 631, 634, 637, 640, 643, 646, 649, 652, 655, 658, 661, 664, 667, 670],
+ [ 422, 425, 428, 431, 434, 437, 440, 443, 446, 449, 452, 455, 458, 461, 632, 635, 638, 641, 644, 647, 650, 653, 656, 659, 662, 665, 668, 671],
+ [ 462, 465, 468, 471, 474, 477, 480, 483, 486, 489, 492, 495, 498, 501, 672, 675, 678, 681, 684, 687, 690, 693, 696, 699, 702, 705, 708, 711],
+ [ 463, 466, 469, 472, 475, 478, 481, 484, 487, 490, 493, 496, 499, 502, 673, 676, 679, 682, 685, 688, 691, 694, 697, 700, 703, 706, 709, 712],
+ [ 464, 467, 470, 473, 476, 479, 482, 485, 488, 491, 494, 497, 500, 503, 674, 677, 680, 683, 686, 689, 692, 695, 698, 701, 704, 707, 710, 713],
+ [ 504, 507, 510, 513, 516, 519, 522, 525, 528, 531, 534, 537, 540, 543, 714, 717, 720, 723, 726, 729, 732, 735, 738, 741, 744, 747, 750, 753],
+ [ 505, 508, 511, 514, 517, 520, 523, 526, 529, 532, 535, 538, 541, 544, 715, 718, 721, 724, 727, 730, 733, 736, 739, 742, 745, 748, 751, 754],
+ [ 506, 509, 512, 515, 518, 521, 524, 527, 530, 533, 536, 539, 542, 545, 716, 719, 722, 725, 728, 731, 734, 737, 740, 743, 746, 749, 752, 755],
+ [ 546, 549, 552, 555, 558, 561, 564, 567, 570, 573, 576, 579, 582, 585, 756, 759, 762, 765, 768, 771, 774, 777, 780, 783, 786, 789, 792, 795],
+ [ 547, 550, 553, 556, 559, 562, 565, 568, 571, 574, 577, 580, 583, 586, 757, 760, 763, 766, 769, 772, 775, 778, 781, 784, 787, 790, 793, 796],
+ [ 548, 551, 554, 557, 560, 563, 566, 569, 572, 575, 578, 581, 584, 587, 758, 761, 764, 767, 770, 773, 776, 779, 782, 785, 788, 791, 794, 797],
+ [ 588, 591, 594, 597, 600, 603, 606, 609, 612, 615, 618, 621, 624, 627, 798, 801, 804, 807, 810, 813, 816, 819, 822, 825, 828, 831, 834, 837],
+ [ 589, 592, 595, 598, 601, 604, 607, 610, 613, 616, 619, 622, 625, 628, 799, 802, 805, 808, 811, 814, 817, 820, 823, 826, 829, 832, 835, 838],
+ [ 590, 593, 596, 599, 602, 605, 608, 611, 614, 617, 620, 623, 626, 629, 800, 803, 806, 809, 812, 815, 818, 821, 824, 827, 830, 833, 836, 839],
+ [ 840, 843, 846, 849, 852, 855, 858, 861, 864, 867, 870, 873, 876, 879, 924, 927, 930, 933, 936, 939, 942, 945, 948, 951, 954, 957, 960, 963],
+ [ 841, 844, 847, 850, 853, 856, 859, 862, 865, 868, 871, 874, 877, 880, 925, 928, 931, 934, 937, 940, 943, 946, 949, 952, 955, 958, 961, 964],
+ [ 842, 845, 848, 851, 854, 857, 860, 863, 866, 869, 872, 875, 878, 881, 926, 929, 932, 935, 938, 941, 944, 947, 950, 953, 956, 959, 962, 965],
+ [ 882, 885, 888, 891, 894, 897, 900, 903, 906, 909, 912, 915, 918, 921, 966, 969, 972, 975, 978, 981, 984, 987, 990, 993, 996, 999, 1002, 1005],
+ [ 883, 886, 889, 892, 895, 898, 901, 904, 907, 910, 913, 916, 919, 922, 967, 970, 973, 976, 979, 982, 985, 988, 991, 994, 997, 1000, 1003, 1006],
+ [ 884, 887, 890, 893, 896, 899, 902, 905, 908, 911, 914, 917, 920, 923, 968, 971, 974, 977, 980, 983, 986, 989, 992, 995, 998, 1001, 1004, 1007]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # Tile group col major
+ tiles = TensorTiler2D.group_tiler(
+ tensor_dims,
+ tile_dims=(3, 2),
+ tile_group_dims=(5, 7),
+ tile_group_col_major=True,
+ allow_partial=True,
+ )
+ reference_tiles = TensorTileSequence.from_tiles(
+ [
+ TensorTile(
+ tensor_dims, offset=0, sizes=[1, 7, 15, 2], strides=[0, 2, 28, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=14, sizes=[1, 7, 15, 2], strides=[0, 2, 28, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=420, sizes=[1, 7, 15, 2], strides=[0, 2, 28, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=434, sizes=[1, 7, 15, 2], strides=[0, 2, 28, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=840, sizes=[1, 7, 6, 2], strides=[0, 2, 28, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=854, sizes=[1, 7, 6, 2], strides=[0, 2, 28, 1]
+ ),
+ ]
+ )
+ assert tiles == reference_tiles
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 1, 30, 31, 60, 61, 90, 91, 120, 121, 150, 151, 180, 181, 210, 211, 240, 241, 270, 271, 300, 301, 330, 331, 360, 361, 390, 391],
+ [ 2, 3, 32, 33, 62, 63, 92, 93, 122, 123, 152, 153, 182, 183, 212, 213, 242, 243, 272, 273, 302, 303, 332, 333, 362, 363, 392, 393],
+ [ 4, 5, 34, 35, 64, 65, 94, 95, 124, 125, 154, 155, 184, 185, 214, 215, 244, 245, 274, 275, 304, 305, 334, 335, 364, 365, 394, 395],
+ [ 6, 7, 36, 37, 66, 67, 96, 97, 126, 127, 156, 157, 186, 187, 216, 217, 246, 247, 276, 277, 306, 307, 336, 337, 366, 367, 396, 397],
+ [ 8, 9, 38, 39, 68, 69, 98, 99, 128, 129, 158, 159, 188, 189, 218, 219, 248, 249, 278, 279, 308, 309, 338, 339, 368, 369, 398, 399],
+ [ 10, 11, 40, 41, 70, 71, 100, 101, 130, 131, 160, 161, 190, 191, 220, 221, 250, 251, 280, 281, 310, 311, 340, 341, 370, 371, 400, 401],
+ [ 12, 13, 42, 43, 72, 73, 102, 103, 132, 133, 162, 163, 192, 193, 222, 223, 252, 253, 282, 283, 312, 313, 342, 343, 372, 373, 402, 403],
+ [ 14, 15, 44, 45, 74, 75, 104, 105, 134, 135, 164, 165, 194, 195, 224, 225, 254, 255, 284, 285, 314, 315, 344, 345, 374, 375, 404, 405],
+ [ 16, 17, 46, 47, 76, 77, 106, 107, 136, 137, 166, 167, 196, 197, 226, 227, 256, 257, 286, 287, 316, 317, 346, 347, 376, 377, 406, 407],
+ [ 18, 19, 48, 49, 78, 79, 108, 109, 138, 139, 168, 169, 198, 199, 228, 229, 258, 259, 288, 289, 318, 319, 348, 349, 378, 379, 408, 409],
+ [ 20, 21, 50, 51, 80, 81, 110, 111, 140, 141, 170, 171, 200, 201, 230, 231, 260, 261, 290, 291, 320, 321, 350, 351, 380, 381, 410, 411],
+ [ 22, 23, 52, 53, 82, 83, 112, 113, 142, 143, 172, 173, 202, 203, 232, 233, 262, 263, 292, 293, 322, 323, 352, 353, 382, 383, 412, 413],
+ [ 24, 25, 54, 55, 84, 85, 114, 115, 144, 145, 174, 175, 204, 205, 234, 235, 264, 265, 294, 295, 324, 325, 354, 355, 384, 385, 414, 415],
+ [ 26, 27, 56, 57, 86, 87, 116, 117, 146, 147, 176, 177, 206, 207, 236, 237, 266, 267, 296, 297, 326, 327, 356, 357, 386, 387, 416, 417],
+ [ 28, 29, 58, 59, 88, 89, 118, 119, 148, 149, 178, 179, 208, 209, 238, 239, 268, 269, 298, 299, 328, 329, 358, 359, 388, 389, 418, 419],
+ [ 420, 421, 450, 451, 480, 481, 510, 511, 540, 541, 570, 571, 600, 601, 630, 631, 660, 661, 690, 691, 720, 721, 750, 751, 780, 781, 810, 811],
+ [ 422, 423, 452, 453, 482, 483, 512, 513, 542, 543, 572, 573, 602, 603, 632, 633, 662, 663, 692, 693, 722, 723, 752, 753, 782, 783, 812, 813],
+ [ 424, 425, 454, 455, 484, 485, 514, 515, 544, 545, 574, 575, 604, 605, 634, 635, 664, 665, 694, 695, 724, 725, 754, 755, 784, 785, 814, 815],
+ [ 426, 427, 456, 457, 486, 487, 516, 517, 546, 547, 576, 577, 606, 607, 636, 637, 666, 667, 696, 697, 726, 727, 756, 757, 786, 787, 816, 817],
+ [ 428, 429, 458, 459, 488, 489, 518, 519, 548, 549, 578, 579, 608, 609, 638, 639, 668, 669, 698, 699, 728, 729, 758, 759, 788, 789, 818, 819],
+ [ 430, 431, 460, 461, 490, 491, 520, 521, 550, 551, 580, 581, 610, 611, 640, 641, 670, 671, 700, 701, 730, 731, 760, 761, 790, 791, 820, 821],
+ [ 432, 433, 462, 463, 492, 493, 522, 523, 552, 553, 582, 583, 612, 613, 642, 643, 672, 673, 702, 703, 732, 733, 762, 763, 792, 793, 822, 823],
+ [ 434, 435, 464, 465, 494, 495, 524, 525, 554, 555, 584, 585, 614, 615, 644, 645, 674, 675, 704, 705, 734, 735, 764, 765, 794, 795, 824, 825],
+ [ 436, 437, 466, 467, 496, 497, 526, 527, 556, 557, 586, 587, 616, 617, 646, 647, 676, 677, 706, 707, 736, 737, 766, 767, 796, 797, 826, 827],
+ [ 438, 439, 468, 469, 498, 499, 528, 529, 558, 559, 588, 589, 618, 619, 648, 649, 678, 679, 708, 709, 738, 739, 768, 769, 798, 799, 828, 829],
+ [ 440, 441, 470, 471, 500, 501, 530, 531, 560, 561, 590, 591, 620, 621, 650, 651, 680, 681, 710, 711, 740, 741, 770, 771, 800, 801, 830, 831],
+ [ 442, 443, 472, 473, 502, 503, 532, 533, 562, 563, 592, 593, 622, 623, 652, 653, 682, 683, 712, 713, 742, 743, 772, 773, 802, 803, 832, 833],
+ [ 444, 445, 474, 475, 504, 505, 534, 535, 564, 565, 594, 595, 624, 625, 654, 655, 684, 685, 714, 715, 744, 745, 774, 775, 804, 805, 834, 835],
+ [ 446, 447, 476, 477, 506, 507, 536, 537, 566, 567, 596, 597, 626, 627, 656, 657, 686, 687, 716, 717, 746, 747, 776, 777, 806, 807, 836, 837],
+ [ 448, 449, 478, 479, 508, 509, 538, 539, 568, 569, 598, 599, 628, 629, 658, 659, 688, 689, 718, 719, 748, 749, 778, 779, 808, 809, 838, 839],
+ [ 840, 841, 852, 853, 864, 865, 876, 877, 888, 889, 900, 901, 912, 913, 924, 925, 936, 937, 948, 949, 960, 961, 972, 973, 984, 985, 996, 997],
+ [ 842, 843, 854, 855, 866, 867, 878, 879, 890, 891, 902, 903, 914, 915, 926, 927, 938, 939, 950, 951, 962, 963, 974, 975, 986, 987, 998, 999],
+ [ 844, 845, 856, 857, 868, 869, 880, 881, 892, 893, 904, 905, 916, 917, 928, 929, 940, 941, 952, 953, 964, 965, 976, 977, 988, 989, 1000, 1001],
+ [ 846, 847, 858, 859, 870, 871, 882, 883, 894, 895, 906, 907, 918, 919, 930, 931, 942, 943, 954, 955, 966, 967, 978, 979, 990, 991, 1002, 1003],
+ [ 848, 849, 860, 861, 872, 873, 884, 885, 896, 897, 908, 909, 920, 921, 932, 933, 944, 945, 956, 957, 968, 969, 980, 981, 992, 993, 1004, 1005],
+ [ 850, 851, 862, 863, 874, 875, 886, 887, 898, 899, 910, 911, 922, 923, 934, 935, 946, 947, 958, 959, 970, 971, 982, 983, 994, 995, 1006, 1007]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # Iter col major (tile groups are visited down each column before moving to the next column)
+ tiles = TensorTiler2D.group_tiler(
+ tensor_dims,
+ tile_dims=(3, 2),
+ tile_group_dims=(5, 7),
+ iter_col_major=True,
+ allow_partial=True,
+ )
+ reference_tiles = TensorTileSequence.from_tiles(
+ [
+ TensorTile(
+ tensor_dims, offset=0, sizes=[5, 7, 3, 2], strides=[84, 2, 28, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=420, sizes=[5, 7, 3, 2], strides=[84, 2, 28, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=840, sizes=[2, 7, 3, 2], strides=[84, 2, 28, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=14, sizes=[5, 7, 3, 2], strides=[84, 2, 28, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=434, sizes=[5, 7, 3, 2], strides=[84, 2, 28, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=854, sizes=[2, 7, 3, 2], strides=[84, 2, 28, 1]
+ ),
+ ]
+ )
+ assert tiles == reference_tiles
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 1, 6, 7, 12, 13, 18, 19, 24, 25, 30, 31, 36, 37, 504, 505, 510, 511, 516, 517, 522, 523, 528, 529, 534, 535, 540, 541],
+ [ 2, 3, 8, 9, 14, 15, 20, 21, 26, 27, 32, 33, 38, 39, 506, 507, 512, 513, 518, 519, 524, 525, 530, 531, 536, 537, 542, 543],
+ [ 4, 5, 10, 11, 16, 17, 22, 23, 28, 29, 34, 35, 40, 41, 508, 509, 514, 515, 520, 521, 526, 527, 532, 533, 538, 539, 544, 545],
+ [ 42, 43, 48, 49, 54, 55, 60, 61, 66, 67, 72, 73, 78, 79, 546, 547, 552, 553, 558, 559, 564, 565, 570, 571, 576, 577, 582, 583],
+ [ 44, 45, 50, 51, 56, 57, 62, 63, 68, 69, 74, 75, 80, 81, 548, 549, 554, 555, 560, 561, 566, 567, 572, 573, 578, 579, 584, 585],
+ [ 46, 47, 52, 53, 58, 59, 64, 65, 70, 71, 76, 77, 82, 83, 550, 551, 556, 557, 562, 563, 568, 569, 574, 575, 580, 581, 586, 587],
+ [ 84, 85, 90, 91, 96, 97, 102, 103, 108, 109, 114, 115, 120, 121, 588, 589, 594, 595, 600, 601, 606, 607, 612, 613, 618, 619, 624, 625],
+ [ 86, 87, 92, 93, 98, 99, 104, 105, 110, 111, 116, 117, 122, 123, 590, 591, 596, 597, 602, 603, 608, 609, 614, 615, 620, 621, 626, 627],
+ [ 88, 89, 94, 95, 100, 101, 106, 107, 112, 113, 118, 119, 124, 125, 592, 593, 598, 599, 604, 605, 610, 611, 616, 617, 622, 623, 628, 629],
+ [ 126, 127, 132, 133, 138, 139, 144, 145, 150, 151, 156, 157, 162, 163, 630, 631, 636, 637, 642, 643, 648, 649, 654, 655, 660, 661, 666, 667],
+ [ 128, 129, 134, 135, 140, 141, 146, 147, 152, 153, 158, 159, 164, 165, 632, 633, 638, 639, 644, 645, 650, 651, 656, 657, 662, 663, 668, 669],
+ [ 130, 131, 136, 137, 142, 143, 148, 149, 154, 155, 160, 161, 166, 167, 634, 635, 640, 641, 646, 647, 652, 653, 658, 659, 664, 665, 670, 671],
+ [ 168, 169, 174, 175, 180, 181, 186, 187, 192, 193, 198, 199, 204, 205, 672, 673, 678, 679, 684, 685, 690, 691, 696, 697, 702, 703, 708, 709],
+ [ 170, 171, 176, 177, 182, 183, 188, 189, 194, 195, 200, 201, 206, 207, 674, 675, 680, 681, 686, 687, 692, 693, 698, 699, 704, 705, 710, 711],
+ [ 172, 173, 178, 179, 184, 185, 190, 191, 196, 197, 202, 203, 208, 209, 676, 677, 682, 683, 688, 689, 694, 695, 700, 701, 706, 707, 712, 713],
+ [ 210, 211, 216, 217, 222, 223, 228, 229, 234, 235, 240, 241, 246, 247, 714, 715, 720, 721, 726, 727, 732, 733, 738, 739, 744, 745, 750, 751],
+ [ 212, 213, 218, 219, 224, 225, 230, 231, 236, 237, 242, 243, 248, 249, 716, 717, 722, 723, 728, 729, 734, 735, 740, 741, 746, 747, 752, 753],
+ [ 214, 215, 220, 221, 226, 227, 232, 233, 238, 239, 244, 245, 250, 251, 718, 719, 724, 725, 730, 731, 736, 737, 742, 743, 748, 749, 754, 755],
+ [ 252, 253, 258, 259, 264, 265, 270, 271, 276, 277, 282, 283, 288, 289, 756, 757, 762, 763, 768, 769, 774, 775, 780, 781, 786, 787, 792, 793],
+ [ 254, 255, 260, 261, 266, 267, 272, 273, 278, 279, 284, 285, 290, 291, 758, 759, 764, 765, 770, 771, 776, 777, 782, 783, 788, 789, 794, 795],
+ [ 256, 257, 262, 263, 268, 269, 274, 275, 280, 281, 286, 287, 292, 293, 760, 761, 766, 767, 772, 773, 778, 779, 784, 785, 790, 791, 796, 797],
+ [ 294, 295, 300, 301, 306, 307, 312, 313, 318, 319, 324, 325, 330, 331, 798, 799, 804, 805, 810, 811, 816, 817, 822, 823, 828, 829, 834, 835],
+ [ 296, 297, 302, 303, 308, 309, 314, 315, 320, 321, 326, 327, 332, 333, 800, 801, 806, 807, 812, 813, 818, 819, 824, 825, 830, 831, 836, 837],
+ [ 298, 299, 304, 305, 310, 311, 316, 317, 322, 323, 328, 329, 334, 335, 802, 803, 808, 809, 814, 815, 820, 821, 826, 827, 832, 833, 838, 839],
+ [ 336, 337, 342, 343, 348, 349, 354, 355, 360, 361, 366, 367, 372, 373, 840, 841, 846, 847, 852, 853, 858, 859, 864, 865, 870, 871, 876, 877],
+ [ 338, 339, 344, 345, 350, 351, 356, 357, 362, 363, 368, 369, 374, 375, 842, 843, 848, 849, 854, 855, 860, 861, 866, 867, 872, 873, 878, 879],
+ [ 340, 341, 346, 347, 352, 353, 358, 359, 364, 365, 370, 371, 376, 377, 844, 845, 850, 851, 856, 857, 862, 863, 868, 869, 874, 875, 880, 881],
+ [ 378, 379, 384, 385, 390, 391, 396, 397, 402, 403, 408, 409, 414, 415, 882, 883, 888, 889, 894, 895, 900, 901, 906, 907, 912, 913, 918, 919],
+ [ 380, 381, 386, 387, 392, 393, 398, 399, 404, 405, 410, 411, 416, 417, 884, 885, 890, 891, 896, 897, 902, 903, 908, 909, 914, 915, 920, 921],
+ [ 382, 383, 388, 389, 394, 395, 400, 401, 406, 407, 412, 413, 418, 419, 886, 887, 892, 893, 898, 899, 904, 905, 910, 911, 916, 917, 922, 923],
+ [ 420, 421, 426, 427, 432, 433, 438, 439, 444, 445, 450, 451, 456, 457, 924, 925, 930, 931, 936, 937, 942, 943, 948, 949, 954, 955, 960, 961],
+ [ 422, 423, 428, 429, 434, 435, 440, 441, 446, 447, 452, 453, 458, 459, 926, 927, 932, 933, 938, 939, 944, 945, 950, 951, 956, 957, 962, 963],
+ [ 424, 425, 430, 431, 436, 437, 442, 443, 448, 449, 454, 455, 460, 461, 928, 929, 934, 935, 940, 941, 946, 947, 952, 953, 958, 959, 964, 965],
+ [ 462, 463, 468, 469, 474, 475, 480, 481, 486, 487, 492, 493, 498, 499, 966, 967, 972, 973, 978, 979, 984, 985, 990, 991, 996, 997, 1002, 1003],
+ [ 464, 465, 470, 471, 476, 477, 482, 483, 488, 489, 494, 495, 500, 501, 968, 969, 974, 975, 980, 981, 986, 987, 992, 993, 998, 999, 1004, 1005],
+ [ 466, 467, 472, 473, 478, 479, 484, 485, 490, 491, 496, 497, 502, 503, 970, 971, 976, 977, 982, 983, 988, 989, 994, 995, 1000, 1001, 1006, 1007]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # All col major (tile_col_major, tile_group_col_major, and iter_col_major all enabled)
+ tiles = TensorTiler2D.group_tiler(
+ tensor_dims,
+ tile_dims=(3, 2),
+ tile_group_dims=(5, 7),
+ tile_col_major=True,
+ tile_group_col_major=True,
+ iter_col_major=True,
+ allow_partial=True,
+ )
+ reference_tiles = TensorTileSequence.from_tiles(
+ [
+ TensorTile(
+ tensor_dims, offset=0, sizes=[7, 5, 2, 3], strides=[2, 84, 1, 28]
+ ),
+ TensorTile(
+ tensor_dims, offset=420, sizes=[7, 5, 2, 3], strides=[2, 84, 1, 28]
+ ),
+ TensorTile(
+ tensor_dims, offset=840, sizes=[7, 2, 2, 3], strides=[2, 84, 1, 28]
+ ),
+ TensorTile(
+ tensor_dims, offset=14, sizes=[7, 5, 2, 3], strides=[2, 84, 1, 28]
+ ),
+ TensorTile(
+ tensor_dims, offset=434, sizes=[7, 5, 2, 3], strides=[2, 84, 1, 28]
+ ),
+ TensorTile(
+ tensor_dims, offset=854, sizes=[7, 2, 2, 3], strides=[2, 84, 1, 28]
+ ),
+ ]
+ )
+ assert tiles == reference_tiles
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 3, 30, 33, 60, 63, 90, 93, 120, 123, 150, 153, 180, 183, 504, 507, 534, 537, 564, 567, 594, 597, 624, 627, 654, 657, 684, 687],
+ [ 1, 4, 31, 34, 61, 64, 91, 94, 121, 124, 151, 154, 181, 184, 505, 508, 535, 538, 565, 568, 595, 598, 625, 628, 655, 658, 685, 688],
+ [ 2, 5, 32, 35, 62, 65, 92, 95, 122, 125, 152, 155, 182, 185, 506, 509, 536, 539, 566, 569, 596, 599, 626, 629, 656, 659, 686, 689],
+ [ 6, 9, 36, 39, 66, 69, 96, 99, 126, 129, 156, 159, 186, 189, 510, 513, 540, 543, 570, 573, 600, 603, 630, 633, 660, 663, 690, 693],
+ [ 7, 10, 37, 40, 67, 70, 97, 100, 127, 130, 157, 160, 187, 190, 511, 514, 541, 544, 571, 574, 601, 604, 631, 634, 661, 664, 691, 694],
+ [ 8, 11, 38, 41, 68, 71, 98, 101, 128, 131, 158, 161, 188, 191, 512, 515, 542, 545, 572, 575, 602, 605, 632, 635, 662, 665, 692, 695],
+ [ 12, 15, 42, 45, 72, 75, 102, 105, 132, 135, 162, 165, 192, 195, 516, 519, 546, 549, 576, 579, 606, 609, 636, 639, 666, 669, 696, 699],
+ [ 13, 16, 43, 46, 73, 76, 103, 106, 133, 136, 163, 166, 193, 196, 517, 520, 547, 550, 577, 580, 607, 610, 637, 640, 667, 670, 697, 700],
+ [ 14, 17, 44, 47, 74, 77, 104, 107, 134, 137, 164, 167, 194, 197, 518, 521, 548, 551, 578, 581, 608, 611, 638, 641, 668, 671, 698, 701],
+ [ 18, 21, 48, 51, 78, 81, 108, 111, 138, 141, 168, 171, 198, 201, 522, 525, 552, 555, 582, 585, 612, 615, 642, 645, 672, 675, 702, 705],
+ [ 19, 22, 49, 52, 79, 82, 109, 112, 139, 142, 169, 172, 199, 202, 523, 526, 553, 556, 583, 586, 613, 616, 643, 646, 673, 676, 703, 706],
+ [ 20, 23, 50, 53, 80, 83, 110, 113, 140, 143, 170, 173, 200, 203, 524, 527, 554, 557, 584, 587, 614, 617, 644, 647, 674, 677, 704, 707],
+ [ 24, 27, 54, 57, 84, 87, 114, 117, 144, 147, 174, 177, 204, 207, 528, 531, 558, 561, 588, 591, 618, 621, 648, 651, 678, 681, 708, 711],
+ [ 25, 28, 55, 58, 85, 88, 115, 118, 145, 148, 175, 178, 205, 208, 529, 532, 559, 562, 589, 592, 619, 622, 649, 652, 679, 682, 709, 712],
+ [ 26, 29, 56, 59, 86, 89, 116, 119, 146, 149, 176, 179, 206, 209, 530, 533, 560, 563, 590, 593, 620, 623, 650, 653, 680, 683, 710, 713],
+ [ 210, 213, 240, 243, 270, 273, 300, 303, 330, 333, 360, 363, 390, 393, 714, 717, 744, 747, 774, 777, 804, 807, 834, 837, 864, 867, 894, 897],
+ [ 211, 214, 241, 244, 271, 274, 301, 304, 331, 334, 361, 364, 391, 394, 715, 718, 745, 748, 775, 778, 805, 808, 835, 838, 865, 868, 895, 898],
+ [ 212, 215, 242, 245, 272, 275, 302, 305, 332, 335, 362, 365, 392, 395, 716, 719, 746, 749, 776, 779, 806, 809, 836, 839, 866, 869, 896, 899],
+ [ 216, 219, 246, 249, 276, 279, 306, 309, 336, 339, 366, 369, 396, 399, 720, 723, 750, 753, 780, 783, 810, 813, 840, 843, 870, 873, 900, 903],
+ [ 217, 220, 247, 250, 277, 280, 307, 310, 337, 340, 367, 370, 397, 400, 721, 724, 751, 754, 781, 784, 811, 814, 841, 844, 871, 874, 901, 904],
+ [ 218, 221, 248, 251, 278, 281, 308, 311, 338, 341, 368, 371, 398, 401, 722, 725, 752, 755, 782, 785, 812, 815, 842, 845, 872, 875, 902, 905],
+ [ 222, 225, 252, 255, 282, 285, 312, 315, 342, 345, 372, 375, 402, 405, 726, 729, 756, 759, 786, 789, 816, 819, 846, 849, 876, 879, 906, 909],
+ [ 223, 226, 253, 256, 283, 286, 313, 316, 343, 346, 373, 376, 403, 406, 727, 730, 757, 760, 787, 790, 817, 820, 847, 850, 877, 880, 907, 910],
+ [ 224, 227, 254, 257, 284, 287, 314, 317, 344, 347, 374, 377, 404, 407, 728, 731, 758, 761, 788, 791, 818, 821, 848, 851, 878, 881, 908, 911],
+ [ 228, 231, 258, 261, 288, 291, 318, 321, 348, 351, 378, 381, 408, 411, 732, 735, 762, 765, 792, 795, 822, 825, 852, 855, 882, 885, 912, 915],
+ [ 229, 232, 259, 262, 289, 292, 319, 322, 349, 352, 379, 382, 409, 412, 733, 736, 763, 766, 793, 796, 823, 826, 853, 856, 883, 886, 913, 916],
+ [ 230, 233, 260, 263, 290, 293, 320, 323, 350, 353, 380, 383, 410, 413, 734, 737, 764, 767, 794, 797, 824, 827, 854, 857, 884, 887, 914, 917],
+ [ 234, 237, 264, 267, 294, 297, 324, 327, 354, 357, 384, 387, 414, 417, 738, 741, 768, 771, 798, 801, 828, 831, 858, 861, 888, 891, 918, 921],
+ [ 235, 238, 265, 268, 295, 298, 325, 328, 355, 358, 385, 388, 415, 418, 739, 742, 769, 772, 799, 802, 829, 832, 859, 862, 889, 892, 919, 922],
+ [ 236, 239, 266, 269, 296, 299, 326, 329, 356, 359, 386, 389, 416, 419, 740, 743, 770, 773, 800, 803, 830, 833, 860, 863, 890, 893, 920, 923],
+ [ 420, 423, 432, 435, 444, 447, 456, 459, 468, 471, 480, 483, 492, 495, 924, 927, 936, 939, 948, 951, 960, 963, 972, 975, 984, 987, 996, 999],
+ [ 421, 424, 433, 436, 445, 448, 457, 460, 469, 472, 481, 484, 493, 496, 925, 928, 937, 940, 949, 952, 961, 964, 973, 976, 985, 988, 997, 1000],
+ [ 422, 425, 434, 437, 446, 449, 458, 461, 470, 473, 482, 485, 494, 497, 926, 929, 938, 941, 950, 953, 962, 965, 974, 977, 986, 989, 998, 1001],
+ [ 426, 429, 438, 441, 450, 453, 462, 465, 474, 477, 486, 489, 498, 501, 930, 933, 942, 945, 954, 957, 966, 969, 978, 981, 990, 993, 1002, 1005],
+ [ 427, 430, 439, 442, 451, 454, 463, 466, 475, 478, 487, 490, 499, 502, 931, 934, 943, 946, 955, 958, 967, 970, 979, 982, 991, 994, 1003, 1006],
+ [ 428, 431, 440, 443, 452, 455, 464, 467, 476, 479, 488, 491, 500, 503, 932, 935, 944, 947, 956, 959, 968, 971, 980, 983, 992, 995, 1004, 1007]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # Pattern repeat (pattern_repeat=3, so every element is expected to be accessed 3 times)
+ tiles = TensorTiler2D.group_tiler(
+ tensor_dims,
+ tile_dims=(3, 2),
+ tile_group_dims=(5, 7),
+ tile_col_major=True,
+ allow_partial=True,
+ pattern_repeat=3,
+ )
+ reference_tiles = TensorTileSequence.from_tiles(
+ [
+ TensorTile(
+ tensor_dims, offset=0, sizes=[3, 5, 14, 3], strides=[0, 84, 1, 28]
+ ),
+ TensorTile(
+ tensor_dims, offset=14, sizes=[3, 5, 14, 3], strides=[0, 84, 1, 28]
+ ),
+ TensorTile(
+ tensor_dims, offset=420, sizes=[3, 5, 14, 3], strides=[0, 84, 1, 28]
+ ),
+ TensorTile(
+ tensor_dims, offset=434, sizes=[3, 5, 14, 3], strides=[0, 84, 1, 28]
+ ),
+ TensorTile(
+ tensor_dims, offset=840, sizes=[3, 2, 14, 3], strides=[0, 84, 1, 28]
+ ),
+ TensorTile(
+ tensor_dims, offset=854, sizes=[3, 2, 14, 3], strides=[0, 84, 1, 28]
+ ),
+ ]
+ )
+ assert tiles == reference_tiles
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 420, 423, 426, 429, 432, 435, 438, 441, 444, 447, 450, 453, 456, 459, 1050, 1053, 1056, 1059, 1062, 1065, 1068, 1071, 1074, 1077, 1080, 1083, 1086, 1089],
+ [ 421, 424, 427, 430, 433, 436, 439, 442, 445, 448, 451, 454, 457, 460, 1051, 1054, 1057, 1060, 1063, 1066, 1069, 1072, 1075, 1078, 1081, 1084, 1087, 1090],
+ [ 422, 425, 428, 431, 434, 437, 440, 443, 446, 449, 452, 455, 458, 461, 1052, 1055, 1058, 1061, 1064, 1067, 1070, 1073, 1076, 1079, 1082, 1085, 1088, 1091],
+ [ 462, 465, 468, 471, 474, 477, 480, 483, 486, 489, 492, 495, 498, 501, 1092, 1095, 1098, 1101, 1104, 1107, 1110, 1113, 1116, 1119, 1122, 1125, 1128, 1131],
+ [ 463, 466, 469, 472, 475, 478, 481, 484, 487, 490, 493, 496, 499, 502, 1093, 1096, 1099, 1102, 1105, 1108, 1111, 1114, 1117, 1120, 1123, 1126, 1129, 1132],
+ [ 464, 467, 470, 473, 476, 479, 482, 485, 488, 491, 494, 497, 500, 503, 1094, 1097, 1100, 1103, 1106, 1109, 1112, 1115, 1118, 1121, 1124, 1127, 1130, 1133],
+ [ 504, 507, 510, 513, 516, 519, 522, 525, 528, 531, 534, 537, 540, 543, 1134, 1137, 1140, 1143, 1146, 1149, 1152, 1155, 1158, 1161, 1164, 1167, 1170, 1173],
+ [ 505, 508, 511, 514, 517, 520, 523, 526, 529, 532, 535, 538, 541, 544, 1135, 1138, 1141, 1144, 1147, 1150, 1153, 1156, 1159, 1162, 1165, 1168, 1171, 1174],
+ [ 506, 509, 512, 515, 518, 521, 524, 527, 530, 533, 536, 539, 542, 545, 1136, 1139, 1142, 1145, 1148, 1151, 1154, 1157, 1160, 1163, 1166, 1169, 1172, 1175],
+ [ 546, 549, 552, 555, 558, 561, 564, 567, 570, 573, 576, 579, 582, 585, 1176, 1179, 1182, 1185, 1188, 1191, 1194, 1197, 1200, 1203, 1206, 1209, 1212, 1215],
+ [ 547, 550, 553, 556, 559, 562, 565, 568, 571, 574, 577, 580, 583, 586, 1177, 1180, 1183, 1186, 1189, 1192, 1195, 1198, 1201, 1204, 1207, 1210, 1213, 1216],
+ [ 548, 551, 554, 557, 560, 563, 566, 569, 572, 575, 578, 581, 584, 587, 1178, 1181, 1184, 1187, 1190, 1193, 1196, 1199, 1202, 1205, 1208, 1211, 1214, 1217],
+ [ 588, 591, 594, 597, 600, 603, 606, 609, 612, 615, 618, 621, 624, 627, 1218, 1221, 1224, 1227, 1230, 1233, 1236, 1239, 1242, 1245, 1248, 1251, 1254, 1257],
+ [ 589, 592, 595, 598, 601, 604, 607, 610, 613, 616, 619, 622, 625, 628, 1219, 1222, 1225, 1228, 1231, 1234, 1237, 1240, 1243, 1246, 1249, 1252, 1255, 1258],
+ [ 590, 593, 596, 599, 602, 605, 608, 611, 614, 617, 620, 623, 626, 629, 1220, 1223, 1226, 1229, 1232, 1235, 1238, 1241, 1244, 1247, 1250, 1253, 1256, 1259],
+ [1680, 1683, 1686, 1689, 1692, 1695, 1698, 1701, 1704, 1707, 1710, 1713, 1716, 1719, 2310, 2313, 2316, 2319, 2322, 2325, 2328, 2331, 2334, 2337, 2340, 2343, 2346, 2349],
+ [1681, 1684, 1687, 1690, 1693, 1696, 1699, 1702, 1705, 1708, 1711, 1714, 1717, 1720, 2311, 2314, 2317, 2320, 2323, 2326, 2329, 2332, 2335, 2338, 2341, 2344, 2347, 2350],
+ [1682, 1685, 1688, 1691, 1694, 1697, 1700, 1703, 1706, 1709, 1712, 1715, 1718, 1721, 2312, 2315, 2318, 2321, 2324, 2327, 2330, 2333, 2336, 2339, 2342, 2345, 2348, 2351],
+ [1722, 1725, 1728, 1731, 1734, 1737, 1740, 1743, 1746, 1749, 1752, 1755, 1758, 1761, 2352, 2355, 2358, 2361, 2364, 2367, 2370, 2373, 2376, 2379, 2382, 2385, 2388, 2391],
+ [1723, 1726, 1729, 1732, 1735, 1738, 1741, 1744, 1747, 1750, 1753, 1756, 1759, 1762, 2353, 2356, 2359, 2362, 2365, 2368, 2371, 2374, 2377, 2380, 2383, 2386, 2389, 2392],
+ [1724, 1727, 1730, 1733, 1736, 1739, 1742, 1745, 1748, 1751, 1754, 1757, 1760, 1763, 2354, 2357, 2360, 2363, 2366, 2369, 2372, 2375, 2378, 2381, 2384, 2387, 2390, 2393],
+ [1764, 1767, 1770, 1773, 1776, 1779, 1782, 1785, 1788, 1791, 1794, 1797, 1800, 1803, 2394, 2397, 2400, 2403, 2406, 2409, 2412, 2415, 2418, 2421, 2424, 2427, 2430, 2433],
+ [1765, 1768, 1771, 1774, 1777, 1780, 1783, 1786, 1789, 1792, 1795, 1798, 1801, 1804, 2395, 2398, 2401, 2404, 2407, 2410, 2413, 2416, 2419, 2422, 2425, 2428, 2431, 2434],
+ [1766, 1769, 1772, 1775, 1778, 1781, 1784, 1787, 1790, 1793, 1796, 1799, 1802, 1805, 2396, 2399, 2402, 2405, 2408, 2411, 2414, 2417, 2420, 2423, 2426, 2429, 2432, 2435],
+ [1806, 1809, 1812, 1815, 1818, 1821, 1824, 1827, 1830, 1833, 1836, 1839, 1842, 1845, 2436, 2439, 2442, 2445, 2448, 2451, 2454, 2457, 2460, 2463, 2466, 2469, 2472, 2475],
+ [1807, 1810, 1813, 1816, 1819, 1822, 1825, 1828, 1831, 1834, 1837, 1840, 1843, 1846, 2437, 2440, 2443, 2446, 2449, 2452, 2455, 2458, 2461, 2464, 2467, 2470, 2473, 2476],
+ [1808, 1811, 1814, 1817, 1820, 1823, 1826, 1829, 1832, 1835, 1838, 1841, 1844, 1847, 2438, 2441, 2444, 2447, 2450, 2453, 2456, 2459, 2462, 2465, 2468, 2471, 2474, 2477],
+ [1848, 1851, 1854, 1857, 1860, 1863, 1866, 1869, 1872, 1875, 1878, 1881, 1884, 1887, 2478, 2481, 2484, 2487, 2490, 2493, 2496, 2499, 2502, 2505, 2508, 2511, 2514, 2517],
+ [1849, 1852, 1855, 1858, 1861, 1864, 1867, 1870, 1873, 1876, 1879, 1882, 1885, 1888, 2479, 2482, 2485, 2488, 2491, 2494, 2497, 2500, 2503, 2506, 2509, 2512, 2515, 2518],
+ [1850, 1853, 1856, 1859, 1862, 1865, 1868, 1871, 1874, 1877, 1880, 1883, 1886, 1889, 2480, 2483, 2486, 2489, 2492, 2495, 2498, 2501, 2504, 2507, 2510, 2513, 2516, 2519],
+ [2688, 2691, 2694, 2697, 2700, 2703, 2706, 2709, 2712, 2715, 2718, 2721, 2724, 2727, 2940, 2943, 2946, 2949, 2952, 2955, 2958, 2961, 2964, 2967, 2970, 2973, 2976, 2979],
+ [2689, 2692, 2695, 2698, 2701, 2704, 2707, 2710, 2713, 2716, 2719, 2722, 2725, 2728, 2941, 2944, 2947, 2950, 2953, 2956, 2959, 2962, 2965, 2968, 2971, 2974, 2977, 2980],
+ [2690, 2693, 2696, 2699, 2702, 2705, 2708, 2711, 2714, 2717, 2720, 2723, 2726, 2729, 2942, 2945, 2948, 2951, 2954, 2957, 2960, 2963, 2966, 2969, 2972, 2975, 2978, 2981],
+ [2730, 2733, 2736, 2739, 2742, 2745, 2748, 2751, 2754, 2757, 2760, 2763, 2766, 2769, 2982, 2985, 2988, 2991, 2994, 2997, 3000, 3003, 3006, 3009, 3012, 3015, 3018, 3021],
+ [2731, 2734, 2737, 2740, 2743, 2746, 2749, 2752, 2755, 2758, 2761, 2764, 2767, 2770, 2983, 2986, 2989, 2992, 2995, 2998, 3001, 3004, 3007, 3010, 3013, 3016, 3019, 3022],
+ [2732, 2735, 2738, 2741, 2744, 2747, 2750, 2753, 2756, 2759, 2762, 2765, 2768, 2771, 2984, 2987, 2990, 2993, 2996, 2999, 3002, 3005, 3008, 3011, 3014, 3017, 3020, 3023]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 3).all()
+
+ # CHECK: Pass!
+ print("Pass!")
+
+
+# CHECK-LABEL: group_tiler_partial_both
+@construct_test
+def group_tiler_partial_both():
+ # Exercises group_tiler with allow_partial=True when the tile group does not evenly divide the tensor in either dimension.
+ # All row major (no col-major flags): 36x24 tensor = 12x12 tiles of 3x2; 5x7 tile groups leave partial groups of 2 tile rows and 5 tile columns
+ tensor_dims = (3 * 4 * 3, 2 * 6 * 2)
+ tiles = TensorTiler2D.group_tiler(
+ tensor_dims,
+ tile_dims=(3, 2),
+ tile_group_dims=(5, 7),
+ allow_partial=True,
+ )
+
+ reference_tiles = TensorTileSequence.from_tiles(
+ [
+ TensorTile(
+ tensor_dims, offset=0, sizes=[5, 7, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=14, sizes=[5, 5, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=360, sizes=[5, 7, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=374, sizes=[5, 5, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=720, sizes=[2, 7, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=734, sizes=[2, 5, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ ]
+ )
+ assert tiles == reference_tiles
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 1, 6, 7, 12, 13, 18, 19, 24, 25, 30, 31, 36, 37, 210, 211, 216, 217, 222, 223, 228, 229, 234, 235],
+ [ 2, 3, 8, 9, 14, 15, 20, 21, 26, 27, 32, 33, 38, 39, 212, 213, 218, 219, 224, 225, 230, 231, 236, 237],
+ [ 4, 5, 10, 11, 16, 17, 22, 23, 28, 29, 34, 35, 40, 41, 214, 215, 220, 221, 226, 227, 232, 233, 238, 239],
+ [ 42, 43, 48, 49, 54, 55, 60, 61, 66, 67, 72, 73, 78, 79, 240, 241, 246, 247, 252, 253, 258, 259, 264, 265],
+ [ 44, 45, 50, 51, 56, 57, 62, 63, 68, 69, 74, 75, 80, 81, 242, 243, 248, 249, 254, 255, 260, 261, 266, 267],
+ [ 46, 47, 52, 53, 58, 59, 64, 65, 70, 71, 76, 77, 82, 83, 244, 245, 250, 251, 256, 257, 262, 263, 268, 269],
+ [ 84, 85, 90, 91, 96, 97, 102, 103, 108, 109, 114, 115, 120, 121, 270, 271, 276, 277, 282, 283, 288, 289, 294, 295],
+ [ 86, 87, 92, 93, 98, 99, 104, 105, 110, 111, 116, 117, 122, 123, 272, 273, 278, 279, 284, 285, 290, 291, 296, 297],
+ [ 88, 89, 94, 95, 100, 101, 106, 107, 112, 113, 118, 119, 124, 125, 274, 275, 280, 281, 286, 287, 292, 293, 298, 299],
+ [126, 127, 132, 133, 138, 139, 144, 145, 150, 151, 156, 157, 162, 163, 300, 301, 306, 307, 312, 313, 318, 319, 324, 325],
+ [128, 129, 134, 135, 140, 141, 146, 147, 152, 153, 158, 159, 164, 165, 302, 303, 308, 309, 314, 315, 320, 321, 326, 327],
+ [130, 131, 136, 137, 142, 143, 148, 149, 154, 155, 160, 161, 166, 167, 304, 305, 310, 311, 316, 317, 322, 323, 328, 329],
+ [168, 169, 174, 175, 180, 181, 186, 187, 192, 193, 198, 199, 204, 205, 330, 331, 336, 337, 342, 343, 348, 349, 354, 355],
+ [170, 171, 176, 177, 182, 183, 188, 189, 194, 195, 200, 201, 206, 207, 332, 333, 338, 339, 344, 345, 350, 351, 356, 357],
+ [172, 173, 178, 179, 184, 185, 190, 191, 196, 197, 202, 203, 208, 209, 334, 335, 340, 341, 346, 347, 352, 353, 358, 359],
+ [360, 361, 366, 367, 372, 373, 378, 379, 384, 385, 390, 391, 396, 397, 570, 571, 576, 577, 582, 583, 588, 589, 594, 595],
+ [362, 363, 368, 369, 374, 375, 380, 381, 386, 387, 392, 393, 398, 399, 572, 573, 578, 579, 584, 585, 590, 591, 596, 597],
+ [364, 365, 370, 371, 376, 377, 382, 383, 388, 389, 394, 395, 400, 401, 574, 575, 580, 581, 586, 587, 592, 593, 598, 599],
+ [402, 403, 408, 409, 414, 415, 420, 421, 426, 427, 432, 433, 438, 439, 600, 601, 606, 607, 612, 613, 618, 619, 624, 625],
+ [404, 405, 410, 411, 416, 417, 422, 423, 428, 429, 434, 435, 440, 441, 602, 603, 608, 609, 614, 615, 620, 621, 626, 627],
+ [406, 407, 412, 413, 418, 419, 424, 425, 430, 431, 436, 437, 442, 443, 604, 605, 610, 611, 616, 617, 622, 623, 628, 629],
+ [444, 445, 450, 451, 456, 457, 462, 463, 468, 469, 474, 475, 480, 481, 630, 631, 636, 637, 642, 643, 648, 649, 654, 655],
+ [446, 447, 452, 453, 458, 459, 464, 465, 470, 471, 476, 477, 482, 483, 632, 633, 638, 639, 644, 645, 650, 651, 656, 657],
+ [448, 449, 454, 455, 460, 461, 466, 467, 472, 473, 478, 479, 484, 485, 634, 635, 640, 641, 646, 647, 652, 653, 658, 659],
+ [486, 487, 492, 493, 498, 499, 504, 505, 510, 511, 516, 517, 522, 523, 660, 661, 666, 667, 672, 673, 678, 679, 684, 685],
+ [488, 489, 494, 495, 500, 501, 506, 507, 512, 513, 518, 519, 524, 525, 662, 663, 668, 669, 674, 675, 680, 681, 686, 687],
+ [490, 491, 496, 497, 502, 503, 508, 509, 514, 515, 520, 521, 526, 527, 664, 665, 670, 671, 676, 677, 682, 683, 688, 689],
+ [528, 529, 534, 535, 540, 541, 546, 547, 552, 553, 558, 559, 564, 565, 690, 691, 696, 697, 702, 703, 708, 709, 714, 715],
+ [530, 531, 536, 537, 542, 543, 548, 549, 554, 555, 560, 561, 566, 567, 692, 693, 698, 699, 704, 705, 710, 711, 716, 717],
+ [532, 533, 538, 539, 544, 545, 550, 551, 556, 557, 562, 563, 568, 569, 694, 695, 700, 701, 706, 707, 712, 713, 718, 719],
+ [720, 721, 726, 727, 732, 733, 738, 739, 744, 745, 750, 751, 756, 757, 804, 805, 810, 811, 816, 817, 822, 823, 828, 829],
+ [722, 723, 728, 729, 734, 735, 740, 741, 746, 747, 752, 753, 758, 759, 806, 807, 812, 813, 818, 819, 824, 825, 830, 831],
+ [724, 725, 730, 731, 736, 737, 742, 743, 748, 749, 754, 755, 760, 761, 808, 809, 814, 815, 820, 821, 826, 827, 832, 833],
+ [762, 763, 768, 769, 774, 775, 780, 781, 786, 787, 792, 793, 798, 799, 834, 835, 840, 841, 846, 847, 852, 853, 858, 859],
+ [764, 765, 770, 771, 776, 777, 782, 783, 788, 789, 794, 795, 800, 801, 836, 837, 842, 843, 848, 849, 854, 855, 860, 861],
+ [766, 767, 772, 773, 778, 779, 784, 785, 790, 791, 796, 797, 802, 803, 838, 839, 844, 845, 850, 851, 856, 857, 862, 863]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # Tile col major: elements inside each 3x2 tile are visited column-first (reference rows 0-2 start 0, 1, 2)
+ tiles = TensorTiler2D.group_tiler(
+ tensor_dims,
+ tile_dims=(3, 2),
+ tile_group_dims=(5, 7),
+ tile_col_major=True,
+ allow_partial=True,
+ )
+
+ reference_tiles = TensorTileSequence.from_tiles(
+ [
+ TensorTile(
+ tensor_dims, offset=0, sizes=[1, 5, 14, 3], strides=[0, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=14, sizes=[1, 5, 10, 3], strides=[0, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=360, sizes=[1, 5, 14, 3], strides=[0, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=374, sizes=[1, 5, 10, 3], strides=[0, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=720, sizes=[1, 2, 14, 3], strides=[0, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=734, sizes=[1, 2, 10, 3], strides=[0, 72, 1, 24]
+ ),
+ ]
+ )
+ assert tiles == reference_tiles
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 210, 213, 216, 219, 222, 225, 228, 231, 234, 237],
+ [ 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 211, 214, 217, 220, 223, 226, 229, 232, 235, 238],
+ [ 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41, 212, 215, 218, 221, 224, 227, 230, 233, 236, 239],
+ [ 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 240, 243, 246, 249, 252, 255, 258, 261, 264, 267],
+ [ 43, 46, 49, 52, 55, 58, 61, 64, 67, 70, 73, 76, 79, 82, 241, 244, 247, 250, 253, 256, 259, 262, 265, 268],
+ [ 44, 47, 50, 53, 56, 59, 62, 65, 68, 71, 74, 77, 80, 83, 242, 245, 248, 251, 254, 257, 260, 263, 266, 269],
+ [ 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 270, 273, 276, 279, 282, 285, 288, 291, 294, 297],
+ [ 85, 88, 91, 94, 97, 100, 103, 106, 109, 112, 115, 118, 121, 124, 271, 274, 277, 280, 283, 286, 289, 292, 295, 298],
+ [ 86, 89, 92, 95, 98, 101, 104, 107, 110, 113, 116, 119, 122, 125, 272, 275, 278, 281, 284, 287, 290, 293, 296, 299],
+ [126, 129, 132, 135, 138, 141, 144, 147, 150, 153, 156, 159, 162, 165, 300, 303, 306, 309, 312, 315, 318, 321, 324, 327],
+ [127, 130, 133, 136, 139, 142, 145, 148, 151, 154, 157, 160, 163, 166, 301, 304, 307, 310, 313, 316, 319, 322, 325, 328],
+ [128, 131, 134, 137, 140, 143, 146, 149, 152, 155, 158, 161, 164, 167, 302, 305, 308, 311, 314, 317, 320, 323, 326, 329],
+ [168, 171, 174, 177, 180, 183, 186, 189, 192, 195, 198, 201, 204, 207, 330, 333, 336, 339, 342, 345, 348, 351, 354, 357],
+ [169, 172, 175, 178, 181, 184, 187, 190, 193, 196, 199, 202, 205, 208, 331, 334, 337, 340, 343, 346, 349, 352, 355, 358],
+ [170, 173, 176, 179, 182, 185, 188, 191, 194, 197, 200, 203, 206, 209, 332, 335, 338, 341, 344, 347, 350, 353, 356, 359],
+ [360, 363, 366, 369, 372, 375, 378, 381, 384, 387, 390, 393, 396, 399, 570, 573, 576, 579, 582, 585, 588, 591, 594, 597],
+ [361, 364, 367, 370, 373, 376, 379, 382, 385, 388, 391, 394, 397, 400, 571, 574, 577, 580, 583, 586, 589, 592, 595, 598],
+ [362, 365, 368, 371, 374, 377, 380, 383, 386, 389, 392, 395, 398, 401, 572, 575, 578, 581, 584, 587, 590, 593, 596, 599],
+ [402, 405, 408, 411, 414, 417, 420, 423, 426, 429, 432, 435, 438, 441, 600, 603, 606, 609, 612, 615, 618, 621, 624, 627],
+ [403, 406, 409, 412, 415, 418, 421, 424, 427, 430, 433, 436, 439, 442, 601, 604, 607, 610, 613, 616, 619, 622, 625, 628],
+ [404, 407, 410, 413, 416, 419, 422, 425, 428, 431, 434, 437, 440, 443, 602, 605, 608, 611, 614, 617, 620, 623, 626, 629],
+ [444, 447, 450, 453, 456, 459, 462, 465, 468, 471, 474, 477, 480, 483, 630, 633, 636, 639, 642, 645, 648, 651, 654, 657],
+ [445, 448, 451, 454, 457, 460, 463, 466, 469, 472, 475, 478, 481, 484, 631, 634, 637, 640, 643, 646, 649, 652, 655, 658],
+ [446, 449, 452, 455, 458, 461, 464, 467, 470, 473, 476, 479, 482, 485, 632, 635, 638, 641, 644, 647, 650, 653, 656, 659],
+ [486, 489, 492, 495, 498, 501, 504, 507, 510, 513, 516, 519, 522, 525, 660, 663, 666, 669, 672, 675, 678, 681, 684, 687],
+ [487, 490, 493, 496, 499, 502, 505, 508, 511, 514, 517, 520, 523, 526, 661, 664, 667, 670, 673, 676, 679, 682, 685, 688],
+ [488, 491, 494, 497, 500, 503, 506, 509, 512, 515, 518, 521, 524, 527, 662, 665, 668, 671, 674, 677, 680, 683, 686, 689],
+ [528, 531, 534, 537, 540, 543, 546, 549, 552, 555, 558, 561, 564, 567, 690, 693, 696, 699, 702, 705, 708, 711, 714, 717],
+ [529, 532, 535, 538, 541, 544, 547, 550, 553, 556, 559, 562, 565, 568, 691, 694, 697, 700, 703, 706, 709, 712, 715, 718],
+ [530, 533, 536, 539, 542, 545, 548, 551, 554, 557, 560, 563, 566, 569, 692, 695, 698, 701, 704, 707, 710, 713, 716, 719],
+ [720, 723, 726, 729, 732, 735, 738, 741, 744, 747, 750, 753, 756, 759, 804, 807, 810, 813, 816, 819, 822, 825, 828, 831],
+ [721, 724, 727, 730, 733, 736, 739, 742, 745, 748, 751, 754, 757, 760, 805, 808, 811, 814, 817, 820, 823, 826, 829, 832],
+ [722, 725, 728, 731, 734, 737, 740, 743, 746, 749, 752, 755, 758, 761, 806, 809, 812, 815, 818, 821, 824, 827, 830, 833],
+ [762, 765, 768, 771, 774, 777, 780, 783, 786, 789, 792, 795, 798, 801, 834, 837, 840, 843, 846, 849, 852, 855, 858, 861],
+ [763, 766, 769, 772, 775, 778, 781, 784, 787, 790, 793, 796, 799, 802, 835, 838, 841, 844, 847, 850, 853, 856, 859, 862],
+ [764, 767, 770, 773, 776, 779, 782, 785, 788, 791, 794, 797, 800, 803, 836, 839, 842, 845, 848, 851, 854, 857, 860, 863]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # Tile group col major: tiles inside a group are visited down a tile column first (reference row 3 starts at 6, column 2 of row 0 at 30)
+ tiles = TensorTiler2D.group_tiler(
+ tensor_dims,
+ tile_dims=(3, 2),
+ tile_group_dims=(5, 7),
+ tile_group_col_major=True,
+ allow_partial=True,
+ )
+
+ reference_tiles = TensorTileSequence.from_tiles(
+ [
+ TensorTile(
+ tensor_dims, offset=0, sizes=[1, 7, 15, 2], strides=[0, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=14, sizes=[1, 5, 15, 2], strides=[0, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=360, sizes=[1, 7, 15, 2], strides=[0, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=374, sizes=[1, 5, 15, 2], strides=[0, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=720, sizes=[1, 7, 6, 2], strides=[0, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=734, sizes=[1, 5, 6, 2], strides=[0, 2, 24, 1]
+ ),
+ ]
+ )
+ assert tiles == reference_tiles
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 1, 30, 31, 60, 61, 90, 91, 120, 121, 150, 151, 180, 181, 210, 211, 240, 241, 270, 271, 300, 301, 330, 331],
+ [ 2, 3, 32, 33, 62, 63, 92, 93, 122, 123, 152, 153, 182, 183, 212, 213, 242, 243, 272, 273, 302, 303, 332, 333],
+ [ 4, 5, 34, 35, 64, 65, 94, 95, 124, 125, 154, 155, 184, 185, 214, 215, 244, 245, 274, 275, 304, 305, 334, 335],
+ [ 6, 7, 36, 37, 66, 67, 96, 97, 126, 127, 156, 157, 186, 187, 216, 217, 246, 247, 276, 277, 306, 307, 336, 337],
+ [ 8, 9, 38, 39, 68, 69, 98, 99, 128, 129, 158, 159, 188, 189, 218, 219, 248, 249, 278, 279, 308, 309, 338, 339],
+ [ 10, 11, 40, 41, 70, 71, 100, 101, 130, 131, 160, 161, 190, 191, 220, 221, 250, 251, 280, 281, 310, 311, 340, 341],
+ [ 12, 13, 42, 43, 72, 73, 102, 103, 132, 133, 162, 163, 192, 193, 222, 223, 252, 253, 282, 283, 312, 313, 342, 343],
+ [ 14, 15, 44, 45, 74, 75, 104, 105, 134, 135, 164, 165, 194, 195, 224, 225, 254, 255, 284, 285, 314, 315, 344, 345],
+ [ 16, 17, 46, 47, 76, 77, 106, 107, 136, 137, 166, 167, 196, 197, 226, 227, 256, 257, 286, 287, 316, 317, 346, 347],
+ [ 18, 19, 48, 49, 78, 79, 108, 109, 138, 139, 168, 169, 198, 199, 228, 229, 258, 259, 288, 289, 318, 319, 348, 349],
+ [ 20, 21, 50, 51, 80, 81, 110, 111, 140, 141, 170, 171, 200, 201, 230, 231, 260, 261, 290, 291, 320, 321, 350, 351],
+ [ 22, 23, 52, 53, 82, 83, 112, 113, 142, 143, 172, 173, 202, 203, 232, 233, 262, 263, 292, 293, 322, 323, 352, 353],
+ [ 24, 25, 54, 55, 84, 85, 114, 115, 144, 145, 174, 175, 204, 205, 234, 235, 264, 265, 294, 295, 324, 325, 354, 355],
+ [ 26, 27, 56, 57, 86, 87, 116, 117, 146, 147, 176, 177, 206, 207, 236, 237, 266, 267, 296, 297, 326, 327, 356, 357],
+ [ 28, 29, 58, 59, 88, 89, 118, 119, 148, 149, 178, 179, 208, 209, 238, 239, 268, 269, 298, 299, 328, 329, 358, 359],
+ [360, 361, 390, 391, 420, 421, 450, 451, 480, 481, 510, 511, 540, 541, 570, 571, 600, 601, 630, 631, 660, 661, 690, 691],
+ [362, 363, 392, 393, 422, 423, 452, 453, 482, 483, 512, 513, 542, 543, 572, 573, 602, 603, 632, 633, 662, 663, 692, 693],
+ [364, 365, 394, 395, 424, 425, 454, 455, 484, 485, 514, 515, 544, 545, 574, 575, 604, 605, 634, 635, 664, 665, 694, 695],
+ [366, 367, 396, 397, 426, 427, 456, 457, 486, 487, 516, 517, 546, 547, 576, 577, 606, 607, 636, 637, 666, 667, 696, 697],
+ [368, 369, 398, 399, 428, 429, 458, 459, 488, 489, 518, 519, 548, 549, 578, 579, 608, 609, 638, 639, 668, 669, 698, 699],
+ [370, 371, 400, 401, 430, 431, 460, 461, 490, 491, 520, 521, 550, 551, 580, 581, 610, 611, 640, 641, 670, 671, 700, 701],
+ [372, 373, 402, 403, 432, 433, 462, 463, 492, 493, 522, 523, 552, 553, 582, 583, 612, 613, 642, 643, 672, 673, 702, 703],
+ [374, 375, 404, 405, 434, 435, 464, 465, 494, 495, 524, 525, 554, 555, 584, 585, 614, 615, 644, 645, 674, 675, 704, 705],
+ [376, 377, 406, 407, 436, 437, 466, 467, 496, 497, 526, 527, 556, 557, 586, 587, 616, 617, 646, 647, 676, 677, 706, 707],
+ [378, 379, 408, 409, 438, 439, 468, 469, 498, 499, 528, 529, 558, 559, 588, 589, 618, 619, 648, 649, 678, 679, 708, 709],
+ [380, 381, 410, 411, 440, 441, 470, 471, 500, 501, 530, 531, 560, 561, 590, 591, 620, 621, 650, 651, 680, 681, 710, 711],
+ [382, 383, 412, 413, 442, 443, 472, 473, 502, 503, 532, 533, 562, 563, 592, 593, 622, 623, 652, 653, 682, 683, 712, 713],
+ [384, 385, 414, 415, 444, 445, 474, 475, 504, 505, 534, 535, 564, 565, 594, 595, 624, 625, 654, 655, 684, 685, 714, 715],
+ [386, 387, 416, 417, 446, 447, 476, 477, 506, 507, 536, 537, 566, 567, 596, 597, 626, 627, 656, 657, 686, 687, 716, 717],
+ [388, 389, 418, 419, 448, 449, 478, 479, 508, 509, 538, 539, 568, 569, 598, 599, 628, 629, 658, 659, 688, 689, 718, 719],
+ [720, 721, 732, 733, 744, 745, 756, 757, 768, 769, 780, 781, 792, 793, 804, 805, 816, 817, 828, 829, 840, 841, 852, 853],
+ [722, 723, 734, 735, 746, 747, 758, 759, 770, 771, 782, 783, 794, 795, 806, 807, 818, 819, 830, 831, 842, 843, 854, 855],
+ [724, 725, 736, 737, 748, 749, 760, 761, 772, 773, 784, 785, 796, 797, 808, 809, 820, 821, 832, 833, 844, 845, 856, 857],
+ [726, 727, 738, 739, 750, 751, 762, 763, 774, 775, 786, 787, 798, 799, 810, 811, 822, 823, 834, 835, 846, 847, 858, 859],
+ [728, 729, 740, 741, 752, 753, 764, 765, 776, 777, 788, 789, 800, 801, 812, 813, 824, 825, 836, 837, 848, 849, 860, 861],
+ [730, 731, 742, 743, 754, 755, 766, 767, 778, 779, 790, 791, 802, 803, 814, 815, 826, 827, 838, 839, 850, 851, 862, 863]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # Iter col major: tile groups are emitted down each group column first (offsets 0, 360, 720, then 14, 374, 734)
+ tiles = TensorTiler2D.group_tiler(
+ tensor_dims,
+ tile_dims=(3, 2),
+ tile_group_dims=(5, 7),
+ iter_col_major=True,
+ allow_partial=True,
+ )
+
+ reference_tiles = TensorTileSequence.from_tiles(
+ [
+ TensorTile(
+ tensor_dims, offset=0, sizes=[5, 7, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=360, sizes=[5, 7, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=720, sizes=[2, 7, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=14, sizes=[5, 5, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=374, sizes=[5, 5, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ TensorTile(
+ tensor_dims, offset=734, sizes=[2, 5, 3, 2], strides=[72, 2, 24, 1]
+ ),
+ ]
+ )
+ assert tiles == reference_tiles
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 1, 6, 7, 12, 13, 18, 19, 24, 25, 30, 31, 36, 37, 504, 505, 510, 511, 516, 517, 522, 523, 528, 529],
+ [ 2, 3, 8, 9, 14, 15, 20, 21, 26, 27, 32, 33, 38, 39, 506, 507, 512, 513, 518, 519, 524, 525, 530, 531],
+ [ 4, 5, 10, 11, 16, 17, 22, 23, 28, 29, 34, 35, 40, 41, 508, 509, 514, 515, 520, 521, 526, 527, 532, 533],
+ [ 42, 43, 48, 49, 54, 55, 60, 61, 66, 67, 72, 73, 78, 79, 534, 535, 540, 541, 546, 547, 552, 553, 558, 559],
+ [ 44, 45, 50, 51, 56, 57, 62, 63, 68, 69, 74, 75, 80, 81, 536, 537, 542, 543, 548, 549, 554, 555, 560, 561],
+ [ 46, 47, 52, 53, 58, 59, 64, 65, 70, 71, 76, 77, 82, 83, 538, 539, 544, 545, 550, 551, 556, 557, 562, 563],
+ [ 84, 85, 90, 91, 96, 97, 102, 103, 108, 109, 114, 115, 120, 121, 564, 565, 570, 571, 576, 577, 582, 583, 588, 589],
+ [ 86, 87, 92, 93, 98, 99, 104, 105, 110, 111, 116, 117, 122, 123, 566, 567, 572, 573, 578, 579, 584, 585, 590, 591],
+ [ 88, 89, 94, 95, 100, 101, 106, 107, 112, 113, 118, 119, 124, 125, 568, 569, 574, 575, 580, 581, 586, 587, 592, 593],
+ [126, 127, 132, 133, 138, 139, 144, 145, 150, 151, 156, 157, 162, 163, 594, 595, 600, 601, 606, 607, 612, 613, 618, 619],
+ [128, 129, 134, 135, 140, 141, 146, 147, 152, 153, 158, 159, 164, 165, 596, 597, 602, 603, 608, 609, 614, 615, 620, 621],
+ [130, 131, 136, 137, 142, 143, 148, 149, 154, 155, 160, 161, 166, 167, 598, 599, 604, 605, 610, 611, 616, 617, 622, 623],
+ [168, 169, 174, 175, 180, 181, 186, 187, 192, 193, 198, 199, 204, 205, 624, 625, 630, 631, 636, 637, 642, 643, 648, 649],
+ [170, 171, 176, 177, 182, 183, 188, 189, 194, 195, 200, 201, 206, 207, 626, 627, 632, 633, 638, 639, 644, 645, 650, 651],
+ [172, 173, 178, 179, 184, 185, 190, 191, 196, 197, 202, 203, 208, 209, 628, 629, 634, 635, 640, 641, 646, 647, 652, 653],
+ [210, 211, 216, 217, 222, 223, 228, 229, 234, 235, 240, 241, 246, 247, 654, 655, 660, 661, 666, 667, 672, 673, 678, 679],
+ [212, 213, 218, 219, 224, 225, 230, 231, 236, 237, 242, 243, 248, 249, 656, 657, 662, 663, 668, 669, 674, 675, 680, 681],
+ [214, 215, 220, 221, 226, 227, 232, 233, 238, 239, 244, 245, 250, 251, 658, 659, 664, 665, 670, 671, 676, 677, 682, 683],
+ [252, 253, 258, 259, 264, 265, 270, 271, 276, 277, 282, 283, 288, 289, 684, 685, 690, 691, 696, 697, 702, 703, 708, 709],
+ [254, 255, 260, 261, 266, 267, 272, 273, 278, 279, 284, 285, 290, 291, 686, 687, 692, 693, 698, 699, 704, 705, 710, 711],
+ [256, 257, 262, 263, 268, 269, 274, 275, 280, 281, 286, 287, 292, 293, 688, 689, 694, 695, 700, 701, 706, 707, 712, 713],
+ [294, 295, 300, 301, 306, 307, 312, 313, 318, 319, 324, 325, 330, 331, 714, 715, 720, 721, 726, 727, 732, 733, 738, 739],
+ [296, 297, 302, 303, 308, 309, 314, 315, 320, 321, 326, 327, 332, 333, 716, 717, 722, 723, 728, 729, 734, 735, 740, 741],
+ [298, 299, 304, 305, 310, 311, 316, 317, 322, 323, 328, 329, 334, 335, 718, 719, 724, 725, 730, 731, 736, 737, 742, 743],
+ [336, 337, 342, 343, 348, 349, 354, 355, 360, 361, 366, 367, 372, 373, 744, 745, 750, 751, 756, 757, 762, 763, 768, 769],
+ [338, 339, 344, 345, 350, 351, 356, 357, 362, 363, 368, 369, 374, 375, 746, 747, 752, 753, 758, 759, 764, 765, 770, 771],
+ [340, 341, 346, 347, 352, 353, 358, 359, 364, 365, 370, 371, 376, 377, 748, 749, 754, 755, 760, 761, 766, 767, 772, 773],
+ [378, 379, 384, 385, 390, 391, 396, 397, 402, 403, 408, 409, 414, 415, 774, 775, 780, 781, 786, 787, 792, 793, 798, 799],
+ [380, 381, 386, 387, 392, 393, 398, 399, 404, 405, 410, 411, 416, 417, 776, 777, 782, 783, 788, 789, 794, 795, 800, 801],
+ [382, 383, 388, 389, 394, 395, 400, 401, 406, 407, 412, 413, 418, 419, 778, 779, 784, 785, 790, 791, 796, 797, 802, 803],
+ [420, 421, 426, 427, 432, 433, 438, 439, 444, 445, 450, 451, 456, 457, 804, 805, 810, 811, 816, 817, 822, 823, 828, 829],
+ [422, 423, 428, 429, 434, 435, 440, 441, 446, 447, 452, 453, 458, 459, 806, 807, 812, 813, 818, 819, 824, 825, 830, 831],
+ [424, 425, 430, 431, 436, 437, 442, 443, 448, 449, 454, 455, 460, 461, 808, 809, 814, 815, 820, 821, 826, 827, 832, 833],
+ [462, 463, 468, 469, 474, 475, 480, 481, 486, 487, 492, 493, 498, 499, 834, 835, 840, 841, 846, 847, 852, 853, 858, 859],
+ [464, 465, 470, 471, 476, 477, 482, 483, 488, 489, 494, 495, 500, 501, 836, 837, 842, 843, 848, 849, 854, 855, 860, 861],
+ [466, 467, 472, 473, 478, 479, 484, 485, 490, 491, 496, 497, 502, 503, 838, 839, 844, 845, 850, 851, 856, 857, 862, 863]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # All col major: group iteration, tile order within a group and element order within a tile are all column-first
+ tiles = TensorTiler2D.group_tiler(
+ tensor_dims,
+ tile_dims=(3, 2),
+ tile_group_dims=(5, 7),
+ iter_col_major=True,
+ tile_col_major=True,
+ tile_group_col_major=True,
+ allow_partial=True,
+ )
+
+ reference_tiles = TensorTileSequence.from_tiles(
+ [
+ TensorTile(
+ tensor_dims, offset=0, sizes=[7, 5, 2, 3], strides=[2, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=360, sizes=[7, 5, 2, 3], strides=[2, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=720, sizes=[7, 2, 2, 3], strides=[2, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=14, sizes=[5, 5, 2, 3], strides=[2, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=374, sizes=[5, 5, 2, 3], strides=[2, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=734, sizes=[5, 2, 2, 3], strides=[2, 72, 1, 24]
+ ),
+ ]
+ )
+ assert tiles == reference_tiles
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 3, 30, 33, 60, 63, 90, 93, 120, 123, 150, 153, 180, 183, 504, 507, 534, 537, 564, 567, 594, 597, 624, 627],
+ [ 1, 4, 31, 34, 61, 64, 91, 94, 121, 124, 151, 154, 181, 184, 505, 508, 535, 538, 565, 568, 595, 598, 625, 628],
+ [ 2, 5, 32, 35, 62, 65, 92, 95, 122, 125, 152, 155, 182, 185, 506, 509, 536, 539, 566, 569, 596, 599, 626, 629],
+ [ 6, 9, 36, 39, 66, 69, 96, 99, 126, 129, 156, 159, 186, 189, 510, 513, 540, 543, 570, 573, 600, 603, 630, 633],
+ [ 7, 10, 37, 40, 67, 70, 97, 100, 127, 130, 157, 160, 187, 190, 511, 514, 541, 544, 571, 574, 601, 604, 631, 634],
+ [ 8, 11, 38, 41, 68, 71, 98, 101, 128, 131, 158, 161, 188, 191, 512, 515, 542, 545, 572, 575, 602, 605, 632, 635],
+ [ 12, 15, 42, 45, 72, 75, 102, 105, 132, 135, 162, 165, 192, 195, 516, 519, 546, 549, 576, 579, 606, 609, 636, 639],
+ [ 13, 16, 43, 46, 73, 76, 103, 106, 133, 136, 163, 166, 193, 196, 517, 520, 547, 550, 577, 580, 607, 610, 637, 640],
+ [ 14, 17, 44, 47, 74, 77, 104, 107, 134, 137, 164, 167, 194, 197, 518, 521, 548, 551, 578, 581, 608, 611, 638, 641],
+ [ 18, 21, 48, 51, 78, 81, 108, 111, 138, 141, 168, 171, 198, 201, 522, 525, 552, 555, 582, 585, 612, 615, 642, 645],
+ [ 19, 22, 49, 52, 79, 82, 109, 112, 139, 142, 169, 172, 199, 202, 523, 526, 553, 556, 583, 586, 613, 616, 643, 646],
+ [ 20, 23, 50, 53, 80, 83, 110, 113, 140, 143, 170, 173, 200, 203, 524, 527, 554, 557, 584, 587, 614, 617, 644, 647],
+ [ 24, 27, 54, 57, 84, 87, 114, 117, 144, 147, 174, 177, 204, 207, 528, 531, 558, 561, 588, 591, 618, 621, 648, 651],
+ [ 25, 28, 55, 58, 85, 88, 115, 118, 145, 148, 175, 178, 205, 208, 529, 532, 559, 562, 589, 592, 619, 622, 649, 652],
+ [ 26, 29, 56, 59, 86, 89, 116, 119, 146, 149, 176, 179, 206, 209, 530, 533, 560, 563, 590, 593, 620, 623, 650, 653],
+ [210, 213, 240, 243, 270, 273, 300, 303, 330, 333, 360, 363, 390, 393, 654, 657, 684, 687, 714, 717, 744, 747, 774, 777],
+ [211, 214, 241, 244, 271, 274, 301, 304, 331, 334, 361, 364, 391, 394, 655, 658, 685, 688, 715, 718, 745, 748, 775, 778],
+ [212, 215, 242, 245, 272, 275, 302, 305, 332, 335, 362, 365, 392, 395, 656, 659, 686, 689, 716, 719, 746, 749, 776, 779],
+ [216, 219, 246, 249, 276, 279, 306, 309, 336, 339, 366, 369, 396, 399, 660, 663, 690, 693, 720, 723, 750, 753, 780, 783],
+ [217, 220, 247, 250, 277, 280, 307, 310, 337, 340, 367, 370, 397, 400, 661, 664, 691, 694, 721, 724, 751, 754, 781, 784],
+ [218, 221, 248, 251, 278, 281, 308, 311, 338, 341, 368, 371, 398, 401, 662, 665, 692, 695, 722, 725, 752, 755, 782, 785],
+ [222, 225, 252, 255, 282, 285, 312, 315, 342, 345, 372, 375, 402, 405, 666, 669, 696, 699, 726, 729, 756, 759, 786, 789],
+ [223, 226, 253, 256, 283, 286, 313, 316, 343, 346, 373, 376, 403, 406, 667, 670, 697, 700, 727, 730, 757, 760, 787, 790],
+ [224, 227, 254, 257, 284, 287, 314, 317, 344, 347, 374, 377, 404, 407, 668, 671, 698, 701, 728, 731, 758, 761, 788, 791],
+ [228, 231, 258, 261, 288, 291, 318, 321, 348, 351, 378, 381, 408, 411, 672, 675, 702, 705, 732, 735, 762, 765, 792, 795],
+ [229, 232, 259, 262, 289, 292, 319, 322, 349, 352, 379, 382, 409, 412, 673, 676, 703, 706, 733, 736, 763, 766, 793, 796],
+ [230, 233, 260, 263, 290, 293, 320, 323, 350, 353, 380, 383, 410, 413, 674, 677, 704, 707, 734, 737, 764, 767, 794, 797],
+ [234, 237, 264, 267, 294, 297, 324, 327, 354, 357, 384, 387, 414, 417, 678, 681, 708, 711, 738, 741, 768, 771, 798, 801],
+ [235, 238, 265, 268, 295, 298, 325, 328, 355, 358, 385, 388, 415, 418, 679, 682, 709, 712, 739, 742, 769, 772, 799, 802],
+ [236, 239, 266, 269, 296, 299, 326, 329, 356, 359, 386, 389, 416, 419, 680, 683, 710, 713, 740, 743, 770, 773, 800, 803],
+ [420, 423, 432, 435, 444, 447, 456, 459, 468, 471, 480, 483, 492, 495, 804, 807, 816, 819, 828, 831, 840, 843, 852, 855],
+ [421, 424, 433, 436, 445, 448, 457, 460, 469, 472, 481, 484, 493, 496, 805, 808, 817, 820, 829, 832, 841, 844, 853, 856],
+ [422, 425, 434, 437, 446, 449, 458, 461, 470, 473, 482, 485, 494, 497, 806, 809, 818, 821, 830, 833, 842, 845, 854, 857],
+ [426, 429, 438, 441, 450, 453, 462, 465, 474, 477, 486, 489, 498, 501, 810, 813, 822, 825, 834, 837, 846, 849, 858, 861],
+ [427, 430, 439, 442, 451, 454, 463, 466, 475, 478, 487, 490, 499, 502, 811, 814, 823, 826, 835, 838, 847, 850, 859, 862],
+ [428, 431, 440, 443, 452, 455, 464, 467, 476, 479, 488, 491, 500, 503, 812, 815, 824, 827, 836, 839, 848, 851, 860, 863]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # Pattern repeat: each group is issued twice (leading size 2, stride 0); every element is accessed twice and the reference order holds the index from the second pass
+ tiles = TensorTiler2D.group_tiler(
+ tensor_dims,
+ tile_dims=(3, 2),
+ tile_group_dims=(5, 7),
+ tile_col_major=True,
+ allow_partial=True,
+ pattern_repeat=2,
+ )
+
+ reference_tiles = TensorTileSequence.from_tiles(
+ [
+ TensorTile(
+ tensor_dims, offset=0, sizes=[2, 5, 14, 3], strides=[0, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=14, sizes=[2, 5, 10, 3], strides=[0, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=360, sizes=[2, 5, 14, 3], strides=[0, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=374, sizes=[2, 5, 10, 3], strides=[0, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=720, sizes=[2, 2, 14, 3], strides=[0, 72, 1, 24]
+ ),
+ TensorTile(
+ tensor_dims, offset=734, sizes=[2, 2, 10, 3], strides=[0, 72, 1, 24]
+ ),
+ ]
+ )
+ assert tiles == reference_tiles
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 210, 213, 216, 219, 222, 225, 228, 231, 234, 237, 240, 243, 246, 249, 570, 573, 576, 579, 582, 585, 588, 591, 594, 597],
+ [ 211, 214, 217, 220, 223, 226, 229, 232, 235, 238, 241, 244, 247, 250, 571, 574, 577, 580, 583, 586, 589, 592, 595, 598],
+ [ 212, 215, 218, 221, 224, 227, 230, 233, 236, 239, 242, 245, 248, 251, 572, 575, 578, 581, 584, 587, 590, 593, 596, 599],
+ [ 252, 255, 258, 261, 264, 267, 270, 273, 276, 279, 282, 285, 288, 291, 600, 603, 606, 609, 612, 615, 618, 621, 624, 627],
+ [ 253, 256, 259, 262, 265, 268, 271, 274, 277, 280, 283, 286, 289, 292, 601, 604, 607, 610, 613, 616, 619, 622, 625, 628],
+ [ 254, 257, 260, 263, 266, 269, 272, 275, 278, 281, 284, 287, 290, 293, 602, 605, 608, 611, 614, 617, 620, 623, 626, 629],
+ [ 294, 297, 300, 303, 306, 309, 312, 315, 318, 321, 324, 327, 330, 333, 630, 633, 636, 639, 642, 645, 648, 651, 654, 657],
+ [ 295, 298, 301, 304, 307, 310, 313, 316, 319, 322, 325, 328, 331, 334, 631, 634, 637, 640, 643, 646, 649, 652, 655, 658],
+ [ 296, 299, 302, 305, 308, 311, 314, 317, 320, 323, 326, 329, 332, 335, 632, 635, 638, 641, 644, 647, 650, 653, 656, 659],
+ [ 336, 339, 342, 345, 348, 351, 354, 357, 360, 363, 366, 369, 372, 375, 660, 663, 666, 669, 672, 675, 678, 681, 684, 687],
+ [ 337, 340, 343, 346, 349, 352, 355, 358, 361, 364, 367, 370, 373, 376, 661, 664, 667, 670, 673, 676, 679, 682, 685, 688],
+ [ 338, 341, 344, 347, 350, 353, 356, 359, 362, 365, 368, 371, 374, 377, 662, 665, 668, 671, 674, 677, 680, 683, 686, 689],
+ [ 378, 381, 384, 387, 390, 393, 396, 399, 402, 405, 408, 411, 414, 417, 690, 693, 696, 699, 702, 705, 708, 711, 714, 717],
+ [ 379, 382, 385, 388, 391, 394, 397, 400, 403, 406, 409, 412, 415, 418, 691, 694, 697, 700, 703, 706, 709, 712, 715, 718],
+ [ 380, 383, 386, 389, 392, 395, 398, 401, 404, 407, 410, 413, 416, 419, 692, 695, 698, 701, 704, 707, 710, 713, 716, 719],
+ [ 930, 933, 936, 939, 942, 945, 948, 951, 954, 957, 960, 963, 966, 969, 1290, 1293, 1296, 1299, 1302, 1305, 1308, 1311, 1314, 1317],
+ [ 931, 934, 937, 940, 943, 946, 949, 952, 955, 958, 961, 964, 967, 970, 1291, 1294, 1297, 1300, 1303, 1306, 1309, 1312, 1315, 1318],
+ [ 932, 935, 938, 941, 944, 947, 950, 953, 956, 959, 962, 965, 968, 971, 1292, 1295, 1298, 1301, 1304, 1307, 1310, 1313, 1316, 1319],
+ [ 972, 975, 978, 981, 984, 987, 990, 993, 996, 999, 1002, 1005, 1008, 1011, 1320, 1323, 1326, 1329, 1332, 1335, 1338, 1341, 1344, 1347],
+ [ 973, 976, 979, 982, 985, 988, 991, 994, 997, 1000, 1003, 1006, 1009, 1012, 1321, 1324, 1327, 1330, 1333, 1336, 1339, 1342, 1345, 1348],
+ [ 974, 977, 980, 983, 986, 989, 992, 995, 998, 1001, 1004, 1007, 1010, 1013, 1322, 1325, 1328, 1331, 1334, 1337, 1340, 1343, 1346, 1349],
+ [1014, 1017, 1020, 1023, 1026, 1029, 1032, 1035, 1038, 1041, 1044, 1047, 1050, 1053, 1350, 1353, 1356, 1359, 1362, 1365, 1368, 1371, 1374, 1377],
+ [1015, 1018, 1021, 1024, 1027, 1030, 1033, 1036, 1039, 1042, 1045, 1048, 1051, 1054, 1351, 1354, 1357, 1360, 1363, 1366, 1369, 1372, 1375, 1378],
+ [1016, 1019, 1022, 1025, 1028, 1031, 1034, 1037, 1040, 1043, 1046, 1049, 1052, 1055, 1352, 1355, 1358, 1361, 1364, 1367, 1370, 1373, 1376, 1379],
+ [1056, 1059, 1062, 1065, 1068, 1071, 1074, 1077, 1080, 1083, 1086, 1089, 1092, 1095, 1380, 1383, 1386, 1389, 1392, 1395, 1398, 1401, 1404, 1407],
+ [1057, 1060, 1063, 1066, 1069, 1072, 1075, 1078, 1081, 1084, 1087, 1090, 1093, 1096, 1381, 1384, 1387, 1390, 1393, 1396, 1399, 1402, 1405, 1408],
+ [1058, 1061, 1064, 1067, 1070, 1073, 1076, 1079, 1082, 1085, 1088, 1091, 1094, 1097, 1382, 1385, 1388, 1391, 1394, 1397, 1400, 1403, 1406, 1409],
+ [1098, 1101, 1104, 1107, 1110, 1113, 1116, 1119, 1122, 1125, 1128, 1131, 1134, 1137, 1410, 1413, 1416, 1419, 1422, 1425, 1428, 1431, 1434, 1437],
+ [1099, 1102, 1105, 1108, 1111, 1114, 1117, 1120, 1123, 1126, 1129, 1132, 1135, 1138, 1411, 1414, 1417, 1420, 1423, 1426, 1429, 1432, 1435, 1438],
+ [1100, 1103, 1106, 1109, 1112, 1115, 1118, 1121, 1124, 1127, 1130, 1133, 1136, 1139, 1412, 1415, 1418, 1421, 1424, 1427, 1430, 1433, 1436, 1439],
+ [1524, 1527, 1530, 1533, 1536, 1539, 1542, 1545, 1548, 1551, 1554, 1557, 1560, 1563, 1668, 1671, 1674, 1677, 1680, 1683, 1686, 1689, 1692, 1695],
+ [1525, 1528, 1531, 1534, 1537, 1540, 1543, 1546, 1549, 1552, 1555, 1558, 1561, 1564, 1669, 1672, 1675, 1678, 1681, 1684, 1687, 1690, 1693, 1696],
+ [1526, 1529, 1532, 1535, 1538, 1541, 1544, 1547, 1550, 1553, 1556, 1559, 1562, 1565, 1670, 1673, 1676, 1679, 1682, 1685, 1688, 1691, 1694, 1697],
+ [1566, 1569, 1572, 1575, 1578, 1581, 1584, 1587, 1590, 1593, 1596, 1599, 1602, 1605, 1698, 1701, 1704, 1707, 1710, 1713, 1716, 1719, 1722, 1725],
+ [1567, 1570, 1573, 1576, 1579, 1582, 1585, 1588, 1591, 1594, 1597, 1600, 1603, 1606, 1699, 1702, 1705, 1708, 1711, 1714, 1717, 1720, 1723, 1726],
+ [1568, 1571, 1574, 1577, 1580, 1583, 1586, 1589, 1592, 1595, 1598, 1601, 1604, 1607, 1700, 1703, 1706, 1709, 1712, 1715, 1718, 1721, 1724, 1727]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 2).all()
+
+ # CHECK: Pass!
+ print("Pass!")
diff --git a/test/python/tensortiler/simple_tiler.py b/test/python/tensortiler/simple_tiler.py
new file mode 100644
index 0000000000..616b9cb8cc
--- /dev/null
+++ b/test/python/tensortiler/simple_tiler.py
@@ -0,0 +1,265 @@
+import numpy as np
+
+from aie.helpers.tensortiler import TensorTile, TensorTileSequence, TensorTiler2D
+from util import construct_test
+
+# RUN: %python %s | FileCheck %s
+
+
+# CHECK-LABEL: simple_tiler
+@construct_test
+def simple_tiler():
+    """Check simple_tiler tiles and access order/count tensors for each option."""
+    single_tile = TensorTiler2D.simple_tiler((3, 5))
+    assert len(single_tile) == 1
+    assert single_tile[0] == TensorTile(
+        (3, 5), offset=0, sizes=[1, 1, 3, 5], strides=[0, 0, 5, 1]
+    )
+    ref_access_order_tensor = np.array(
+        [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14]]
+    )
+    access_order, access_count = single_tile.access_tensors()
+    assert (access_order == ref_access_order_tensor).all()
+    assert (access_count == 1).all()
+    # 9x4 tensor cut into 3x2 tiles: six tiles, taken left to right, top to bottom.
+    tiles = TensorTiler2D.simple_tiler((9, 4), (3, 2))
+    assert len(tiles) == 6
+    # fmt: off
+    ref_access_order_tensor = np.array([
+        [ 0,  1,  6,  7],
+        [ 2,  3,  8,  9],
+        [ 4,  5, 10, 11],
+        [12, 13, 18, 19],
+        [14, 15, 20, 21],
+        [16, 17, 22, 23],
+        [24, 25, 30, 31],
+        [26, 27, 32, 33],
+        [28, 29, 34, 35]])
+    # fmt: on
+    access_order, access_count = tiles.access_tensors()
+    assert (access_order == ref_access_order_tensor).all()
+    assert (access_count == 1).all()
+
+    def offset_fn(step, _prev_offset):
+        return (0, 2, 12, 14, 24, 26)[step]
+
+    tiles2 = TensorTileSequence(
+        (9, 4),
+        num_steps=6,
+        sizes=[1, 1, 3, 2],
+        strides=[0, 0, 4, 1],
+        offset_fn=offset_fn,
+    )
+    assert tiles == tiles2
+
+    tile0_0 = TensorTile((9, 4), offset=0, sizes=[1, 1, 3, 2], strides=[0, 0, 4, 1])
+    tile0_1 = TensorTile((9, 4), offset=2, sizes=[1, 1, 3, 2], strides=[0, 0, 4, 1])
+    tile1_0 = TensorTile((9, 4), offset=12, sizes=[1, 1, 3, 2], strides=[0, 0, 4, 1])
+    tile1_1 = TensorTile((9, 4), offset=14, sizes=[1, 1, 3, 2], strides=[0, 0, 4, 1])
+
+    assert tiles2[0] == tile0_0
+    assert tiles2[1] == tile0_1
+    assert tiles2[2] == tile1_0
+    assert tiles2[3] == tile1_1
+
+    access_order, access_count = tiles2.access_tensors()
+    assert (access_order == ref_access_order_tensor).all()
+    assert (access_count == 1).all()
+
+    # Check with column major iter order
+    tiles_iter_col_major = TensorTiler2D.simple_tiler(
+        (9, 4), (3, 2), iter_col_major=True
+    )
+    assert tiles_iter_col_major[0] == tile0_0
+    assert tiles_iter_col_major[1] == tile1_0
+    assert tiles_iter_col_major[3] == tile0_1
+    assert tiles_iter_col_major[4] == tile1_1
+
+    # fmt: off
+    ref_access_order_tensor = np.array([
+        [ 0,  1, 18, 19],
+        [ 2,  3, 20, 21],
+        [ 4,  5, 22, 23],
+        [ 6,  7, 24, 25],
+        [ 8,  9, 26, 27],
+        [10, 11, 28, 29],
+        [12, 13, 30, 31],
+        [14, 15, 32, 33],
+        [16, 17, 34, 35]])
+    # fmt: on
+    access_order, access_count = tiles_iter_col_major.access_tensors()
+    assert (access_order == ref_access_order_tensor).all()
+    assert (access_count == 1).all()
+    # tile_col_major=True: elements inside each tile are walked column by column.
+    tiles_tile_col_major = TensorTiler2D.simple_tiler(
+        (9, 4), (3, 2), tile_col_major=True
+    )
+    tile0_0 = TensorTile((9, 4), offset=0, sizes=[1, 1, 2, 3], strides=[0, 0, 1, 4])
+    tile0_1 = TensorTile((9, 4), offset=2, sizes=[1, 1, 2, 3], strides=[0, 0, 1, 4])
+    tile1_0 = TensorTile((9, 4), offset=12, sizes=[1, 1, 2, 3], strides=[0, 0, 1, 4])
+    tile1_1 = TensorTile((9, 4), offset=14, sizes=[1, 1, 2, 3], strides=[0, 0, 1, 4])
+    assert tiles_tile_col_major[0] == tile0_0
+    assert tiles_tile_col_major[1] == tile0_1
+    assert tiles_tile_col_major[2] == tile1_0
+    assert tiles_tile_col_major[3] == tile1_1
+
+    # fmt: off
+    ref_access_order_tensor = np.array([
+        [ 0,  3,  6,  9],
+        [ 1,  4,  7, 10],
+        [ 2,  5,  8, 11],
+        [12, 15, 18, 21],
+        [13, 16, 19, 22],
+        [14, 17, 20, 23],
+        [24, 27, 30, 33],
+        [25, 28, 31, 34],
+        [26, 29, 32, 35]])
+    # fmt: on
+    access_order, access_count = tiles_tile_col_major.access_tensors()
+    assert (access_order == ref_access_order_tensor).all()
+    assert (access_count == 1).all()
+    # Column-major both inside each tile and in the order the tiles are visited.
+    tiles_tile_col_major_iter_col_major = TensorTiler2D.simple_tiler(
+        (9, 4), (3, 2), tile_col_major=True, iter_col_major=True
+    )
+    assert tiles_tile_col_major_iter_col_major[0] == tile0_0
+    assert tiles_tile_col_major_iter_col_major[1] == tile1_0
+    assert tiles_tile_col_major_iter_col_major[3] == tile0_1
+    assert tiles_tile_col_major_iter_col_major[4] == tile1_1
+
+    # fmt: off
+    ref_access_order_tensor = np.array([
+        [ 0,  3, 18, 21],
+        [ 1,  4, 19, 22],
+        [ 2,  5, 20, 23],
+        [ 6,  9, 24, 27],
+        [ 7, 10, 25, 28],
+        [ 8, 11, 26, 29],
+        [12, 15, 30, 33],
+        [13, 16, 31, 34],
+        [14, 17, 32, 35]])
+    # fmt: on
+    access_order, access_count = tiles_tile_col_major_iter_col_major.access_tensors()
+    assert (access_order == ref_access_order_tensor).all()
+    assert (access_count == 5 - 4).all()
+    # pattern_repeat=5: count is 5 everywhere; reference order is the final pass.
+    tiles_repeat = TensorTiler2D.simple_tiler((9, 4), (3, 2), pattern_repeat=5)
+    tile_repeat0_0 = TensorTile(
+        (9, 4), offset=0, sizes=[1, 5, 3, 2], strides=[0, 0, 4, 1]
+    )
+    assert tiles_repeat[0] == tile_repeat0_0
+
+    # fmt: off
+    ref_access_order_tensor = np.array([
+        [ 24,  25,  54,  55],
+        [ 26,  27,  56,  57],
+        [ 28,  29,  58,  59],
+        [ 84,  85, 114, 115],
+        [ 86,  87, 116, 117],
+        [ 88,  89, 118, 119],
+        [144, 145, 174, 175],
+        [146, 147, 176, 177],
+        [148, 149, 178, 179]])
+    # fmt: on
+    access_order, access_count = tiles_repeat.access_tensors()
+    assert (access_order == ref_access_order_tensor).all()
+    assert (access_count == 5).all()
+    # Repeat combined with column-major element order inside each tile.
+    tiles_repeat = TensorTiler2D.simple_tiler(
+        (9, 4), (3, 2), tile_col_major=True, pattern_repeat=5
+    )
+    tile_repeat0_0 = TensorTile(
+        (9, 4), offset=0, sizes=[1, 5, 2, 3], strides=[0, 0, 1, 4]
+    )
+    assert tiles_repeat[0] == tile_repeat0_0
+
+    # fmt: off
+    ref_access_order_tensor = np.array([
+        [ 24,  27,  54,  57],
+        [ 25,  28,  55,  58],
+        [ 26,  29,  56,  59],
+        [ 84,  87, 114, 117],
+        [ 85,  88, 115, 118],
+        [ 86,  89, 116, 119],
+        [144, 147, 174, 177],
+        [145, 148, 175, 178],
+        [146, 149, 176, 179]])
+    # fmt: on
+    access_order, access_count = tiles_repeat.access_tensors()
+    assert (access_order == ref_access_order_tensor).all()
+    assert (access_count == 5).all()
+
+    # CHECK: Pass!
+    print("Pass!")
+
+
+# CHECK-LABEL: simple_tiler_invalid
+@construct_test
+def simple_tiler_invalid():
+    """Check that simple_tiler rejects each malformed argument set with ValueError.
+
+    A call that returns normally raises AssertionError so the case fails visibly.
+    """
+    try:
+        TensorTiler2D.simple_tiler((), (3, 2), tile_col_major=True, pattern_repeat=5)
+    except ValueError:
+        pass  # expected
+    else:
+        raise AssertionError("Bad tensor dims, should fail.")
+    try:
+        TensorTiler2D.simple_tiler(
+            (10, 9, 4), (3, 2), tile_col_major=True, pattern_repeat=5
+        )
+    except ValueError:
+        pass  # expected
+    else:
+        raise AssertionError("Too many tensor dims, should fail.")
+    try:
+        TensorTiler2D.simple_tiler(
+            (9, 4), (3, -1), tile_col_major=True, pattern_repeat=5
+        )
+    except ValueError:
+        pass  # expected
+    else:
+        raise AssertionError("Bad tile dims, should fail.")
+    try:
+        TensorTiler2D.simple_tiler((9, 4), (3,), tile_col_major=True, pattern_repeat=5)
+    except ValueError:
+        pass  # expected
+    else:
+        raise AssertionError("Too few tile dims, should fail.")
+    try:
+        TensorTiler2D.simple_tiler(
+            (9, 4), (1, 1, 1), tile_col_major=True, pattern_repeat=5
+        )
+    except ValueError:
+        pass  # expected
+    else:
+        raise AssertionError("Too many tile dims, should fail.")
+    try:
+        TensorTiler2D.simple_tiler(
+            (9, 4), (3, 2), tile_col_major=True, pattern_repeat=0
+        )
+    except ValueError:
+        pass  # expected
+    else:
+        raise AssertionError("Invalid repeat.")
+    try:
+        TensorTiler2D.simple_tiler(
+            (9, 4), (4, 2), tile_col_major=True, pattern_repeat=5
+        )
+    except ValueError:
+        pass  # expected
+    else:
+        raise AssertionError("Indivisible tile (height)")
+    try:
+        TensorTiler2D.simple_tiler(
+            (9, 4), (3, 3), tile_col_major=True, pattern_repeat=5
+        )
+    except ValueError:
+        pass  # expected
+    else:
+        raise AssertionError("Indivisible tile (width)")
+
+    # CHECK: Pass!
+    print("Pass!")
diff --git a/test/python/tensortiler/step_tiler.py b/test/python/tensortiler/step_tiler.py
new file mode 100644
index 0000000000..11821c60fe
--- /dev/null
+++ b/test/python/tensortiler/step_tiler.py
@@ -0,0 +1,857 @@
+import numpy as np
+
+from aie.helpers.tensortiler import TensorTile, TensorTiler2D
+from util import construct_test
+
+# RUN: %python %s | FileCheck %s
+
+
+# CHECK-LABEL: step_tiler
+@construct_test
+def step_tiler():
+ # Start with Step == (1, 1)
+ tiles = TensorTiler2D.step_tiler(
+ (32, 32),
+ tile_dims=(2, 2),
+ tile_group_repeats=(2, 2),
+ tile_group_steps=(1, 1),
+ )
+ assert len(tiles) == (32 // (2 * 2)) * (32 // (2 * 2))
+ assert tiles[0] == TensorTile(
+ (32, 32), offset=0, sizes=[2, 2, 2, 2], strides=[64, 2, 32, 1]
+ )
+ assert tiles[1] == TensorTile(
+ (32, 32), offset=4, sizes=[2, 2, 2, 2], strides=[64, 2, 32, 1]
+ )
+ assert tiles[26] == TensorTile(
+ (32, 32), offset=392, sizes=[2, 2, 2, 2], strides=[64, 2, 32, 1]
+ )
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 1, 4, 5, 16, 17, 20, 21, 32, 33, 36, 37, 48, 49, 52, 53, 64, 65, 68, 69, 80, 81, 84, 85, 96, 97, 100, 101, 112, 113, 116, 117],
+ [ 2, 3, 6, 7, 18, 19, 22, 23, 34, 35, 38, 39, 50, 51, 54, 55, 66, 67, 70, 71, 82, 83, 86, 87, 98, 99, 102, 103, 114, 115, 118, 119],
+ [ 8, 9, 12, 13, 24, 25, 28, 29, 40, 41, 44, 45, 56, 57, 60, 61, 72, 73, 76, 77, 88, 89, 92, 93, 104, 105, 108, 109, 120, 121, 124, 125],
+ [ 10, 11, 14, 15, 26, 27, 30, 31, 42, 43, 46, 47, 58, 59, 62, 63, 74, 75, 78, 79, 90, 91, 94, 95, 106, 107, 110, 111, 122, 123, 126, 127],
+ [ 128, 129, 132, 133, 144, 145, 148, 149, 160, 161, 164, 165, 176, 177, 180, 181, 192, 193, 196, 197, 208, 209, 212, 213, 224, 225, 228, 229, 240, 241, 244, 245],
+ [ 130, 131, 134, 135, 146, 147, 150, 151, 162, 163, 166, 167, 178, 179, 182, 183, 194, 195, 198, 199, 210, 211, 214, 215, 226, 227, 230, 231, 242, 243, 246, 247],
+ [ 136, 137, 140, 141, 152, 153, 156, 157, 168, 169, 172, 173, 184, 185, 188, 189, 200, 201, 204, 205, 216, 217, 220, 221, 232, 233, 236, 237, 248, 249, 252, 253],
+ [ 138, 139, 142, 143, 154, 155, 158, 159, 170, 171, 174, 175, 186, 187, 190, 191, 202, 203, 206, 207, 218, 219, 222, 223, 234, 235, 238, 239, 250, 251, 254, 255],
+ [ 256, 257, 260, 261, 272, 273, 276, 277, 288, 289, 292, 293, 304, 305, 308, 309, 320, 321, 324, 325, 336, 337, 340, 341, 352, 353, 356, 357, 368, 369, 372, 373],
+ [ 258, 259, 262, 263, 274, 275, 278, 279, 290, 291, 294, 295, 306, 307, 310, 311, 322, 323, 326, 327, 338, 339, 342, 343, 354, 355, 358, 359, 370, 371, 374, 375],
+ [ 264, 265, 268, 269, 280, 281, 284, 285, 296, 297, 300, 301, 312, 313, 316, 317, 328, 329, 332, 333, 344, 345, 348, 349, 360, 361, 364, 365, 376, 377, 380, 381],
+ [ 266, 267, 270, 271, 282, 283, 286, 287, 298, 299, 302, 303, 314, 315, 318, 319, 330, 331, 334, 335, 346, 347, 350, 351, 362, 363, 366, 367, 378, 379, 382, 383],
+ [ 384, 385, 388, 389, 400, 401, 404, 405, 416, 417, 420, 421, 432, 433, 436, 437, 448, 449, 452, 453, 464, 465, 468, 469, 480, 481, 484, 485, 496, 497, 500, 501],
+ [ 386, 387, 390, 391, 402, 403, 406, 407, 418, 419, 422, 423, 434, 435, 438, 439, 450, 451, 454, 455, 466, 467, 470, 471, 482, 483, 486, 487, 498, 499, 502, 503],
+ [ 392, 393, 396, 397, 408, 409, 412, 413, 424, 425, 428, 429, 440, 441, 444, 445, 456, 457, 460, 461, 472, 473, 476, 477, 488, 489, 492, 493, 504, 505, 508, 509],
+ [ 394, 395, 398, 399, 410, 411, 414, 415, 426, 427, 430, 431, 442, 443, 446, 447, 458, 459, 462, 463, 474, 475, 478, 479, 490, 491, 494, 495, 506, 507, 510, 511],
+ [ 512, 513, 516, 517, 528, 529, 532, 533, 544, 545, 548, 549, 560, 561, 564, 565, 576, 577, 580, 581, 592, 593, 596, 597, 608, 609, 612, 613, 624, 625, 628, 629],
+ [ 514, 515, 518, 519, 530, 531, 534, 535, 546, 547, 550, 551, 562, 563, 566, 567, 578, 579, 582, 583, 594, 595, 598, 599, 610, 611, 614, 615, 626, 627, 630, 631],
+ [ 520, 521, 524, 525, 536, 537, 540, 541, 552, 553, 556, 557, 568, 569, 572, 573, 584, 585, 588, 589, 600, 601, 604, 605, 616, 617, 620, 621, 632, 633, 636, 637],
+ [ 522, 523, 526, 527, 538, 539, 542, 543, 554, 555, 558, 559, 570, 571, 574, 575, 586, 587, 590, 591, 602, 603, 606, 607, 618, 619, 622, 623, 634, 635, 638, 639],
+ [ 640, 641, 644, 645, 656, 657, 660, 661, 672, 673, 676, 677, 688, 689, 692, 693, 704, 705, 708, 709, 720, 721, 724, 725, 736, 737, 740, 741, 752, 753, 756, 757],
+ [ 642, 643, 646, 647, 658, 659, 662, 663, 674, 675, 678, 679, 690, 691, 694, 695, 706, 707, 710, 711, 722, 723, 726, 727, 738, 739, 742, 743, 754, 755, 758, 759],
+ [ 648, 649, 652, 653, 664, 665, 668, 669, 680, 681, 684, 685, 696, 697, 700, 701, 712, 713, 716, 717, 728, 729, 732, 733, 744, 745, 748, 749, 760, 761, 764, 765],
+ [ 650, 651, 654, 655, 666, 667, 670, 671, 682, 683, 686, 687, 698, 699, 702, 703, 714, 715, 718, 719, 730, 731, 734, 735, 746, 747, 750, 751, 762, 763, 766, 767],
+ [ 768, 769, 772, 773, 784, 785, 788, 789, 800, 801, 804, 805, 816, 817, 820, 821, 832, 833, 836, 837, 848, 849, 852, 853, 864, 865, 868, 869, 880, 881, 884, 885],
+ [ 770, 771, 774, 775, 786, 787, 790, 791, 802, 803, 806, 807, 818, 819, 822, 823, 834, 835, 838, 839, 850, 851, 854, 855, 866, 867, 870, 871, 882, 883, 886, 887],
+ [ 776, 777, 780, 781, 792, 793, 796, 797, 808, 809, 812, 813, 824, 825, 828, 829, 840, 841, 844, 845, 856, 857, 860, 861, 872, 873, 876, 877, 888, 889, 892, 893],
+ [ 778, 779, 782, 783, 794, 795, 798, 799, 810, 811, 814, 815, 826, 827, 830, 831, 842, 843, 846, 847, 858, 859, 862, 863, 874, 875, 878, 879, 890, 891, 894, 895],
+ [ 896, 897, 900, 901, 912, 913, 916, 917, 928, 929, 932, 933, 944, 945, 948, 949, 960, 961, 964, 965, 976, 977, 980, 981, 992, 993, 996, 997, 1008, 1009, 1012, 1013],
+ [ 898, 899, 902, 903, 914, 915, 918, 919, 930, 931, 934, 935, 946, 947, 950, 951, 962, 963, 966, 967, 978, 979, 982, 983, 994, 995, 998, 999, 1010, 1011, 1014, 1015],
+ [ 904, 905, 908, 909, 920, 921, 924, 925, 936, 937, 940, 941, 952, 953, 956, 957, 968, 969, 972, 973, 984, 985, 988, 989, 1000, 1001, 1004, 1005, 1016, 1017, 1020, 1021],
+ [ 906, 907, 910, 911, 922, 923, 926, 927, 938, 939, 942, 943, 954, 955, 958, 959, 970, 971, 974, 975, 986, 987, 990, 991, 1002, 1003, 1006, 1007, 1018, 1019, 1022, 1023]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # Step == (2, 1)
+ tiles = TensorTiler2D.step_tiler(
+ (32, 32),
+ tile_dims=(2, 2),
+ tile_group_repeats=(2, 2),
+ tile_group_steps=(2, 1),
+ )
+ assert len(tiles) == (32 // (2 * 2)) * (32 // (2 * 2))
+ assert tiles[0] == TensorTile(
+ (32, 32), offset=0, sizes=[2, 2, 2, 2], strides=[128, 2, 32, 1]
+ )
+ assert tiles[1] == TensorTile(
+ (32, 32), offset=4, sizes=[2, 2, 2, 2], strides=[128, 2, 32, 1]
+ )
+ assert tiles[26] == TensorTile(
+ (32, 32), offset=328, sizes=[2, 2, 2, 2], strides=[128, 2, 32, 1]
+ )
+ assert tiles[-1] == TensorTile(
+ (32, 32), offset=860, sizes=[2, 2, 2, 2], strides=[128, 2, 32, 1]
+ )
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 1, 4, 5, 16, 17, 20, 21, 32, 33, 36, 37, 48, 49, 52, 53, 64, 65, 68, 69, 80, 81, 84, 85, 96, 97, 100, 101, 112, 113, 116, 117],
+ [ 2, 3, 6, 7, 18, 19, 22, 23, 34, 35, 38, 39, 50, 51, 54, 55, 66, 67, 70, 71, 82, 83, 86, 87, 98, 99, 102, 103, 114, 115, 118, 119],
+ [ 128, 129, 132, 133, 144, 145, 148, 149, 160, 161, 164, 165, 176, 177, 180, 181, 192, 193, 196, 197, 208, 209, 212, 213, 224, 225, 228, 229, 240, 241, 244, 245],
+ [ 130, 131, 134, 135, 146, 147, 150, 151, 162, 163, 166, 167, 178, 179, 182, 183, 194, 195, 198, 199, 210, 211, 214, 215, 226, 227, 230, 231, 242, 243, 246, 247],
+ [ 8, 9, 12, 13, 24, 25, 28, 29, 40, 41, 44, 45, 56, 57, 60, 61, 72, 73, 76, 77, 88, 89, 92, 93, 104, 105, 108, 109, 120, 121, 124, 125],
+ [ 10, 11, 14, 15, 26, 27, 30, 31, 42, 43, 46, 47, 58, 59, 62, 63, 74, 75, 78, 79, 90, 91, 94, 95, 106, 107, 110, 111, 122, 123, 126, 127],
+ [ 136, 137, 140, 141, 152, 153, 156, 157, 168, 169, 172, 173, 184, 185, 188, 189, 200, 201, 204, 205, 216, 217, 220, 221, 232, 233, 236, 237, 248, 249, 252, 253],
+ [ 138, 139, 142, 143, 154, 155, 158, 159, 170, 171, 174, 175, 186, 187, 190, 191, 202, 203, 206, 207, 218, 219, 222, 223, 234, 235, 238, 239, 250, 251, 254, 255],
+ [ 256, 257, 260, 261, 272, 273, 276, 277, 288, 289, 292, 293, 304, 305, 308, 309, 320, 321, 324, 325, 336, 337, 340, 341, 352, 353, 356, 357, 368, 369, 372, 373],
+ [ 258, 259, 262, 263, 274, 275, 278, 279, 290, 291, 294, 295, 306, 307, 310, 311, 322, 323, 326, 327, 338, 339, 342, 343, 354, 355, 358, 359, 370, 371, 374, 375],
+ [ 384, 385, 388, 389, 400, 401, 404, 405, 416, 417, 420, 421, 432, 433, 436, 437, 448, 449, 452, 453, 464, 465, 468, 469, 480, 481, 484, 485, 496, 497, 500, 501],
+ [ 386, 387, 390, 391, 402, 403, 406, 407, 418, 419, 422, 423, 434, 435, 438, 439, 450, 451, 454, 455, 466, 467, 470, 471, 482, 483, 486, 487, 498, 499, 502, 503],
+ [ 264, 265, 268, 269, 280, 281, 284, 285, 296, 297, 300, 301, 312, 313, 316, 317, 328, 329, 332, 333, 344, 345, 348, 349, 360, 361, 364, 365, 376, 377, 380, 381],
+ [ 266, 267, 270, 271, 282, 283, 286, 287, 298, 299, 302, 303, 314, 315, 318, 319, 330, 331, 334, 335, 346, 347, 350, 351, 362, 363, 366, 367, 378, 379, 382, 383],
+ [ 392, 393, 396, 397, 408, 409, 412, 413, 424, 425, 428, 429, 440, 441, 444, 445, 456, 457, 460, 461, 472, 473, 476, 477, 488, 489, 492, 493, 504, 505, 508, 509],
+ [ 394, 395, 398, 399, 410, 411, 414, 415, 426, 427, 430, 431, 442, 443, 446, 447, 458, 459, 462, 463, 474, 475, 478, 479, 490, 491, 494, 495, 506, 507, 510, 511],
+ [ 512, 513, 516, 517, 528, 529, 532, 533, 544, 545, 548, 549, 560, 561, 564, 565, 576, 577, 580, 581, 592, 593, 596, 597, 608, 609, 612, 613, 624, 625, 628, 629],
+ [ 514, 515, 518, 519, 530, 531, 534, 535, 546, 547, 550, 551, 562, 563, 566, 567, 578, 579, 582, 583, 594, 595, 598, 599, 610, 611, 614, 615, 626, 627, 630, 631],
+ [ 640, 641, 644, 645, 656, 657, 660, 661, 672, 673, 676, 677, 688, 689, 692, 693, 704, 705, 708, 709, 720, 721, 724, 725, 736, 737, 740, 741, 752, 753, 756, 757],
+ [ 642, 643, 646, 647, 658, 659, 662, 663, 674, 675, 678, 679, 690, 691, 694, 695, 706, 707, 710, 711, 722, 723, 726, 727, 738, 739, 742, 743, 754, 755, 758, 759],
+ [ 520, 521, 524, 525, 536, 537, 540, 541, 552, 553, 556, 557, 568, 569, 572, 573, 584, 585, 588, 589, 600, 601, 604, 605, 616, 617, 620, 621, 632, 633, 636, 637],
+ [ 522, 523, 526, 527, 538, 539, 542, 543, 554, 555, 558, 559, 570, 571, 574, 575, 586, 587, 590, 591, 602, 603, 606, 607, 618, 619, 622, 623, 634, 635, 638, 639],
+ [ 648, 649, 652, 653, 664, 665, 668, 669, 680, 681, 684, 685, 696, 697, 700, 701, 712, 713, 716, 717, 728, 729, 732, 733, 744, 745, 748, 749, 760, 761, 764, 765],
+ [ 650, 651, 654, 655, 666, 667, 670, 671, 682, 683, 686, 687, 698, 699, 702, 703, 714, 715, 718, 719, 730, 731, 734, 735, 746, 747, 750, 751, 762, 763, 766, 767],
+ [ 768, 769, 772, 773, 784, 785, 788, 789, 800, 801, 804, 805, 816, 817, 820, 821, 832, 833, 836, 837, 848, 849, 852, 853, 864, 865, 868, 869, 880, 881, 884, 885],
+ [ 770, 771, 774, 775, 786, 787, 790, 791, 802, 803, 806, 807, 818, 819, 822, 823, 834, 835, 838, 839, 850, 851, 854, 855, 866, 867, 870, 871, 882, 883, 886, 887],
+ [ 896, 897, 900, 901, 912, 913, 916, 917, 928, 929, 932, 933, 944, 945, 948, 949, 960, 961, 964, 965, 976, 977, 980, 981, 992, 993, 996, 997, 1008, 1009, 1012, 1013],
+ [ 898, 899, 902, 903, 914, 915, 918, 919, 930, 931, 934, 935, 946, 947, 950, 951, 962, 963, 966, 967, 978, 979, 982, 983, 994, 995, 998, 999, 1010, 1011, 1014, 1015],
+ [ 776, 777, 780, 781, 792, 793, 796, 797, 808, 809, 812, 813, 824, 825, 828, 829, 840, 841, 844, 845, 856, 857, 860, 861, 872, 873, 876, 877, 888, 889, 892, 893],
+ [ 778, 779, 782, 783, 794, 795, 798, 799, 810, 811, 814, 815, 826, 827, 830, 831, 842, 843, 846, 847, 858, 859, 862, 863, 874, 875, 878, 879, 890, 891, 894, 895],
+ [ 904, 905, 908, 909, 920, 921, 924, 925, 936, 937, 940, 941, 952, 953, 956, 957, 968, 969, 972, 973, 984, 985, 988, 989, 1000, 1001, 1004, 1005, 1016, 1017, 1020, 1021],
+ [ 906, 907, 910, 911, 922, 923, 926, 927, 938, 939, 942, 943, 954, 955, 958, 959, 970, 971, 974, 975, 986, 987, 990, 991, 1002, 1003, 1006, 1007, 1018, 1019, 1022, 1023]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # Step == (1, 2)
+ tiles = TensorTiler2D.step_tiler(
+ (32, 32),
+ tile_dims=(2, 2),
+ tile_group_repeats=(2, 2),
+ tile_group_steps=(1, 2),
+ )
+ assert len(tiles) == (32 // (2 * 2)) * (32 // (2 * 2))
+ assert tiles[0] == TensorTile(
+ (32, 32), offset=0, sizes=[2, 2, 2, 2], strides=[64, 4, 32, 1]
+ )
+ assert tiles[1] == TensorTile(
+ (32, 32), offset=2, sizes=[2, 2, 2, 2], strides=[64, 4, 32, 1]
+ )
+ assert tiles[26] == TensorTile(
+ (32, 32), offset=392, sizes=[2, 2, 2, 2], strides=[64, 4, 32, 1]
+ )
+ assert tiles[-1] == TensorTile(
+ (32, 32), offset=922, sizes=[2, 2, 2, 2], strides=[64, 4, 32, 1]
+ )
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 1, 16, 17, 4, 5, 20, 21, 32, 33, 48, 49, 36, 37, 52, 53, 64, 65, 80, 81, 68, 69, 84, 85, 96, 97, 112, 113, 100, 101, 116, 117],
+ [ 2, 3, 18, 19, 6, 7, 22, 23, 34, 35, 50, 51, 38, 39, 54, 55, 66, 67, 82, 83, 70, 71, 86, 87, 98, 99, 114, 115, 102, 103, 118, 119],
+ [ 8, 9, 24, 25, 12, 13, 28, 29, 40, 41, 56, 57, 44, 45, 60, 61, 72, 73, 88, 89, 76, 77, 92, 93, 104, 105, 120, 121, 108, 109, 124, 125],
+ [ 10, 11, 26, 27, 14, 15, 30, 31, 42, 43, 58, 59, 46, 47, 62, 63, 74, 75, 90, 91, 78, 79, 94, 95, 106, 107, 122, 123, 110, 111, 126, 127],
+ [ 128, 129, 144, 145, 132, 133, 148, 149, 160, 161, 176, 177, 164, 165, 180, 181, 192, 193, 208, 209, 196, 197, 212, 213, 224, 225, 240, 241, 228, 229, 244, 245],
+ [ 130, 131, 146, 147, 134, 135, 150, 151, 162, 163, 178, 179, 166, 167, 182, 183, 194, 195, 210, 211, 198, 199, 214, 215, 226, 227, 242, 243, 230, 231, 246, 247],
+ [ 136, 137, 152, 153, 140, 141, 156, 157, 168, 169, 184, 185, 172, 173, 188, 189, 200, 201, 216, 217, 204, 205, 220, 221, 232, 233, 248, 249, 236, 237, 252, 253],
+ [ 138, 139, 154, 155, 142, 143, 158, 159, 170, 171, 186, 187, 174, 175, 190, 191, 202, 203, 218, 219, 206, 207, 222, 223, 234, 235, 250, 251, 238, 239, 254, 255],
+ [ 256, 257, 272, 273, 260, 261, 276, 277, 288, 289, 304, 305, 292, 293, 308, 309, 320, 321, 336, 337, 324, 325, 340, 341, 352, 353, 368, 369, 356, 357, 372, 373],
+ [ 258, 259, 274, 275, 262, 263, 278, 279, 290, 291, 306, 307, 294, 295, 310, 311, 322, 323, 338, 339, 326, 327, 342, 343, 354, 355, 370, 371, 358, 359, 374, 375],
+ [ 264, 265, 280, 281, 268, 269, 284, 285, 296, 297, 312, 313, 300, 301, 316, 317, 328, 329, 344, 345, 332, 333, 348, 349, 360, 361, 376, 377, 364, 365, 380, 381],
+ [ 266, 267, 282, 283, 270, 271, 286, 287, 298, 299, 314, 315, 302, 303, 318, 319, 330, 331, 346, 347, 334, 335, 350, 351, 362, 363, 378, 379, 366, 367, 382, 383],
+ [ 384, 385, 400, 401, 388, 389, 404, 405, 416, 417, 432, 433, 420, 421, 436, 437, 448, 449, 464, 465, 452, 453, 468, 469, 480, 481, 496, 497, 484, 485, 500, 501],
+ [ 386, 387, 402, 403, 390, 391, 406, 407, 418, 419, 434, 435, 422, 423, 438, 439, 450, 451, 466, 467, 454, 455, 470, 471, 482, 483, 498, 499, 486, 487, 502, 503],
+ [ 392, 393, 408, 409, 396, 397, 412, 413, 424, 425, 440, 441, 428, 429, 444, 445, 456, 457, 472, 473, 460, 461, 476, 477, 488, 489, 504, 505, 492, 493, 508, 509],
+ [ 394, 395, 410, 411, 398, 399, 414, 415, 426, 427, 442, 443, 430, 431, 446, 447, 458, 459, 474, 475, 462, 463, 478, 479, 490, 491, 506, 507, 494, 495, 510, 511],
+ [ 512, 513, 528, 529, 516, 517, 532, 533, 544, 545, 560, 561, 548, 549, 564, 565, 576, 577, 592, 593, 580, 581, 596, 597, 608, 609, 624, 625, 612, 613, 628, 629],
+ [ 514, 515, 530, 531, 518, 519, 534, 535, 546, 547, 562, 563, 550, 551, 566, 567, 578, 579, 594, 595, 582, 583, 598, 599, 610, 611, 626, 627, 614, 615, 630, 631],
+ [ 520, 521, 536, 537, 524, 525, 540, 541, 552, 553, 568, 569, 556, 557, 572, 573, 584, 585, 600, 601, 588, 589, 604, 605, 616, 617, 632, 633, 620, 621, 636, 637],
+ [ 522, 523, 538, 539, 526, 527, 542, 543, 554, 555, 570, 571, 558, 559, 574, 575, 586, 587, 602, 603, 590, 591, 606, 607, 618, 619, 634, 635, 622, 623, 638, 639],
+ [ 640, 641, 656, 657, 644, 645, 660, 661, 672, 673, 688, 689, 676, 677, 692, 693, 704, 705, 720, 721, 708, 709, 724, 725, 736, 737, 752, 753, 740, 741, 756, 757],
+ [ 642, 643, 658, 659, 646, 647, 662, 663, 674, 675, 690, 691, 678, 679, 694, 695, 706, 707, 722, 723, 710, 711, 726, 727, 738, 739, 754, 755, 742, 743, 758, 759],
+ [ 648, 649, 664, 665, 652, 653, 668, 669, 680, 681, 696, 697, 684, 685, 700, 701, 712, 713, 728, 729, 716, 717, 732, 733, 744, 745, 760, 761, 748, 749, 764, 765],
+ [ 650, 651, 666, 667, 654, 655, 670, 671, 682, 683, 698, 699, 686, 687, 702, 703, 714, 715, 730, 731, 718, 719, 734, 735, 746, 747, 762, 763, 750, 751, 766, 767],
+ [ 768, 769, 784, 785, 772, 773, 788, 789, 800, 801, 816, 817, 804, 805, 820, 821, 832, 833, 848, 849, 836, 837, 852, 853, 864, 865, 880, 881, 868, 869, 884, 885],
+ [ 770, 771, 786, 787, 774, 775, 790, 791, 802, 803, 818, 819, 806, 807, 822, 823, 834, 835, 850, 851, 838, 839, 854, 855, 866, 867, 882, 883, 870, 871, 886, 887],
+ [ 776, 777, 792, 793, 780, 781, 796, 797, 808, 809, 824, 825, 812, 813, 828, 829, 840, 841, 856, 857, 844, 845, 860, 861, 872, 873, 888, 889, 876, 877, 892, 893],
+ [ 778, 779, 794, 795, 782, 783, 798, 799, 810, 811, 826, 827, 814, 815, 830, 831, 842, 843, 858, 859, 846, 847, 862, 863, 874, 875, 890, 891, 878, 879, 894, 895],
+ [ 896, 897, 912, 913, 900, 901, 916, 917, 928, 929, 944, 945, 932, 933, 948, 949, 960, 961, 976, 977, 964, 965, 980, 981, 992, 993, 1008, 1009, 996, 997, 1012, 1013],
+ [ 898, 899, 914, 915, 902, 903, 918, 919, 930, 931, 946, 947, 934, 935, 950, 951, 962, 963, 978, 979, 966, 967, 982, 983, 994, 995, 1010, 1011, 998, 999, 1014, 1015],
+ [ 904, 905, 920, 921, 908, 909, 924, 925, 936, 937, 952, 953, 940, 941, 956, 957, 968, 969, 984, 985, 972, 973, 988, 989, 1000, 1001, 1016, 1017, 1004, 1005, 1020, 1021],
+ [ 906, 907, 922, 923, 910, 911, 926, 927, 938, 939, 954, 955, 942, 943, 958, 959, 970, 971, 986, 987, 974, 975, 990, 991, 1002, 1003, 1018, 1019, 1006, 1007, 1022, 1023]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # Step == (2, 2)
+ tiles = TensorTiler2D.step_tiler(
+ (32, 32),
+ tile_dims=(2, 2),
+ tile_group_repeats=(2, 2),
+ tile_group_steps=(2, 2),
+ )
+ assert len(tiles) == (32 // (2 * 2)) * (32 // (2 * 2))
+ assert tiles[0] == TensorTile(
+ (32, 32), offset=0, sizes=[2, 2, 2, 2], strides=[128, 4, 32, 1]
+ )
+ assert tiles[1] == TensorTile(
+ (32, 32), offset=2, sizes=[2, 2, 2, 2], strides=[128, 4, 32, 1]
+ )
+ assert tiles[26] == TensorTile(
+ (32, 32), offset=328, sizes=[2, 2, 2, 2], strides=[128, 4, 32, 1]
+ )
+ assert tiles[-1] == TensorTile(
+ (32, 32), offset=858, sizes=[2, 2, 2, 2], strides=[128, 4, 32, 1]
+ )
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 1, 16, 17, 4, 5, 20, 21, 32, 33, 48, 49, 36, 37, 52, 53, 64, 65, 80, 81, 68, 69, 84, 85, 96, 97, 112, 113, 100, 101, 116, 117],
+ [ 2, 3, 18, 19, 6, 7, 22, 23, 34, 35, 50, 51, 38, 39, 54, 55, 66, 67, 82, 83, 70, 71, 86, 87, 98, 99, 114, 115, 102, 103, 118, 119],
+ [ 128, 129, 144, 145, 132, 133, 148, 149, 160, 161, 176, 177, 164, 165, 180, 181, 192, 193, 208, 209, 196, 197, 212, 213, 224, 225, 240, 241, 228, 229, 244, 245],
+ [ 130, 131, 146, 147, 134, 135, 150, 151, 162, 163, 178, 179, 166, 167, 182, 183, 194, 195, 210, 211, 198, 199, 214, 215, 226, 227, 242, 243, 230, 231, 246, 247],
+ [ 8, 9, 24, 25, 12, 13, 28, 29, 40, 41, 56, 57, 44, 45, 60, 61, 72, 73, 88, 89, 76, 77, 92, 93, 104, 105, 120, 121, 108, 109, 124, 125],
+ [ 10, 11, 26, 27, 14, 15, 30, 31, 42, 43, 58, 59, 46, 47, 62, 63, 74, 75, 90, 91, 78, 79, 94, 95, 106, 107, 122, 123, 110, 111, 126, 127],
+ [ 136, 137, 152, 153, 140, 141, 156, 157, 168, 169, 184, 185, 172, 173, 188, 189, 200, 201, 216, 217, 204, 205, 220, 221, 232, 233, 248, 249, 236, 237, 252, 253],
+ [ 138, 139, 154, 155, 142, 143, 158, 159, 170, 171, 186, 187, 174, 175, 190, 191, 202, 203, 218, 219, 206, 207, 222, 223, 234, 235, 250, 251, 238, 239, 254, 255],
+ [ 256, 257, 272, 273, 260, 261, 276, 277, 288, 289, 304, 305, 292, 293, 308, 309, 320, 321, 336, 337, 324, 325, 340, 341, 352, 353, 368, 369, 356, 357, 372, 373],
+ [ 258, 259, 274, 275, 262, 263, 278, 279, 290, 291, 306, 307, 294, 295, 310, 311, 322, 323, 338, 339, 326, 327, 342, 343, 354, 355, 370, 371, 358, 359, 374, 375],
+ [ 384, 385, 400, 401, 388, 389, 404, 405, 416, 417, 432, 433, 420, 421, 436, 437, 448, 449, 464, 465, 452, 453, 468, 469, 480, 481, 496, 497, 484, 485, 500, 501],
+ [ 386, 387, 402, 403, 390, 391, 406, 407, 418, 419, 434, 435, 422, 423, 438, 439, 450, 451, 466, 467, 454, 455, 470, 471, 482, 483, 498, 499, 486, 487, 502, 503],
+ [ 264, 265, 280, 281, 268, 269, 284, 285, 296, 297, 312, 313, 300, 301, 316, 317, 328, 329, 344, 345, 332, 333, 348, 349, 360, 361, 376, 377, 364, 365, 380, 381],
+ [ 266, 267, 282, 283, 270, 271, 286, 287, 298, 299, 314, 315, 302, 303, 318, 319, 330, 331, 346, 347, 334, 335, 350, 351, 362, 363, 378, 379, 366, 367, 382, 383],
+ [ 392, 393, 408, 409, 396, 397, 412, 413, 424, 425, 440, 441, 428, 429, 444, 445, 456, 457, 472, 473, 460, 461, 476, 477, 488, 489, 504, 505, 492, 493, 508, 509],
+ [ 394, 395, 410, 411, 398, 399, 414, 415, 426, 427, 442, 443, 430, 431, 446, 447, 458, 459, 474, 475, 462, 463, 478, 479, 490, 491, 506, 507, 494, 495, 510, 511],
+ [ 512, 513, 528, 529, 516, 517, 532, 533, 544, 545, 560, 561, 548, 549, 564, 565, 576, 577, 592, 593, 580, 581, 596, 597, 608, 609, 624, 625, 612, 613, 628, 629],
+ [ 514, 515, 530, 531, 518, 519, 534, 535, 546, 547, 562, 563, 550, 551, 566, 567, 578, 579, 594, 595, 582, 583, 598, 599, 610, 611, 626, 627, 614, 615, 630, 631],
+ [ 640, 641, 656, 657, 644, 645, 660, 661, 672, 673, 688, 689, 676, 677, 692, 693, 704, 705, 720, 721, 708, 709, 724, 725, 736, 737, 752, 753, 740, 741, 756, 757],
+ [ 642, 643, 658, 659, 646, 647, 662, 663, 674, 675, 690, 691, 678, 679, 694, 695, 706, 707, 722, 723, 710, 711, 726, 727, 738, 739, 754, 755, 742, 743, 758, 759],
+ [ 520, 521, 536, 537, 524, 525, 540, 541, 552, 553, 568, 569, 556, 557, 572, 573, 584, 585, 600, 601, 588, 589, 604, 605, 616, 617, 632, 633, 620, 621, 636, 637],
+ [ 522, 523, 538, 539, 526, 527, 542, 543, 554, 555, 570, 571, 558, 559, 574, 575, 586, 587, 602, 603, 590, 591, 606, 607, 618, 619, 634, 635, 622, 623, 638, 639],
+ [ 648, 649, 664, 665, 652, 653, 668, 669, 680, 681, 696, 697, 684, 685, 700, 701, 712, 713, 728, 729, 716, 717, 732, 733, 744, 745, 760, 761, 748, 749, 764, 765],
+ [ 650, 651, 666, 667, 654, 655, 670, 671, 682, 683, 698, 699, 686, 687, 702, 703, 714, 715, 730, 731, 718, 719, 734, 735, 746, 747, 762, 763, 750, 751, 766, 767],
+ [ 768, 769, 784, 785, 772, 773, 788, 789, 800, 801, 816, 817, 804, 805, 820, 821, 832, 833, 848, 849, 836, 837, 852, 853, 864, 865, 880, 881, 868, 869, 884, 885],
+ [ 770, 771, 786, 787, 774, 775, 790, 791, 802, 803, 818, 819, 806, 807, 822, 823, 834, 835, 850, 851, 838, 839, 854, 855, 866, 867, 882, 883, 870, 871, 886, 887],
+ [ 896, 897, 912, 913, 900, 901, 916, 917, 928, 929, 944, 945, 932, 933, 948, 949, 960, 961, 976, 977, 964, 965, 980, 981, 992, 993, 1008, 1009, 996, 997, 1012, 1013],
+ [ 898, 899, 914, 915, 902, 903, 918, 919, 930, 931, 946, 947, 934, 935, 950, 951, 962, 963, 978, 979, 966, 967, 982, 983, 994, 995, 1010, 1011, 998, 999, 1014, 1015],
+ [ 776, 777, 792, 793, 780, 781, 796, 797, 808, 809, 824, 825, 812, 813, 828, 829, 840, 841, 856, 857, 844, 845, 860, 861, 872, 873, 888, 889, 876, 877, 892, 893],
+ [ 778, 779, 794, 795, 782, 783, 798, 799, 810, 811, 826, 827, 814, 815, 830, 831, 842, 843, 858, 859, 846, 847, 862, 863, 874, 875, 890, 891, 878, 879, 894, 895],
+ [ 904, 905, 920, 921, 908, 909, 924, 925, 936, 937, 952, 953, 940, 941, 956, 957, 968, 969, 984, 985, 972, 973, 988, 989, 1000, 1001, 1016, 1017, 1004, 1005, 1020, 1021],
+ [ 906, 907, 922, 923, 910, 911, 926, 927, 938, 939, 954, 955, 942, 943, 958, 959, 970, 971, 986, 987, 974, 975, 990, 991, 1002, 1003, 1018, 1019, 1006, 1007, 1022, 1023]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # Step == (2, 2)
+ tiles = TensorTiler2D.step_tiler(
+ (32, 32),
+ tile_dims=(2, 2),
+ tile_group_repeats=(2, 2),
+ tile_group_steps=(2, 2),
+ )
+ assert len(tiles) == (32 // (2 * 2)) * (32 // (2 * 2))
+ assert tiles[0] == TensorTile(
+ (32, 32), offset=0, sizes=[2, 2, 2, 2], strides=[128, 4, 32, 1]
+ )
+ assert tiles[1] == TensorTile(
+ (32, 32), offset=2, sizes=[2, 2, 2, 2], strides=[128, 4, 32, 1]
+ )
+ assert tiles[26] == TensorTile(
+ (32, 32), offset=328, sizes=[2, 2, 2, 2], strides=[128, 4, 32, 1]
+ )
+ assert tiles[-1] == TensorTile(
+ (32, 32), offset=858, sizes=[2, 2, 2, 2], strides=[128, 4, 32, 1]
+ )
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 1, 16, 17, 4, 5, 20, 21, 32, 33, 48, 49, 36, 37, 52, 53, 64, 65, 80, 81, 68, 69, 84, 85, 96, 97, 112, 113, 100, 101, 116, 117],
+ [ 2, 3, 18, 19, 6, 7, 22, 23, 34, 35, 50, 51, 38, 39, 54, 55, 66, 67, 82, 83, 70, 71, 86, 87, 98, 99, 114, 115, 102, 103, 118, 119],
+ [ 128, 129, 144, 145, 132, 133, 148, 149, 160, 161, 176, 177, 164, 165, 180, 181, 192, 193, 208, 209, 196, 197, 212, 213, 224, 225, 240, 241, 228, 229, 244, 245],
+ [ 130, 131, 146, 147, 134, 135, 150, 151, 162, 163, 178, 179, 166, 167, 182, 183, 194, 195, 210, 211, 198, 199, 214, 215, 226, 227, 242, 243, 230, 231, 246, 247],
+ [ 8, 9, 24, 25, 12, 13, 28, 29, 40, 41, 56, 57, 44, 45, 60, 61, 72, 73, 88, 89, 76, 77, 92, 93, 104, 105, 120, 121, 108, 109, 124, 125],
+ [ 10, 11, 26, 27, 14, 15, 30, 31, 42, 43, 58, 59, 46, 47, 62, 63, 74, 75, 90, 91, 78, 79, 94, 95, 106, 107, 122, 123, 110, 111, 126, 127],
+ [ 136, 137, 152, 153, 140, 141, 156, 157, 168, 169, 184, 185, 172, 173, 188, 189, 200, 201, 216, 217, 204, 205, 220, 221, 232, 233, 248, 249, 236, 237, 252, 253],
+ [ 138, 139, 154, 155, 142, 143, 158, 159, 170, 171, 186, 187, 174, 175, 190, 191, 202, 203, 218, 219, 206, 207, 222, 223, 234, 235, 250, 251, 238, 239, 254, 255],
+ [ 256, 257, 272, 273, 260, 261, 276, 277, 288, 289, 304, 305, 292, 293, 308, 309, 320, 321, 336, 337, 324, 325, 340, 341, 352, 353, 368, 369, 356, 357, 372, 373],
+ [ 258, 259, 274, 275, 262, 263, 278, 279, 290, 291, 306, 307, 294, 295, 310, 311, 322, 323, 338, 339, 326, 327, 342, 343, 354, 355, 370, 371, 358, 359, 374, 375],
+ [ 384, 385, 400, 401, 388, 389, 404, 405, 416, 417, 432, 433, 420, 421, 436, 437, 448, 449, 464, 465, 452, 453, 468, 469, 480, 481, 496, 497, 484, 485, 500, 501],
+ [ 386, 387, 402, 403, 390, 391, 406, 407, 418, 419, 434, 435, 422, 423, 438, 439, 450, 451, 466, 467, 454, 455, 470, 471, 482, 483, 498, 499, 486, 487, 502, 503],
+ [ 264, 265, 280, 281, 268, 269, 284, 285, 296, 297, 312, 313, 300, 301, 316, 317, 328, 329, 344, 345, 332, 333, 348, 349, 360, 361, 376, 377, 364, 365, 380, 381],
+ [ 266, 267, 282, 283, 270, 271, 286, 287, 298, 299, 314, 315, 302, 303, 318, 319, 330, 331, 346, 347, 334, 335, 350, 351, 362, 363, 378, 379, 366, 367, 382, 383],
+ [ 392, 393, 408, 409, 396, 397, 412, 413, 424, 425, 440, 441, 428, 429, 444, 445, 456, 457, 472, 473, 460, 461, 476, 477, 488, 489, 504, 505, 492, 493, 508, 509],
+ [ 394, 395, 410, 411, 398, 399, 414, 415, 426, 427, 442, 443, 430, 431, 446, 447, 458, 459, 474, 475, 462, 463, 478, 479, 490, 491, 506, 507, 494, 495, 510, 511],
+ [ 512, 513, 528, 529, 516, 517, 532, 533, 544, 545, 560, 561, 548, 549, 564, 565, 576, 577, 592, 593, 580, 581, 596, 597, 608, 609, 624, 625, 612, 613, 628, 629],
+ [ 514, 515, 530, 531, 518, 519, 534, 535, 546, 547, 562, 563, 550, 551, 566, 567, 578, 579, 594, 595, 582, 583, 598, 599, 610, 611, 626, 627, 614, 615, 630, 631],
+ [ 640, 641, 656, 657, 644, 645, 660, 661, 672, 673, 688, 689, 676, 677, 692, 693, 704, 705, 720, 721, 708, 709, 724, 725, 736, 737, 752, 753, 740, 741, 756, 757],
+ [ 642, 643, 658, 659, 646, 647, 662, 663, 674, 675, 690, 691, 678, 679, 694, 695, 706, 707, 722, 723, 710, 711, 726, 727, 738, 739, 754, 755, 742, 743, 758, 759],
+ [ 520, 521, 536, 537, 524, 525, 540, 541, 552, 553, 568, 569, 556, 557, 572, 573, 584, 585, 600, 601, 588, 589, 604, 605, 616, 617, 632, 633, 620, 621, 636, 637],
+ [ 522, 523, 538, 539, 526, 527, 542, 543, 554, 555, 570, 571, 558, 559, 574, 575, 586, 587, 602, 603, 590, 591, 606, 607, 618, 619, 634, 635, 622, 623, 638, 639],
+ [ 648, 649, 664, 665, 652, 653, 668, 669, 680, 681, 696, 697, 684, 685, 700, 701, 712, 713, 728, 729, 716, 717, 732, 733, 744, 745, 760, 761, 748, 749, 764, 765],
+ [ 650, 651, 666, 667, 654, 655, 670, 671, 682, 683, 698, 699, 686, 687, 702, 703, 714, 715, 730, 731, 718, 719, 734, 735, 746, 747, 762, 763, 750, 751, 766, 767],
+ [ 768, 769, 784, 785, 772, 773, 788, 789, 800, 801, 816, 817, 804, 805, 820, 821, 832, 833, 848, 849, 836, 837, 852, 853, 864, 865, 880, 881, 868, 869, 884, 885],
+ [ 770, 771, 786, 787, 774, 775, 790, 791, 802, 803, 818, 819, 806, 807, 822, 823, 834, 835, 850, 851, 838, 839, 854, 855, 866, 867, 882, 883, 870, 871, 886, 887],
+ [ 896, 897, 912, 913, 900, 901, 916, 917, 928, 929, 944, 945, 932, 933, 948, 949, 960, 961, 976, 977, 964, 965, 980, 981, 992, 993, 1008, 1009, 996, 997, 1012, 1013],
+ [ 898, 899, 914, 915, 902, 903, 918, 919, 930, 931, 946, 947, 934, 935, 950, 951, 962, 963, 978, 979, 966, 967, 982, 983, 994, 995, 1010, 1011, 998, 999, 1014, 1015],
+ [ 776, 777, 792, 793, 780, 781, 796, 797, 808, 809, 824, 825, 812, 813, 828, 829, 840, 841, 856, 857, 844, 845, 860, 861, 872, 873, 888, 889, 876, 877, 892, 893],
+ [ 778, 779, 794, 795, 782, 783, 798, 799, 810, 811, 826, 827, 814, 815, 830, 831, 842, 843, 858, 859, 846, 847, 862, 863, 874, 875, 890, 891, 878, 879, 894, 895],
+ [ 904, 905, 920, 921, 908, 909, 924, 925, 936, 937, 952, 953, 940, 941, 956, 957, 968, 969, 984, 985, 972, 973, 988, 989, 1000, 1001, 1016, 1017, 1004, 1005, 1020, 1021],
+ [ 906, 907, 922, 923, 910, 911, 926, 927, 938, 939, 954, 955, 942, 943, 958, 959, 970, 971, 986, 987, 974, 975, 990, 991, 1002, 1003, 1018, 1019, 1006, 1007, 1022, 1023]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # Repeat across column/row
+ tiles = TensorTiler2D.step_tiler(
+ (32, 32),
+ tile_dims=(2, 2),
+ tile_group_repeats=(32 // 4, 32 // 4),
+ tile_group_steps=(2, 2),
+ )
+ assert len(tiles) == 4 # (32//(2*(32//4))) * (32//(2*(32//4)))
+ assert tiles[0] == TensorTile(
+ (32, 32), offset=0, sizes=[8, 8, 2, 2], strides=[128, 4, 32, 1]
+ )
+ assert tiles[1] == TensorTile(
+ (32, 32), offset=2, sizes=[8, 8, 2, 2], strides=[128, 4, 32, 1]
+ )
+ assert tiles[2] == TensorTile(
+ (32, 32), offset=64, sizes=[8, 8, 2, 2], strides=[128, 4, 32, 1]
+ )
+ assert tiles[3] == TensorTile(
+ (32, 32), offset=66, sizes=[8, 8, 2, 2], strides=[128, 4, 32, 1]
+ )
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 1, 256, 257, 4, 5, 260, 261, 8, 9, 264, 265, 12, 13, 268, 269, 16, 17, 272, 273, 20, 21, 276, 277, 24, 25, 280, 281, 28, 29, 284, 285],
+ [ 2, 3, 258, 259, 6, 7, 262, 263, 10, 11, 266, 267, 14, 15, 270, 271, 18, 19, 274, 275, 22, 23, 278, 279, 26, 27, 282, 283, 30, 31, 286, 287],
+ [ 512, 513, 768, 769, 516, 517, 772, 773, 520, 521, 776, 777, 524, 525, 780, 781, 528, 529, 784, 785, 532, 533, 788, 789, 536, 537, 792, 793, 540, 541, 796, 797],
+ [ 514, 515, 770, 771, 518, 519, 774, 775, 522, 523, 778, 779, 526, 527, 782, 783, 530, 531, 786, 787, 534, 535, 790, 791, 538, 539, 794, 795, 542, 543, 798, 799],
+ [ 32, 33, 288, 289, 36, 37, 292, 293, 40, 41, 296, 297, 44, 45, 300, 301, 48, 49, 304, 305, 52, 53, 308, 309, 56, 57, 312, 313, 60, 61, 316, 317],
+ [ 34, 35, 290, 291, 38, 39, 294, 295, 42, 43, 298, 299, 46, 47, 302, 303, 50, 51, 306, 307, 54, 55, 310, 311, 58, 59, 314, 315, 62, 63, 318, 319],
+ [ 544, 545, 800, 801, 548, 549, 804, 805, 552, 553, 808, 809, 556, 557, 812, 813, 560, 561, 816, 817, 564, 565, 820, 821, 568, 569, 824, 825, 572, 573, 828, 829],
+ [ 546, 547, 802, 803, 550, 551, 806, 807, 554, 555, 810, 811, 558, 559, 814, 815, 562, 563, 818, 819, 566, 567, 822, 823, 570, 571, 826, 827, 574, 575, 830, 831],
+ [ 64, 65, 320, 321, 68, 69, 324, 325, 72, 73, 328, 329, 76, 77, 332, 333, 80, 81, 336, 337, 84, 85, 340, 341, 88, 89, 344, 345, 92, 93, 348, 349],
+ [ 66, 67, 322, 323, 70, 71, 326, 327, 74, 75, 330, 331, 78, 79, 334, 335, 82, 83, 338, 339, 86, 87, 342, 343, 90, 91, 346, 347, 94, 95, 350, 351],
+ [ 576, 577, 832, 833, 580, 581, 836, 837, 584, 585, 840, 841, 588, 589, 844, 845, 592, 593, 848, 849, 596, 597, 852, 853, 600, 601, 856, 857, 604, 605, 860, 861],
+ [ 578, 579, 834, 835, 582, 583, 838, 839, 586, 587, 842, 843, 590, 591, 846, 847, 594, 595, 850, 851, 598, 599, 854, 855, 602, 603, 858, 859, 606, 607, 862, 863],
+ [ 96, 97, 352, 353, 100, 101, 356, 357, 104, 105, 360, 361, 108, 109, 364, 365, 112, 113, 368, 369, 116, 117, 372, 373, 120, 121, 376, 377, 124, 125, 380, 381],
+ [ 98, 99, 354, 355, 102, 103, 358, 359, 106, 107, 362, 363, 110, 111, 366, 367, 114, 115, 370, 371, 118, 119, 374, 375, 122, 123, 378, 379, 126, 127, 382, 383],
+ [ 608, 609, 864, 865, 612, 613, 868, 869, 616, 617, 872, 873, 620, 621, 876, 877, 624, 625, 880, 881, 628, 629, 884, 885, 632, 633, 888, 889, 636, 637, 892, 893],
+ [ 610, 611, 866, 867, 614, 615, 870, 871, 618, 619, 874, 875, 622, 623, 878, 879, 626, 627, 882, 883, 630, 631, 886, 887, 634, 635, 890, 891, 638, 639, 894, 895],
+ [ 128, 129, 384, 385, 132, 133, 388, 389, 136, 137, 392, 393, 140, 141, 396, 397, 144, 145, 400, 401, 148, 149, 404, 405, 152, 153, 408, 409, 156, 157, 412, 413],
+ [ 130, 131, 386, 387, 134, 135, 390, 391, 138, 139, 394, 395, 142, 143, 398, 399, 146, 147, 402, 403, 150, 151, 406, 407, 154, 155, 410, 411, 158, 159, 414, 415],
+ [ 640, 641, 896, 897, 644, 645, 900, 901, 648, 649, 904, 905, 652, 653, 908, 909, 656, 657, 912, 913, 660, 661, 916, 917, 664, 665, 920, 921, 668, 669, 924, 925],
+ [ 642, 643, 898, 899, 646, 647, 902, 903, 650, 651, 906, 907, 654, 655, 910, 911, 658, 659, 914, 915, 662, 663, 918, 919, 666, 667, 922, 923, 670, 671, 926, 927],
+ [ 160, 161, 416, 417, 164, 165, 420, 421, 168, 169, 424, 425, 172, 173, 428, 429, 176, 177, 432, 433, 180, 181, 436, 437, 184, 185, 440, 441, 188, 189, 444, 445],
+ [ 162, 163, 418, 419, 166, 167, 422, 423, 170, 171, 426, 427, 174, 175, 430, 431, 178, 179, 434, 435, 182, 183, 438, 439, 186, 187, 442, 443, 190, 191, 446, 447],
+ [ 672, 673, 928, 929, 676, 677, 932, 933, 680, 681, 936, 937, 684, 685, 940, 941, 688, 689, 944, 945, 692, 693, 948, 949, 696, 697, 952, 953, 700, 701, 956, 957],
+ [ 674, 675, 930, 931, 678, 679, 934, 935, 682, 683, 938, 939, 686, 687, 942, 943, 690, 691, 946, 947, 694, 695, 950, 951, 698, 699, 954, 955, 702, 703, 958, 959],
+ [ 192, 193, 448, 449, 196, 197, 452, 453, 200, 201, 456, 457, 204, 205, 460, 461, 208, 209, 464, 465, 212, 213, 468, 469, 216, 217, 472, 473, 220, 221, 476, 477],
+ [ 194, 195, 450, 451, 198, 199, 454, 455, 202, 203, 458, 459, 206, 207, 462, 463, 210, 211, 466, 467, 214, 215, 470, 471, 218, 219, 474, 475, 222, 223, 478, 479],
+ [ 704, 705, 960, 961, 708, 709, 964, 965, 712, 713, 968, 969, 716, 717, 972, 973, 720, 721, 976, 977, 724, 725, 980, 981, 728, 729, 984, 985, 732, 733, 988, 989],
+ [ 706, 707, 962, 963, 710, 711, 966, 967, 714, 715, 970, 971, 718, 719, 974, 975, 722, 723, 978, 979, 726, 727, 982, 983, 730, 731, 986, 987, 734, 735, 990, 991],
+ [ 224, 225, 480, 481, 228, 229, 484, 485, 232, 233, 488, 489, 236, 237, 492, 493, 240, 241, 496, 497, 244, 245, 500, 501, 248, 249, 504, 505, 252, 253, 508, 509],
+ [ 226, 227, 482, 483, 230, 231, 486, 487, 234, 235, 490, 491, 238, 239, 494, 495, 242, 243, 498, 499, 246, 247, 502, 503, 250, 251, 506, 507, 254, 255, 510, 511],
+ [ 736, 737, 992, 993, 740, 741, 996, 997, 744, 745, 1000, 1001, 748, 749, 1004, 1005, 752, 753, 1008, 1009, 756, 757, 1012, 1013, 760, 761, 1016, 1017, 764, 765, 1020, 1021],
+ [ 738, 739, 994, 995, 742, 743, 998, 999, 746, 747, 1002, 1003, 750, 751, 1006, 1007, 754, 755, 1010, 1011, 758, 759, 1014, 1015, 762, 763, 1018, 1019, 766, 767, 1022, 1023]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # Repeat one dimension
+ tiles = TensorTiler2D.step_tiler(
+ (32, 32),
+ tile_dims=(2, 2),
+ tile_group_repeats=(1, 32 // 4),
+ tile_group_steps=(2, 2),
+ )
+ assert len(tiles) == (32 // (2 * 1)) * (32 // (2 * (32 // 4)))
+ assert tiles[0] == TensorTile(
+ (32, 32), offset=0, sizes=[1, 8, 2, 2], strides=[0, 4, 32, 1]
+ )
+ assert tiles[1] == TensorTile(
+ (32, 32), offset=2, sizes=[1, 8, 2, 2], strides=[0, 4, 32, 1]
+ )
+ assert tiles[26] == TensorTile(
+ (32, 32), offset=832, sizes=[1, 8, 2, 2], strides=[0, 4, 32, 1]
+ )
+ assert tiles[-1] == TensorTile(
+ (32, 32), offset=962, sizes=[1, 8, 2, 2], strides=[0, 4, 32, 1]
+ )
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 1, 32, 33, 4, 5, 36, 37, 8, 9, 40, 41, 12, 13, 44, 45, 16, 17, 48, 49, 20, 21, 52, 53, 24, 25, 56, 57, 28, 29, 60, 61],
+ [ 2, 3, 34, 35, 6, 7, 38, 39, 10, 11, 42, 43, 14, 15, 46, 47, 18, 19, 50, 51, 22, 23, 54, 55, 26, 27, 58, 59, 30, 31, 62, 63],
+ [ 64, 65, 96, 97, 68, 69, 100, 101, 72, 73, 104, 105, 76, 77, 108, 109, 80, 81, 112, 113, 84, 85, 116, 117, 88, 89, 120, 121, 92, 93, 124, 125],
+ [ 66, 67, 98, 99, 70, 71, 102, 103, 74, 75, 106, 107, 78, 79, 110, 111, 82, 83, 114, 115, 86, 87, 118, 119, 90, 91, 122, 123, 94, 95, 126, 127],
+ [ 128, 129, 160, 161, 132, 133, 164, 165, 136, 137, 168, 169, 140, 141, 172, 173, 144, 145, 176, 177, 148, 149, 180, 181, 152, 153, 184, 185, 156, 157, 188, 189],
+ [ 130, 131, 162, 163, 134, 135, 166, 167, 138, 139, 170, 171, 142, 143, 174, 175, 146, 147, 178, 179, 150, 151, 182, 183, 154, 155, 186, 187, 158, 159, 190, 191],
+ [ 192, 193, 224, 225, 196, 197, 228, 229, 200, 201, 232, 233, 204, 205, 236, 237, 208, 209, 240, 241, 212, 213, 244, 245, 216, 217, 248, 249, 220, 221, 252, 253],
+ [ 194, 195, 226, 227, 198, 199, 230, 231, 202, 203, 234, 235, 206, 207, 238, 239, 210, 211, 242, 243, 214, 215, 246, 247, 218, 219, 250, 251, 222, 223, 254, 255],
+ [ 256, 257, 288, 289, 260, 261, 292, 293, 264, 265, 296, 297, 268, 269, 300, 301, 272, 273, 304, 305, 276, 277, 308, 309, 280, 281, 312, 313, 284, 285, 316, 317],
+ [ 258, 259, 290, 291, 262, 263, 294, 295, 266, 267, 298, 299, 270, 271, 302, 303, 274, 275, 306, 307, 278, 279, 310, 311, 282, 283, 314, 315, 286, 287, 318, 319],
+ [ 320, 321, 352, 353, 324, 325, 356, 357, 328, 329, 360, 361, 332, 333, 364, 365, 336, 337, 368, 369, 340, 341, 372, 373, 344, 345, 376, 377, 348, 349, 380, 381],
+ [ 322, 323, 354, 355, 326, 327, 358, 359, 330, 331, 362, 363, 334, 335, 366, 367, 338, 339, 370, 371, 342, 343, 374, 375, 346, 347, 378, 379, 350, 351, 382, 383],
+ [ 384, 385, 416, 417, 388, 389, 420, 421, 392, 393, 424, 425, 396, 397, 428, 429, 400, 401, 432, 433, 404, 405, 436, 437, 408, 409, 440, 441, 412, 413, 444, 445],
+ [ 386, 387, 418, 419, 390, 391, 422, 423, 394, 395, 426, 427, 398, 399, 430, 431, 402, 403, 434, 435, 406, 407, 438, 439, 410, 411, 442, 443, 414, 415, 446, 447],
+ [ 448, 449, 480, 481, 452, 453, 484, 485, 456, 457, 488, 489, 460, 461, 492, 493, 464, 465, 496, 497, 468, 469, 500, 501, 472, 473, 504, 505, 476, 477, 508, 509],
+ [ 450, 451, 482, 483, 454, 455, 486, 487, 458, 459, 490, 491, 462, 463, 494, 495, 466, 467, 498, 499, 470, 471, 502, 503, 474, 475, 506, 507, 478, 479, 510, 511],
+ [ 512, 513, 544, 545, 516, 517, 548, 549, 520, 521, 552, 553, 524, 525, 556, 557, 528, 529, 560, 561, 532, 533, 564, 565, 536, 537, 568, 569, 540, 541, 572, 573],
+ [ 514, 515, 546, 547, 518, 519, 550, 551, 522, 523, 554, 555, 526, 527, 558, 559, 530, 531, 562, 563, 534, 535, 566, 567, 538, 539, 570, 571, 542, 543, 574, 575],
+ [ 576, 577, 608, 609, 580, 581, 612, 613, 584, 585, 616, 617, 588, 589, 620, 621, 592, 593, 624, 625, 596, 597, 628, 629, 600, 601, 632, 633, 604, 605, 636, 637],
+ [ 578, 579, 610, 611, 582, 583, 614, 615, 586, 587, 618, 619, 590, 591, 622, 623, 594, 595, 626, 627, 598, 599, 630, 631, 602, 603, 634, 635, 606, 607, 638, 639],
+ [ 640, 641, 672, 673, 644, 645, 676, 677, 648, 649, 680, 681, 652, 653, 684, 685, 656, 657, 688, 689, 660, 661, 692, 693, 664, 665, 696, 697, 668, 669, 700, 701],
+ [ 642, 643, 674, 675, 646, 647, 678, 679, 650, 651, 682, 683, 654, 655, 686, 687, 658, 659, 690, 691, 662, 663, 694, 695, 666, 667, 698, 699, 670, 671, 702, 703],
+ [ 704, 705, 736, 737, 708, 709, 740, 741, 712, 713, 744, 745, 716, 717, 748, 749, 720, 721, 752, 753, 724, 725, 756, 757, 728, 729, 760, 761, 732, 733, 764, 765],
+ [ 706, 707, 738, 739, 710, 711, 742, 743, 714, 715, 746, 747, 718, 719, 750, 751, 722, 723, 754, 755, 726, 727, 758, 759, 730, 731, 762, 763, 734, 735, 766, 767],
+ [ 768, 769, 800, 801, 772, 773, 804, 805, 776, 777, 808, 809, 780, 781, 812, 813, 784, 785, 816, 817, 788, 789, 820, 821, 792, 793, 824, 825, 796, 797, 828, 829],
+ [ 770, 771, 802, 803, 774, 775, 806, 807, 778, 779, 810, 811, 782, 783, 814, 815, 786, 787, 818, 819, 790, 791, 822, 823, 794, 795, 826, 827, 798, 799, 830, 831],
+ [ 832, 833, 864, 865, 836, 837, 868, 869, 840, 841, 872, 873, 844, 845, 876, 877, 848, 849, 880, 881, 852, 853, 884, 885, 856, 857, 888, 889, 860, 861, 892, 893],
+ [ 834, 835, 866, 867, 838, 839, 870, 871, 842, 843, 874, 875, 846, 847, 878, 879, 850, 851, 882, 883, 854, 855, 886, 887, 858, 859, 890, 891, 862, 863, 894, 895],
+ [ 896, 897, 928, 929, 900, 901, 932, 933, 904, 905, 936, 937, 908, 909, 940, 941, 912, 913, 944, 945, 916, 917, 948, 949, 920, 921, 952, 953, 924, 925, 956, 957],
+ [ 898, 899, 930, 931, 902, 903, 934, 935, 906, 907, 938, 939, 910, 911, 942, 943, 914, 915, 946, 947, 918, 919, 950, 951, 922, 923, 954, 955, 926, 927, 958, 959],
+ [ 960, 961, 992, 993, 964, 965, 996, 997, 968, 969, 1000, 1001, 972, 973, 1004, 1005, 976, 977, 1008, 1009, 980, 981, 1012, 1013, 984, 985, 1016, 1017, 988, 989, 1020, 1021],
+ [ 962, 963, 994, 995, 966, 967, 998, 999, 970, 971, 1002, 1003, 974, 975, 1006, 1007, 978, 979, 1010, 1011, 982, 983, 1014, 1015, 986, 987, 1018, 1019, 990, 991, 1022, 1023]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # Repeat other dimension
+ tiles = TensorTiler2D.step_tiler(
+ (32, 32),
+ tile_dims=(2, 2),
+ tile_group_repeats=(32 // 4, 1),
+ tile_group_steps=(2, 2),
+ )
+ assert len(tiles) == (32 // (2 * 1)) * (32 // (2 * (32 // 4)))
+ assert tiles[0] == TensorTile(
+ (32, 32), offset=0, sizes=[1, 8, 2, 2], strides=[0, 128, 32, 1]
+ )
+ assert tiles[1] == TensorTile(
+ (32, 32), offset=2, sizes=[1, 8, 2, 2], strides=[0, 128, 32, 1]
+ )
+ assert tiles[26] == TensorTile(
+ (32, 32), offset=84, sizes=[1, 8, 2, 2], strides=[0, 128, 32, 1]
+ )
+ assert tiles[-1] == TensorTile(
+ (32, 32), offset=94, sizes=[1, 8, 2, 2], strides=[0, 128, 32, 1]
+ )
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 1, 32, 33, 64, 65, 96, 97, 128, 129, 160, 161, 192, 193, 224, 225, 256, 257, 288, 289, 320, 321, 352, 353, 384, 385, 416, 417, 448, 449, 480, 481],
+ [ 2, 3, 34, 35, 66, 67, 98, 99, 130, 131, 162, 163, 194, 195, 226, 227, 258, 259, 290, 291, 322, 323, 354, 355, 386, 387, 418, 419, 450, 451, 482, 483],
+ [ 512, 513, 544, 545, 576, 577, 608, 609, 640, 641, 672, 673, 704, 705, 736, 737, 768, 769, 800, 801, 832, 833, 864, 865, 896, 897, 928, 929, 960, 961, 992, 993],
+ [ 514, 515, 546, 547, 578, 579, 610, 611, 642, 643, 674, 675, 706, 707, 738, 739, 770, 771, 802, 803, 834, 835, 866, 867, 898, 899, 930, 931, 962, 963, 994, 995],
+ [ 4, 5, 36, 37, 68, 69, 100, 101, 132, 133, 164, 165, 196, 197, 228, 229, 260, 261, 292, 293, 324, 325, 356, 357, 388, 389, 420, 421, 452, 453, 484, 485],
+ [ 6, 7, 38, 39, 70, 71, 102, 103, 134, 135, 166, 167, 198, 199, 230, 231, 262, 263, 294, 295, 326, 327, 358, 359, 390, 391, 422, 423, 454, 455, 486, 487],
+ [ 516, 517, 548, 549, 580, 581, 612, 613, 644, 645, 676, 677, 708, 709, 740, 741, 772, 773, 804, 805, 836, 837, 868, 869, 900, 901, 932, 933, 964, 965, 996, 997],
+ [ 518, 519, 550, 551, 582, 583, 614, 615, 646, 647, 678, 679, 710, 711, 742, 743, 774, 775, 806, 807, 838, 839, 870, 871, 902, 903, 934, 935, 966, 967, 998, 999],
+ [ 8, 9, 40, 41, 72, 73, 104, 105, 136, 137, 168, 169, 200, 201, 232, 233, 264, 265, 296, 297, 328, 329, 360, 361, 392, 393, 424, 425, 456, 457, 488, 489],
+ [ 10, 11, 42, 43, 74, 75, 106, 107, 138, 139, 170, 171, 202, 203, 234, 235, 266, 267, 298, 299, 330, 331, 362, 363, 394, 395, 426, 427, 458, 459, 490, 491],
+ [ 520, 521, 552, 553, 584, 585, 616, 617, 648, 649, 680, 681, 712, 713, 744, 745, 776, 777, 808, 809, 840, 841, 872, 873, 904, 905, 936, 937, 968, 969, 1000, 1001],
+ [ 522, 523, 554, 555, 586, 587, 618, 619, 650, 651, 682, 683, 714, 715, 746, 747, 778, 779, 810, 811, 842, 843, 874, 875, 906, 907, 938, 939, 970, 971, 1002, 1003],
+ [ 12, 13, 44, 45, 76, 77, 108, 109, 140, 141, 172, 173, 204, 205, 236, 237, 268, 269, 300, 301, 332, 333, 364, 365, 396, 397, 428, 429, 460, 461, 492, 493],
+ [ 14, 15, 46, 47, 78, 79, 110, 111, 142, 143, 174, 175, 206, 207, 238, 239, 270, 271, 302, 303, 334, 335, 366, 367, 398, 399, 430, 431, 462, 463, 494, 495],
+ [ 524, 525, 556, 557, 588, 589, 620, 621, 652, 653, 684, 685, 716, 717, 748, 749, 780, 781, 812, 813, 844, 845, 876, 877, 908, 909, 940, 941, 972, 973, 1004, 1005],
+ [ 526, 527, 558, 559, 590, 591, 622, 623, 654, 655, 686, 687, 718, 719, 750, 751, 782, 783, 814, 815, 846, 847, 878, 879, 910, 911, 942, 943, 974, 975, 1006, 1007],
+ [ 16, 17, 48, 49, 80, 81, 112, 113, 144, 145, 176, 177, 208, 209, 240, 241, 272, 273, 304, 305, 336, 337, 368, 369, 400, 401, 432, 433, 464, 465, 496, 497],
+ [ 18, 19, 50, 51, 82, 83, 114, 115, 146, 147, 178, 179, 210, 211, 242, 243, 274, 275, 306, 307, 338, 339, 370, 371, 402, 403, 434, 435, 466, 467, 498, 499],
+ [ 528, 529, 560, 561, 592, 593, 624, 625, 656, 657, 688, 689, 720, 721, 752, 753, 784, 785, 816, 817, 848, 849, 880, 881, 912, 913, 944, 945, 976, 977, 1008, 1009],
+ [ 530, 531, 562, 563, 594, 595, 626, 627, 658, 659, 690, 691, 722, 723, 754, 755, 786, 787, 818, 819, 850, 851, 882, 883, 914, 915, 946, 947, 978, 979, 1010, 1011],
+ [ 20, 21, 52, 53, 84, 85, 116, 117, 148, 149, 180, 181, 212, 213, 244, 245, 276, 277, 308, 309, 340, 341, 372, 373, 404, 405, 436, 437, 468, 469, 500, 501],
+ [ 22, 23, 54, 55, 86, 87, 118, 119, 150, 151, 182, 183, 214, 215, 246, 247, 278, 279, 310, 311, 342, 343, 374, 375, 406, 407, 438, 439, 470, 471, 502, 503],
+ [ 532, 533, 564, 565, 596, 597, 628, 629, 660, 661, 692, 693, 724, 725, 756, 757, 788, 789, 820, 821, 852, 853, 884, 885, 916, 917, 948, 949, 980, 981, 1012, 1013],
+ [ 534, 535, 566, 567, 598, 599, 630, 631, 662, 663, 694, 695, 726, 727, 758, 759, 790, 791, 822, 823, 854, 855, 886, 887, 918, 919, 950, 951, 982, 983, 1014, 1015],
+ [ 24, 25, 56, 57, 88, 89, 120, 121, 152, 153, 184, 185, 216, 217, 248, 249, 280, 281, 312, 313, 344, 345, 376, 377, 408, 409, 440, 441, 472, 473, 504, 505],
+ [ 26, 27, 58, 59, 90, 91, 122, 123, 154, 155, 186, 187, 218, 219, 250, 251, 282, 283, 314, 315, 346, 347, 378, 379, 410, 411, 442, 443, 474, 475, 506, 507],
+ [ 536, 537, 568, 569, 600, 601, 632, 633, 664, 665, 696, 697, 728, 729, 760, 761, 792, 793, 824, 825, 856, 857, 888, 889, 920, 921, 952, 953, 984, 985, 1016, 1017],
+ [ 538, 539, 570, 571, 602, 603, 634, 635, 666, 667, 698, 699, 730, 731, 762, 763, 794, 795, 826, 827, 858, 859, 890, 891, 922, 923, 954, 955, 986, 987, 1018, 1019],
+ [ 28, 29, 60, 61, 92, 93, 124, 125, 156, 157, 188, 189, 220, 221, 252, 253, 284, 285, 316, 317, 348, 349, 380, 381, 412, 413, 444, 445, 476, 477, 508, 509],
+ [ 30, 31, 62, 63, 94, 95, 126, 127, 158, 159, 190, 191, 222, 223, 254, 255, 286, 287, 318, 319, 350, 351, 382, 383, 414, 415, 446, 447, 478, 479, 510, 511],
+ [ 540, 541, 572, 573, 604, 605, 636, 637, 668, 669, 700, 701, 732, 733, 764, 765, 796, 797, 828, 829, 860, 861, 892, 893, 924, 925, 956, 957, 988, 989, 1020, 1021],
+ [ 542, 543, 574, 575, 606, 607, 638, 639, 670, 671, 702, 703, 734, 735, 766, 767, 798, 799, 830, 831, 862, 863, 894, 895, 926, 927, 958, 959, 990, 991, 1022, 1023]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # Different repeats and steps
+ tiles = TensorTiler2D.step_tiler(
+ (32, 32),
+ tile_dims=(2, 2),
+ tile_group_repeats=(8, 2),
+ tile_group_steps=(2, 4),
+ )
+ assert len(tiles) == (32 // (2 * 8)) * (32 // (2 * 2))
+ assert tiles[0] == TensorTile(
+ (32, 32), offset=0, sizes=[8, 2, 2, 2], strides=[128, 8, 32, 1]
+ )
+ assert tiles[1] == TensorTile(
+ (32, 32), offset=2, sizes=[8, 2, 2, 2], strides=[128, 8, 32, 1]
+ )
+ assert tiles[12] == TensorTile(
+ (32, 32), offset=80, sizes=[8, 2, 2, 2], strides=[128, 8, 32, 1]
+ )
+ assert tiles[-1] == TensorTile(
+ (32, 32), offset=86, sizes=[8, 2, 2, 2], strides=[128, 8, 32, 1]
+ )
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 1, 64, 65, 128, 129, 192, 193, 4, 5, 68, 69, 132, 133, 196, 197, 256, 257, 320, 321, 384, 385, 448, 449, 260, 261, 324, 325, 388, 389, 452, 453],
+ [ 2, 3, 66, 67, 130, 131, 194, 195, 6, 7, 70, 71, 134, 135, 198, 199, 258, 259, 322, 323, 386, 387, 450, 451, 262, 263, 326, 327, 390, 391, 454, 455],
+ [ 512, 513, 576, 577, 640, 641, 704, 705, 516, 517, 580, 581, 644, 645, 708, 709, 768, 769, 832, 833, 896, 897, 960, 961, 772, 773, 836, 837, 900, 901, 964, 965],
+ [ 514, 515, 578, 579, 642, 643, 706, 707, 518, 519, 582, 583, 646, 647, 710, 711, 770, 771, 834, 835, 898, 899, 962, 963, 774, 775, 838, 839, 902, 903, 966, 967],
+ [ 8, 9, 72, 73, 136, 137, 200, 201, 12, 13, 76, 77, 140, 141, 204, 205, 264, 265, 328, 329, 392, 393, 456, 457, 268, 269, 332, 333, 396, 397, 460, 461],
+ [ 10, 11, 74, 75, 138, 139, 202, 203, 14, 15, 78, 79, 142, 143, 206, 207, 266, 267, 330, 331, 394, 395, 458, 459, 270, 271, 334, 335, 398, 399, 462, 463],
+ [ 520, 521, 584, 585, 648, 649, 712, 713, 524, 525, 588, 589, 652, 653, 716, 717, 776, 777, 840, 841, 904, 905, 968, 969, 780, 781, 844, 845, 908, 909, 972, 973],
+ [ 522, 523, 586, 587, 650, 651, 714, 715, 526, 527, 590, 591, 654, 655, 718, 719, 778, 779, 842, 843, 906, 907, 970, 971, 782, 783, 846, 847, 910, 911, 974, 975],
+ [ 16, 17, 80, 81, 144, 145, 208, 209, 20, 21, 84, 85, 148, 149, 212, 213, 272, 273, 336, 337, 400, 401, 464, 465, 276, 277, 340, 341, 404, 405, 468, 469],
+ [ 18, 19, 82, 83, 146, 147, 210, 211, 22, 23, 86, 87, 150, 151, 214, 215, 274, 275, 338, 339, 402, 403, 466, 467, 278, 279, 342, 343, 406, 407, 470, 471],
+ [ 528, 529, 592, 593, 656, 657, 720, 721, 532, 533, 596, 597, 660, 661, 724, 725, 784, 785, 848, 849, 912, 913, 976, 977, 788, 789, 852, 853, 916, 917, 980, 981],
+ [ 530, 531, 594, 595, 658, 659, 722, 723, 534, 535, 598, 599, 662, 663, 726, 727, 786, 787, 850, 851, 914, 915, 978, 979, 790, 791, 854, 855, 918, 919, 982, 983],
+ [ 24, 25, 88, 89, 152, 153, 216, 217, 28, 29, 92, 93, 156, 157, 220, 221, 280, 281, 344, 345, 408, 409, 472, 473, 284, 285, 348, 349, 412, 413, 476, 477],
+ [ 26, 27, 90, 91, 154, 155, 218, 219, 30, 31, 94, 95, 158, 159, 222, 223, 282, 283, 346, 347, 410, 411, 474, 475, 286, 287, 350, 351, 414, 415, 478, 479],
+ [ 536, 537, 600, 601, 664, 665, 728, 729, 540, 541, 604, 605, 668, 669, 732, 733, 792, 793, 856, 857, 920, 921, 984, 985, 796, 797, 860, 861, 924, 925, 988, 989],
+ [ 538, 539, 602, 603, 666, 667, 730, 731, 542, 543, 606, 607, 670, 671, 734, 735, 794, 795, 858, 859, 922, 923, 986, 987, 798, 799, 862, 863, 926, 927, 990, 991],
+ [ 32, 33, 96, 97, 160, 161, 224, 225, 36, 37, 100, 101, 164, 165, 228, 229, 288, 289, 352, 353, 416, 417, 480, 481, 292, 293, 356, 357, 420, 421, 484, 485],
+ [ 34, 35, 98, 99, 162, 163, 226, 227, 38, 39, 102, 103, 166, 167, 230, 231, 290, 291, 354, 355, 418, 419, 482, 483, 294, 295, 358, 359, 422, 423, 486, 487],
+ [ 544, 545, 608, 609, 672, 673, 736, 737, 548, 549, 612, 613, 676, 677, 740, 741, 800, 801, 864, 865, 928, 929, 992, 993, 804, 805, 868, 869, 932, 933, 996, 997],
+ [ 546, 547, 610, 611, 674, 675, 738, 739, 550, 551, 614, 615, 678, 679, 742, 743, 802, 803, 866, 867, 930, 931, 994, 995, 806, 807, 870, 871, 934, 935, 998, 999],
+ [ 40, 41, 104, 105, 168, 169, 232, 233, 44, 45, 108, 109, 172, 173, 236, 237, 296, 297, 360, 361, 424, 425, 488, 489, 300, 301, 364, 365, 428, 429, 492, 493],
+ [ 42, 43, 106, 107, 170, 171, 234, 235, 46, 47, 110, 111, 174, 175, 238, 239, 298, 299, 362, 363, 426, 427, 490, 491, 302, 303, 366, 367, 430, 431, 494, 495],
+ [ 552, 553, 616, 617, 680, 681, 744, 745, 556, 557, 620, 621, 684, 685, 748, 749, 808, 809, 872, 873, 936, 937, 1000, 1001, 812, 813, 876, 877, 940, 941, 1004, 1005],
+ [ 554, 555, 618, 619, 682, 683, 746, 747, 558, 559, 622, 623, 686, 687, 750, 751, 810, 811, 874, 875, 938, 939, 1002, 1003, 814, 815, 878, 879, 942, 943, 1006, 1007],
+ [ 48, 49, 112, 113, 176, 177, 240, 241, 52, 53, 116, 117, 180, 181, 244, 245, 304, 305, 368, 369, 432, 433, 496, 497, 308, 309, 372, 373, 436, 437, 500, 501],
+ [ 50, 51, 114, 115, 178, 179, 242, 243, 54, 55, 118, 119, 182, 183, 246, 247, 306, 307, 370, 371, 434, 435, 498, 499, 310, 311, 374, 375, 438, 439, 502, 503],
+ [ 560, 561, 624, 625, 688, 689, 752, 753, 564, 565, 628, 629, 692, 693, 756, 757, 816, 817, 880, 881, 944, 945, 1008, 1009, 820, 821, 884, 885, 948, 949, 1012, 1013],
+ [ 562, 563, 626, 627, 690, 691, 754, 755, 566, 567, 630, 631, 694, 695, 758, 759, 818, 819, 882, 883, 946, 947, 1010, 1011, 822, 823, 886, 887, 950, 951, 1014, 1015],
+ [ 56, 57, 120, 121, 184, 185, 248, 249, 60, 61, 124, 125, 188, 189, 252, 253, 312, 313, 376, 377, 440, 441, 504, 505, 316, 317, 380, 381, 444, 445, 508, 509],
+ [ 58, 59, 122, 123, 186, 187, 250, 251, 62, 63, 126, 127, 190, 191, 254, 255, 314, 315, 378, 379, 442, 443, 506, 507, 318, 319, 382, 383, 446, 447, 510, 511],
+ [ 568, 569, 632, 633, 696, 697, 760, 761, 572, 573, 636, 637, 700, 701, 764, 765, 824, 825, 888, 889, 952, 953, 1016, 1017, 828, 829, 892, 893, 956, 957, 1020, 1021],
+ [ 570, 571, 634, 635, 698, 699, 762, 763, 574, 575, 638, 639, 702, 703, 766, 767, 826, 827, 890, 891, 954, 955, 1018, 1019, 830, 831, 894, 895, 958, 959, 1022, 1023]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # Tile col major
+ tiles = TensorTiler2D.step_tiler(
+ (32, 32),
+ tile_dims=(2, 2),
+ tile_group_repeats=(8, 2),
+ tile_group_steps=(2, 4),
+ tile_col_major=True,
+ )
+ assert len(tiles) == (32 // (2 * 8)) * (32 // (2 * 2))
+ assert tiles[0] == TensorTile(
+ (32, 32), offset=0, sizes=[8, 2, 2, 2], strides=[128, 8, 1, 32]
+ )
+ assert tiles[1] == TensorTile(
+ (32, 32), offset=2, sizes=[8, 2, 2, 2], strides=[128, 8, 1, 32]
+ )
+ assert tiles[12] == TensorTile(
+ (32, 32), offset=80, sizes=[8, 2, 2, 2], strides=[128, 8, 1, 32]
+ )
+ assert tiles[-1] == TensorTile(
+ (32, 32), offset=86, sizes=[8, 2, 2, 2], strides=[128, 8, 1, 32]
+ )
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 2, 64, 66, 128, 130, 192, 194, 4, 6, 68, 70, 132, 134, 196, 198, 256, 258, 320, 322, 384, 386, 448, 450, 260, 262, 324, 326, 388, 390, 452, 454],
+ [ 1, 3, 65, 67, 129, 131, 193, 195, 5, 7, 69, 71, 133, 135, 197, 199, 257, 259, 321, 323, 385, 387, 449, 451, 261, 263, 325, 327, 389, 391, 453, 455],
+ [ 512, 514, 576, 578, 640, 642, 704, 706, 516, 518, 580, 582, 644, 646, 708, 710, 768, 770, 832, 834, 896, 898, 960, 962, 772, 774, 836, 838, 900, 902, 964, 966],
+ [ 513, 515, 577, 579, 641, 643, 705, 707, 517, 519, 581, 583, 645, 647, 709, 711, 769, 771, 833, 835, 897, 899, 961, 963, 773, 775, 837, 839, 901, 903, 965, 967],
+ [ 8, 10, 72, 74, 136, 138, 200, 202, 12, 14, 76, 78, 140, 142, 204, 206, 264, 266, 328, 330, 392, 394, 456, 458, 268, 270, 332, 334, 396, 398, 460, 462],
+ [ 9, 11, 73, 75, 137, 139, 201, 203, 13, 15, 77, 79, 141, 143, 205, 207, 265, 267, 329, 331, 393, 395, 457, 459, 269, 271, 333, 335, 397, 399, 461, 463],
+ [ 520, 522, 584, 586, 648, 650, 712, 714, 524, 526, 588, 590, 652, 654, 716, 718, 776, 778, 840, 842, 904, 906, 968, 970, 780, 782, 844, 846, 908, 910, 972, 974],
+ [ 521, 523, 585, 587, 649, 651, 713, 715, 525, 527, 589, 591, 653, 655, 717, 719, 777, 779, 841, 843, 905, 907, 969, 971, 781, 783, 845, 847, 909, 911, 973, 975],
+ [ 16, 18, 80, 82, 144, 146, 208, 210, 20, 22, 84, 86, 148, 150, 212, 214, 272, 274, 336, 338, 400, 402, 464, 466, 276, 278, 340, 342, 404, 406, 468, 470],
+ [ 17, 19, 81, 83, 145, 147, 209, 211, 21, 23, 85, 87, 149, 151, 213, 215, 273, 275, 337, 339, 401, 403, 465, 467, 277, 279, 341, 343, 405, 407, 469, 471],
+ [ 528, 530, 592, 594, 656, 658, 720, 722, 532, 534, 596, 598, 660, 662, 724, 726, 784, 786, 848, 850, 912, 914, 976, 978, 788, 790, 852, 854, 916, 918, 980, 982],
+ [ 529, 531, 593, 595, 657, 659, 721, 723, 533, 535, 597, 599, 661, 663, 725, 727, 785, 787, 849, 851, 913, 915, 977, 979, 789, 791, 853, 855, 917, 919, 981, 983],
+ [ 24, 26, 88, 90, 152, 154, 216, 218, 28, 30, 92, 94, 156, 158, 220, 222, 280, 282, 344, 346, 408, 410, 472, 474, 284, 286, 348, 350, 412, 414, 476, 478],
+ [ 25, 27, 89, 91, 153, 155, 217, 219, 29, 31, 93, 95, 157, 159, 221, 223, 281, 283, 345, 347, 409, 411, 473, 475, 285, 287, 349, 351, 413, 415, 477, 479],
+ [ 536, 538, 600, 602, 664, 666, 728, 730, 540, 542, 604, 606, 668, 670, 732, 734, 792, 794, 856, 858, 920, 922, 984, 986, 796, 798, 860, 862, 924, 926, 988, 990],
+ [ 537, 539, 601, 603, 665, 667, 729, 731, 541, 543, 605, 607, 669, 671, 733, 735, 793, 795, 857, 859, 921, 923, 985, 987, 797, 799, 861, 863, 925, 927, 989, 991],
+ [ 32, 34, 96, 98, 160, 162, 224, 226, 36, 38, 100, 102, 164, 166, 228, 230, 288, 290, 352, 354, 416, 418, 480, 482, 292, 294, 356, 358, 420, 422, 484, 486],
+ [ 33, 35, 97, 99, 161, 163, 225, 227, 37, 39, 101, 103, 165, 167, 229, 231, 289, 291, 353, 355, 417, 419, 481, 483, 293, 295, 357, 359, 421, 423, 485, 487],
+ [ 544, 546, 608, 610, 672, 674, 736, 738, 548, 550, 612, 614, 676, 678, 740, 742, 800, 802, 864, 866, 928, 930, 992, 994, 804, 806, 868, 870, 932, 934, 996, 998],
+ [ 545, 547, 609, 611, 673, 675, 737, 739, 549, 551, 613, 615, 677, 679, 741, 743, 801, 803, 865, 867, 929, 931, 993, 995, 805, 807, 869, 871, 933, 935, 997, 999],
+ [ 40, 42, 104, 106, 168, 170, 232, 234, 44, 46, 108, 110, 172, 174, 236, 238, 296, 298, 360, 362, 424, 426, 488, 490, 300, 302, 364, 366, 428, 430, 492, 494],
+ [ 41, 43, 105, 107, 169, 171, 233, 235, 45, 47, 109, 111, 173, 175, 237, 239, 297, 299, 361, 363, 425, 427, 489, 491, 301, 303, 365, 367, 429, 431, 493, 495],
+ [ 552, 554, 616, 618, 680, 682, 744, 746, 556, 558, 620, 622, 684, 686, 748, 750, 808, 810, 872, 874, 936, 938, 1000, 1002, 812, 814, 876, 878, 940, 942, 1004, 1006],
+ [ 553, 555, 617, 619, 681, 683, 745, 747, 557, 559, 621, 623, 685, 687, 749, 751, 809, 811, 873, 875, 937, 939, 1001, 1003, 813, 815, 877, 879, 941, 943, 1005, 1007],
+ [ 48, 50, 112, 114, 176, 178, 240, 242, 52, 54, 116, 118, 180, 182, 244, 246, 304, 306, 368, 370, 432, 434, 496, 498, 308, 310, 372, 374, 436, 438, 500, 502],
+ [ 49, 51, 113, 115, 177, 179, 241, 243, 53, 55, 117, 119, 181, 183, 245, 247, 305, 307, 369, 371, 433, 435, 497, 499, 309, 311, 373, 375, 437, 439, 501, 503],
+ [ 560, 562, 624, 626, 688, 690, 752, 754, 564, 566, 628, 630, 692, 694, 756, 758, 816, 818, 880, 882, 944, 946, 1008, 1010, 820, 822, 884, 886, 948, 950, 1012, 1014],
+ [ 561, 563, 625, 627, 689, 691, 753, 755, 565, 567, 629, 631, 693, 695, 757, 759, 817, 819, 881, 883, 945, 947, 1009, 1011, 821, 823, 885, 887, 949, 951, 1013, 1015],
+ [ 56, 58, 120, 122, 184, 186, 248, 250, 60, 62, 124, 126, 188, 190, 252, 254, 312, 314, 376, 378, 440, 442, 504, 506, 316, 318, 380, 382, 444, 446, 508, 510],
+ [ 57, 59, 121, 123, 185, 187, 249, 251, 61, 63, 125, 127, 189, 191, 253, 255, 313, 315, 377, 379, 441, 443, 505, 507, 317, 319, 381, 383, 445, 447, 509, 511],
+ [ 568, 570, 632, 634, 696, 698, 760, 762, 572, 574, 636, 638, 700, 702, 764, 766, 824, 826, 888, 890, 952, 954, 1016, 1018, 828, 830, 892, 894, 956, 958, 1020, 1022],
+ [ 569, 571, 633, 635, 697, 699, 761, 763, 573, 575, 637, 639, 701, 703, 765, 767, 825, 827, 889, 891, 953, 955, 1017, 1019, 829, 831, 893, 895, 957, 959, 1021, 1023]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # Tile col major and tile group col major
+ tiles = TensorTiler2D.step_tiler(
+ (32, 32),
+ tile_dims=(2, 2),
+ tile_group_repeats=(8, 2),
+ tile_group_steps=(2, 4),
+ tile_col_major=True,
+ tile_group_col_major=True,
+ )
+ assert len(tiles) == (32 // (2 * 8)) * (32 // (2 * 2))
+ assert tiles[0] == TensorTile(
+ (32, 32), offset=0, sizes=[2, 8, 2, 2], strides=[8, 128, 1, 32]
+ )
+ assert tiles[1] == TensorTile(
+ (32, 32), offset=2, sizes=[2, 8, 2, 2], strides=[8, 128, 1, 32]
+ )
+ assert tiles[12] == TensorTile(
+ (32, 32), offset=80, sizes=[2, 8, 2, 2], strides=[8, 128, 1, 32]
+ )
+ assert tiles[-1] == TensorTile(
+ (32, 32), offset=86, sizes=[2, 8, 2, 2], strides=[8, 128, 1, 32]
+ )
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 2, 64, 66, 128, 130, 192, 194, 32, 34, 96, 98, 160, 162, 224, 226, 256, 258, 320, 322, 384, 386, 448, 450, 288, 290, 352, 354, 416, 418, 480, 482],
+ [ 1, 3, 65, 67, 129, 131, 193, 195, 33, 35, 97, 99, 161, 163, 225, 227, 257, 259, 321, 323, 385, 387, 449, 451, 289, 291, 353, 355, 417, 419, 481, 483],
+ [ 512, 514, 576, 578, 640, 642, 704, 706, 544, 546, 608, 610, 672, 674, 736, 738, 768, 770, 832, 834, 896, 898, 960, 962, 800, 802, 864, 866, 928, 930, 992, 994],
+ [ 513, 515, 577, 579, 641, 643, 705, 707, 545, 547, 609, 611, 673, 675, 737, 739, 769, 771, 833, 835, 897, 899, 961, 963, 801, 803, 865, 867, 929, 931, 993, 995],
+ [ 4, 6, 68, 70, 132, 134, 196, 198, 36, 38, 100, 102, 164, 166, 228, 230, 260, 262, 324, 326, 388, 390, 452, 454, 292, 294, 356, 358, 420, 422, 484, 486],
+ [ 5, 7, 69, 71, 133, 135, 197, 199, 37, 39, 101, 103, 165, 167, 229, 231, 261, 263, 325, 327, 389, 391, 453, 455, 293, 295, 357, 359, 421, 423, 485, 487],
+ [ 516, 518, 580, 582, 644, 646, 708, 710, 548, 550, 612, 614, 676, 678, 740, 742, 772, 774, 836, 838, 900, 902, 964, 966, 804, 806, 868, 870, 932, 934, 996, 998],
+ [ 517, 519, 581, 583, 645, 647, 709, 711, 549, 551, 613, 615, 677, 679, 741, 743, 773, 775, 837, 839, 901, 903, 965, 967, 805, 807, 869, 871, 933, 935, 997, 999],
+ [ 8, 10, 72, 74, 136, 138, 200, 202, 40, 42, 104, 106, 168, 170, 232, 234, 264, 266, 328, 330, 392, 394, 456, 458, 296, 298, 360, 362, 424, 426, 488, 490],
+ [ 9, 11, 73, 75, 137, 139, 201, 203, 41, 43, 105, 107, 169, 171, 233, 235, 265, 267, 329, 331, 393, 395, 457, 459, 297, 299, 361, 363, 425, 427, 489, 491],
+ [ 520, 522, 584, 586, 648, 650, 712, 714, 552, 554, 616, 618, 680, 682, 744, 746, 776, 778, 840, 842, 904, 906, 968, 970, 808, 810, 872, 874, 936, 938, 1000, 1002],
+ [ 521, 523, 585, 587, 649, 651, 713, 715, 553, 555, 617, 619, 681, 683, 745, 747, 777, 779, 841, 843, 905, 907, 969, 971, 809, 811, 873, 875, 937, 939, 1001, 1003],
+ [ 12, 14, 76, 78, 140, 142, 204, 206, 44, 46, 108, 110, 172, 174, 236, 238, 268, 270, 332, 334, 396, 398, 460, 462, 300, 302, 364, 366, 428, 430, 492, 494],
+ [ 13, 15, 77, 79, 141, 143, 205, 207, 45, 47, 109, 111, 173, 175, 237, 239, 269, 271, 333, 335, 397, 399, 461, 463, 301, 303, 365, 367, 429, 431, 493, 495],
+ [ 524, 526, 588, 590, 652, 654, 716, 718, 556, 558, 620, 622, 684, 686, 748, 750, 780, 782, 844, 846, 908, 910, 972, 974, 812, 814, 876, 878, 940, 942, 1004, 1006],
+ [ 525, 527, 589, 591, 653, 655, 717, 719, 557, 559, 621, 623, 685, 687, 749, 751, 781, 783, 845, 847, 909, 911, 973, 975, 813, 815, 877, 879, 941, 943, 1005, 1007],
+ [ 16, 18, 80, 82, 144, 146, 208, 210, 48, 50, 112, 114, 176, 178, 240, 242, 272, 274, 336, 338, 400, 402, 464, 466, 304, 306, 368, 370, 432, 434, 496, 498],
+ [ 17, 19, 81, 83, 145, 147, 209, 211, 49, 51, 113, 115, 177, 179, 241, 243, 273, 275, 337, 339, 401, 403, 465, 467, 305, 307, 369, 371, 433, 435, 497, 499],
+ [ 528, 530, 592, 594, 656, 658, 720, 722, 560, 562, 624, 626, 688, 690, 752, 754, 784, 786, 848, 850, 912, 914, 976, 978, 816, 818, 880, 882, 944, 946, 1008, 1010],
+ [ 529, 531, 593, 595, 657, 659, 721, 723, 561, 563, 625, 627, 689, 691, 753, 755, 785, 787, 849, 851, 913, 915, 977, 979, 817, 819, 881, 883, 945, 947, 1009, 1011],
+ [ 20, 22, 84, 86, 148, 150, 212, 214, 52, 54, 116, 118, 180, 182, 244, 246, 276, 278, 340, 342, 404, 406, 468, 470, 308, 310, 372, 374, 436, 438, 500, 502],
+ [ 21, 23, 85, 87, 149, 151, 213, 215, 53, 55, 117, 119, 181, 183, 245, 247, 277, 279, 341, 343, 405, 407, 469, 471, 309, 311, 373, 375, 437, 439, 501, 503],
+ [ 532, 534, 596, 598, 660, 662, 724, 726, 564, 566, 628, 630, 692, 694, 756, 758, 788, 790, 852, 854, 916, 918, 980, 982, 820, 822, 884, 886, 948, 950, 1012, 1014],
+ [ 533, 535, 597, 599, 661, 663, 725, 727, 565, 567, 629, 631, 693, 695, 757, 759, 789, 791, 853, 855, 917, 919, 981, 983, 821, 823, 885, 887, 949, 951, 1013, 1015],
+ [ 24, 26, 88, 90, 152, 154, 216, 218, 56, 58, 120, 122, 184, 186, 248, 250, 280, 282, 344, 346, 408, 410, 472, 474, 312, 314, 376, 378, 440, 442, 504, 506],
+ [ 25, 27, 89, 91, 153, 155, 217, 219, 57, 59, 121, 123, 185, 187, 249, 251, 281, 283, 345, 347, 409, 411, 473, 475, 313, 315, 377, 379, 441, 443, 505, 507],
+ [ 536, 538, 600, 602, 664, 666, 728, 730, 568, 570, 632, 634, 696, 698, 760, 762, 792, 794, 856, 858, 920, 922, 984, 986, 824, 826, 888, 890, 952, 954, 1016, 1018],
+ [ 537, 539, 601, 603, 665, 667, 729, 731, 569, 571, 633, 635, 697, 699, 761, 763, 793, 795, 857, 859, 921, 923, 985, 987, 825, 827, 889, 891, 953, 955, 1017, 1019],
+ [ 28, 30, 92, 94, 156, 158, 220, 222, 60, 62, 124, 126, 188, 190, 252, 254, 284, 286, 348, 350, 412, 414, 476, 478, 316, 318, 380, 382, 444, 446, 508, 510],
+ [ 29, 31, 93, 95, 157, 159, 221, 223, 61, 63, 125, 127, 189, 191, 253, 255, 285, 287, 349, 351, 413, 415, 477, 479, 317, 319, 381, 383, 445, 447, 509, 511],
+ [ 540, 542, 604, 606, 668, 670, 732, 734, 572, 574, 636, 638, 700, 702, 764, 766, 796, 798, 860, 862, 924, 926, 988, 990, 828, 830, 892, 894, 956, 958, 1020, 1022],
+ [ 541, 543, 605, 607, 669, 671, 733, 735, 573, 575, 637, 639, 701, 703, 765, 767, 797, 799, 861, 863, 925, 927, 989, 991, 829, 831, 893, 895, 957, 959, 1021, 1023]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # Tile col major and tile group col major and iter col major
+ tiles = TensorTiler2D.step_tiler(
+ (32, 32),
+ tile_dims=(2, 2),
+ tile_group_repeats=(8, 2),
+ tile_group_steps=(2, 4),
+ tile_col_major=True,
+ tile_group_col_major=True,
+ iter_col_major=True,
+ )
+ assert len(tiles) == (32 // (2 * 8)) * (32 // (2 * 2))
+ assert tiles[0] == TensorTile(
+ (32, 32), offset=0, sizes=[2, 8, 2, 2], strides=[8, 128, 1, 32]
+ )
+ assert tiles[1] == TensorTile(
+ (32, 32), offset=64, sizes=[2, 8, 2, 2], strides=[8, 128, 1, 32]
+ )
+ assert tiles[12] == TensorTile(
+ (32, 32), offset=20, sizes=[2, 8, 2, 2], strides=[8, 128, 1, 32]
+ )
+ assert tiles[-1] == TensorTile(
+ (32, 32), offset=86, sizes=[2, 8, 2, 2], strides=[8, 128, 1, 32]
+ )
+ # fmt: off
+ ref_access_order_tensor = np.array([
+ [ 0, 2, 128, 130, 256, 258, 384, 386, 32, 34, 160, 162, 288, 290, 416, 418, 512, 514, 640, 642, 768, 770, 896, 898, 544, 546, 672, 674, 800, 802, 928, 930],
+ [ 1, 3, 129, 131, 257, 259, 385, 387, 33, 35, 161, 163, 289, 291, 417, 419, 513, 515, 641, 643, 769, 771, 897, 899, 545, 547, 673, 675, 801, 803, 929, 931],
+ [ 64, 66, 192, 194, 320, 322, 448, 450, 96, 98, 224, 226, 352, 354, 480, 482, 576, 578, 704, 706, 832, 834, 960, 962, 608, 610, 736, 738, 864, 866, 992, 994],
+ [ 65, 67, 193, 195, 321, 323, 449, 451, 97, 99, 225, 227, 353, 355, 481, 483, 577, 579, 705, 707, 833, 835, 961, 963, 609, 611, 737, 739, 865, 867, 993, 995],
+ [ 4, 6, 132, 134, 260, 262, 388, 390, 36, 38, 164, 166, 292, 294, 420, 422, 516, 518, 644, 646, 772, 774, 900, 902, 548, 550, 676, 678, 804, 806, 932, 934],
+ [ 5, 7, 133, 135, 261, 263, 389, 391, 37, 39, 165, 167, 293, 295, 421, 423, 517, 519, 645, 647, 773, 775, 901, 903, 549, 551, 677, 679, 805, 807, 933, 935],
+ [ 68, 70, 196, 198, 324, 326, 452, 454, 100, 102, 228, 230, 356, 358, 484, 486, 580, 582, 708, 710, 836, 838, 964, 966, 612, 614, 740, 742, 868, 870, 996, 998],
+ [ 69, 71, 197, 199, 325, 327, 453, 455, 101, 103, 229, 231, 357, 359, 485, 487, 581, 583, 709, 711, 837, 839, 965, 967, 613, 615, 741, 743, 869, 871, 997, 999],
+ [ 8, 10, 136, 138, 264, 266, 392, 394, 40, 42, 168, 170, 296, 298, 424, 426, 520, 522, 648, 650, 776, 778, 904, 906, 552, 554, 680, 682, 808, 810, 936, 938],
+ [ 9, 11, 137, 139, 265, 267, 393, 395, 41, 43, 169, 171, 297, 299, 425, 427, 521, 523, 649, 651, 777, 779, 905, 907, 553, 555, 681, 683, 809, 811, 937, 939],
+ [ 72, 74, 200, 202, 328, 330, 456, 458, 104, 106, 232, 234, 360, 362, 488, 490, 584, 586, 712, 714, 840, 842, 968, 970, 616, 618, 744, 746, 872, 874, 1000, 1002],
+ [ 73, 75, 201, 203, 329, 331, 457, 459, 105, 107, 233, 235, 361, 363, 489, 491, 585, 587, 713, 715, 841, 843, 969, 971, 617, 619, 745, 747, 873, 875, 1001, 1003],
+ [ 12, 14, 140, 142, 268, 270, 396, 398, 44, 46, 172, 174, 300, 302, 428, 430, 524, 526, 652, 654, 780, 782, 908, 910, 556, 558, 684, 686, 812, 814, 940, 942],
+ [ 13, 15, 141, 143, 269, 271, 397, 399, 45, 47, 173, 175, 301, 303, 429, 431, 525, 527, 653, 655, 781, 783, 909, 911, 557, 559, 685, 687, 813, 815, 941, 943],
+ [ 76, 78, 204, 206, 332, 334, 460, 462, 108, 110, 236, 238, 364, 366, 492, 494, 588, 590, 716, 718, 844, 846, 972, 974, 620, 622, 748, 750, 876, 878, 1004, 1006],
+ [ 77, 79, 205, 207, 333, 335, 461, 463, 109, 111, 237, 239, 365, 367, 493, 495, 589, 591, 717, 719, 845, 847, 973, 975, 621, 623, 749, 751, 877, 879, 1005, 1007],
+ [ 16, 18, 144, 146, 272, 274, 400, 402, 48, 50, 176, 178, 304, 306, 432, 434, 528, 530, 656, 658, 784, 786, 912, 914, 560, 562, 688, 690, 816, 818, 944, 946],
+ [ 17, 19, 145, 147, 273, 275, 401, 403, 49, 51, 177, 179, 305, 307, 433, 435, 529, 531, 657, 659, 785, 787, 913, 915, 561, 563, 689, 691, 817, 819, 945, 947],
+ [ 80, 82, 208, 210, 336, 338, 464, 466, 112, 114, 240, 242, 368, 370, 496, 498, 592, 594, 720, 722, 848, 850, 976, 978, 624, 626, 752, 754, 880, 882, 1008, 1010],
+ [ 81, 83, 209, 211, 337, 339, 465, 467, 113, 115, 241, 243, 369, 371, 497, 499, 593, 595, 721, 723, 849, 851, 977, 979, 625, 627, 753, 755, 881, 883, 1009, 1011],
+ [ 20, 22, 148, 150, 276, 278, 404, 406, 52, 54, 180, 182, 308, 310, 436, 438, 532, 534, 660, 662, 788, 790, 916, 918, 564, 566, 692, 694, 820, 822, 948, 950],
+ [ 21, 23, 149, 151, 277, 279, 405, 407, 53, 55, 181, 183, 309, 311, 437, 439, 533, 535, 661, 663, 789, 791, 917, 919, 565, 567, 693, 695, 821, 823, 949, 951],
+ [ 84, 86, 212, 214, 340, 342, 468, 470, 116, 118, 244, 246, 372, 374, 500, 502, 596, 598, 724, 726, 852, 854, 980, 982, 628, 630, 756, 758, 884, 886, 1012, 1014],
+ [ 85, 87, 213, 215, 341, 343, 469, 471, 117, 119, 245, 247, 373, 375, 501, 503, 597, 599, 725, 727, 853, 855, 981, 983, 629, 631, 757, 759, 885, 887, 1013, 1015],
+ [ 24, 26, 152, 154, 280, 282, 408, 410, 56, 58, 184, 186, 312, 314, 440, 442, 536, 538, 664, 666, 792, 794, 920, 922, 568, 570, 696, 698, 824, 826, 952, 954],
+ [ 25, 27, 153, 155, 281, 283, 409, 411, 57, 59, 185, 187, 313, 315, 441, 443, 537, 539, 665, 667, 793, 795, 921, 923, 569, 571, 697, 699, 825, 827, 953, 955],
+ [ 88, 90, 216, 218, 344, 346, 472, 474, 120, 122, 248, 250, 376, 378, 504, 506, 600, 602, 728, 730, 856, 858, 984, 986, 632, 634, 760, 762, 888, 890, 1016, 1018],
+ [ 89, 91, 217, 219, 345, 347, 473, 475, 121, 123, 249, 251, 377, 379, 505, 507, 601, 603, 729, 731, 857, 859, 985, 987, 633, 635, 761, 763, 889, 891, 1017, 1019],
+ [ 28, 30, 156, 158, 284, 286, 412, 414, 60, 62, 188, 190, 316, 318, 444, 446, 540, 542, 668, 670, 796, 798, 924, 926, 572, 574, 700, 702, 828, 830, 956, 958],
+ [ 29, 31, 157, 159, 285, 287, 413, 415, 61, 63, 189, 191, 317, 319, 445, 447, 541, 543, 669, 671, 797, 799, 925, 927, 573, 575, 701, 703, 829, 831, 957, 959],
+ [ 92, 94, 220, 222, 348, 350, 476, 478, 124, 126, 252, 254, 380, 382, 508, 510, 604, 606, 732, 734, 860, 862, 988, 990, 636, 638, 764, 766, 892, 894, 1020, 1022],
+ [ 93, 95, 221, 223, 349, 351, 477, 479, 125, 127, 253, 255, 381, 383, 509, 511, 605, 607, 733, 735, 861, 863, 989, 991, 637, 639, 765, 767, 893, 895, 1021, 1023]])
+ # fmt: on
+ access_order, access_count = tiles.access_tensors()
+ assert (access_order == ref_access_order_tensor).all()
+ assert (access_count == 1).all()
+
+ # CHECK: Pass!
+ print("Pass!")
+
+
+# CHECK-LABEL: step_tiler_invalid
+@construct_test
+def step_tiler_invalid():
+    """Verify that step_tiler rejects malformed arguments with ValueError.
+
+    Each case passes one invalid argument and expects a ValueError.
+
+    The sentinel raised when a call unexpectedly succeeds is a plain
+    Exception rather than a ValueError: a ValueError sentinel would be
+    swallowed by the handler directly below it, so no case could fail.
+    """
+    # Tensor dims: empty, then too many dimensions
+    try:
+        tiles = TensorTiler2D.step_tiler(
+            (), (3, 2), (1, 1), (1, 1), tile_col_major=True, pattern_repeat=5
+        )
+        raise Exception("Bad tensor dims, should fail.")
+    except ValueError:
+        pass
+    try:
+        tiles = TensorTiler2D.step_tiler(
+            (10, 9, 4), (3, 2), (1, 1), (1, 1), tile_col_major=True, pattern_repeat=5
+        )
+        raise Exception("Too many tensor dims, should fail.")
+    except ValueError:
+        pass
+    # Tile dims: negative entry, too few entries, too many entries
+    try:
+        tiles = TensorTiler2D.step_tiler(
+            (9, 4), (3, -1), (1, 1), (1, 1), tile_col_major=True, pattern_repeat=5
+        )
+        raise Exception("Bad tile dims, should fail.")
+    except ValueError:
+        pass
+    try:
+        tiles = TensorTiler2D.step_tiler(
+            (9, 4), (3,), (1, 1), (1, 1), tile_col_major=True, pattern_repeat=5
+        )
+        raise Exception("Too few tile dims, should fail.")
+    except ValueError:
+        pass
+    try:
+        tiles = TensorTiler2D.step_tiler(
+            (9, 4), (1, 1, 1), (1, 1), (1, 1), tile_col_major=True, pattern_repeat=5
+        )
+        raise Exception("Too many tile dims, should fail.")
+    except ValueError:
+        pass
+    # pattern_repeat of zero
+    try:
+        tiles = TensorTiler2D.step_tiler(
+            (9, 4), (3, 2), (1, 1), (1, 1), tile_col_major=True, pattern_repeat=0
+        )
+        raise Exception("Invalid repeat.")
+    except ValueError:
+        pass
+    # Tile dims that do not evenly divide the tensor dims
+    try:
+        tiles = TensorTiler2D.step_tiler(
+            (9, 4), (4, 2), (1, 1), (1, 1), tile_col_major=True, pattern_repeat=5
+        )
+        raise Exception("Indivisible tile (height)")
+    except ValueError:
+        pass
+    try:
+        tiles = TensorTiler2D.step_tiler(
+            (9, 4), (3, 3), (1, 1), (1, 1), tile_col_major=True, pattern_repeat=5
+        )
+        raise Exception("Indivisible tile (width)")
+    except ValueError:
+        pass
+
+    # Tile group dims (third argument): too few, negative entry, too many
+    try:
+        tiles = TensorTiler2D.step_tiler(
+            (9, 4), (3, 2), (1,), (1, 1), tile_col_major=True, pattern_repeat=5
+        )
+        raise Exception("Too few tile group dims, should fail.")
+    except ValueError:
+        pass
+    try:
+        tiles = TensorTiler2D.step_tiler(
+            (9, 4), (3, 2), (1, -1), (1, 1), tile_col_major=True, pattern_repeat=5
+        )
+        raise Exception("Bad tile group dims, should fail.")
+    except ValueError:
+        pass
+    try:
+        tiles = TensorTiler2D.step_tiler(
+            (9, 4), (3, 2), (1, 1, 1), (1, 1), tile_col_major=True
+        )
+        raise Exception("Too many tile group dims, should fail.")
+    except ValueError:
+        pass
+    # Tensor not divisible by the repeated tile group (allow_partial not passed)
+    try:
+        tiles = TensorTiler2D.step_tiler(
+            (18, 8), (3, 2), (2, 3), (1, 1), tile_col_major=True
+        )
+        raise Exception(
+            "Indivisible by tile repeat width (but without allow_partial)."
+        )
+    except ValueError:
+        pass
+    try:
+        tiles = TensorTiler2D.step_tiler(
+            (18, 8), (3, 2), (4, 2), (1, 1), tile_col_major=True
+        )
+        raise Exception(
+            "Indivisible by tile repeat height (but without allow_partial)."
+        )
+    except ValueError:
+        pass
+    # Tile step dims (fourth argument): negative entry, too few entries
+    try:
+        tiles = TensorTiler2D.step_tiler(
+            (18, 8), (3, 2), (4, 2), (1, -1), tile_col_major=True
+        )
+        raise Exception("Bad tile step dims")
+    except ValueError:
+        pass
+    try:
+        tiles = TensorTiler2D.step_tiler(
+            (18, 8), (3, 2), (4, 2), (1,), tile_col_major=True
+        )
+        raise Exception("Too few tile step dims")
+    except ValueError:
+        pass
+
+    # CHECK: Pass!
+    print("Pass!")
diff --git a/test/python/tensortiler/step_tiler_partial.py b/test/python/tensortiler/step_tiler_partial.py
new file mode 100644
index 0000000000..f5a8a44912
--- /dev/null
+++ b/test/python/tensortiler/step_tiler_partial.py
@@ -0,0 +1,70 @@
+import numpy as np
+
+from aie.helpers.tensortiler import TensorTile, TensorTileSequence, TensorTiler2D
+from util import construct_test
+
+# RUN: %python %s | FileCheck %s
+
+
+# CHECK-LABEL: step_tiler_partial_row
+@construct_test
+def step_tiler_partial_row():
+    """Placeholder that only prints; the comments presumably list planned cases."""
+    # all row major
+    # tile col major
+    # tile group col major
+    # iter col major
+    # all col major
+    # pattern repeat
+
+    # CHECK: Pass!
+    print("Pass!")
+
+
+# CHECK-LABEL: step_tiler_partial_col
+@construct_test
+def step_tiler_partial_col():
+    """
+    Placeholder that only prints; the disabled scenario below is marked BUGGY:
+    tensor_dims = (3 * 5 * 3, 2 * 6 * 2)
+
+    # All row major
+    tiles = TensorTiler2D.step_tiler(
+        tensor_dims, tile_dims=(3, 2), tile_group_repeats=(5, 7), tile_group_steps=(2, 3), allow_partial=True
+    )
+    print(len(tiles))
+    for t in tiles:
+        print(t)
+    print(tiles[0])
+    print(tiles[1])
+    print(tiles[3])
+    print(tiles[-1])
+    tiles.visualize(plot_access_count=True)
+    anim = tiles.animate()
+    HTML(anim.to_jshtml())
+    """
+
+    # all row major
+    # tile col major
+    # tile group col major
+    # iter col major
+    # all col major
+    # pattern repeat
+
+    # CHECK: Pass!
+    print("Pass!")
+
+
+# CHECK-LABEL: step_tiler_partial_both
+@construct_test
+def step_tiler_partial_both():
+    """Placeholder that only prints; the comments presumably list planned cases."""
+    # all row major
+    # tile col major
+    # tile group col major
+    # iter col major
+    # all col major
+    # pattern repeat
+
+    # CHECK: Pass!
+    print("Pass!")
diff --git a/test/python/tensortiler/tensortile.py b/test/python/tensortiler/tensortile.py
new file mode 100644
index 0000000000..6c8acf7eae
--- /dev/null
+++ b/test/python/tensortiler/tensortile.py
@@ -0,0 +1,136 @@
+import numpy as np
+
+from aie.helpers.tensortiler import TensorTile
+from util import construct_test
+
+# RUN: %python %s | FileCheck %s
+
+
+# CHECK-LABEL: tensor_tile
+@construct_test
+def tensor_tile():
+    """Build one valid TensorTile and verify its accessors and equality."""
+    reference = TensorTile((2, 3), 4, sizes=[1, 2], strides=[0, 1])
+
+    # Every constructor argument must be readable back unchanged
+    assert reference.tensor_dims[0] == 2
+    assert reference.tensor_dims[1] == 3
+    assert len(reference.tensor_dims) == 2
+    assert reference.offset == 4
+    assert reference.sizes == [1, 2]
+    assert reference.strides == [0, 1]
+    assert reference.transformation_dims == [(1, 0), (2, 1)]
+
+    # Only flat elements 4 and 5 are visited: order is -1 and count 0 elsewhere
+    order, count = reference.access_tensors()
+    expected_order = np.array([[-1, -1, -1], [-1, 0, 1]], dtype=order.dtype)
+    expected_count = np.array([[0, 0, 0], [0, 1, 1]], dtype=count.dtype)
+    assert (order == expected_order).all()
+    assert (count == expected_count).all()
+
+    # An identically constructed tile compares equal
+    same_args = TensorTile((2, 3), 4, sizes=[1, 2], strides=[0, 1])
+    assert same_args == reference
+
+    # Changing only the offset breaks equality
+    other_offset = TensorTile((2, 3), 2, sizes=[1, 2], strides=[0, 1])
+    assert other_offset != reference
+
+    # CHECK: Pass!
+    print("Pass!")
+
+
+# CHECK-LABEL: tensor_tile_invalid
+@construct_test
+def tensor_tile_invalid():
+    """Check that TensorTile raises ValueError for each kind of invalid argument."""
+    # Bad tensor dims
+    try:
+        tile = TensorTile((), 4, sizes=[1, 2], strides=[0, 1])
+        raise Exception("Should fail, bad dims (no dims)")
+    except ValueError:
+        # Good
+        pass
+    try:
+        tile = TensorTile((0, 1), 4, sizes=[1, 2], strides=[0, 1])
+        raise Exception("Should fail, bad dims (first dim 0)")
+    except ValueError:
+        # Good
+        pass
+    try:
+        tile = TensorTile((1, 0), 4, sizes=[1, 2], strides=[0, 1])
+        raise Exception("Should fail, bad dims (second dim 0)")
+    except ValueError:
+        # Good
+        pass
+    try:
+        tile = TensorTile((-1, 1), 4, sizes=[1, 2], strides=[0, 1])
+        raise Exception("Should fail, bad dims (first dim negative)")
+    except ValueError:
+        # Good
+        pass
+    try:
+        tile = TensorTile((1, -1), 4, sizes=[1, 2], strides=[0, 1])
+        raise Exception("Should fail, bad dims (second dim negative)")
+    except ValueError:
+        # Good
+        pass
+
+    # Bad offset
+    try:
+        tile = TensorTile((2, 3), -1, sizes=[1, 2], strides=[0, 1])
+        raise Exception("Should fail, bad offset (negative)")
+    except ValueError:
+        # Good
+        pass
+    try:
+        tile = TensorTile((2, 3), 2 * 3, sizes=[1, 2], strides=[0, 1])
+        raise Exception("Should fail, bad offset (too large)")
+    except ValueError:
+        # Good
+        pass
+
+    # Bad sizes (offset must be valid, or the offset check would mask these cases)
+    try:
+        tile = TensorTile((2, 3), 0, sizes=[-1], strides=[1])
+        raise Exception("Should fail, size (negative)")
+    except ValueError:
+        # Good
+        pass
+    try:
+        tile = TensorTile((2, 3), 0, sizes=[0], strides=[1])
+        raise Exception("Should fail, size (zero)")
+    except ValueError:
+        # Good
+        pass
+
+    # Bad sizes + strides (valid offset, as above)
+    try:
+        tile = TensorTile((2, 3), 0, sizes=[], strides=[])
+        raise Exception("Should fail, size and stride empty")
+    except ValueError:
+        # Good
+        pass
+    try:
+        tile = TensorTile((2, 3), 0, sizes=[1], strides=[0, 1])
+        raise Exception("Should fail, sizes and strides uneven dimensions")
+    except ValueError:
+        # Good
+        pass
+    try:
+        tile = TensorTile((2, 3), 0, sizes=[1, 1], strides=[0])
+        raise Exception("Should fail, sizes and strides uneven dimensions 2")
+    except ValueError:
+        # Good
+        pass
+
+    # Bad strides (valid offset, as above)
+    try:
+        tile = TensorTile((2, 3), 0, sizes=[1], strides=[-1])
+        raise Exception("Should fail, bad stride (negative)")
+    except ValueError:
+        # Good
+        pass
+
+    # CHECK: Pass!
+    print("Pass!")
diff --git a/test/python/tensortiler/tensortilesequence.py b/test/python/tensortiler/tensortilesequence.py
new file mode 100644
index 0000000000..8058fb5d75
--- /dev/null
+++ b/test/python/tensortiler/tensortilesequence.py
@@ -0,0 +1,145 @@
+import numpy as np
+
+from aie.helpers.tensortiler import TensorTile, TensorTileSequence
+from util import construct_test
+
+# RUN: %python %s | FileCheck %s
+
+
+# CHECK-LABEL: tensor_tile_sequence
+@construct_test
+def tensor_tile_sequence():
+    """Exercise TensorTileSequence: construction, container protocol, equality."""
+
+    def expect_access(seq, order, count):
+        want_order, want_count = np.array(order), np.array(count)
+        got_order, got_count = seq.access_tensors()
+        assert (got_order == want_order).all()
+        assert (got_count == want_count).all()
+
+    def step_as_offset(step, _prev_offset):
+        return step
+
+    # A sequence with zero steps touches nothing
+    no_tiles = TensorTileSequence((2, 2), 0)
+    assert len(no_tiles) == 0
+    expect_access(no_tiles, [[-1, -1], [-1, -1]], [[0, 0], [0, 0]])
+
+    # Four single-element tiles, one per element, visited in row-major order
+    seq = TensorTileSequence(
+        (2, 2), 4, sizes=[1, 1], strides=[1, 1], offset_fn=step_as_offset
+    )
+    assert len(seq) == 4
+    expect_access(seq, [[0, 1], [2, 3]], [[1, 1], [1, 1]])
+
+    # Membership, indexing and iteration all see the tile generated for step 2
+    third = TensorTile((2, 2), offset=2, sizes=[1, 1], strides=[1, 1])
+    assert third in seq
+    assert seq[2] == third
+    as_list = list(iter(seq))
+    assert as_list[2] == third
+
+    # Deleting and re-inserting the tile is reflected by membership
+    del seq[2]
+    assert third not in seq
+    seq.insert(2, third)
+    assert third in seq
+
+    # A tile over a different tensor shape is never a member
+    foreign = TensorTile((3, 3), offset=2, sizes=[1, 1], strides=[1, 1])
+    assert foreign not in seq
+
+    # Equality: same construction is equal, one step fewer is not
+    twin = TensorTileSequence(
+        (2, 2), 4, sizes=[1, 1], strides=[1, 1], offset_fn=step_as_offset
+    )
+    assert seq == twin
+    shorter = TensorTileSequence(
+        (2, 2), 3, sizes=[1, 1], strides=[1, 1], offset_fn=step_as_offset
+    )
+    assert seq != shorter
+    expect_access(shorter, [[0, 1], [2, -1]], [[1, 1], [1, 0]])
+
+    # from_tiles yields an equal sequence with identical access tensors
+    rebuilt = TensorTileSequence.from_tiles(shorter)
+    assert rebuilt == shorter
+    expect_access(rebuilt, [[0, 1], [2, -1]], [[1, 1], [1, 0]])
+
+    # CHECK: Pass!
+    print("Pass!")
+
+
+# CHECK-LABEL: tensor_tile_sequence_invalid
+@construct_test
+def tensor_tile_sequence_invalid():
+    """Check that malformed TensorTileSequence requests raise ValueError.
+
+    Any call that unexpectedly succeeds aborts the test with a plain
+    Exception, which the ValueError handler does not intercept.
+    """
+
+    def offset_fn(step, _prev_offset):
+        return step
+
+    def must_reject(action, failure_message):
+        # Run action; only a ValueError counts as the expected outcome
+        try:
+            action()
+        except ValueError:
+            return
+        raise Exception(failure_message)
+
+    # Constructor arguments that are invalid on their own
+    must_reject(
+        lambda: TensorTileSequence(
+            (0, 2), 4, sizes=[1, 1], strides=[1, 1], offset_fn=offset_fn
+        ),
+        "Should fail, bad dims",
+    )
+    must_reject(
+        lambda: TensorTileSequence(
+            (2, 2), -1, sizes=[1, 1], strides=[1, 1], offset_fn=offset_fn
+        ),
+        "Should fail, bad num steps",
+    )
+    must_reject(
+        lambda: TensorTileSequence(
+            (2, 2), 1, sizes=[1, 0], strides=[0, 1], offset_fn=offset_fn
+        ),
+        "Should fail, bad sizes",
+    )
+    must_reject(
+        lambda: TensorTileSequence(
+            (2, 2), 1, sizes=[1, 1], strides=[-1, 1], offset_fn=offset_fn
+        ),
+        "Should fail, bad strides",
+    )
+    # Required pieces left out: sizes, then strides, then the offset
+    must_reject(
+        lambda: TensorTileSequence((2, 2), 1, strides=[1, 1], offset_fn=offset_fn),
+        "Should fail, missing sizes",
+    )
+    must_reject(
+        lambda: TensorTileSequence((2, 2), 1, sizes=[1, 1], offset_fn=offset_fn),
+        "Should fail, missing strides",
+    )
+    must_reject(
+        lambda: TensorTileSequence((2, 2), 1, strides=[1, 1], sizes=[1, 1]),
+        "Should fail, missing offset",
+    )
+
+    # Appending a tile whose tensor dims differ from the sequence's
+    seq = TensorTileSequence((2, 3), 1, offset=0, strides=[0, 1], sizes=[1, 1])
+    must_reject(
+        lambda: seq.append(TensorTile((3, 2), offset=0, strides=[0, 1], sizes=[1, 1])),
+        "Should not be able to add tile with inconsistent tensor dim",
+    )
+
+    # A sequence cannot be built from an empty tile list
+    must_reject(
+        lambda: TensorTileSequence.from_tiles([]),
+        "Should not be able to create sequence from no tiles",
+    )
+
+    # CHECK: Pass!
+    print("Pass!")
diff --git a/test/python/tensortiler/util.py b/test/python/tensortiler/util.py
new file mode 100644
index 0000000000..2267760a46
--- /dev/null
+++ b/test/python/tensortiler/util.py
@@ -0,0 +1,13 @@
+def construct_test(f):
+    """Announce and immediately run a test function (used as a decorator).
+
+    Args:
+        f: Zero-argument test callable.
+
+    Returns:
+        ``f`` itself, so the decorated name stays bound to the function
+        instead of being rebound to ``None``.
+    """
+    print("\nTEST:", f.__name__)
+    f()
+    return f