diff --git a/programming_examples/basic/dma_transpose/Makefile b/programming_examples/basic/dma_transpose/Makefile index 87771d1ab0..212cedabf3 100644 --- a/programming_examples/basic/dma_transpose/Makefile +++ b/programming_examples/basic/dma_transpose/Makefile @@ -44,5 +44,9 @@ endif run: ${targetname}.exe build/final.xclbin ${powershell} ./$< -x build/final.xclbin -i build/insts.txt -k MLIR_AIE --M ${M} --K ${K} +generate_access_map: ${srcdir}/aie2.py + mkdir -p ${@D} + python3 $< --generate-access-map ${M} ${K} + clean: rm -rf build _build inst ${targetname}.exe diff --git a/programming_examples/basic/dma_transpose/README.md b/programming_examples/basic/dma_transpose/README.md index 5a73dde0e3..45125d377c 100644 --- a/programming_examples/basic/dma_transpose/README.md +++ b/programming_examples/basic/dma_transpose/README.md @@ -15,11 +15,24 @@ This reference design can be run on a Ryzen™ AI NPU. In the [design](./aie2.py), a 2-D array in a row-major layout is read from external memory to `ComputeTile2` with a transposed layout, by using an implicit copy via the compute tile's Data Movement Accelerator (DMA). The data is read from and written to external memory through the Shim tile (`col`, 0). +This data movement transformation can be visualized as a map which shows the order the data the data is streamed (e.g., in transposed layout): +

+ +

Visualization of the Transpose Data Transformation for M=32, K=16. +

+

+ The implicit copy is performed using the `object_fifo_link` operation that specifies how input data arriving via `of_in` should be sent further via `of_out` by specifically leveraging the compute tile's DMA. This operation and its functionality are described in more depth in [Section-2b](../../../programming_guide/section-2/section-2b/README.md/#object-fifo-link) of the programming guide. To compile and run the design for NPU: -``` +```bash make make run +``` + +To generate a data visualization of the transpose (like that above), run: +```bash +make generate_access_map ``` \ No newline at end of file diff --git a/programming_examples/basic/dma_transpose/aie2.py b/programming_examples/basic/dma_transpose/aie2.py index d5299f4a06..bbc3a056d9 100644 --- a/programming_examples/basic/dma_transpose/aie2.py +++ b/programming_examples/basic/dma_transpose/aie2.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # # (c) Copyright 2024 Advanced Micro Devices, Inc. or its affiliates +import argparse import numpy as np import sys @@ -12,20 +13,24 @@ from aie.dialects.aiex import * from aie.extras.context import mlir_mod_ctx from aie.helpers.dialects.ext.scf import _for as range_ +from aie.helpers.tensortiler.tensortiler2d import TensorTile -N = 4096 -M = 64 -K = 64 -if len(sys.argv) == 3: - M = int(sys.argv[1]) - K = int(sys.argv[2]) - N = M * K +def my_passthrough(M, K, N, generate_acccess_map=False): + tensor_ty = np.ndarray[(M, K), np.dtype[np.int32]] + data_transform = TensorTile( + tensor_height=M, + tensor_width=K, + sizes=[1, 1, K, M], + strides=[1, 1, 1, K], + offset=0, + ) + if generate_acccess_map: + data_transform.visualize( + plot_access_count=False, file_path="transpose_data.png" + ) + return -tensor_ty = np.ndarray[(M, K), np.dtype[np.int32]] - - -def my_passthrough(): with mlir_mod_ctx() as ctx: @device(AIEDevice.npu1_1col) @@ -56,8 +61,7 @@ def sequence(A, B, C): metadata=of_in, bd_id=1, mem=A, - sizes=[1, 1, K, M], - strides=[1, 1, 1, K], + tensor_tile=data_transform, issue_token=True, ) npu_dma_memcpy_nd(metadata=of_out, bd_id=0, mem=C, sizes=[1, 1, 1, N]) @@ -66,4 +70,24 @@ def sequence(A, B, C): print(ctx.module) -my_passthrough() +if __name__ == "__main__": + p = argparse.ArgumentParser() + p.add_argument("dims", help="M K", type=int, nargs="*", default=[64, 64]) + p.add_argument( + "--generate-access-map", + action="store_true", + help="Produce a file showing data access order", + ) + args = p.parse_args() + + if len(args.dims) != 2: + print( + "ERROR: Must provide either no dimensions or both M and K", file=sys.stderr + ) + exit(-1) + my_passthrough( + M=args.dims[0], + K=args.dims[1], + N=args.dims[0] * args.dims[1], + generate_acccess_map=args.generate_access_map, + ) diff --git a/programming_examples/basic/dma_transpose/transpose_data.png b/programming_examples/basic/dma_transpose/transpose_data.png new file mode 100644 index 0000000000..33112790be Binary files /dev/null and b/programming_examples/basic/dma_transpose/transpose_data.png differ diff --git a/programming_examples/basic/matrix_multiplication/whole_array/tiling_scratch.py b/programming_examples/basic/matrix_multiplication/whole_array/tiling_scratch.py new file mode 100644 index 0000000000..835efd87ea --- /dev/null +++ b/programming_examples/basic/matrix_multiplication/whole_array/tiling_scratch.py @@ -0,0 +1,92 @@ +import numpy as np + +from aie.helpers.tensortiler.tensortiler2d import TensorTiler2D, TensorTile +from util import construct_test + + +# RUN: %python %s | FileCheck %s +def ceildiv(a, b): + return -(a // -b) + + +def run_checks(n_aie_cols, n_aie_rows, M, N, K, m, n, k): + tb_max_n_rows = 4 + tb_n_rows = tb_max_n_rows // 2 + + # Define tilers + c_tiler = TensorTiler2D(M, N, m * n_aie_rows, n) + c_iter = c_tiler.tile_iter( + tile_repeat_step_horizontal=n_aie_cols, iter_step=tb_n_rows + ) + + for tb in range(ceildiv(M // m // n_aie_rows, tb_max_n_rows)): + for pingpong in [0, 1]: + row_base = tb * tb_max_n_rows + pingpong * tb_max_n_rows // 2 + tb_n_rows = min([tb_max_n_rows // 2, M // m // n_aie_rows - row_base]) + print(tb_n_rows) + if tb_n_rows <= 0: + # for small input sizes, we may not even need a "pong" iteration + break + + for col in range(n_aie_cols): + C_row_offset = row_base * m * n_aie_rows * N + C_col_offset = col * n + C_offset = C_col_offset + C_row_offset + C_sizes = [tb_n_rows, N // n // n_aie_cols, m * n_aie_rows, n] + C_strides = [m * n_aie_rows * N, n * n_aie_cols, N, 1] + expected_c_tile = TensorTile( + M, N, offset=C_offset, sizes=C_sizes, strides=C_strides + ) + + c_tile = next(c_iter) + if c_tile != expected_c_tile: + # equivalence for tensor tile checks offset, size, stride + # but there may be different but equivalent transformations + + reference_access, reference_count = expected_c_tile.access_tensors() + c_access, c_count = c_tile.access_tensors() + + """ + assert (reference_access == c_access).all(), ( + f"C access orders do not match. " + f"Expected ({expected_c_tile}), got ({c_tile})" + ) + assert (reference_count == c_count).all() + """ + print(f"Expected: {expected_c_tile}") + print(f"Actual: {c_tile}") + + +def matrix_whole_array_tiling_sweep(): + n_aie_cols_sweep = [1, 2, 4] # TODO: when partial, add 3 + n_aie_rows_sweep = [1, 2, 4] # TODO: when partial, add 3 + M_sweep = range(512, 4096, 512) + K_sweep = range(512, 4096, 512) + N_sweep = range(512, 4096, 512) + m_sweep = [16, 32, 64] + n_sweep = [16, 32, 64] + k_sweep = [16, 32, 64] + + for n_aie_cols in n_aie_cols_sweep: + for n_aie_rows in n_aie_rows_sweep: + for M in M_sweep: + for N in N_sweep: + for K in K_sweep: + for m in m_sweep: + for n in n_sweep: + for k in k_sweep: + run_checks( + n_aie_cols=n_aie_cols, + n_aie_rows=n_aie_rows, + M=M, + N=N, + K=K, + m=m, + k=k, + n=n, + ) + return + + +if __name__ == "__main__": + matrix_whole_array_tiling_sweep() diff --git a/programming_examples/basic/matrix_scalar_add/aie2.py b/programming_examples/basic/matrix_scalar_add/aie2.py index 87d75fd88b..75a980d208 100644 --- a/programming_examples/basic/matrix_scalar_add/aie2.py +++ b/programming_examples/basic/matrix_scalar_add/aie2.py @@ -12,6 +12,7 @@ from aie.dialects.aiex import * from aie.extras.context import mlir_mod_ctx from aie.helpers.dialects.ext.scf import _for as range_ +from aie.helpers.tensortiler.tensortiler2d import TensorTiler2D # Size of the entire image IMAGE_HEIGHT = 16 @@ -68,14 +69,16 @@ def core_body(): of_out1.release(ObjectFifoPort.Produce, 1) # To/from AIE-array data movement + tiler = TensorTiler2D(IMAGE_HEIGHT, IMAGE_WIDTH, TILE_HEIGHT, TILE_WIDTH) + t = next(tiler.tile_iter()) # Only transfer one (first) tile of data + @runtime_sequence(tile_ty, tile_ty, tile_ty) def sequence(inTensor, notUsed, outTensor): npu_dma_memcpy_nd( metadata=of_in1, bd_id=1, mem=inTensor, - sizes=[1, 1, TILE_HEIGHT, TILE_WIDTH], - strides=[1, 1, IMAGE_WIDTH, 1], + tensor_tile=t, issue_token=True, ) @@ -83,8 +86,7 @@ def sequence(inTensor, notUsed, outTensor): metadata=of_out1, bd_id=0, mem=outTensor, - sizes=[1, 1, TILE_HEIGHT, TILE_WIDTH], - strides=[1, 1, IMAGE_WIDTH, 1], + tensor_tile=t, ) dma_wait(of_in1, of_out1) diff --git a/programming_examples/basic/row_wise_bias_add/aie2.py b/programming_examples/basic/row_wise_bias_add/aie2.py index 1589473636..eac3b6d702 100644 --- a/programming_examples/basic/row_wise_bias_add/aie2.py +++ b/programming_examples/basic/row_wise_bias_add/aie2.py @@ -11,6 +11,7 @@ from aie.dialects.aiex import * from aie.extras.context import mlir_mod_ctx from aie.helpers.dialects.ext.scf import _for as range_ +from aie.helpers.tensortiler.tensortiler2d import TensorTiler2D def row_wise_bias_add(M, N, m, n): @@ -48,28 +49,32 @@ def core_body(): in_fifo.release(ObjectFifoPort.Consume, 1) bias_fifo.release(ObjectFifoPort.Consume, 1) + tiler = TensorTiler2D(M, N, m, n, tensor_col_major=True) + t = next( + tiler.tile_iter(tile_group_height=M // m, tile_group_width=N // n) + ) # Transfer all tiles at once + bias_tiler = TensorTiler2D(1, N, 1, n) + bias_t = next(bias_tiler.tile_iter(tile_group_width=N // n)) + @runtime_sequence(tensor_ty, bias_ty, tensor_ty) def sequence(inp, bias, out): npu_dma_memcpy_nd( metadata=in_fifo, bd_id=0, mem=inp, - sizes=[1, N // n, M, n], - strides=[0, n, N, 1], + tensor_tile=t, ) npu_dma_memcpy_nd( metadata=bias_fifo, bd_id=1, mem=bias, - sizes=[1, 1, N // n, n], - strides=[0, 0, n, 1], + tensor_tile=bias_t, ) npu_dma_memcpy_nd( metadata=out_fifo, bd_id=2, mem=out, - sizes=[1, N // n, M, n], - strides=[0, n, N, 1], + tensor_tile=t, ) # of_out will only complete after of_in completes, so we just wait on of_out instead of both dma_wait(out_fifo) diff --git a/programming_examples/basic/tiling_exploration/per_tile/Makefile b/programming_examples/basic/tiling_exploration/per_tile/Makefile new file mode 100644 index 0000000000..ccec0077b0 --- /dev/null +++ b/programming_examples/basic/tiling_exploration/per_tile/Makefile @@ -0,0 +1,39 @@ +##===- Makefile -----------------------------------------------------------===## +# +# This file licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# +# Copyright (C) 2024, Advanced Micro Devices, Inc. +# +##===----------------------------------------------------------------------===## + +srcdir := $(shell dirname $(realpath $(firstword $(MAKEFILE_LIST)))) + +include ${srcdir}/../../../makefile-common + +tensor_height = 32 +tensor_width = 32 +tile_height = 4 +tile_width = 4 +data_str=${tensor_height}_${tensor_width}_${tile_height}_${tile_width} + +.PHONY: all template clean + +all: build/final_${data_str}.xclbin + +build/aie_${data_str}.mlir: ${srcdir}/aie2.py + mkdir -p ${@D} + python3 $< --tensor-height ${tensor_height} --tensor-width ${tensor_width} --tile-height ${tile_height} --tile-width ${tile_width} > $@ + +build/final_${data_str}.xclbin: build/aie_${data_str}.mlir + mkdir -p ${@D} + cd ${@D} && aiecc.py --aie-generate-cdo --aie-generate-npu --no-compile-host \ + --no-xchesscc --no-xbridge \ + --xclbin-name=${@F} --npu-insts-name=insts_${data_str}.txt $(<:%=../%) + +run: build/final_${data_str}.xclbin build/insts_${data_str}.txt + ${powershell} python3 ${srcdir}/test.py -x build/final_${data_str}.xclbin -i build/insts_${data_str}.txt -k MLIR_AIE --tensor-height ${tensor_height} --tensor-width ${tensor_width} --tile-height ${tile_height} --tile-width ${tile_width} + +clean: + rm -rf build diff --git a/programming_examples/basic/tiling_exploration/per_tile/README.md b/programming_examples/basic/tiling_exploration/per_tile/README.md new file mode 100644 index 0000000000..cfb0936362 --- /dev/null +++ b/programming_examples/basic/tiling_exploration/per_tile/README.md @@ -0,0 +1,31 @@ + + +# Tiling Exploration + +This IRON design flow example, called "Tiling Exploration", demonstrates how data may be `tiled` on input/output. This is a common data transformation pattern, and this example is meant to be interactive. + +## Source Files Overview + +TODO + +## Design Overview + +TODO + +## Design Component Details + +### AIE Array Structural Design + +TODO + +## Usage + +TODO diff --git a/programming_examples/basic/tiling_exploration/per_tile/aie2.py b/programming_examples/basic/tiling_exploration/per_tile/aie2.py new file mode 100644 index 0000000000..dfd6fc9495 --- /dev/null +++ b/programming_examples/basic/tiling_exploration/per_tile/aie2.py @@ -0,0 +1,88 @@ +# tiling_exploration/aie2.py -*- Python -*- +# +# This file is licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# +# (c) Copyright 2024 Advanced Micro Devices, Inc. or its affiliates +import argparse +import numpy as np +import sys + +from aie.dialects.aie import * +from aie.dialects.aiex import * +from aie.dialects import arith +from aie.extras.context import mlir_mod_ctx +from aie.helpers.dialects.ext.scf import _for as range_ +from aie.helpers.tensortiler.tensortiler2d import TensorTiler2D + + +def generate_module(tensor_height, tensor_width, tile_height, tile_width): + @device(AIEDevice.npu1_1col) + def device_body(): + # define types + tensor_size = tensor_height * tensor_width + tile_size = tile_height * tile_width + flattened_tensor = np.ndarray[(tensor_size,), np.dtype[TensorTiler2D.DTYPE]] + flattened_tile = np.ndarray[(tile_size,), np.dtype[TensorTiler2D.DTYPE]] + + # Tile declarations + ShimTile = tile(0, 0) + ComputeTile2 = tile(0, 2) + + # AIE-array data movement with object fifos + of_out = object_fifo("out", ComputeTile2, ShimTile, 2, flattened_tile) + + # Set up compute tiles + + # Compute tile 2 + @core(ComputeTile2) + def core_body(): + # TODO: better way to get mutable constant than buffer?? + access_counter = buffer( + ComputeTile2, + np.ndarray[(1,), np.dtype[TensorTiler2D.DTYPE]], + "access_counter", + initial_value=np.array([0], dtype=TensorTiler2D.DTYPE), + ) + for _ in range_(sys.maxsize): + elemOut = of_out.acquire(ObjectFifoPort.Produce, 1) + for i in range_(tile_size): + elemOut[i] = access_counter[0] + access_counter[0] += 1 + of_out.release(ObjectFifoPort.Produce, 1) + + @runtime_sequence(flattened_tensor) + def sequence(access_count): + tiler = TensorTiler2D(tensor_height, tensor_width, tile_height, tile_width) + for t in tiler.tile_iter(): + npu_dma_memcpy_nd( + metadata=of_out, + bd_id=1, + mem=access_count, + tensor_tile=t, + ) + dma_wait(of_out) + + +def main(opts): + with mlir_mod_ctx() as ctx: + generate_module( + opts.tensor_height, opts.tensor_width, opts.tile_height, opts.tile_width + ) + print(ctx.module) + + +def get_arg_parser(): + p = argparse.ArgumentParser() + p.add_argument("--tensor-height", required=True, help="Tensor height", type=int) + p.add_argument("--tensor-width", required=True, help="Tensor width", type=int) + p.add_argument("--tile-height", required=True, help="Tile height", type=int) + p.add_argument("--tile-width", required=True, help="Tile width", type=int) + return p + + +if __name__ == "__main__": + p = get_arg_parser() + opts = p.parse_args() + main(opts) diff --git a/programming_examples/basic/tiling_exploration/per_tile/run_makefile.lit b/programming_examples/basic/tiling_exploration/per_tile/run_makefile.lit new file mode 100644 index 0000000000..bdafa81a59 --- /dev/null +++ b/programming_examples/basic/tiling_exploration/per_tile/run_makefile.lit @@ -0,0 +1,10 @@ +// (c) Copyright 2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// REQUIRES: ryzen_ai, peano +// +// RUN: make -f %S/Makefile clean +// RUN: make -f %S/Makefile +// RUN: %run_on_npu make -f %S/Makefile run | FileCheck %s +// CHECK: Running... +// CHECK: PASS! diff --git a/programming_examples/basic/tiling_exploration/per_tile/test.py b/programming_examples/basic/tiling_exploration/per_tile/test.py new file mode 100644 index 0000000000..d9adaaf2fc --- /dev/null +++ b/programming_examples/basic/tiling_exploration/per_tile/test.py @@ -0,0 +1,87 @@ +# tiling_exploration/test.py -*- Python -*- +# +# This file is licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# +# (c) Copyright 2024 Advanced Micro Devices, Inc. or its affiliates +import argparse +import numpy as np + +from aie.helpers.tensortiler.tensortiler2d import TensorTiler2D +from aie.utils.xrt import setup_aie, execute as execute_on_aie + + +def main(opts): + print("Running...\n") + + dtype = TensorTiler2D.DTYPE + data_size = opts.tensor_height * opts.tensor_width + + reference_tiler = TensorTiler2D( + opts.tensor_height, opts.tensor_width, opts.tile_height, opts.tile_width + ) + reference_access_order = reference_tiler.access_order() + + app = setup_aie( + opts.xclbin, + opts.instr, + None, + None, + None, + None, + data_size, + dtype, + ) + aie_output = execute_on_aie(app) + aie_output = aie_output.reshape((opts.tensor_height, opts.tensor_width)) + + # Copy output results and verify they are correct + errors = 0 + if opts.verbosity >= 1: + print("Verifying results ...") + e = np.equal(reference_access_order, aie_output) + errors = np.size(e) - np.count_nonzero(e) + + if not errors: + print("\nPASS!\n") + exit(0) + else: + print("\nError count: ", errors) + print("\nFailed.\n") + exit(-1) + + +def get_arg_parser(): + p = argparse.ArgumentParser() + p.add_argument( + "-x", "--xclbin", default="final.xclbin", dest="xclbin", help="the xclbin path" + ) + p.add_argument( + "-k", + "--kernel", + dest="kernel", + default="MLIR_AIE", + help="the kernel name in the XCLBIN (for instance MLIR_AIE)", + ) + p.add_argument( + "-v", "--verbosity", default=0, type=int, help="the verbosity of the output" + ) + p.add_argument( + "-i", + "--instr", + dest="instr", + default="instr.txt", + help="path of file containing userspace instructions sent to the NPU", + ) + p.add_argument("--tensor-height", required=True, help="Tensor height", type=int) + p.add_argument("--tensor-width", required=True, help="Tensor width", type=int) + p.add_argument("--tile-height", required=True, help="Tile height", type=int) + p.add_argument("--tile-width", required=True, help="Tile width", type=int) + return p + + +if __name__ == "__main__": + p = get_arg_parser() + opts = p.parse_args() + main(opts) diff --git a/programming_examples/basic/tiling_exploration/single_transform/Makefile b/programming_examples/basic/tiling_exploration/single_transform/Makefile new file mode 100644 index 0000000000..ccec0077b0 --- /dev/null +++ b/programming_examples/basic/tiling_exploration/single_transform/Makefile @@ -0,0 +1,39 @@ +##===- Makefile -----------------------------------------------------------===## +# +# This file licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# +# Copyright (C) 2024, Advanced Micro Devices, Inc. +# +##===----------------------------------------------------------------------===## + +srcdir := $(shell dirname $(realpath $(firstword $(MAKEFILE_LIST)))) + +include ${srcdir}/../../../makefile-common + +tensor_height = 32 +tensor_width = 32 +tile_height = 4 +tile_width = 4 +data_str=${tensor_height}_${tensor_width}_${tile_height}_${tile_width} + +.PHONY: all template clean + +all: build/final_${data_str}.xclbin + +build/aie_${data_str}.mlir: ${srcdir}/aie2.py + mkdir -p ${@D} + python3 $< --tensor-height ${tensor_height} --tensor-width ${tensor_width} --tile-height ${tile_height} --tile-width ${tile_width} > $@ + +build/final_${data_str}.xclbin: build/aie_${data_str}.mlir + mkdir -p ${@D} + cd ${@D} && aiecc.py --aie-generate-cdo --aie-generate-npu --no-compile-host \ + --no-xchesscc --no-xbridge \ + --xclbin-name=${@F} --npu-insts-name=insts_${data_str}.txt $(<:%=../%) + +run: build/final_${data_str}.xclbin build/insts_${data_str}.txt + ${powershell} python3 ${srcdir}/test.py -x build/final_${data_str}.xclbin -i build/insts_${data_str}.txt -k MLIR_AIE --tensor-height ${tensor_height} --tensor-width ${tensor_width} --tile-height ${tile_height} --tile-width ${tile_width} + +clean: + rm -rf build diff --git a/programming_examples/basic/tiling_exploration/single_transform/README.md b/programming_examples/basic/tiling_exploration/single_transform/README.md new file mode 100644 index 0000000000..cfb0936362 --- /dev/null +++ b/programming_examples/basic/tiling_exploration/single_transform/README.md @@ -0,0 +1,31 @@ + + +# Tiling Exploration + +This IRON design flow example, called "Tiling Exploration", demonstrates how data may be `tiled` on input/output. This is a common data transformation pattern, and this example is meant to be interactive. + +## Source Files Overview + +TODO + +## Design Overview + +TODO + +## Design Component Details + +### AIE Array Structural Design + +TODO + +## Usage + +TODO diff --git a/programming_examples/basic/tiling_exploration/single_transform/aie2.py b/programming_examples/basic/tiling_exploration/single_transform/aie2.py new file mode 100644 index 0000000000..d1b8f98951 --- /dev/null +++ b/programming_examples/basic/tiling_exploration/single_transform/aie2.py @@ -0,0 +1,80 @@ +# tiling_exploration/aie2.py -*- Python -*- +# +# This file is licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# +# (c) Copyright 2024 Advanced Micro Devices, Inc. or its affiliates +import argparse +import numpy as np +import sys + +from aie.dialects.aie import * +from aie.dialects.aiex import * +from aie.dialects import arith +from aie.extras.context import mlir_mod_ctx +from aie.helpers.dialects.ext.scf import _for as range_ +from aie.helpers.tensortiler.tensortiler2d import TensorTiler2D + + +def generate_module(tensor_height, tensor_width, tile_height, tile_width): + @device(AIEDevice.npu1_1col) + def device_body(): + # define types + tensor_size = tensor_height * tensor_width + flattened_tensor = np.ndarray[(tensor_size,), np.dtype[TensorTiler2D.DTYPE]] + + # Tile declarations + ShimTile = tile(0, 0) + ComputeTile2 = tile(0, 2) + + # AIE-array data movement with object fifos + of_out = object_fifo("out", ComputeTile2, ShimTile, 2, flattened_tensor) + + # Set up compute tiles + + # Compute tile 2 + @core(ComputeTile2) + def core_body(): + for _ in range_(sys.maxsize): + elemOut = of_out.acquire(ObjectFifoPort.Produce, 1) + for i in range_(tensor_size): + # TODO: fix need for cast here. + elemOut[i] = arith.index_cast(T.i32(), i) + of_out.release(ObjectFifoPort.Produce, 1) + + @runtime_sequence(flattened_tensor) + def sequence(access_count): + t = TensorTiler2D( + tensor_height, tensor_width, tile_height, tile_width + ).as_tile() + npu_dma_memcpy_nd( + metadata=of_out, + bd_id=1, + mem=access_count, + tensor_tile=t, + ) + dma_wait(of_out) + + +def main(opts): + with mlir_mod_ctx() as ctx: + generate_module( + opts.tensor_height, opts.tensor_width, opts.tile_height, opts.tile_width + ) + print(ctx.module) + + +def get_arg_parser(): + p = argparse.ArgumentParser() + p.add_argument("--tensor-height", required=True, help="Tensor height", type=int) + p.add_argument("--tensor-width", required=True, help="Tensor width", type=int) + p.add_argument("--tile-height", required=True, help="Tile height", type=int) + p.add_argument("--tile-width", required=True, help="Tile width", type=int) + return p + + +if __name__ == "__main__": + p = get_arg_parser() + opts = p.parse_args() + main(opts) diff --git a/programming_examples/basic/tiling_exploration/single_transform/run_makefile.lit b/programming_examples/basic/tiling_exploration/single_transform/run_makefile.lit new file mode 100644 index 0000000000..bdafa81a59 --- /dev/null +++ b/programming_examples/basic/tiling_exploration/single_transform/run_makefile.lit @@ -0,0 +1,10 @@ +// (c) Copyright 2024 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// REQUIRES: ryzen_ai, peano +// +// RUN: make -f %S/Makefile clean +// RUN: make -f %S/Makefile +// RUN: %run_on_npu make -f %S/Makefile run | FileCheck %s +// CHECK: Running... +// CHECK: PASS! diff --git a/programming_examples/basic/tiling_exploration/single_transform/test.py b/programming_examples/basic/tiling_exploration/single_transform/test.py new file mode 100644 index 0000000000..d9adaaf2fc --- /dev/null +++ b/programming_examples/basic/tiling_exploration/single_transform/test.py @@ -0,0 +1,87 @@ +# tiling_exploration/test.py -*- Python -*- +# +# This file is licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# +# (c) Copyright 2024 Advanced Micro Devices, Inc. or its affiliates +import argparse +import numpy as np + +from aie.helpers.tensortiler.tensortiler2d import TensorTiler2D +from aie.utils.xrt import setup_aie, execute as execute_on_aie + + +def main(opts): + print("Running...\n") + + dtype = TensorTiler2D.DTYPE + data_size = opts.tensor_height * opts.tensor_width + + reference_tiler = TensorTiler2D( + opts.tensor_height, opts.tensor_width, opts.tile_height, opts.tile_width + ) + reference_access_order = reference_tiler.access_order() + + app = setup_aie( + opts.xclbin, + opts.instr, + None, + None, + None, + None, + data_size, + dtype, + ) + aie_output = execute_on_aie(app) + aie_output = aie_output.reshape((opts.tensor_height, opts.tensor_width)) + + # Copy output results and verify they are correct + errors = 0 + if opts.verbosity >= 1: + print("Verifying results ...") + e = np.equal(reference_access_order, aie_output) + errors = np.size(e) - np.count_nonzero(e) + + if not errors: + print("\nPASS!\n") + exit(0) + else: + print("\nError count: ", errors) + print("\nFailed.\n") + exit(-1) + + +def get_arg_parser(): + p = argparse.ArgumentParser() + p.add_argument( + "-x", "--xclbin", default="final.xclbin", dest="xclbin", help="the xclbin path" + ) + p.add_argument( + "-k", + "--kernel", + dest="kernel", + default="MLIR_AIE", + help="the kernel name in the XCLBIN (for instance MLIR_AIE)", + ) + p.add_argument( + "-v", "--verbosity", default=0, type=int, help="the verbosity of the output" + ) + p.add_argument( + "-i", + "--instr", + dest="instr", + default="instr.txt", + help="path of file containing userspace instructions sent to the NPU", + ) + p.add_argument("--tensor-height", required=True, help="Tensor height", type=int) + p.add_argument("--tensor-width", required=True, help="Tensor width", type=int) + p.add_argument("--tile-height", required=True, help="Tile height", type=int) + p.add_argument("--tile-width", required=True, help="Tile width", type=int) + return p + + +if __name__ == "__main__": + p = get_arg_parser() + opts = p.parse_args() + main(opts) diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 97d2dbceb3..fa95393a48 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -42,6 +42,7 @@ declare_mlir_python_sources(AIEPythonSources.Helpers helpers/*.py helpers/dialects/ext/*.py helpers/runtime/*.py + helpers/tensortiler/*.py ) declare_mlir_dialect_python_bindings( diff --git a/python/dialects/aiex.py b/python/dialects/aiex.py index 0f5c329e1b..c3cd2d124b 100644 --- a/python/dialects/aiex.py +++ b/python/dialects/aiex.py @@ -28,6 +28,7 @@ # noinspection PyUnresolvedReferences from ..extras import types as T from ..helpers.util import try_convert_np_type_to_mlir_type, np_ndarray_type_get_shape +from ..helpers.tensortiler.tensortiler2d import TensorTile # Comes from _aie register_dialect(get_dialect_registry()) @@ -55,19 +56,29 @@ def __init__( metadata: str | ObjectFifoCreateOp, bd_id, mem, - offsets: MixedValues = None, - sizes: MixedValues = None, - strides: MixedValues = None, + tensor_tile: TensorTile | None = None, + offsets: MixedValues | None = None, + sizes: MixedValues | None = None, + strides: MixedValues | None = None, issue_token: bool | None = None, ): x = 0 y = 0 - if offsets is None: - offsets = [0] * 4 - if sizes is None: - sizes = [0] * 4 - if strides is None: - strides = [0] * 3 + [1] + if tensor_tile and not (offsets is None and sizes is None and strides is None): + raise ValueError( + "NpuDmaMemcpyNd can take either a tensor_tile OR (sizes and/or strides and/or offsets), but not both." + ) + if tensor_tile: + sizes = tensor_tile.sizes.copy() + strides = tensor_tile.strides.copy() + offsets = [0, 0, 0, tensor_tile.offset] + else: + if offsets is None: + offsets = [0] * 4 + if sizes is None: + sizes = [0] * 4 + if strides is None: + strides = [0] * 3 + [1] dynamic_offsets, _packed_offsets, static_offsets = _dispatch_mixed_values( offsets ) diff --git a/python/helpers/tensortiler/__init__.py b/python/helpers/tensortiler/__init__.py new file mode 100644 index 0000000000..5641295b7f --- /dev/null +++ b/python/helpers/tensortiler/__init__.py @@ -0,0 +1,5 @@ +from .tensortile import TensorTile +from .tensortilesequence import ( + TensorTileSequence, +) +from .tensortiler2d import TensorTiler2D diff --git a/python/helpers/tensortiler/tensortile.py b/python/helpers/tensortiler/tensortile.py new file mode 100644 index 0000000000..b9df1d1a8a --- /dev/null +++ b/python/helpers/tensortiler/tensortile.py @@ -0,0 +1,120 @@ +from copy import deepcopy +import numpy as np +import itertools +from typing import Sequence + +from .utils import ( + validate_and_clean_sizes_strides, + validate_offset, + validate_tensor_dims, +) +from .visualization2d import visualize_from_access_tensors + + +class TensorTile: + _DTYPE = np.int32 + + def __init__( + self, + tensor_dims: Sequence[int], + offset: int, + sizes: Sequence[int], + strides: Sequence[int], + ): + self._tensor_dims = validate_tensor_dims(tensor_dims) + self._offset = validate_offset(offset) + if self._offset >= np.prod(tensor_dims): + raise ValueError( + f"Offset too large: {self._offset}. Max value allowed for tensor: {np.prod(tensor_dims)}" + ) + self._sizes, self._strides = validate_and_clean_sizes_strides(sizes, strides) + + @property + def tensor_dims(self) -> Sequence[int]: + # Copy to prevent callers from mutating self + return deepcopy(self._tensor_dims) + + @property + def offset(self) -> int: + return self._offset + + @property + def sizes(self) -> Sequence[int]: + # Copy to prevent callers from mutating self + return deepcopy(self._sizes) + + @property + def strides(self) -> Sequence[int]: + # Copy to prevent callers from mutating self + return deepcopy(self._strides) + + @property + def transformation_dims(self) -> Sequence[tuple[int, int]]: + return list(zip(self._sizes, self._strides)) + + def access_tensors(self) -> tuple[np.ndarray, np.ndarray]: + # TODO: should access order be a list of lists instead of generate two separate tensors? + # TODO: for performance, should cache and return copies? Or just cache? + + # Initialize access order and count maps + total_elems = np.prod(self._tensor_dims) + access_order_tensor = np.full(total_elems, -1, dtype=self._DTYPE) + access_count_tensor = np.full(total_elems, 0, dtype=self._DTYPE) + + # Use itertools.product to collapse len(sizes) nested forloop into one forloop + access_count = 0 + for dims in itertools.product(*[range(0, n) for n in self._sizes]): + access_idx = ( + self._offset + np.sum(np.multiply(dims, self._strides)) + ) % total_elems + access_count_tensor[access_idx] += 1 + access_order_tensor[access_idx] = access_count + access_count += 1 + + access_order_tensor = access_order_tensor.reshape(self._tensor_dims) + access_count_tensor = access_count_tensor.reshape(self._tensor_dims) + return access_order_tensor, access_count_tensor + + def visualize( + self, + show_arrows: bool | None = None, + title: str = None, + file_path: str | None = None, + show_plot: bool = True, + plot_access_count: bool = False, + ) -> None: + access_order, access_count = self.access_tensors() + if title is None: + title = str(self) + if not plot_access_count: + access_count = None + if len(self._tensor_dims) == 2: + visualize_from_access_tensors( + access_order, + access_count, + title=title, + show_arrows=show_arrows, + file_path=file_path, + show_plot=show_plot, + ) + else: + raise NotImplementedError( + "Visualization is only currently supported for 1- or 2-dimensional tensors" + ) + + def __str__(self) -> str: + return f"TensorTile(offset={self._offset}, sizes={self._sizes}, strides={self._strides})" + + def __eq__(self, other): + if isinstance(other, self.__class__): + return ( + self._tensor_dims == other._tensor_dims + and self._offset == other._offset + and self._sizes == other._sizes + and self._strides == other._strides + ) + else: + return False + + def __ne__(self, other): + return not self.__eq__(other) diff --git a/python/helpers/tensortiler/tensortiler2d.py b/python/helpers/tensortiler/tensortiler2d.py new file mode 100644 index 0000000000..c70481b587 --- /dev/null +++ b/python/helpers/tensortiler/tensortiler2d.py @@ -0,0 +1,337 @@ +from copy import deepcopy +from functools import partial +import numpy as np +from typing import Sequence + +from .tensortilesequence import TensorTileSequence +from .utils import ceildiv, validate_and_clean_sizes_strides, validate_tensor_dims + + +class TensorTiler2D: + """ + This is a generator (similar to factory pattern) class which produces TensorTileSequence + objects for common 2-dimensional tiling patterns. + """ + + _DTYPE = np.int32 + _NUM_DIMS = 2 + + def __init__(self): + raise Exception( + f"{self.__class__} cannot be instantiated. Use it as a factory/generator of TensorTileSequences." + ) + + @classmethod + def simple_tiler( + cls, + tensor_dims: Sequence[int], + tile_dims: Sequence[int] | None = None, + tile_col_major: bool = False, + iter_col_major: bool = False, + pattern_repeat: int = 1, + ) -> TensorTileSequence: + if tile_dims is None: + tile_dims = deepcopy(tensor_dims) + # Special case of group_tiler + return cls.group_tiler( + tensor_dims=tensor_dims, + tile_dims=tile_dims, + tile_col_major=tile_col_major, + iter_col_major=iter_col_major, + pattern_repeat=pattern_repeat, + ) + + @classmethod + def group_tiler( + cls, + tensor_dims: Sequence[int], + tile_dims: Sequence[int], + tile_group_dims: Sequence[int] | None = None, + tile_col_major: bool = False, + tile_group_col_major: bool = False, + iter_col_major: bool = False, + pattern_repeat: int = 1, + allow_partial: bool = False, + ) -> TensorTileSequence: + if tile_group_dims is None: + tile_group_dims = (1,) * cls._NUM_DIMS + # Special case of step_tiler + return cls.step_tiler( + tensor_dims=tensor_dims, + tile_dims=tile_dims, + tile_group_repeats=tile_group_dims, + tile_col_major=tile_col_major, + tile_group_col_major=tile_group_col_major, + iter_col_major=iter_col_major, + pattern_repeat=pattern_repeat, + allow_partial=allow_partial, + ) + + @classmethod + def step_tiler( + cls, + tensor_dims: Sequence[int], + tile_dims: Sequence[int], + tile_group_repeats: Sequence[int], + tile_group_steps: Sequence[int] | None = None, + tile_col_major: bool = False, + tile_group_col_major: bool = False, + iter_col_major: bool = False, + allow_partial: bool = False, + pattern_repeat: int = 1, + ) -> TensorTileSequence: + if tile_group_steps is None: + tile_group_steps = (1,) * cls._NUM_DIMS + tensor_dims = validate_tensor_dims(tensor_dims, expected_dims=cls._NUM_DIMS) + tile_dims = validate_tensor_dims(tile_dims, expected_dims=cls._NUM_DIMS) + tile_group_repeats = validate_tensor_dims( + tile_group_repeats, expected_dims=cls._NUM_DIMS + ) + if len(tile_group_steps) != cls._NUM_DIMS: + raise ValueError(f"Expcted {cls._NUM_DIMS} dimensions of tile group steps") + for i, s in enumerate(tile_group_steps): + if s < 0: + raise ValueError( + f"Tile group step dimension {i} must be >= 0, but got {s}" + ) + + for i, (tensor_dim, tile_dim) in enumerate(zip(tensor_dims, tile_dims)): + if tensor_dim % tile_dim != 0: + raise ValueError( + f"Tensor dimension {i} ({tensor_dim}) is not divisible by tile dim ({tile_dim})" + ) + + if not isinstance(pattern_repeat, int) or pattern_repeat < 1: + raise ValueError(f"Pattern repeat must be >= 1 but is {pattern_repeat}") + if not allow_partial: + for i, (tensor_dim, tile_dim, repeat_dim, step_dim) in enumerate( + zip(tensor_dims, tile_dims, tile_group_repeats, tile_group_steps) + ): + if tensor_dim % (tile_dim * repeat_dim * step_dim) != 0: + raise ValueError( + f"allow_partial={allow_partial} but tensor does not divide evenly into tile groups in dimension {i}" + ) + if tile_dim * repeat_dim * step_dim > tensor_dim: + raise ValueError( + f"Tile pattern exceeds tensor size in dimension {i} ({tile_dim}x{repeat_dim}x{step_dim} > {tensor_dim})" + ) + + steps_per_dim = cls.__get_num_steps( + tensor_dims=tensor_dims, + tile_dims=tile_dims, + step_dims=tile_group_steps, + repeat_dims=tile_group_repeats, + ) + num_steps = np.prod(steps_per_dim) + + def offset_fn(step_num: int, _prev_offset: int) -> int: + tile_offsets = cls.__tile_offset_by_step_num( + step_num, + tile_group_steps, + tile_group_repeats, + steps_per_dim, + iter_col_major, + ) + total_offset = 0 + num_dims = len(tile_offsets) + for dim, (offset, tile_dim) in enumerate(zip(tile_offsets, tile_dims)): + total_offset += ( + offset + * tile_dim + * (np.prod(tensor_dims[dim + 1 :]) if dim < num_dims - 1 else 1) + ) + return total_offset + + def sizes_or_strides_fn(step_num, _prev_sizes, is_sizes): + tile_offsets = cls.__tile_offset_by_step_num( + step_num, + tile_group_steps, + tile_group_repeats, + steps_per_dim, + iter_col_major, + ) + + iter_sizes, iter_strides = cls.__sizes_strides_for_step_tile_group( + tensor_dims, + tile_dims, + tile_group_steps, + tile_group_repeats, + tile_offsets, + tile_col_major, + tile_group_col_major, + pattern_repeat=pattern_repeat, + ) + if is_sizes: + return iter_sizes + else: + return iter_strides + + sizes_fn = partial(sizes_or_strides_fn, is_sizes=True) + strides_fn = partial(sizes_or_strides_fn, is_sizes=False) + + return TensorTileSequence( + tensor_dims, + num_steps, + sizes_fn=sizes_fn, + strides_fn=strides_fn, + offset_fn=offset_fn, + ) + + @classmethod + def __get_num_steps( + cls, + tensor_dims: Sequence[int], + tile_dims: Sequence[int], + step_dims: Sequence[int], + repeat_dims: Sequence[int], + ) -> Sequence[int]: + num_steps_dims = [] + for tensor_dim, tile_dim, step_dim, repeat_dim in zip( + tensor_dims, tile_dims, step_dims, repeat_dims + ): + num_steps_per_dim = tensor_dim // (tile_dim * repeat_dim) + tiles_in_tensor = (tensor_dim // tile_dim) - num_steps_per_dim * repeat_dim + partial_height_steps = 0 + while tiles_in_tensor > 0: + tiles_in_tensor -= min(repeat_dim, ceildiv(tiles_in_tensor, step_dim)) + partial_height_steps += 1 + num_steps_per_dim += partial_height_steps + num_steps_dims.append(num_steps_per_dim) + return num_steps_dims + + @classmethod + def __tile_offset_by_step_num( + cls, + step_num: int, + tile_group_steps: Sequence[int], + tile_group_repeats: Sequence[int], + num_steps: Sequence[int], + iter_col_major: bool, + ) -> Sequence[int]: + # TODO: this code is still specific to two dimensions + steps_per_col, steps_per_row = num_steps + tile_step_height, tile_step_width = tile_group_steps + tile_repeat_height, tile_repeat_width = tile_group_repeats + + if not iter_col_major: + row_idx = step_num % steps_per_row + col_idx = step_num // steps_per_row + else: + col_idx = step_num % steps_per_col + row_idx = step_num // steps_per_col + + col_chunk_idx = col_idx // tile_step_height + row_chunk_idx = row_idx // tile_step_width + col_in_chunk_idx = col_idx % tile_step_height + row_in_chunk_idx = row_idx % tile_step_width + + tile_offset_in_row = ( + row_chunk_idx * tile_step_width * tile_repeat_width + row_in_chunk_idx + ) + tile_offset_in_col = ( + col_chunk_idx * tile_step_height * tile_repeat_height + col_in_chunk_idx + ) + return (tile_offset_in_col, tile_offset_in_row) + + @classmethod + def __sizes_strides_for_step_tile_group( + cls, + tensor_dims: Sequence[int], + tile_dims: Sequence[int], + tile_group_steps: Sequence[int], + tile_group_repeats: Sequence[int], + tile_offsets: Sequence[int], + tile_col_major: bool, + tile_group_col_major: bool, + pattern_repeat: int, + ) -> tuple[Sequence[int], Sequence[int]]: + # TODO: this code is still specific to two dimensions + # TODO: this code assumes sizes/strides of len 4 + + # Interior method, assumes all validation already done + tensor_height, tensor_width = tensor_dims + tile_height, tile_width = tile_dims + tile_step_height, tile_step_width = tile_group_steps + tile_repeat_height, tile_repeat_width = tile_group_repeats + tile_offset_height, tile_offset_width = tile_offsets + + tiles_remaining_height = tensor_height // tile_height - tile_offset_height + tiles_remaining_width = tensor_width // tile_width - tile_offset_width + + # use tile offsets to prune step + if tile_step_height > tiles_remaining_height: + tile_step_height = 1 + if tile_step_width > tiles_remaining_width: + tile_step_width = 1 + + # use tile offsets to prune repeat count + tile_repeat_width = min( + tile_repeat_width, + ceildiv(tiles_remaining_width, tile_step_width), + ) + tile_repeat_height = min( + tile_repeat_height, + ceildiv(tiles_remaining_height, tile_step_height), + ) + tile_group_repeats = (tile_repeat_height, tile_repeat_width) + + if ( + tile_group_col_major + and tile_step_height == 1 + and tile_repeat_height > 1 + and not tile_col_major + ): + # Can combine into one big tile vertically + tile_height *= tile_repeat_height + tile_repeat_height = 1 + elif ( + not tile_group_col_major + and tile_step_width == 1 + and tile_repeat_width > 1 + and tile_col_major + ): + # Can combine into one big tile horizontally + tile_width *= tile_repeat_width + tile_repeat_width = 1 + + if not tile_col_major: + iter_sizes = [1, 1, tile_height, tile_width] + iter_strides = [0, 0, tensor_width, 1] + else: + iter_sizes = [1, 1, tile_width, tile_height] + iter_strides = [0, 0, 1, tensor_width] + + if tile_repeat_width > 1 and tile_repeat_height > 1: + idx_order = [0, 1] + if tile_group_col_major: + idx_order = [1, 0] + iter_sizes[idx_order[0]] = tile_repeat_height + iter_sizes[idx_order[1]] = tile_repeat_width + iter_strides[idx_order[0]] = tensor_width * tile_height * tile_step_height + iter_strides[idx_order[1]] = tile_width * tile_step_width + elif tile_repeat_height > 1: + iter_sizes[1], iter_strides[1] = ( + tile_repeat_height, + tensor_width * tile_height * tile_step_height, + ) + elif tile_repeat_width > 1: + iter_sizes[1], iter_strides[1] = ( + tile_repeat_width, + tile_width * tile_step_width, + ) + + iter_sizes, iter_strides = validate_and_clean_sizes_strides( + iter_sizes, iter_strides + ) + + if pattern_repeat != 1: + if iter_sizes[1] == 1 and iter_strides[0] == 0: + iter_sizes[1] = pattern_repeat + else: + if iter_sizes[0] != 1 or iter_strides[0] != 0: + raise ValueError( + f"Ran out of dimensions for repeat (sizes={iter_sizes}, strides={iter_strides})" + ) + iter_sizes[0] = pattern_repeat + + return iter_sizes, iter_strides diff --git a/python/helpers/tensortiler/tensortilesequence.py b/python/helpers/tensortiler/tensortilesequence.py new file mode 100644 index 0000000000..f57b44fd44 --- /dev/null +++ b/python/helpers/tensortiler/tensortilesequence.py @@ -0,0 +1,247 @@ +from __future__ import annotations +from collections import abc +from copy import deepcopy +import matplotlib.animation as animation +import numpy as np +from typing import Callable, Sequence + +from .tensortile import TensorTile +from .utils import ( + validate_and_clean_sizes_strides, + validate_offset, + validate_tensor_dims, +) +from .visualization2d import animate_from_access_tensors, visualize_from_access_tensors + + +class TensorTileSequence(abc.MutableSequence, abc.Iterable): + """ + TensorTileSequence is a MutableSequence and an Iterable which is a thin wrapper around a list[TensorTiles]. + """ + + def __init__( + self, + tensor_dims: Sequence[int], + num_steps: int, + offset: int | None = None, + sizes: Sequence[int] | None = None, + strides: Sequence[int] | None = None, + offset_fn: Callable[[int, int], int] | None = None, + sizes_fn: Callable[[int, Sequence[int]], Sequence[int]] | None = None, + strides_fn: Callable[[int, Sequence[int]], Sequence[int]] | None = None, + ): + self._current_step = 0 + + # Check tensor dims, offset, sizes, strides + self._tensor_dims = validate_tensor_dims(tensor_dims) + if not (offset is None): + offset = validate_offset(offset) + sizes, strides = validate_and_clean_sizes_strides( + sizes, strides, allow_none=True + ) + + # Validate and set num steps + if num_steps < 0: + raise ValueError(f"Number of steps must be positive (but is {num_steps})") + + if num_steps == 0: + if ( + offset != None + or sizes != None + or strides != None + or offset_fn != None + or sizes_fn != None + or strides_fn != None + ): + raise ValueError( + f"If num_steps=0, no sizes/strides/offset information may be specified" + ) + self._tiles = [] + else: + # Make sure values or not None if iteration functions are None; also set default iter fn + if offset_fn is None: + if offset is None: + raise ValueError("Offset must be provided if offset_fn is None") + offset_fn = lambda _step, _prev_offset: offset + else: + offset_fn = offset_fn + if sizes_fn is None: + if sizes is None: + raise ValueError("Sizes must be provided if size_fn is None") + sizes_fn = lambda _step, _prev_sizes: sizes + else: + sizes_fn = sizes_fn + if strides_fn is None: + if strides is None: + raise ValueError("Strides must be provided if stride_fn is None") + strides_fn = lambda _step, _prev_strides: strides + else: + strides_fn = strides_fn + + # Pre-calculate tiles, because better for error handling up-front (and for visualizing full iter) + # This is somewhat against the mentality behind iterations, but should be okay at the scale this + # class will be used for (e.g., no scalability concerns with keeping all tiles in mem) + self._tiles = [] + for step in range(num_steps): + offset = offset_fn(step, offset) + sizes = sizes_fn(step, sizes) + strides = strides_fn(step, strides) + + self._tiles.append( + TensorTile( + self._tensor_dims, + offset, + sizes, + strides, + ) + ) + + @classmethod + def from_tiles(cls, tiles: Sequence[TensorTile]) -> TensorTileSequence: + if len(tiles) < 1: + raise ValueError( + "Received no tiles; must have at least one tile to create a tile sequence." + ) + tensor_dims = tiles[0].tensor_dims + for t in tiles: + if t.tensor_dims != tensor_dims: + raise ValueError( + f"Tiles have multiple tensor dimensions (found {tensor_dims} and {t.tensor_dims})" + ) + tileseq = cls( + tensor_dims, + num_steps=1, + offset=tiles[0].offset, + sizes=tiles[0].sizes, + strides=tiles[0].strides, + ) + for t in tiles[1:]: + tileseq.append(t) + return tileseq + + def access_tensors(self) -> tuple[np.ndarray, np.ndarray]: + total_elems = np.prod(self._tensor_dims) + + combined_access_order_tensor = np.full( + total_elems, 0, TensorTile._DTYPE + ).reshape(self._tensor_dims) + combined_access_count_tensor = np.full( + total_elems, 0, TensorTile._DTYPE + ).reshape(self._tensor_dims) + highest_count = 0 + for t in self._tiles: + t_access_order, t_access_count = t.access_tensors() + t_access_order[t_access_order != -1] += 1 + highest_count + t_access_order[t_access_order == -1] = 0 + combined_access_order_tensor += t_access_order + highest_count = np.max(combined_access_order_tensor) + + combined_access_count_tensor += t_access_count + + combined_access_order_tensor -= 1 + return (combined_access_order_tensor, combined_access_count_tensor) + + def animate( + self, title: str = None, animate_access_count: bool = False + ) -> animation.FuncAnimation: + if title is None: + title = "TensorTileSequence Animation" + if len(self._tensor_dims) == 2: + total_elems = np.prod(self._tensor_dims) + + animate_order_frames = [ + np.full(total_elems, -1, TensorTile._DTYPE).reshape(self._tensor_dims) + ] + if animate_access_count: + animate_count_frames = [ + np.full(total_elems, 0, TensorTile._DTYPE).reshape( + self._tensor_dims + ) + ] + else: + animate_count_frames = None + + for t in self._tiles: + t_access_order, t_access_count = t.access_tensors() + animate_order_frames.append(t_access_order) + if animate_access_count: + animate_count_frames.append(t_access_count) + + return animate_from_access_tensors( + animate_order_frames, + animate_count_frames, + title=title, + ) + + else: + raise NotImplementedError( + "Visualization is only currently supported for 1- or 2-dimensional tensors" + ) + + def visualize( + self, + title: str = None, + file_path: str | None = None, + show_plot: bool = True, + plot_access_count: bool = False, + ) -> None: + if len(self._tensor_dims) != 2: + raise NotImplementedError( + "Visualization is only currently supported for 1- or 2-dimensional tensors" + ) + + if title is None: + title = "TensorTileSequence" + access_order_tensor, access_count_tensor = self.access_tensors() + if not plot_access_count: + access_count_tensor = None + + visualize_from_access_tensors( + access_order_tensor, + access_count_tensor, + title=title, + show_arrows=False, + file_path=file_path, + show_plot=show_plot, + ) + + def __contains__(self, tile: TensorTile): + return tile in self._tiles + + def __iter__(self): + return iter(deepcopy(self._tiles)) + + def __len__(self) -> int: + return len(self._tiles) + + def __getitem__(self, idx: int) -> TensorTile: + return self._tiles[idx] + + def __setitem__(self, idx: int, tile: TensorTile): + if self._tensor_dims != tile.tensor_dims: + raise ValueError( + f"Cannot add tile with tensor dims {tile.tensor_dims} to sequence of tiles with tensor dims {self._tensor_dims}" + ) + self._tiles[idx] = deepcopy(tile) + + def __delitem__(self, idx: int): + del self._tiles[idx] + + def insert(self, idx: int, tile: TensorTile): + if self._tensor_dims != tile.tensor_dims: + raise ValueError( + f"Cannot add tile with tensor dims {tile.tensor_dims} to sequence of tiles with tensor dims {self._tensor_dims}" + ) + self._tiles.insert(idx, tile) + + def __eq__(self, other): + if isinstance(other, self.__class__): + return ( + self._tiles == other._tiles + and self._current_step == other._current_step + ) + else: + return False + + def __ne__(self, other): + return not self.__eq__(other) diff --git a/python/helpers/tensortiler/utils.py b/python/helpers/tensortiler/utils.py new file mode 100644 index 0000000000..c997e72d94 --- /dev/null +++ b/python/helpers/tensortiler/utils.py @@ -0,0 +1,114 @@ +from copy import deepcopy +from typing import Sequence + + +def ceildiv(a, b): + return -(a // -b) + + +def validate_and_clean_sizes_strides( + sizes: Sequence[int] | None, + strides: Sequence[int] | None, + allow_none: bool = False, + expected_dims: int | None = None, +) -> tuple[Sequence[int] | None, Sequence[int] | None]: + if not allow_none: + if sizes is None: + raise ValueError("Sizes is None, but expected Sequence[int]") + if strides is None: + raise ValueError("Strides is None, but expected Sequence[int]") + # After this point can assume None is ok for sizes/strides + + if not (expected_dims is None): + if expected_dims < 1: + raise ValueError(f"Expected dimensions ({expected_dims}) should be >= 1") + + if sizes is None and strides is None: + # nothing to do + return None, None + + # Validate dimensions + if (not (sizes is None)) and len(sizes) == 0: + raise ValueError("len(sizes) must be >0") + if (not (strides is None)) and len(strides) == 0: + raise ValueError("len(strides) must be >0") + + if sizes and strides: + if expected_dims: + if len(sizes) != expected_dims: + raise ( + f"Num dimensions of sizes ({sizes}) is not expected number of dimensions ({expected_dims})" + ) + if len(strides) != expected_dims: + raise ( + f"Num dimensions of strides ({strides}) is not expected number of dimensions ({expected_dims})" + ) + elif len(strides) != len(sizes): + raise ValueError( + f"len(sizes) ({len(sizes)}) != len(strides) ({len(strides)})" + ) + if strides: + num_dims = len(strides) + else: + num_dims = len(sizes) + + # Validate sizes/strides values + if sizes: + sizes = deepcopy(sizes) + for s in sizes: + if s < 1: + raise ValueError(f"All sizes must be >= 1, but got {sizes}") + if strides: + strides = deepcopy(strides) + for s in strides: + if s < 0: + raise ValueError(f"All strides must be >= 0, but got {strides}") + + # Clean (set size=1, stride=0 for as many dims as possible) + if sizes and strides: + for i in range(num_dims): + if sizes[i] == 1: + if isinstance(strides, tuple): + # Tuple is immutable, so convert if necessary + strides = list(strides) + strides[i] = 0 + else: + break + return sizes, strides + + +def validate_tensor_dims( + tensor_dims: Sequence[int], expected_dims: int | None = None +) -> Sequence[int]: + if not (expected_dims is None): + if expected_dims < 1: + raise ValueError(f"Expected dimensions ({expected_dims}) should be >= 1") + tensor_dims = deepcopy(tensor_dims) + + # Validate tensor dims and offset, then set + if len(tensor_dims) == 0: + raise ValueError( + f"Number of tensor dimensions must be >= 1 (dimensions={tensor_dims})" + ) + for d in tensor_dims: + if d <= 0: + raise ValueError( + f"Each tensor dimension must be >= 1 (dimensions={tensor_dims})" + ) + + # We can treat a 1-dimensional tensor as a 2-dimensional tensor, + if len(tensor_dims) == 1: + tensor_dims = [1, tensor_dims[0]] + + if not (expected_dims is None) and len(tensor_dims) != expected_dims: + raise ValueError( + f"Tensor dimension ({tensor_dims}) does not match expected dimension ({expected_dims})" + ) + + return tensor_dims + + +def validate_offset(offset: int): + if offset < 0: + raise ValueError(f"Offset must be >= 0 (offset={offset})") + return offset diff --git a/python/helpers/tensortiler/visualization2d.py b/python/helpers/tensortiler/visualization2d.py new file mode 100644 index 0000000000..7244475f8e --- /dev/null +++ b/python/helpers/tensortiler/visualization2d.py @@ -0,0 +1,172 @@ +import matplotlib.animation as animation +import matplotlib.patheffects as pe +import matplotlib.pyplot as plt +import numpy as np +import os +import sys + +from .utils import ceildiv + + +def animate_from_access_tensors( + access_order_tensors: list[np.ndarray], + access_count_tensors: list[np.ndarray] | None, + title: str = "Animated Access Visualization", +) -> animation.FuncAnimation: + if len(access_order_tensors) < 1: + raise ValueError("At least one access order tensor is required.") + if not (access_count_tensors is None): + if len(access_count_tensors) < 1: + raise ValueError( + "access_count_tensor should be None or requires at least one tensor" + ) + if len(access_count_tensors) != len(access_order_tensors): + raise ValueError( + "Number of access count tensors and number of access order tensors should be equal" + ) + + tensor_height, tensor_width = access_order_tensors[0].shape + fig_width = 7 + if tensor_width < 32: + fig_width = 5 + height_width_ratio = ceildiv(tensor_height, tensor_width) + fig_height = min(fig_width, fig_width * height_width_ratio) + + if not (access_count_tensors is None): + fig_height *= 2 + fig, (ax_order, ax_count) = plt.subplots(2, 1) + else: + fig, ax_order = plt.subplots() + + fig.set_figheight(fig_height) + fig.set_figwidth(fig_width) + fig.suptitle(title) + xs = np.arange(access_order_tensors[0].shape[1]) + ys = np.arange(access_order_tensors[0].shape[0]) + + ax_order.xaxis.tick_top() + ax_order.invert_yaxis() + ax_order.set_title("Access Order Animation") + + if not (access_count_tensors is None): + ax_count.xaxis.tick_top() + ax_count.invert_yaxis() + ax_count.set_title(f"Access Counts") + + def animate_order(i): + access_heatmap = ax_order.pcolormesh(access_order_tensors[i]) + + if not (access_count_tensors is None): + count_heatmap = ax_count.pcolormesh( + xs, ys, access_count_tensors[i], cmap="gnuplot2" + ) + return ( + access_heatmap, + count_heatmap, + ) + return access_heatmap + + _animation = animation.FuncAnimation( + fig, + animate_order, + frames=len(access_order_tensors), + interval=max(400, 100 + 5 * len(access_order_tensors)), + ) + + plt.tight_layout() + plt.close() + return _animation + + +def visualize_from_access_tensors( + access_order_tensor: np.ndarray, + access_count_tensor: np.ndarray | None, + title: str = "Access Visualization", + show_arrows: bool | None = None, + file_path: str | None = None, + show_plot: bool = True, +): + tensor_height, tensor_width = access_order_tensor.shape + if tensor_height * tensor_width >= 1024: + if show_arrows: + print( + f"show_arrows not recommended for tensor sizes > 1024 elements", + file=sys.stderr, + ) + if show_arrows is None: + show_arrows = False + elif show_arrows is None: + # Set to true by default only for 'small' tensor sizes + show_arrows = True + + fig_width = 7 + if tensor_width < 32: + fig_width = 5 + height_width_ratio = ceildiv(tensor_height, tensor_width) + fig_height = min(fig_width, fig_width * height_width_ratio) + + if not (access_count_tensor is None): + fig_height *= 2 + fig, (ax_order, ax_count) = plt.subplots(2, 1) + else: + fig, ax_order = plt.subplots() + + fig.set_figheight(fig_height) + fig.set_figwidth(fig_width) + fig.suptitle(title) + xs = np.arange(access_order_tensor.shape[1]) + ys = np.arange(access_order_tensor.shape[0]) + + _access_heatmap = ax_order.pcolormesh(xs, ys, access_order_tensor, cmap="gnuplot2") + ax_order.xaxis.tick_top() + ax_order.invert_yaxis() + ax_order.set_title("Access Order") + + if not (access_count_tensor is None): + max_count = np.max(access_count_tensor) + _count_heatmap = ax_count.pcolormesh( + xs, ys, access_count_tensor, cmap="gnuplot2" + ) + ax_count.xaxis.tick_top() + ax_count.invert_yaxis() + ax_count.set_title(f"Access Counts (max={max_count})") + + # Add arrows to show access order + if show_arrows: + # Thanks to https://stackoverflow.com/questions/37719304/python-imshow-set-certain-value-to-defined-color + # Thanks to tmdavison answer here https://stackoverflow.com/a/40890587/7871710 + + order_dict = {} + for i in range(access_order_tensor.shape[0]): + for j in range(access_order_tensor.shape[1]): + if access_order_tensor[i, j] != -1: + order_dict[access_order_tensor[i, j]] = (i, j) + + order_keys = list(order_dict.keys()) + order_keys.sort() + for i in range(order_keys[0], order_keys[-1]): + y1, x1 = order_dict[i] + y2, x2 = order_dict[i + 1] + ax_order.arrow( + x1, + y1, + x2 - x1, + y2 - y1, + length_includes_head=True, + head_width=0.1, + head_length=0.15, + overhang=0.2, + path_effects=[pe.withStroke(linewidth=3, foreground="white")], + ) + + plt.tight_layout() + if show_plot: + plt.show() + if file_path: + if os.path.exists(file_path): + print( + f"Cannot save plot to {file_path}; file already exists", + file=sys.stderr, + ) + plt.savefig(file_path) + plt.close() diff --git a/python/requirements.txt b/python/requirements.txt index 780c4f4481..6efee7f4bd 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -12,3 +12,4 @@ rich setuptools wheel ml_dtypes +matplotlib diff --git a/python/utils/xrt.py b/python/utils/xrt.py index 276e9e2cde..bd4fc1e3a8 100644 --- a/python/utils/xrt.py +++ b/python/utils/xrt.py @@ -131,7 +131,8 @@ def setup_aie( ): app = AIE_Application(xclbin_path, insts_path, kernel_name) - app.register_buffer(3, shape=in_0_shape, dtype=in_0_dtype) + if in_0_shape or in_0_dtype: + app.register_buffer(3, shape=in_0_shape, dtype=in_0_dtype) if in_1_shape or in_1_dtype: app.register_buffer(4, shape=in_1_shape, dtype=in_1_dtype) @@ -159,8 +160,9 @@ def write_out_trace(trace, file_name): f.write(out_str) -def execute(app, input_one, input_two=None): - app.buffers[3].write(input_one) +def execute(app, input_one=None, input_two=None): + if not (input_one is None): + app.buffers[3].write(input_one) if not (input_two is None): app.buffers[4].write(input_two) app.run() diff --git a/test/python/tensortiler/group_tiler.py b/test/python/tensortiler/group_tiler.py new file mode 100644 index 0000000000..527ce8397e --- /dev/null +++ b/test/python/tensortiler/group_tiler.py @@ -0,0 +1,576 @@ +import numpy as np + +from aie.helpers.tensortiler import TensorTile, TensorTiler2D +from util import construct_test + +# RUN: %python %s | FileCheck %s + + +# CHECK-LABEL: group_tiler +@construct_test +def group_tiler(): + # Default tile group dims + tiles = TensorTiler2D.group_tiler((3 * 5, 2 * 7), tile_dims=(3, 2)) + assert len(tiles) == 5 * 7 + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 6, 7, 12, 13, 18, 19, 24, 25, 30, 31, 36, 37], + [ 2, 3, 8, 9, 14, 15, 20, 21, 26, 27, 32, 33, 38, 39], + [ 4, 5, 10, 11, 16, 17, 22, 23, 28, 29, 34, 35, 40, 41], + [ 42, 43, 48, 49, 54, 55, 60, 61, 66, 67, 72, 73, 78, 79], + [ 44, 45, 50, 51, 56, 57, 62, 63, 68, 69, 74, 75, 80, 81], + [ 46, 47, 52, 53, 58, 59, 64, 65, 70, 71, 76, 77, 82, 83], + [ 84, 85, 90, 91, 96, 97, 102, 103, 108, 109, 114, 115, 120, 121], + [ 86, 87, 92, 93, 98, 99, 104, 105, 110, 111, 116, 117, 122, 123], + [ 88, 89, 94, 95, 100, 101, 106, 107, 112, 113, 118, 119, 124, 125], + [126, 127, 132, 133, 138, 139, 144, 145, 150, 151, 156, 157, 162, 163], + [128, 129, 134, 135, 140, 141, 146, 147, 152, 153, 158, 159, 164, 165], + [130, 131, 136, 137, 142, 143, 148, 149, 154, 155, 160, 161, 166, 167], + [168, 169, 174, 175, 180, 181, 186, 187, 192, 193, 198, 199, 204, 205], + [170, 171, 176, 177, 182, 183, 188, 189, 194, 195, 200, 201, 206, 207], + [172, 173, 178, 179, 184, 185, 190, 191, 196, 197, 202, 203, 208, 209]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + tiles = TensorTiler2D.group_tiler( + (3 * 5 * 3, 2 * 7 * 2), tile_dims=(3, 2), tile_group_dims=(5, 7) + ) + assert len(tiles) == 3 * 2 + tile0_0 = TensorTile( + (3 * 5 * 3, 2 * 7 * 2), offset=0, sizes=[5, 7, 3, 2], strides=[84, 2, 28, 1] + ) + assert tiles[0] == tile0_0 + tile0_1 = TensorTile( + (3 * 5 * 3, 2 * 7 * 2), offset=14, sizes=[5, 7, 3, 2], strides=[84, 2, 28, 1] + ) + assert tiles[1] == tile0_1 + tile1_0 = TensorTile( + (3 * 5 * 3, 2 * 7 * 2), offset=420, sizes=[5, 7, 3, 2], strides=[84, 2, 28, 1] + ) + assert tiles[2] == tile1_0 + + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 6, 7, 12, 13, 18, 19, 24, 25, 30, 31, 36, 37, 210, 211, 216, 217, 222, 223, 228, 229, 234, 235, 240, 241, 246, 247], + [ 2, 3, 8, 9, 14, 15, 20, 21, 26, 27, 32, 33, 38, 39, 212, 213, 218, 219, 224, 225, 230, 231, 236, 237, 242, 243, 248, 249], + [ 4, 5, 10, 11, 16, 17, 22, 23, 28, 29, 34, 35, 40, 41, 214, 215, 220, 221, 226, 227, 232, 233, 238, 239, 244, 245, 250, 251], + [ 42, 43, 48, 49, 54, 55, 60, 61, 66, 67, 72, 73, 78, 79, 252, 253, 258, 259, 264, 265, 270, 271, 276, 277, 282, 283, 288, 289], + [ 44, 45, 50, 51, 56, 57, 62, 63, 68, 69, 74, 75, 80, 81, 254, 255, 260, 261, 266, 267, 272, 273, 278, 279, 284, 285, 290, 291], + [ 46, 47, 52, 53, 58, 59, 64, 65, 70, 71, 76, 77, 82, 83, 256, 257, 262, 263, 268, 269, 274, 275, 280, 281, 286, 287, 292, 293], + [ 84, 85, 90, 91, 96, 97, 102, 103, 108, 109, 114, 115, 120, 121, 294, 295, 300, 301, 306, 307, 312, 313, 318, 319, 324, 325, 330, 331], + [ 86, 87, 92, 93, 98, 99, 104, 105, 110, 111, 116, 117, 122, 123, 296, 297, 302, 303, 308, 309, 314, 315, 320, 321, 326, 327, 332, 333], + [ 88, 89, 94, 95, 100, 101, 106, 107, 112, 113, 118, 119, 124, 125, 298, 299, 304, 305, 310, 311, 316, 317, 322, 323, 328, 329, 334, 335], + [ 126, 127, 132, 133, 138, 139, 144, 145, 150, 151, 156, 157, 162, 163, 336, 337, 342, 343, 348, 349, 354, 355, 360, 361, 366, 367, 372, 373], + [ 128, 129, 134, 135, 140, 141, 146, 147, 152, 153, 158, 159, 164, 165, 338, 339, 344, 345, 350, 351, 356, 357, 362, 363, 368, 369, 374, 375], + [ 130, 131, 136, 137, 142, 143, 148, 149, 154, 155, 160, 161, 166, 167, 340, 341, 346, 347, 352, 353, 358, 359, 364, 365, 370, 371, 376, 377], + [ 168, 169, 174, 175, 180, 181, 186, 187, 192, 193, 198, 199, 204, 205, 378, 379, 384, 385, 390, 391, 396, 397, 402, 403, 408, 409, 414, 415], + [ 170, 171, 176, 177, 182, 183, 188, 189, 194, 195, 200, 201, 206, 207, 380, 381, 386, 387, 392, 393, 398, 399, 404, 405, 410, 411, 416, 417], + [ 172, 173, 178, 179, 184, 185, 190, 191, 196, 197, 202, 203, 208, 209, 382, 383, 388, 389, 394, 395, 400, 401, 406, 407, 412, 413, 418, 419], + [ 420, 421, 426, 427, 432, 433, 438, 439, 444, 445, 450, 451, 456, 457, 630, 631, 636, 637, 642, 643, 648, 649, 654, 655, 660, 661, 666, 667], + [ 422, 423, 428, 429, 434, 435, 440, 441, 446, 447, 452, 453, 458, 459, 632, 633, 638, 639, 644, 645, 650, 651, 656, 657, 662, 663, 668, 669], + [ 424, 425, 430, 431, 436, 437, 442, 443, 448, 449, 454, 455, 460, 461, 634, 635, 640, 641, 646, 647, 652, 653, 658, 659, 664, 665, 670, 671], + [ 462, 463, 468, 469, 474, 475, 480, 481, 486, 487, 492, 493, 498, 499, 672, 673, 678, 679, 684, 685, 690, 691, 696, 697, 702, 703, 708, 709], + [ 464, 465, 470, 471, 476, 477, 482, 483, 488, 489, 494, 495, 500, 501, 674, 675, 680, 681, 686, 687, 692, 693, 698, 699, 704, 705, 710, 711], + [ 466, 467, 472, 473, 478, 479, 484, 485, 490, 491, 496, 497, 502, 503, 676, 677, 682, 683, 688, 689, 694, 695, 700, 701, 706, 707, 712, 713], + [ 504, 505, 510, 511, 516, 517, 522, 523, 528, 529, 534, 535, 540, 541, 714, 715, 720, 721, 726, 727, 732, 733, 738, 739, 744, 745, 750, 751], + [ 506, 507, 512, 513, 518, 519, 524, 525, 530, 531, 536, 537, 542, 543, 716, 717, 722, 723, 728, 729, 734, 735, 740, 741, 746, 747, 752, 753], + [ 508, 509, 514, 515, 520, 521, 526, 527, 532, 533, 538, 539, 544, 545, 718, 719, 724, 725, 730, 731, 736, 737, 742, 743, 748, 749, 754, 755], + [ 546, 547, 552, 553, 558, 559, 564, 565, 570, 571, 576, 577, 582, 583, 756, 757, 762, 763, 768, 769, 774, 775, 780, 781, 786, 787, 792, 793], + [ 548, 549, 554, 555, 560, 561, 566, 567, 572, 573, 578, 579, 584, 585, 758, 759, 764, 765, 770, 771, 776, 777, 782, 783, 788, 789, 794, 795], + [ 550, 551, 556, 557, 562, 563, 568, 569, 574, 575, 580, 581, 586, 587, 760, 761, 766, 767, 772, 773, 778, 779, 784, 785, 790, 791, 796, 797], + [ 588, 589, 594, 595, 600, 601, 606, 607, 612, 613, 618, 619, 624, 625, 798, 799, 804, 805, 810, 811, 816, 817, 822, 823, 828, 829, 834, 835], + [ 590, 591, 596, 597, 602, 603, 608, 609, 614, 615, 620, 621, 626, 627, 800, 801, 806, 807, 812, 813, 818, 819, 824, 825, 830, 831, 836, 837], + [ 592, 593, 598, 599, 604, 605, 610, 611, 616, 617, 622, 623, 628, 629, 802, 803, 808, 809, 814, 815, 820, 821, 826, 827, 832, 833, 838, 839], + [ 840, 841, 846, 847, 852, 853, 858, 859, 864, 865, 870, 871, 876, 877, 1050, 1051, 1056, 1057, 1062, 1063, 1068, 1069, 1074, 1075, 1080, 1081, 1086, 1087], + [ 842, 843, 848, 849, 854, 855, 860, 861, 866, 867, 872, 873, 878, 879, 1052, 1053, 1058, 1059, 1064, 1065, 1070, 1071, 1076, 1077, 1082, 1083, 1088, 1089], + [ 844, 845, 850, 851, 856, 857, 862, 863, 868, 869, 874, 875, 880, 881, 1054, 1055, 1060, 1061, 1066, 1067, 1072, 1073, 1078, 1079, 1084, 1085, 1090, 1091], + [ 882, 883, 888, 889, 894, 895, 900, 901, 906, 907, 912, 913, 918, 919, 1092, 1093, 1098, 1099, 1104, 1105, 1110, 1111, 1116, 1117, 1122, 1123, 1128, 1129], + [ 884, 885, 890, 891, 896, 897, 902, 903, 908, 909, 914, 915, 920, 921, 1094, 1095, 1100, 1101, 1106, 1107, 1112, 1113, 1118, 1119, 1124, 1125, 1130, 1131], + [ 886, 887, 892, 893, 898, 899, 904, 905, 910, 911, 916, 917, 922, 923, 1096, 1097, 1102, 1103, 1108, 1109, 1114, 1115, 1120, 1121, 1126, 1127, 1132, 1133], + [ 924, 925, 930, 931, 936, 937, 942, 943, 948, 949, 954, 955, 960, 961, 1134, 1135, 1140, 1141, 1146, 1147, 1152, 1153, 1158, 1159, 1164, 1165, 1170, 1171], + [ 926, 927, 932, 933, 938, 939, 944, 945, 950, 951, 956, 957, 962, 963, 1136, 1137, 1142, 1143, 1148, 1149, 1154, 1155, 1160, 1161, 1166, 1167, 1172, 1173], + [ 928, 929, 934, 935, 940, 941, 946, 947, 952, 953, 958, 959, 964, 965, 1138, 1139, 1144, 1145, 1150, 1151, 1156, 1157, 1162, 1163, 1168, 1169, 1174, 1175], + [ 966, 967, 972, 973, 978, 979, 984, 985, 990, 991, 996, 997, 1002, 1003, 1176, 1177, 1182, 1183, 1188, 1189, 1194, 1195, 1200, 1201, 1206, 1207, 1212, 1213], + [ 968, 969, 974, 975, 980, 981, 986, 987, 992, 993, 998, 999, 1004, 1005, 1178, 1179, 1184, 1185, 1190, 1191, 1196, 1197, 1202, 1203, 1208, 1209, 1214, 1215], + [ 970, 971, 976, 977, 982, 983, 988, 989, 994, 995, 1000, 1001, 1006, 1007, 1180, 1181, 1186, 1187, 1192, 1193, 1198, 1199, 1204, 1205, 1210, 1211, 1216, 1217], + [1008, 1009, 1014, 1015, 1020, 1021, 1026, 1027, 1032, 1033, 1038, 1039, 1044, 1045, 1218, 1219, 1224, 1225, 1230, 1231, 1236, 1237, 1242, 1243, 1248, 1249, 1254, 1255], + [1010, 1011, 1016, 1017, 1022, 1023, 1028, 1029, 1034, 1035, 1040, 1041, 1046, 1047, 1220, 1221, 1226, 1227, 1232, 1233, 1238, 1239, 1244, 1245, 1250, 1251, 1256, 1257], + [1012, 1013, 1018, 1019, 1024, 1025, 1030, 1031, 1036, 1037, 1042, 1043, 1048, 1049, 1222, 1223, 1228, 1229, 1234, 1235, 1240, 1241, 1246, 1247, 1252, 1253, 1258, 1259]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # iter_col_major + tiles_col_iter = TensorTiler2D.group_tiler( + (3 * 5 * 3, 2 * 7 * 2), + tile_dims=(3, 2), + tile_group_dims=(5, 7), + iter_col_major=True, + ) + assert tiles_col_iter[0] == tile0_0 + assert tiles_col_iter[1] == tile1_0 + assert tiles_col_iter[3] == tile0_1 + + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 6, 7, 12, 13, 18, 19, 24, 25, 30, 31, 36, 37, 630, 631, 636, 637, 642, 643, 648, 649, 654, 655, 660, 661, 666, 667], + [ 2, 3, 8, 9, 14, 15, 20, 21, 26, 27, 32, 33, 38, 39, 632, 633, 638, 639, 644, 645, 650, 651, 656, 657, 662, 663, 668, 669], + [ 4, 5, 10, 11, 16, 17, 22, 23, 28, 29, 34, 35, 40, 41, 634, 635, 640, 641, 646, 647, 652, 653, 658, 659, 664, 665, 670, 671], + [ 42, 43, 48, 49, 54, 55, 60, 61, 66, 67, 72, 73, 78, 79, 672, 673, 678, 679, 684, 685, 690, 691, 696, 697, 702, 703, 708, 709], + [ 44, 45, 50, 51, 56, 57, 62, 63, 68, 69, 74, 75, 80, 81, 674, 675, 680, 681, 686, 687, 692, 693, 698, 699, 704, 705, 710, 711], + [ 46, 47, 52, 53, 58, 59, 64, 65, 70, 71, 76, 77, 82, 83, 676, 677, 682, 683, 688, 689, 694, 695, 700, 701, 706, 707, 712, 713], + [ 84, 85, 90, 91, 96, 97, 102, 103, 108, 109, 114, 115, 120, 121, 714, 715, 720, 721, 726, 727, 732, 733, 738, 739, 744, 745, 750, 751], + [ 86, 87, 92, 93, 98, 99, 104, 105, 110, 111, 116, 117, 122, 123, 716, 717, 722, 723, 728, 729, 734, 735, 740, 741, 746, 747, 752, 753], + [ 88, 89, 94, 95, 100, 101, 106, 107, 112, 113, 118, 119, 124, 125, 718, 719, 724, 725, 730, 731, 736, 737, 742, 743, 748, 749, 754, 755], + [ 126, 127, 132, 133, 138, 139, 144, 145, 150, 151, 156, 157, 162, 163, 756, 757, 762, 763, 768, 769, 774, 775, 780, 781, 786, 787, 792, 793], + [ 128, 129, 134, 135, 140, 141, 146, 147, 152, 153, 158, 159, 164, 165, 758, 759, 764, 765, 770, 771, 776, 777, 782, 783, 788, 789, 794, 795], + [ 130, 131, 136, 137, 142, 143, 148, 149, 154, 155, 160, 161, 166, 167, 760, 761, 766, 767, 772, 773, 778, 779, 784, 785, 790, 791, 796, 797], + [ 168, 169, 174, 175, 180, 181, 186, 187, 192, 193, 198, 199, 204, 205, 798, 799, 804, 805, 810, 811, 816, 817, 822, 823, 828, 829, 834, 835], + [ 170, 171, 176, 177, 182, 183, 188, 189, 194, 195, 200, 201, 206, 207, 800, 801, 806, 807, 812, 813, 818, 819, 824, 825, 830, 831, 836, 837], + [ 172, 173, 178, 179, 184, 185, 190, 191, 196, 197, 202, 203, 208, 209, 802, 803, 808, 809, 814, 815, 820, 821, 826, 827, 832, 833, 838, 839], + [ 210, 211, 216, 217, 222, 223, 228, 229, 234, 235, 240, 241, 246, 247, 840, 841, 846, 847, 852, 853, 858, 859, 864, 865, 870, 871, 876, 877], + [ 212, 213, 218, 219, 224, 225, 230, 231, 236, 237, 242, 243, 248, 249, 842, 843, 848, 849, 854, 855, 860, 861, 866, 867, 872, 873, 878, 879], + [ 214, 215, 220, 221, 226, 227, 232, 233, 238, 239, 244, 245, 250, 251, 844, 845, 850, 851, 856, 857, 862, 863, 868, 869, 874, 875, 880, 881], + [ 252, 253, 258, 259, 264, 265, 270, 271, 276, 277, 282, 283, 288, 289, 882, 883, 888, 889, 894, 895, 900, 901, 906, 907, 912, 913, 918, 919], + [ 254, 255, 260, 261, 266, 267, 272, 273, 278, 279, 284, 285, 290, 291, 884, 885, 890, 891, 896, 897, 902, 903, 908, 909, 914, 915, 920, 921], + [ 256, 257, 262, 263, 268, 269, 274, 275, 280, 281, 286, 287, 292, 293, 886, 887, 892, 893, 898, 899, 904, 905, 910, 911, 916, 917, 922, 923], + [ 294, 295, 300, 301, 306, 307, 312, 313, 318, 319, 324, 325, 330, 331, 924, 925, 930, 931, 936, 937, 942, 943, 948, 949, 954, 955, 960, 961], + [ 296, 297, 302, 303, 308, 309, 314, 315, 320, 321, 326, 327, 332, 333, 926, 927, 932, 933, 938, 939, 944, 945, 950, 951, 956, 957, 962, 963], + [ 298, 299, 304, 305, 310, 311, 316, 317, 322, 323, 328, 329, 334, 335, 928, 929, 934, 935, 940, 941, 946, 947, 952, 953, 958, 959, 964, 965], + [ 336, 337, 342, 343, 348, 349, 354, 355, 360, 361, 366, 367, 372, 373, 966, 967, 972, 973, 978, 979, 984, 985, 990, 991, 996, 997, 1002, 1003], + [ 338, 339, 344, 345, 350, 351, 356, 357, 362, 363, 368, 369, 374, 375, 968, 969, 974, 975, 980, 981, 986, 987, 992, 993, 998, 999, 1004, 1005], + [ 340, 341, 346, 347, 352, 353, 358, 359, 364, 365, 370, 371, 376, 377, 970, 971, 976, 977, 982, 983, 988, 989, 994, 995, 1000, 1001, 1006, 1007], + [ 378, 379, 384, 385, 390, 391, 396, 397, 402, 403, 408, 409, 414, 415, 1008, 1009, 1014, 1015, 1020, 1021, 1026, 1027, 1032, 1033, 1038, 1039, 1044, 1045], + [ 380, 381, 386, 387, 392, 393, 398, 399, 404, 405, 410, 411, 416, 417, 1010, 1011, 1016, 1017, 1022, 1023, 1028, 1029, 1034, 1035, 1040, 1041, 1046, 1047], + [ 382, 383, 388, 389, 394, 395, 400, 401, 406, 407, 412, 413, 418, 419, 1012, 1013, 1018, 1019, 1024, 1025, 1030, 1031, 1036, 1037, 1042, 1043, 1048, 1049], + [ 420, 421, 426, 427, 432, 433, 438, 439, 444, 445, 450, 451, 456, 457, 1050, 1051, 1056, 1057, 1062, 1063, 1068, 1069, 1074, 1075, 1080, 1081, 1086, 1087], + [ 422, 423, 428, 429, 434, 435, 440, 441, 446, 447, 452, 453, 458, 459, 1052, 1053, 1058, 1059, 1064, 1065, 1070, 1071, 1076, 1077, 1082, 1083, 1088, 1089], + [ 424, 425, 430, 431, 436, 437, 442, 443, 448, 449, 454, 455, 460, 461, 1054, 1055, 1060, 1061, 1066, 1067, 1072, 1073, 1078, 1079, 1084, 1085, 1090, 1091], + [ 462, 463, 468, 469, 474, 475, 480, 481, 486, 487, 492, 493, 498, 499, 1092, 1093, 1098, 1099, 1104, 1105, 1110, 1111, 1116, 1117, 1122, 1123, 1128, 1129], + [ 464, 465, 470, 471, 476, 477, 482, 483, 488, 489, 494, 495, 500, 501, 1094, 1095, 1100, 1101, 1106, 1107, 1112, 1113, 1118, 1119, 1124, 1125, 1130, 1131], + [ 466, 467, 472, 473, 478, 479, 484, 485, 490, 491, 496, 497, 502, 503, 1096, 1097, 1102, 1103, 1108, 1109, 1114, 1115, 1120, 1121, 1126, 1127, 1132, 1133], + [ 504, 505, 510, 511, 516, 517, 522, 523, 528, 529, 534, 535, 540, 541, 1134, 1135, 1140, 1141, 1146, 1147, 1152, 1153, 1158, 1159, 1164, 1165, 1170, 1171], + [ 506, 507, 512, 513, 518, 519, 524, 525, 530, 531, 536, 537, 542, 543, 1136, 1137, 1142, 1143, 1148, 1149, 1154, 1155, 1160, 1161, 1166, 1167, 1172, 1173], + [ 508, 509, 514, 515, 520, 521, 526, 527, 532, 533, 538, 539, 544, 545, 1138, 1139, 1144, 1145, 1150, 1151, 1156, 1157, 1162, 1163, 1168, 1169, 1174, 1175], + [ 546, 547, 552, 553, 558, 559, 564, 565, 570, 571, 576, 577, 582, 583, 1176, 1177, 1182, 1183, 1188, 1189, 1194, 1195, 1200, 1201, 1206, 1207, 1212, 1213], + [ 548, 549, 554, 555, 560, 561, 566, 567, 572, 573, 578, 579, 584, 585, 1178, 1179, 1184, 1185, 1190, 1191, 1196, 1197, 1202, 1203, 1208, 1209, 1214, 1215], + [ 550, 551, 556, 557, 562, 563, 568, 569, 574, 575, 580, 581, 586, 587, 1180, 1181, 1186, 1187, 1192, 1193, 1198, 1199, 1204, 1205, 1210, 1211, 1216, 1217], + [ 588, 589, 594, 595, 600, 601, 606, 607, 612, 613, 618, 619, 624, 625, 1218, 1219, 1224, 1225, 1230, 1231, 1236, 1237, 1242, 1243, 1248, 1249, 1254, 1255], + [ 590, 591, 596, 597, 602, 603, 608, 609, 614, 615, 620, 621, 626, 627, 1220, 1221, 1226, 1227, 1232, 1233, 1238, 1239, 1244, 1245, 1250, 1251, 1256, 1257], + [ 592, 593, 598, 599, 604, 605, 610, 611, 616, 617, 622, 623, 628, 629, 1222, 1223, 1228, 1229, 1234, 1235, 1240, 1241, 1246, 1247, 1252, 1253, 1258, 1259]]) + # fmt: on + access_order, access_count = tiles_col_iter.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # tile_col_major + tiles_tile_col_major = TensorTiler2D.group_tiler( + (3 * 5 * 3, 2 * 7 * 2), + tile_dims=(3, 2), + tile_group_dims=(5, 7), + tile_col_major=True, + ) + tile0_0 = TensorTile( + (3 * 5 * 3, 2 * 7 * 2), offset=0, sizes=[1, 5, 14, 3], strides=[0, 84, 1, 28] + ) + assert tiles_tile_col_major[0] == tile0_0 + tile0_1 = TensorTile( + (3 * 5 * 3, 2 * 7 * 2), offset=14, sizes=[1, 5, 14, 3], strides=[0, 84, 1, 28] + ) + assert tiles_tile_col_major[1] == tile0_1 + tile1_0 = TensorTile( + (3 * 5 * 3, 2 * 7 * 2), offset=420, sizes=[1, 5, 14, 3], strides=[0, 84, 1, 28] + ) + assert tiles_tile_col_major[2] == tile1_0 + + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 210, 213, 216, 219, 222, 225, 228, 231, 234, 237, 240, 243, 246, 249], + [ 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 211, 214, 217, 220, 223, 226, 229, 232, 235, 238, 241, 244, 247, 250], + [ 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41, 212, 215, 218, 221, 224, 227, 230, 233, 236, 239, 242, 245, 248, 251], + [ 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 252, 255, 258, 261, 264, 267, 270, 273, 276, 279, 282, 285, 288, 291], + [ 43, 46, 49, 52, 55, 58, 61, 64, 67, 70, 73, 76, 79, 82, 253, 256, 259, 262, 265, 268, 271, 274, 277, 280, 283, 286, 289, 292], + [ 44, 47, 50, 53, 56, 59, 62, 65, 68, 71, 74, 77, 80, 83, 254, 257, 260, 263, 266, 269, 272, 275, 278, 281, 284, 287, 290, 293], + [ 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 294, 297, 300, 303, 306, 309, 312, 315, 318, 321, 324, 327, 330, 333], + [ 85, 88, 91, 94, 97, 100, 103, 106, 109, 112, 115, 118, 121, 124, 295, 298, 301, 304, 307, 310, 313, 316, 319, 322, 325, 328, 331, 334], + [ 86, 89, 92, 95, 98, 101, 104, 107, 110, 113, 116, 119, 122, 125, 296, 299, 302, 305, 308, 311, 314, 317, 320, 323, 326, 329, 332, 335], + [ 126, 129, 132, 135, 138, 141, 144, 147, 150, 153, 156, 159, 162, 165, 336, 339, 342, 345, 348, 351, 354, 357, 360, 363, 366, 369, 372, 375], + [ 127, 130, 133, 136, 139, 142, 145, 148, 151, 154, 157, 160, 163, 166, 337, 340, 343, 346, 349, 352, 355, 358, 361, 364, 367, 370, 373, 376], + [ 128, 131, 134, 137, 140, 143, 146, 149, 152, 155, 158, 161, 164, 167, 338, 341, 344, 347, 350, 353, 356, 359, 362, 365, 368, 371, 374, 377], + [ 168, 171, 174, 177, 180, 183, 186, 189, 192, 195, 198, 201, 204, 207, 378, 381, 384, 387, 390, 393, 396, 399, 402, 405, 408, 411, 414, 417], + [ 169, 172, 175, 178, 181, 184, 187, 190, 193, 196, 199, 202, 205, 208, 379, 382, 385, 388, 391, 394, 397, 400, 403, 406, 409, 412, 415, 418], + [ 170, 173, 176, 179, 182, 185, 188, 191, 194, 197, 200, 203, 206, 209, 380, 383, 386, 389, 392, 395, 398, 401, 404, 407, 410, 413, 416, 419], + [ 420, 423, 426, 429, 432, 435, 438, 441, 444, 447, 450, 453, 456, 459, 630, 633, 636, 639, 642, 645, 648, 651, 654, 657, 660, 663, 666, 669], + [ 421, 424, 427, 430, 433, 436, 439, 442, 445, 448, 451, 454, 457, 460, 631, 634, 637, 640, 643, 646, 649, 652, 655, 658, 661, 664, 667, 670], + [ 422, 425, 428, 431, 434, 437, 440, 443, 446, 449, 452, 455, 458, 461, 632, 635, 638, 641, 644, 647, 650, 653, 656, 659, 662, 665, 668, 671], + [ 462, 465, 468, 471, 474, 477, 480, 483, 486, 489, 492, 495, 498, 501, 672, 675, 678, 681, 684, 687, 690, 693, 696, 699, 702, 705, 708, 711], + [ 463, 466, 469, 472, 475, 478, 481, 484, 487, 490, 493, 496, 499, 502, 673, 676, 679, 682, 685, 688, 691, 694, 697, 700, 703, 706, 709, 712], + [ 464, 467, 470, 473, 476, 479, 482, 485, 488, 491, 494, 497, 500, 503, 674, 677, 680, 683, 686, 689, 692, 695, 698, 701, 704, 707, 710, 713], + [ 504, 507, 510, 513, 516, 519, 522, 525, 528, 531, 534, 537, 540, 543, 714, 717, 720, 723, 726, 729, 732, 735, 738, 741, 744, 747, 750, 753], + [ 505, 508, 511, 514, 517, 520, 523, 526, 529, 532, 535, 538, 541, 544, 715, 718, 721, 724, 727, 730, 733, 736, 739, 742, 745, 748, 751, 754], + [ 506, 509, 512, 515, 518, 521, 524, 527, 530, 533, 536, 539, 542, 545, 716, 719, 722, 725, 728, 731, 734, 737, 740, 743, 746, 749, 752, 755], + [ 546, 549, 552, 555, 558, 561, 564, 567, 570, 573, 576, 579, 582, 585, 756, 759, 762, 765, 768, 771, 774, 777, 780, 783, 786, 789, 792, 795], + [ 547, 550, 553, 556, 559, 562, 565, 568, 571, 574, 577, 580, 583, 586, 757, 760, 763, 766, 769, 772, 775, 778, 781, 784, 787, 790, 793, 796], + [ 548, 551, 554, 557, 560, 563, 566, 569, 572, 575, 578, 581, 584, 587, 758, 761, 764, 767, 770, 773, 776, 779, 782, 785, 788, 791, 794, 797], + [ 588, 591, 594, 597, 600, 603, 606, 609, 612, 615, 618, 621, 624, 627, 798, 801, 804, 807, 810, 813, 816, 819, 822, 825, 828, 831, 834, 837], + [ 589, 592, 595, 598, 601, 604, 607, 610, 613, 616, 619, 622, 625, 628, 799, 802, 805, 808, 811, 814, 817, 820, 823, 826, 829, 832, 835, 838], + [ 590, 593, 596, 599, 602, 605, 608, 611, 614, 617, 620, 623, 626, 629, 800, 803, 806, 809, 812, 815, 818, 821, 824, 827, 830, 833, 836, 839], + [ 840, 843, 846, 849, 852, 855, 858, 861, 864, 867, 870, 873, 876, 879, 1050, 1053, 1056, 1059, 1062, 1065, 1068, 1071, 1074, 1077, 1080, 1083, 1086, 1089], + [ 841, 844, 847, 850, 853, 856, 859, 862, 865, 868, 871, 874, 877, 880, 1051, 1054, 1057, 1060, 1063, 1066, 1069, 1072, 1075, 1078, 1081, 1084, 1087, 1090], + [ 842, 845, 848, 851, 854, 857, 860, 863, 866, 869, 872, 875, 878, 881, 1052, 1055, 1058, 1061, 1064, 1067, 1070, 1073, 1076, 1079, 1082, 1085, 1088, 1091], + [ 882, 885, 888, 891, 894, 897, 900, 903, 906, 909, 912, 915, 918, 921, 1092, 1095, 1098, 1101, 1104, 1107, 1110, 1113, 1116, 1119, 1122, 1125, 1128, 1131], + [ 883, 886, 889, 892, 895, 898, 901, 904, 907, 910, 913, 916, 919, 922, 1093, 1096, 1099, 1102, 1105, 1108, 1111, 1114, 1117, 1120, 1123, 1126, 1129, 1132], + [ 884, 887, 890, 893, 896, 899, 902, 905, 908, 911, 914, 917, 920, 923, 1094, 1097, 1100, 1103, 1106, 1109, 1112, 1115, 1118, 1121, 1124, 1127, 1130, 1133], + [ 924, 927, 930, 933, 936, 939, 942, 945, 948, 951, 954, 957, 960, 963, 1134, 1137, 1140, 1143, 1146, 1149, 1152, 1155, 1158, 1161, 1164, 1167, 1170, 1173], + [ 925, 928, 931, 934, 937, 940, 943, 946, 949, 952, 955, 958, 961, 964, 1135, 1138, 1141, 1144, 1147, 1150, 1153, 1156, 1159, 1162, 1165, 1168, 1171, 1174], + [ 926, 929, 932, 935, 938, 941, 944, 947, 950, 953, 956, 959, 962, 965, 1136, 1139, 1142, 1145, 1148, 1151, 1154, 1157, 1160, 1163, 1166, 1169, 1172, 1175], + [ 966, 969, 972, 975, 978, 981, 984, 987, 990, 993, 996, 999, 1002, 1005, 1176, 1179, 1182, 1185, 1188, 1191, 1194, 1197, 1200, 1203, 1206, 1209, 1212, 1215], + [ 967, 970, 973, 976, 979, 982, 985, 988, 991, 994, 997, 1000, 1003, 1006, 1177, 1180, 1183, 1186, 1189, 1192, 1195, 1198, 1201, 1204, 1207, 1210, 1213, 1216], + [ 968, 971, 974, 977, 980, 983, 986, 989, 992, 995, 998, 1001, 1004, 1007, 1178, 1181, 1184, 1187, 1190, 1193, 1196, 1199, 1202, 1205, 1208, 1211, 1214, 1217], + [1008, 1011, 1014, 1017, 1020, 1023, 1026, 1029, 1032, 1035, 1038, 1041, 1044, 1047, 1218, 1221, 1224, 1227, 1230, 1233, 1236, 1239, 1242, 1245, 1248, 1251, 1254, 1257], + [1009, 1012, 1015, 1018, 1021, 1024, 1027, 1030, 1033, 1036, 1039, 1042, 1045, 1048, 1219, 1222, 1225, 1228, 1231, 1234, 1237, 1240, 1243, 1246, 1249, 1252, 1255, 1258], + [1010, 1013, 1016, 1019, 1022, 1025, 1028, 1031, 1034, 1037, 1040, 1043, 1046, 1049, 1220, 1223, 1226, 1229, 1232, 1235, 1238, 1241, 1244, 1247, 1250, 1253, 1256, 1259]]) + # fmt: on + access_order, access_count = tiles_tile_col_major.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # iter_col_major and tile_col_major + tiles_tile_col_major_col_iter = TensorTiler2D.group_tiler( + (3 * 5 * 3, 2 * 7 * 2), + tile_dims=(3, 2), + tile_group_dims=(5, 7), + iter_col_major=True, + tile_col_major=True, + ) + assert tiles_tile_col_major_col_iter[0] == tile0_0 + assert tiles_tile_col_major_col_iter[1] == tile1_0 + assert tiles_tile_col_major_col_iter[3] == tile0_1 + + # tile_col_major and pattern_repeat + tiles_tile_col_major_pattern_repeat = TensorTiler2D.group_tiler( + (3 * 5 * 3, 2 * 7 * 2), + tile_dims=(3, 2), + tile_group_dims=(5, 7), + tile_col_major=True, + pattern_repeat=2, + ) + assert tiles_tile_col_major_pattern_repeat[0] == TensorTile( + (3 * 5 * 3, 2 * 7 * 2), offset=0, sizes=[2, 5, 14, 3], strides=[0, 84, 1, 28] + ) + + # fmt: off + ref_access_order_tensor = np.array([ + [ 210, 213, 216, 219, 222, 225, 228, 231, 234, 237, 240, 243, 246, 249, 630, 633, 636, 639, 642, 645, 648, 651, 654, 657, 660, 663, 666, 669], + [ 211, 214, 217, 220, 223, 226, 229, 232, 235, 238, 241, 244, 247, 250, 631, 634, 637, 640, 643, 646, 649, 652, 655, 658, 661, 664, 667, 670], + [ 212, 215, 218, 221, 224, 227, 230, 233, 236, 239, 242, 245, 248, 251, 632, 635, 638, 641, 644, 647, 650, 653, 656, 659, 662, 665, 668, 671], + [ 252, 255, 258, 261, 264, 267, 270, 273, 276, 279, 282, 285, 288, 291, 672, 675, 678, 681, 684, 687, 690, 693, 696, 699, 702, 705, 708, 711], + [ 253, 256, 259, 262, 265, 268, 271, 274, 277, 280, 283, 286, 289, 292, 673, 676, 679, 682, 685, 688, 691, 694, 697, 700, 703, 706, 709, 712], + [ 254, 257, 260, 263, 266, 269, 272, 275, 278, 281, 284, 287, 290, 293, 674, 677, 680, 683, 686, 689, 692, 695, 698, 701, 704, 707, 710, 713], + [ 294, 297, 300, 303, 306, 309, 312, 315, 318, 321, 324, 327, 330, 333, 714, 717, 720, 723, 726, 729, 732, 735, 738, 741, 744, 747, 750, 753], + [ 295, 298, 301, 304, 307, 310, 313, 316, 319, 322, 325, 328, 331, 334, 715, 718, 721, 724, 727, 730, 733, 736, 739, 742, 745, 748, 751, 754], + [ 296, 299, 302, 305, 308, 311, 314, 317, 320, 323, 326, 329, 332, 335, 716, 719, 722, 725, 728, 731, 734, 737, 740, 743, 746, 749, 752, 755], + [ 336, 339, 342, 345, 348, 351, 354, 357, 360, 363, 366, 369, 372, 375, 756, 759, 762, 765, 768, 771, 774, 777, 780, 783, 786, 789, 792, 795], + [ 337, 340, 343, 346, 349, 352, 355, 358, 361, 364, 367, 370, 373, 376, 757, 760, 763, 766, 769, 772, 775, 778, 781, 784, 787, 790, 793, 796], + [ 338, 341, 344, 347, 350, 353, 356, 359, 362, 365, 368, 371, 374, 377, 758, 761, 764, 767, 770, 773, 776, 779, 782, 785, 788, 791, 794, 797], + [ 378, 381, 384, 387, 390, 393, 396, 399, 402, 405, 408, 411, 414, 417, 798, 801, 804, 807, 810, 813, 816, 819, 822, 825, 828, 831, 834, 837], + [ 379, 382, 385, 388, 391, 394, 397, 400, 403, 406, 409, 412, 415, 418, 799, 802, 805, 808, 811, 814, 817, 820, 823, 826, 829, 832, 835, 838], + [ 380, 383, 386, 389, 392, 395, 398, 401, 404, 407, 410, 413, 416, 419, 800, 803, 806, 809, 812, 815, 818, 821, 824, 827, 830, 833, 836, 839], + [1050, 1053, 1056, 1059, 1062, 1065, 1068, 1071, 1074, 1077, 1080, 1083, 1086, 1089, 1470, 1473, 1476, 1479, 1482, 1485, 1488, 1491, 1494, 1497, 1500, 1503, 1506, 1509], + [1051, 1054, 1057, 1060, 1063, 1066, 1069, 1072, 1075, 1078, 1081, 1084, 1087, 1090, 1471, 1474, 1477, 1480, 1483, 1486, 1489, 1492, 1495, 1498, 1501, 1504, 1507, 1510], + [1052, 1055, 1058, 1061, 1064, 1067, 1070, 1073, 1076, 1079, 1082, 1085, 1088, 1091, 1472, 1475, 1478, 1481, 1484, 1487, 1490, 1493, 1496, 1499, 1502, 1505, 1508, 1511], + [1092, 1095, 1098, 1101, 1104, 1107, 1110, 1113, 1116, 1119, 1122, 1125, 1128, 1131, 1512, 1515, 1518, 1521, 1524, 1527, 1530, 1533, 1536, 1539, 1542, 1545, 1548, 1551], + [1093, 1096, 1099, 1102, 1105, 1108, 1111, 1114, 1117, 1120, 1123, 1126, 1129, 1132, 1513, 1516, 1519, 1522, 1525, 1528, 1531, 1534, 1537, 1540, 1543, 1546, 1549, 1552], + [1094, 1097, 1100, 1103, 1106, 1109, 1112, 1115, 1118, 1121, 1124, 1127, 1130, 1133, 1514, 1517, 1520, 1523, 1526, 1529, 1532, 1535, 1538, 1541, 1544, 1547, 1550, 1553], + [1134, 1137, 1140, 1143, 1146, 1149, 1152, 1155, 1158, 1161, 1164, 1167, 1170, 1173, 1554, 1557, 1560, 1563, 1566, 1569, 1572, 1575, 1578, 1581, 1584, 1587, 1590, 1593], + [1135, 1138, 1141, 1144, 1147, 1150, 1153, 1156, 1159, 1162, 1165, 1168, 1171, 1174, 1555, 1558, 1561, 1564, 1567, 1570, 1573, 1576, 1579, 1582, 1585, 1588, 1591, 1594], + [1136, 1139, 1142, 1145, 1148, 1151, 1154, 1157, 1160, 1163, 1166, 1169, 1172, 1175, 1556, 1559, 1562, 1565, 1568, 1571, 1574, 1577, 1580, 1583, 1586, 1589, 1592, 1595], + [1176, 1179, 1182, 1185, 1188, 1191, 1194, 1197, 1200, 1203, 1206, 1209, 1212, 1215, 1596, 1599, 1602, 1605, 1608, 1611, 1614, 1617, 1620, 1623, 1626, 1629, 1632, 1635], + [1177, 1180, 1183, 1186, 1189, 1192, 1195, 1198, 1201, 1204, 1207, 1210, 1213, 1216, 1597, 1600, 1603, 1606, 1609, 1612, 1615, 1618, 1621, 1624, 1627, 1630, 1633, 1636], + [1178, 1181, 1184, 1187, 1190, 1193, 1196, 1199, 1202, 1205, 1208, 1211, 1214, 1217, 1598, 1601, 1604, 1607, 1610, 1613, 1616, 1619, 1622, 1625, 1628, 1631, 1634, 1637], + [1218, 1221, 1224, 1227, 1230, 1233, 1236, 1239, 1242, 1245, 1248, 1251, 1254, 1257, 1638, 1641, 1644, 1647, 1650, 1653, 1656, 1659, 1662, 1665, 1668, 1671, 1674, 1677], + [1219, 1222, 1225, 1228, 1231, 1234, 1237, 1240, 1243, 1246, 1249, 1252, 1255, 1258, 1639, 1642, 1645, 1648, 1651, 1654, 1657, 1660, 1663, 1666, 1669, 1672, 1675, 1678], + [1220, 1223, 1226, 1229, 1232, 1235, 1238, 1241, 1244, 1247, 1250, 1253, 1256, 1259, 1640, 1643, 1646, 1649, 1652, 1655, 1658, 1661, 1664, 1667, 1670, 1673, 1676, 1679], + [1890, 1893, 1896, 1899, 1902, 1905, 1908, 1911, 1914, 1917, 1920, 1923, 1926, 1929, 2310, 2313, 2316, 2319, 2322, 2325, 2328, 2331, 2334, 2337, 2340, 2343, 2346, 2349], + [1891, 1894, 1897, 1900, 1903, 1906, 1909, 1912, 1915, 1918, 1921, 1924, 1927, 1930, 2311, 2314, 2317, 2320, 2323, 2326, 2329, 2332, 2335, 2338, 2341, 2344, 2347, 2350], + [1892, 1895, 1898, 1901, 1904, 1907, 1910, 1913, 1916, 1919, 1922, 1925, 1928, 1931, 2312, 2315, 2318, 2321, 2324, 2327, 2330, 2333, 2336, 2339, 2342, 2345, 2348, 2351], + [1932, 1935, 1938, 1941, 1944, 1947, 1950, 1953, 1956, 1959, 1962, 1965, 1968, 1971, 2352, 2355, 2358, 2361, 2364, 2367, 2370, 2373, 2376, 2379, 2382, 2385, 2388, 2391], + [1933, 1936, 1939, 1942, 1945, 1948, 1951, 1954, 1957, 1960, 1963, 1966, 1969, 1972, 2353, 2356, 2359, 2362, 2365, 2368, 2371, 2374, 2377, 2380, 2383, 2386, 2389, 2392], + [1934, 1937, 1940, 1943, 1946, 1949, 1952, 1955, 1958, 1961, 1964, 1967, 1970, 1973, 2354, 2357, 2360, 2363, 2366, 2369, 2372, 2375, 2378, 2381, 2384, 2387, 2390, 2393], + [1974, 1977, 1980, 1983, 1986, 1989, 1992, 1995, 1998, 2001, 2004, 2007, 2010, 2013, 2394, 2397, 2400, 2403, 2406, 2409, 2412, 2415, 2418, 2421, 2424, 2427, 2430, 2433], + [1975, 1978, 1981, 1984, 1987, 1990, 1993, 1996, 1999, 2002, 2005, 2008, 2011, 2014, 2395, 2398, 2401, 2404, 2407, 2410, 2413, 2416, 2419, 2422, 2425, 2428, 2431, 2434], + [1976, 1979, 1982, 1985, 1988, 1991, 1994, 1997, 2000, 2003, 2006, 2009, 2012, 2015, 2396, 2399, 2402, 2405, 2408, 2411, 2414, 2417, 2420, 2423, 2426, 2429, 2432, 2435], + [2016, 2019, 2022, 2025, 2028, 2031, 2034, 2037, 2040, 2043, 2046, 2049, 2052, 2055, 2436, 2439, 2442, 2445, 2448, 2451, 2454, 2457, 2460, 2463, 2466, 2469, 2472, 2475], + [2017, 2020, 2023, 2026, 2029, 2032, 2035, 2038, 2041, 2044, 2047, 2050, 2053, 2056, 2437, 2440, 2443, 2446, 2449, 2452, 2455, 2458, 2461, 2464, 2467, 2470, 2473, 2476], + [2018, 2021, 2024, 2027, 2030, 2033, 2036, 2039, 2042, 2045, 2048, 2051, 2054, 2057, 2438, 2441, 2444, 2447, 2450, 2453, 2456, 2459, 2462, 2465, 2468, 2471, 2474, 2477], + [2058, 2061, 2064, 2067, 2070, 2073, 2076, 2079, 2082, 2085, 2088, 2091, 2094, 2097, 2478, 2481, 2484, 2487, 2490, 2493, 2496, 2499, 2502, 2505, 2508, 2511, 2514, 2517], + [2059, 2062, 2065, 2068, 2071, 2074, 2077, 2080, 2083, 2086, 2089, 2092, 2095, 2098, 2479, 2482, 2485, 2488, 2491, 2494, 2497, 2500, 2503, 2506, 2509, 2512, 2515, 2518], + [2060, 2063, 2066, 2069, 2072, 2075, 2078, 2081, 2084, 2087, 2090, 2093, 2096, 2099, 2480, 2483, 2486, 2489, 2492, 2495, 2498, 2501, 2504, 2507, 2510, 2513, 2516, 2519]]) + # fmt: on + access_order, access_count = tiles_tile_col_major_pattern_repeat.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 2).all() + + # tile_group_col_major + tiles_group_col_major = TensorTiler2D.group_tiler( + (3 * 5 * 3, 2 * 7 * 2), + tile_dims=(3, 2), + tile_group_dims=(5, 7), + tile_group_col_major=True, + ) + tile0_0 = TensorTile( + (3 * 5 * 3, 2 * 7 * 2), offset=0, sizes=[1, 7, 15, 2], strides=[0, 2, 28, 1] + ) + assert tiles_group_col_major[0] == tile0_0 + tile0_1 = TensorTile( + (3 * 5 * 3, 2 * 7 * 2), offset=14, sizes=[1, 7, 15, 2], strides=[0, 2, 28, 1] + ) + assert tiles_group_col_major[1] == tile0_1 + tile1_0 = TensorTile( + (3 * 5 * 3, 2 * 7 * 2), offset=420, sizes=[1, 7, 15, 2], strides=[0, 2, 28, 1] + ) + assert tiles_group_col_major[2] == tile1_0 + + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 30, 31, 60, 61, 90, 91, 120, 121, 150, 151, 180, 181, 210, 211, 240, 241, 270, 271, 300, 301, 330, 331, 360, 361, 390, 391], + [ 2, 3, 32, 33, 62, 63, 92, 93, 122, 123, 152, 153, 182, 183, 212, 213, 242, 243, 272, 273, 302, 303, 332, 333, 362, 363, 392, 393], + [ 4, 5, 34, 35, 64, 65, 94, 95, 124, 125, 154, 155, 184, 185, 214, 215, 244, 245, 274, 275, 304, 305, 334, 335, 364, 365, 394, 395], + [ 6, 7, 36, 37, 66, 67, 96, 97, 126, 127, 156, 157, 186, 187, 216, 217, 246, 247, 276, 277, 306, 307, 336, 337, 366, 367, 396, 397], + [ 8, 9, 38, 39, 68, 69, 98, 99, 128, 129, 158, 159, 188, 189, 218, 219, 248, 249, 278, 279, 308, 309, 338, 339, 368, 369, 398, 399], + [ 10, 11, 40, 41, 70, 71, 100, 101, 130, 131, 160, 161, 190, 191, 220, 221, 250, 251, 280, 281, 310, 311, 340, 341, 370, 371, 400, 401], + [ 12, 13, 42, 43, 72, 73, 102, 103, 132, 133, 162, 163, 192, 193, 222, 223, 252, 253, 282, 283, 312, 313, 342, 343, 372, 373, 402, 403], + [ 14, 15, 44, 45, 74, 75, 104, 105, 134, 135, 164, 165, 194, 195, 224, 225, 254, 255, 284, 285, 314, 315, 344, 345, 374, 375, 404, 405], + [ 16, 17, 46, 47, 76, 77, 106, 107, 136, 137, 166, 167, 196, 197, 226, 227, 256, 257, 286, 287, 316, 317, 346, 347, 376, 377, 406, 407], + [ 18, 19, 48, 49, 78, 79, 108, 109, 138, 139, 168, 169, 198, 199, 228, 229, 258, 259, 288, 289, 318, 319, 348, 349, 378, 379, 408, 409], + [ 20, 21, 50, 51, 80, 81, 110, 111, 140, 141, 170, 171, 200, 201, 230, 231, 260, 261, 290, 291, 320, 321, 350, 351, 380, 381, 410, 411], + [ 22, 23, 52, 53, 82, 83, 112, 113, 142, 143, 172, 173, 202, 203, 232, 233, 262, 263, 292, 293, 322, 323, 352, 353, 382, 383, 412, 413], + [ 24, 25, 54, 55, 84, 85, 114, 115, 144, 145, 174, 175, 204, 205, 234, 235, 264, 265, 294, 295, 324, 325, 354, 355, 384, 385, 414, 415], + [ 26, 27, 56, 57, 86, 87, 116, 117, 146, 147, 176, 177, 206, 207, 236, 237, 266, 267, 296, 297, 326, 327, 356, 357, 386, 387, 416, 417], + [ 28, 29, 58, 59, 88, 89, 118, 119, 148, 149, 178, 179, 208, 209, 238, 239, 268, 269, 298, 299, 328, 329, 358, 359, 388, 389, 418, 419], + [ 420, 421, 450, 451, 480, 481, 510, 511, 540, 541, 570, 571, 600, 601, 630, 631, 660, 661, 690, 691, 720, 721, 750, 751, 780, 781, 810, 811], + [ 422, 423, 452, 453, 482, 483, 512, 513, 542, 543, 572, 573, 602, 603, 632, 633, 662, 663, 692, 693, 722, 723, 752, 753, 782, 783, 812, 813], + [ 424, 425, 454, 455, 484, 485, 514, 515, 544, 545, 574, 575, 604, 605, 634, 635, 664, 665, 694, 695, 724, 725, 754, 755, 784, 785, 814, 815], + [ 426, 427, 456, 457, 486, 487, 516, 517, 546, 547, 576, 577, 606, 607, 636, 637, 666, 667, 696, 697, 726, 727, 756, 757, 786, 787, 816, 817], + [ 428, 429, 458, 459, 488, 489, 518, 519, 548, 549, 578, 579, 608, 609, 638, 639, 668, 669, 698, 699, 728, 729, 758, 759, 788, 789, 818, 819], + [ 430, 431, 460, 461, 490, 491, 520, 521, 550, 551, 580, 581, 610, 611, 640, 641, 670, 671, 700, 701, 730, 731, 760, 761, 790, 791, 820, 821], + [ 432, 433, 462, 463, 492, 493, 522, 523, 552, 553, 582, 583, 612, 613, 642, 643, 672, 673, 702, 703, 732, 733, 762, 763, 792, 793, 822, 823], + [ 434, 435, 464, 465, 494, 495, 524, 525, 554, 555, 584, 585, 614, 615, 644, 645, 674, 675, 704, 705, 734, 735, 764, 765, 794, 795, 824, 825], + [ 436, 437, 466, 467, 496, 497, 526, 527, 556, 557, 586, 587, 616, 617, 646, 647, 676, 677, 706, 707, 736, 737, 766, 767, 796, 797, 826, 827], + [ 438, 439, 468, 469, 498, 499, 528, 529, 558, 559, 588, 589, 618, 619, 648, 649, 678, 679, 708, 709, 738, 739, 768, 769, 798, 799, 828, 829], + [ 440, 441, 470, 471, 500, 501, 530, 531, 560, 561, 590, 591, 620, 621, 650, 651, 680, 681, 710, 711, 740, 741, 770, 771, 800, 801, 830, 831], + [ 442, 443, 472, 473, 502, 503, 532, 533, 562, 563, 592, 593, 622, 623, 652, 653, 682, 683, 712, 713, 742, 743, 772, 773, 802, 803, 832, 833], + [ 444, 445, 474, 475, 504, 505, 534, 535, 564, 565, 594, 595, 624, 625, 654, 655, 684, 685, 714, 715, 744, 745, 774, 775, 804, 805, 834, 835], + [ 446, 447, 476, 477, 506, 507, 536, 537, 566, 567, 596, 597, 626, 627, 656, 657, 686, 687, 716, 717, 746, 747, 776, 777, 806, 807, 836, 837], + [ 448, 449, 478, 479, 508, 509, 538, 539, 568, 569, 598, 599, 628, 629, 658, 659, 688, 689, 718, 719, 748, 749, 778, 779, 808, 809, 838, 839], + [ 840, 841, 870, 871, 900, 901, 930, 931, 960, 961, 990, 991, 1020, 1021, 1050, 1051, 1080, 1081, 1110, 1111, 1140, 1141, 1170, 1171, 1200, 1201, 1230, 1231], + [ 842, 843, 872, 873, 902, 903, 932, 933, 962, 963, 992, 993, 1022, 1023, 1052, 1053, 1082, 1083, 1112, 1113, 1142, 1143, 1172, 1173, 1202, 1203, 1232, 1233], + [ 844, 845, 874, 875, 904, 905, 934, 935, 964, 965, 994, 995, 1024, 1025, 1054, 1055, 1084, 1085, 1114, 1115, 1144, 1145, 1174, 1175, 1204, 1205, 1234, 1235], + [ 846, 847, 876, 877, 906, 907, 936, 937, 966, 967, 996, 997, 1026, 1027, 1056, 1057, 1086, 1087, 1116, 1117, 1146, 1147, 1176, 1177, 1206, 1207, 1236, 1237], + [ 848, 849, 878, 879, 908, 909, 938, 939, 968, 969, 998, 999, 1028, 1029, 1058, 1059, 1088, 1089, 1118, 1119, 1148, 1149, 1178, 1179, 1208, 1209, 1238, 1239], + [ 850, 851, 880, 881, 910, 911, 940, 941, 970, 971, 1000, 1001, 1030, 1031, 1060, 1061, 1090, 1091, 1120, 1121, 1150, 1151, 1180, 1181, 1210, 1211, 1240, 1241], + [ 852, 853, 882, 883, 912, 913, 942, 943, 972, 973, 1002, 1003, 1032, 1033, 1062, 1063, 1092, 1093, 1122, 1123, 1152, 1153, 1182, 1183, 1212, 1213, 1242, 1243], + [ 854, 855, 884, 885, 914, 915, 944, 945, 974, 975, 1004, 1005, 1034, 1035, 1064, 1065, 1094, 1095, 1124, 1125, 1154, 1155, 1184, 1185, 1214, 1215, 1244, 1245], + [ 856, 857, 886, 887, 916, 917, 946, 947, 976, 977, 1006, 1007, 1036, 1037, 1066, 1067, 1096, 1097, 1126, 1127, 1156, 1157, 1186, 1187, 1216, 1217, 1246, 1247], + [ 858, 859, 888, 889, 918, 919, 948, 949, 978, 979, 1008, 1009, 1038, 1039, 1068, 1069, 1098, 1099, 1128, 1129, 1158, 1159, 1188, 1189, 1218, 1219, 1248, 1249], + [ 860, 861, 890, 891, 920, 921, 950, 951, 980, 981, 1010, 1011, 1040, 1041, 1070, 1071, 1100, 1101, 1130, 1131, 1160, 1161, 1190, 1191, 1220, 1221, 1250, 1251], + [ 862, 863, 892, 893, 922, 923, 952, 953, 982, 983, 1012, 1013, 1042, 1043, 1072, 1073, 1102, 1103, 1132, 1133, 1162, 1163, 1192, 1193, 1222, 1223, 1252, 1253], + [ 864, 865, 894, 895, 924, 925, 954, 955, 984, 985, 1014, 1015, 1044, 1045, 1074, 1075, 1104, 1105, 1134, 1135, 1164, 1165, 1194, 1195, 1224, 1225, 1254, 1255], + [ 866, 867, 896, 897, 926, 927, 956, 957, 986, 987, 1016, 1017, 1046, 1047, 1076, 1077, 1106, 1107, 1136, 1137, 1166, 1167, 1196, 1197, 1226, 1227, 1256, 1257], + [ 868, 869, 898, 899, 928, 929, 958, 959, 988, 989, 1018, 1019, 1048, 1049, 1078, 1079, 1108, 1109, 1138, 1139, 1168, 1169, 1198, 1199, 1228, 1229, 1258, 1259]]) + # fmt: on + access_order, access_count = tiles_group_col_major.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # tile_group_col_major and tile_col_major + tiles_group_col_major = TensorTiler2D.group_tiler( + (3 * 5 * 3, 2 * 7 * 2), + tile_dims=(3, 2), + tile_group_dims=(5, 7), + tile_col_major=True, + tile_group_col_major=True, + ) + tile0_0 = TensorTile( + (3 * 5 * 3, 2 * 7 * 2), offset=0, sizes=[7, 5, 2, 3], strides=[2, 84, 1, 28] + ) + assert tiles_group_col_major[0] == tile0_0 + tile0_1 = TensorTile( + (3 * 5 * 3, 2 * 7 * 2), offset=14, sizes=[7, 5, 2, 3], strides=[2, 84, 1, 28] + ) + assert tiles_group_col_major[1] == tile0_1 + tile1_0 = TensorTile( + (3 * 5 * 3, 2 * 7 * 2), offset=420, sizes=[7, 5, 2, 3], strides=[2, 84, 1, 28] + ) + assert tiles_group_col_major[2] == tile1_0 + + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 3, 30, 33, 60, 63, 90, 93, 120, 123, 150, 153, 180, 183, 210, 213, 240, 243, 270, 273, 300, 303, 330, 333, 360, 363, 390, 393], + [ 1, 4, 31, 34, 61, 64, 91, 94, 121, 124, 151, 154, 181, 184, 211, 214, 241, 244, 271, 274, 301, 304, 331, 334, 361, 364, 391, 394], + [ 2, 5, 32, 35, 62, 65, 92, 95, 122, 125, 152, 155, 182, 185, 212, 215, 242, 245, 272, 275, 302, 305, 332, 335, 362, 365, 392, 395], + [ 6, 9, 36, 39, 66, 69, 96, 99, 126, 129, 156, 159, 186, 189, 216, 219, 246, 249, 276, 279, 306, 309, 336, 339, 366, 369, 396, 399], + [ 7, 10, 37, 40, 67, 70, 97, 100, 127, 130, 157, 160, 187, 190, 217, 220, 247, 250, 277, 280, 307, 310, 337, 340, 367, 370, 397, 400], + [ 8, 11, 38, 41, 68, 71, 98, 101, 128, 131, 158, 161, 188, 191, 218, 221, 248, 251, 278, 281, 308, 311, 338, 341, 368, 371, 398, 401], + [ 12, 15, 42, 45, 72, 75, 102, 105, 132, 135, 162, 165, 192, 195, 222, 225, 252, 255, 282, 285, 312, 315, 342, 345, 372, 375, 402, 405], + [ 13, 16, 43, 46, 73, 76, 103, 106, 133, 136, 163, 166, 193, 196, 223, 226, 253, 256, 283, 286, 313, 316, 343, 346, 373, 376, 403, 406], + [ 14, 17, 44, 47, 74, 77, 104, 107, 134, 137, 164, 167, 194, 197, 224, 227, 254, 257, 284, 287, 314, 317, 344, 347, 374, 377, 404, 407], + [ 18, 21, 48, 51, 78, 81, 108, 111, 138, 141, 168, 171, 198, 201, 228, 231, 258, 261, 288, 291, 318, 321, 348, 351, 378, 381, 408, 411], + [ 19, 22, 49, 52, 79, 82, 109, 112, 139, 142, 169, 172, 199, 202, 229, 232, 259, 262, 289, 292, 319, 322, 349, 352, 379, 382, 409, 412], + [ 20, 23, 50, 53, 80, 83, 110, 113, 140, 143, 170, 173, 200, 203, 230, 233, 260, 263, 290, 293, 320, 323, 350, 353, 380, 383, 410, 413], + [ 24, 27, 54, 57, 84, 87, 114, 117, 144, 147, 174, 177, 204, 207, 234, 237, 264, 267, 294, 297, 324, 327, 354, 357, 384, 387, 414, 417], + [ 25, 28, 55, 58, 85, 88, 115, 118, 145, 148, 175, 178, 205, 208, 235, 238, 265, 268, 295, 298, 325, 328, 355, 358, 385, 388, 415, 418], + [ 26, 29, 56, 59, 86, 89, 116, 119, 146, 149, 176, 179, 206, 209, 236, 239, 266, 269, 296, 299, 326, 329, 356, 359, 386, 389, 416, 419], + [ 420, 423, 450, 453, 480, 483, 510, 513, 540, 543, 570, 573, 600, 603, 630, 633, 660, 663, 690, 693, 720, 723, 750, 753, 780, 783, 810, 813], + [ 421, 424, 451, 454, 481, 484, 511, 514, 541, 544, 571, 574, 601, 604, 631, 634, 661, 664, 691, 694, 721, 724, 751, 754, 781, 784, 811, 814], + [ 422, 425, 452, 455, 482, 485, 512, 515, 542, 545, 572, 575, 602, 605, 632, 635, 662, 665, 692, 695, 722, 725, 752, 755, 782, 785, 812, 815], + [ 426, 429, 456, 459, 486, 489, 516, 519, 546, 549, 576, 579, 606, 609, 636, 639, 666, 669, 696, 699, 726, 729, 756, 759, 786, 789, 816, 819], + [ 427, 430, 457, 460, 487, 490, 517, 520, 547, 550, 577, 580, 607, 610, 637, 640, 667, 670, 697, 700, 727, 730, 757, 760, 787, 790, 817, 820], + [ 428, 431, 458, 461, 488, 491, 518, 521, 548, 551, 578, 581, 608, 611, 638, 641, 668, 671, 698, 701, 728, 731, 758, 761, 788, 791, 818, 821], + [ 432, 435, 462, 465, 492, 495, 522, 525, 552, 555, 582, 585, 612, 615, 642, 645, 672, 675, 702, 705, 732, 735, 762, 765, 792, 795, 822, 825], + [ 433, 436, 463, 466, 493, 496, 523, 526, 553, 556, 583, 586, 613, 616, 643, 646, 673, 676, 703, 706, 733, 736, 763, 766, 793, 796, 823, 826], + [ 434, 437, 464, 467, 494, 497, 524, 527, 554, 557, 584, 587, 614, 617, 644, 647, 674, 677, 704, 707, 734, 737, 764, 767, 794, 797, 824, 827], + [ 438, 441, 468, 471, 498, 501, 528, 531, 558, 561, 588, 591, 618, 621, 648, 651, 678, 681, 708, 711, 738, 741, 768, 771, 798, 801, 828, 831], + [ 439, 442, 469, 472, 499, 502, 529, 532, 559, 562, 589, 592, 619, 622, 649, 652, 679, 682, 709, 712, 739, 742, 769, 772, 799, 802, 829, 832], + [ 440, 443, 470, 473, 500, 503, 530, 533, 560, 563, 590, 593, 620, 623, 650, 653, 680, 683, 710, 713, 740, 743, 770, 773, 800, 803, 830, 833], + [ 444, 447, 474, 477, 504, 507, 534, 537, 564, 567, 594, 597, 624, 627, 654, 657, 684, 687, 714, 717, 744, 747, 774, 777, 804, 807, 834, 837], + [ 445, 448, 475, 478, 505, 508, 535, 538, 565, 568, 595, 598, 625, 628, 655, 658, 685, 688, 715, 718, 745, 748, 775, 778, 805, 808, 835, 838], + [ 446, 449, 476, 479, 506, 509, 536, 539, 566, 569, 596, 599, 626, 629, 656, 659, 686, 689, 716, 719, 746, 749, 776, 779, 806, 809, 836, 839], + [ 840, 843, 870, 873, 900, 903, 930, 933, 960, 963, 990, 993, 1020, 1023, 1050, 1053, 1080, 1083, 1110, 1113, 1140, 1143, 1170, 1173, 1200, 1203, 1230, 1233], + [ 841, 844, 871, 874, 901, 904, 931, 934, 961, 964, 991, 994, 1021, 1024, 1051, 1054, 1081, 1084, 1111, 1114, 1141, 1144, 1171, 1174, 1201, 1204, 1231, 1234], + [ 842, 845, 872, 875, 902, 905, 932, 935, 962, 965, 992, 995, 1022, 1025, 1052, 1055, 1082, 1085, 1112, 1115, 1142, 1145, 1172, 1175, 1202, 1205, 1232, 1235], + [ 846, 849, 876, 879, 906, 909, 936, 939, 966, 969, 996, 999, 1026, 1029, 1056, 1059, 1086, 1089, 1116, 1119, 1146, 1149, 1176, 1179, 1206, 1209, 1236, 1239], + [ 847, 850, 877, 880, 907, 910, 937, 940, 967, 970, 997, 1000, 1027, 1030, 1057, 1060, 1087, 1090, 1117, 1120, 1147, 1150, 1177, 1180, 1207, 1210, 1237, 1240], + [ 848, 851, 878, 881, 908, 911, 938, 941, 968, 971, 998, 1001, 1028, 1031, 1058, 1061, 1088, 1091, 1118, 1121, 1148, 1151, 1178, 1181, 1208, 1211, 1238, 1241], + [ 852, 855, 882, 885, 912, 915, 942, 945, 972, 975, 1002, 1005, 1032, 1035, 1062, 1065, 1092, 1095, 1122, 1125, 1152, 1155, 1182, 1185, 1212, 1215, 1242, 1245], + [ 853, 856, 883, 886, 913, 916, 943, 946, 973, 976, 1003, 1006, 1033, 1036, 1063, 1066, 1093, 1096, 1123, 1126, 1153, 1156, 1183, 1186, 1213, 1216, 1243, 1246], + [ 854, 857, 884, 887, 914, 917, 944, 947, 974, 977, 1004, 1007, 1034, 1037, 1064, 1067, 1094, 1097, 1124, 1127, 1154, 1157, 1184, 1187, 1214, 1217, 1244, 1247], + [ 858, 861, 888, 891, 918, 921, 948, 951, 978, 981, 1008, 1011, 1038, 1041, 1068, 1071, 1098, 1101, 1128, 1131, 1158, 1161, 1188, 1191, 1218, 1221, 1248, 1251], + [ 859, 862, 889, 892, 919, 922, 949, 952, 979, 982, 1009, 1012, 1039, 1042, 1069, 1072, 1099, 1102, 1129, 1132, 1159, 1162, 1189, 1192, 1219, 1222, 1249, 1252], + [ 860, 863, 890, 893, 920, 923, 950, 953, 980, 983, 1010, 1013, 1040, 1043, 1070, 1073, 1100, 1103, 1130, 1133, 1160, 1163, 1190, 1193, 1220, 1223, 1250, 1253], + [ 864, 867, 894, 897, 924, 927, 954, 957, 984, 987, 1014, 1017, 1044, 1047, 1074, 1077, 1104, 1107, 1134, 1137, 1164, 1167, 1194, 1197, 1224, 1227, 1254, 1257], + [ 865, 868, 895, 898, 925, 928, 955, 958, 985, 988, 1015, 1018, 1045, 1048, 1075, 1078, 1105, 1108, 1135, 1138, 1165, 1168, 1195, 1198, 1225, 1228, 1255, 1258], + [ 866, 869, 896, 899, 926, 929, 956, 959, 986, 989, 1016, 1019, 1046, 1049, 1076, 1079, 1106, 1109, 1136, 1139, 1166, 1169, 1196, 1199, 1226, 1229, 1256, 1259]]) + # fmt: on + access_order, access_count = tiles_group_col_major.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # CHECK: Pass! + print("Pass!") + + +# CHECK-LABEL: group_tiler_invalid +@construct_test +def group_tiler_invalid(): + try: + tiles = TensorTiler2D.group_tiler( + (), (3, 2), (1, 1), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Bad tensor dims, should fail.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.group_tiler( + (10, 9, 4), (3, 2), (1, 1), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Too many tensor dims, should fail.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.group_tiler( + (9, 4), (3, -1), (1, 1), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Bad tile dims, should fail.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.group_tiler( + (9, 4), (3,), (1, 1), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Too few tile dims, should fail.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.group_tiler( + (9, 4), (1, 1, 1), (1, 1), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Too many tile dims, should fail.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.group_tiler( + (9, 4), (3, 2), (1, 1), tile_col_major=True, pattern_repeat=0 + ) + raise ValueError("Invalid repeat.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.group_tiler( + (9, 4), (4, 2), (1, 1), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Indivisible tile (height)") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.group_tiler( + (9, 4), (3, 3), (1, 1), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Indivisible tile (width)") + except ValueError: + # good + pass + + try: + tiles = TensorTiler2D.group_tiler( + (9, 4), (3, 2), (1,), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Too few tile group dims, should fail.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.group_tiler( + (9, 4), (3, 2), (1, -1), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Bad tile group dims, should fail.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.group_tiler( + (9, 4), (3, 2), (1, 1, 1), tile_col_major=True + ) + raise ValueError("Too many tile group dims, should fail.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.group_tiler((18, 8), (3, 2), (2, 3), tile_col_major=True) + raise ValueError( + "Indivisible by tile repeat width (but without allow_partial)." + ) + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.group_tiler((18, 8), (3, 2), (4, 2), tile_col_major=True) + raise ValueError( + "Indivisible by tile repeat height (but without allow_partial)." + ) + except ValueError: + # good + pass + + # CHECK: Pass! + print("Pass!") diff --git a/test/python/tensortiler/group_tiler_partial.py b/test/python/tensortiler/group_tiler_partial.py new file mode 100644 index 0000000000..8f2f54370a --- /dev/null +++ b/test/python/tensortiler/group_tiler_partial.py @@ -0,0 +1,1427 @@ +import numpy as np + +from aie.helpers.tensortiler import TensorTile, TensorTileSequence, TensorTiler2D +from util import construct_test + +# RUN: %python %s | FileCheck %s + + +# CHECK-LABEL: group_tiler_partial_row +@construct_test +def group_tiler_partial_row(): + + tensor_dims = (3 * 5 * 3, 2 * 6 * 2) + + # All row major + tiles = TensorTiler2D.group_tiler( + tensor_dims, tile_dims=(3, 2), tile_group_dims=(5, 7), allow_partial=True + ) + reference_tiles = TensorTileSequence.from_tiles( + [ + TensorTile( + tensor_dims, offset=0, sizes=[5, 7, 3, 2], strides=[72, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=14, sizes=[5, 5, 3, 2], strides=[72, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=360, sizes=[5, 7, 3, 2], strides=[72, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=374, sizes=[5, 5, 3, 2], strides=[72, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=720, sizes=[5, 7, 3, 2], strides=[72, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=734, sizes=[5, 5, 3, 2], strides=[72, 2, 24, 1] + ), + ] + ) + assert tiles == reference_tiles + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 6, 7, 12, 13, 18, 19, 24, 25, 30, 31, 36, 37, 210, 211, 216, 217, 222, 223, 228, 229, 234, 235], + [ 2, 3, 8, 9, 14, 15, 20, 21, 26, 27, 32, 33, 38, 39, 212, 213, 218, 219, 224, 225, 230, 231, 236, 237], + [ 4, 5, 10, 11, 16, 17, 22, 23, 28, 29, 34, 35, 40, 41, 214, 215, 220, 221, 226, 227, 232, 233, 238, 239], + [ 42, 43, 48, 49, 54, 55, 60, 61, 66, 67, 72, 73, 78, 79, 240, 241, 246, 247, 252, 253, 258, 259, 264, 265], + [ 44, 45, 50, 51, 56, 57, 62, 63, 68, 69, 74, 75, 80, 81, 242, 243, 248, 249, 254, 255, 260, 261, 266, 267], + [ 46, 47, 52, 53, 58, 59, 64, 65, 70, 71, 76, 77, 82, 83, 244, 245, 250, 251, 256, 257, 262, 263, 268, 269], + [ 84, 85, 90, 91, 96, 97, 102, 103, 108, 109, 114, 115, 120, 121, 270, 271, 276, 277, 282, 283, 288, 289, 294, 295], + [ 86, 87, 92, 93, 98, 99, 104, 105, 110, 111, 116, 117, 122, 123, 272, 273, 278, 279, 284, 285, 290, 291, 296, 297], + [ 88, 89, 94, 95, 100, 101, 106, 107, 112, 113, 118, 119, 124, 125, 274, 275, 280, 281, 286, 287, 292, 293, 298, 299], + [ 126, 127, 132, 133, 138, 139, 144, 145, 150, 151, 156, 157, 162, 163, 300, 301, 306, 307, 312, 313, 318, 319, 324, 325], + [ 128, 129, 134, 135, 140, 141, 146, 147, 152, 153, 158, 159, 164, 165, 302, 303, 308, 309, 314, 315, 320, 321, 326, 327], + [ 130, 131, 136, 137, 142, 143, 148, 149, 154, 155, 160, 161, 166, 167, 304, 305, 310, 311, 316, 317, 322, 323, 328, 329], + [ 168, 169, 174, 175, 180, 181, 186, 187, 192, 193, 198, 199, 204, 205, 330, 331, 336, 337, 342, 343, 348, 349, 354, 355], + [ 170, 171, 176, 177, 182, 183, 188, 189, 194, 195, 200, 201, 206, 207, 332, 333, 338, 339, 344, 345, 350, 351, 356, 357], + [ 172, 173, 178, 179, 184, 185, 190, 191, 196, 197, 202, 203, 208, 209, 334, 335, 340, 341, 346, 347, 352, 353, 358, 359], + [ 360, 361, 366, 367, 372, 373, 378, 379, 384, 385, 390, 391, 396, 397, 570, 571, 576, 577, 582, 583, 588, 589, 594, 595], + [ 362, 363, 368, 369, 374, 375, 380, 381, 386, 387, 392, 393, 398, 399, 572, 573, 578, 579, 584, 585, 590, 591, 596, 597], + [ 364, 365, 370, 371, 376, 377, 382, 383, 388, 389, 394, 395, 400, 401, 574, 575, 580, 581, 586, 587, 592, 593, 598, 599], + [ 402, 403, 408, 409, 414, 415, 420, 421, 426, 427, 432, 433, 438, 439, 600, 601, 606, 607, 612, 613, 618, 619, 624, 625], + [ 404, 405, 410, 411, 416, 417, 422, 423, 428, 429, 434, 435, 440, 441, 602, 603, 608, 609, 614, 615, 620, 621, 626, 627], + [ 406, 407, 412, 413, 418, 419, 424, 425, 430, 431, 436, 437, 442, 443, 604, 605, 610, 611, 616, 617, 622, 623, 628, 629], + [ 444, 445, 450, 451, 456, 457, 462, 463, 468, 469, 474, 475, 480, 481, 630, 631, 636, 637, 642, 643, 648, 649, 654, 655], + [ 446, 447, 452, 453, 458, 459, 464, 465, 470, 471, 476, 477, 482, 483, 632, 633, 638, 639, 644, 645, 650, 651, 656, 657], + [ 448, 449, 454, 455, 460, 461, 466, 467, 472, 473, 478, 479, 484, 485, 634, 635, 640, 641, 646, 647, 652, 653, 658, 659], + [ 486, 487, 492, 493, 498, 499, 504, 505, 510, 511, 516, 517, 522, 523, 660, 661, 666, 667, 672, 673, 678, 679, 684, 685], + [ 488, 489, 494, 495, 500, 501, 506, 507, 512, 513, 518, 519, 524, 525, 662, 663, 668, 669, 674, 675, 680, 681, 686, 687], + [ 490, 491, 496, 497, 502, 503, 508, 509, 514, 515, 520, 521, 526, 527, 664, 665, 670, 671, 676, 677, 682, 683, 688, 689], + [ 528, 529, 534, 535, 540, 541, 546, 547, 552, 553, 558, 559, 564, 565, 690, 691, 696, 697, 702, 703, 708, 709, 714, 715], + [ 530, 531, 536, 537, 542, 543, 548, 549, 554, 555, 560, 561, 566, 567, 692, 693, 698, 699, 704, 705, 710, 711, 716, 717], + [ 532, 533, 538, 539, 544, 545, 550, 551, 556, 557, 562, 563, 568, 569, 694, 695, 700, 701, 706, 707, 712, 713, 718, 719], + [ 720, 721, 726, 727, 732, 733, 738, 739, 744, 745, 750, 751, 756, 757, 930, 931, 936, 937, 942, 943, 948, 949, 954, 955], + [ 722, 723, 728, 729, 734, 735, 740, 741, 746, 747, 752, 753, 758, 759, 932, 933, 938, 939, 944, 945, 950, 951, 956, 957], + [ 724, 725, 730, 731, 736, 737, 742, 743, 748, 749, 754, 755, 760, 761, 934, 935, 940, 941, 946, 947, 952, 953, 958, 959], + [ 762, 763, 768, 769, 774, 775, 780, 781, 786, 787, 792, 793, 798, 799, 960, 961, 966, 967, 972, 973, 978, 979, 984, 985], + [ 764, 765, 770, 771, 776, 777, 782, 783, 788, 789, 794, 795, 800, 801, 962, 963, 968, 969, 974, 975, 980, 981, 986, 987], + [ 766, 767, 772, 773, 778, 779, 784, 785, 790, 791, 796, 797, 802, 803, 964, 965, 970, 971, 976, 977, 982, 983, 988, 989], + [ 804, 805, 810, 811, 816, 817, 822, 823, 828, 829, 834, 835, 840, 841, 990, 991, 996, 997, 1002, 1003, 1008, 1009, 1014, 1015], + [ 806, 807, 812, 813, 818, 819, 824, 825, 830, 831, 836, 837, 842, 843, 992, 993, 998, 999, 1004, 1005, 1010, 1011, 1016, 1017], + [ 808, 809, 814, 815, 820, 821, 826, 827, 832, 833, 838, 839, 844, 845, 994, 995, 1000, 1001, 1006, 1007, 1012, 1013, 1018, 1019], + [ 846, 847, 852, 853, 858, 859, 864, 865, 870, 871, 876, 877, 882, 883, 1020, 1021, 1026, 1027, 1032, 1033, 1038, 1039, 1044, 1045], + [ 848, 849, 854, 855, 860, 861, 866, 867, 872, 873, 878, 879, 884, 885, 1022, 1023, 1028, 1029, 1034, 1035, 1040, 1041, 1046, 1047], + [ 850, 851, 856, 857, 862, 863, 868, 869, 874, 875, 880, 881, 886, 887, 1024, 1025, 1030, 1031, 1036, 1037, 1042, 1043, 1048, 1049], + [ 888, 889, 894, 895, 900, 901, 906, 907, 912, 913, 918, 919, 924, 925, 1050, 1051, 1056, 1057, 1062, 1063, 1068, 1069, 1074, 1075], + [ 890, 891, 896, 897, 902, 903, 908, 909, 914, 915, 920, 921, 926, 927, 1052, 1053, 1058, 1059, 1064, 1065, 1070, 1071, 1076, 1077], + [ 892, 893, 898, 899, 904, 905, 910, 911, 916, 917, 922, 923, 928, 929, 1054, 1055, 1060, 1061, 1066, 1067, 1072, 1073, 1078, 1079]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # Tile col major + tiles = TensorTiler2D.group_tiler( + tensor_dims, + tile_dims=(3, 2), + tile_group_dims=(5, 7), + tile_col_major=True, + allow_partial=True, + ) + reference_tiles = TensorTileSequence.from_tiles( + [ + TensorTile( + tensor_dims, offset=0, sizes=[1, 5, 14, 3], strides=[0, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=14, sizes=[1, 5, 10, 3], strides=[0, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=360, sizes=[1, 5, 14, 3], strides=[0, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=374, sizes=[1, 5, 10, 3], strides=[0, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=720, sizes=[1, 5, 14, 3], strides=[0, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=734, sizes=[1, 5, 10, 3], strides=[0, 72, 1, 24] + ), + ] + ) + assert tiles == reference_tiles + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 210, 213, 216, 219, 222, 225, 228, 231, 234, 237], + [ 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 211, 214, 217, 220, 223, 226, 229, 232, 235, 238], + [ 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41, 212, 215, 218, 221, 224, 227, 230, 233, 236, 239], + [ 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 240, 243, 246, 249, 252, 255, 258, 261, 264, 267], + [ 43, 46, 49, 52, 55, 58, 61, 64, 67, 70, 73, 76, 79, 82, 241, 244, 247, 250, 253, 256, 259, 262, 265, 268], + [ 44, 47, 50, 53, 56, 59, 62, 65, 68, 71, 74, 77, 80, 83, 242, 245, 248, 251, 254, 257, 260, 263, 266, 269], + [ 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 270, 273, 276, 279, 282, 285, 288, 291, 294, 297], + [ 85, 88, 91, 94, 97, 100, 103, 106, 109, 112, 115, 118, 121, 124, 271, 274, 277, 280, 283, 286, 289, 292, 295, 298], + [ 86, 89, 92, 95, 98, 101, 104, 107, 110, 113, 116, 119, 122, 125, 272, 275, 278, 281, 284, 287, 290, 293, 296, 299], + [ 126, 129, 132, 135, 138, 141, 144, 147, 150, 153, 156, 159, 162, 165, 300, 303, 306, 309, 312, 315, 318, 321, 324, 327], + [ 127, 130, 133, 136, 139, 142, 145, 148, 151, 154, 157, 160, 163, 166, 301, 304, 307, 310, 313, 316, 319, 322, 325, 328], + [ 128, 131, 134, 137, 140, 143, 146, 149, 152, 155, 158, 161, 164, 167, 302, 305, 308, 311, 314, 317, 320, 323, 326, 329], + [ 168, 171, 174, 177, 180, 183, 186, 189, 192, 195, 198, 201, 204, 207, 330, 333, 336, 339, 342, 345, 348, 351, 354, 357], + [ 169, 172, 175, 178, 181, 184, 187, 190, 193, 196, 199, 202, 205, 208, 331, 334, 337, 340, 343, 346, 349, 352, 355, 358], + [ 170, 173, 176, 179, 182, 185, 188, 191, 194, 197, 200, 203, 206, 209, 332, 335, 338, 341, 344, 347, 350, 353, 356, 359], + [ 360, 363, 366, 369, 372, 375, 378, 381, 384, 387, 390, 393, 396, 399, 570, 573, 576, 579, 582, 585, 588, 591, 594, 597], + [ 361, 364, 367, 370, 373, 376, 379, 382, 385, 388, 391, 394, 397, 400, 571, 574, 577, 580, 583, 586, 589, 592, 595, 598], + [ 362, 365, 368, 371, 374, 377, 380, 383, 386, 389, 392, 395, 398, 401, 572, 575, 578, 581, 584, 587, 590, 593, 596, 599], + [ 402, 405, 408, 411, 414, 417, 420, 423, 426, 429, 432, 435, 438, 441, 600, 603, 606, 609, 612, 615, 618, 621, 624, 627], + [ 403, 406, 409, 412, 415, 418, 421, 424, 427, 430, 433, 436, 439, 442, 601, 604, 607, 610, 613, 616, 619, 622, 625, 628], + [ 404, 407, 410, 413, 416, 419, 422, 425, 428, 431, 434, 437, 440, 443, 602, 605, 608, 611, 614, 617, 620, 623, 626, 629], + [ 444, 447, 450, 453, 456, 459, 462, 465, 468, 471, 474, 477, 480, 483, 630, 633, 636, 639, 642, 645, 648, 651, 654, 657], + [ 445, 448, 451, 454, 457, 460, 463, 466, 469, 472, 475, 478, 481, 484, 631, 634, 637, 640, 643, 646, 649, 652, 655, 658], + [ 446, 449, 452, 455, 458, 461, 464, 467, 470, 473, 476, 479, 482, 485, 632, 635, 638, 641, 644, 647, 650, 653, 656, 659], + [ 486, 489, 492, 495, 498, 501, 504, 507, 510, 513, 516, 519, 522, 525, 660, 663, 666, 669, 672, 675, 678, 681, 684, 687], + [ 487, 490, 493, 496, 499, 502, 505, 508, 511, 514, 517, 520, 523, 526, 661, 664, 667, 670, 673, 676, 679, 682, 685, 688], + [ 488, 491, 494, 497, 500, 503, 506, 509, 512, 515, 518, 521, 524, 527, 662, 665, 668, 671, 674, 677, 680, 683, 686, 689], + [ 528, 531, 534, 537, 540, 543, 546, 549, 552, 555, 558, 561, 564, 567, 690, 693, 696, 699, 702, 705, 708, 711, 714, 717], + [ 529, 532, 535, 538, 541, 544, 547, 550, 553, 556, 559, 562, 565, 568, 691, 694, 697, 700, 703, 706, 709, 712, 715, 718], + [ 530, 533, 536, 539, 542, 545, 548, 551, 554, 557, 560, 563, 566, 569, 692, 695, 698, 701, 704, 707, 710, 713, 716, 719], + [ 720, 723, 726, 729, 732, 735, 738, 741, 744, 747, 750, 753, 756, 759, 930, 933, 936, 939, 942, 945, 948, 951, 954, 957], + [ 721, 724, 727, 730, 733, 736, 739, 742, 745, 748, 751, 754, 757, 760, 931, 934, 937, 940, 943, 946, 949, 952, 955, 958], + [ 722, 725, 728, 731, 734, 737, 740, 743, 746, 749, 752, 755, 758, 761, 932, 935, 938, 941, 944, 947, 950, 953, 956, 959], + [ 762, 765, 768, 771, 774, 777, 780, 783, 786, 789, 792, 795, 798, 801, 960, 963, 966, 969, 972, 975, 978, 981, 984, 987], + [ 763, 766, 769, 772, 775, 778, 781, 784, 787, 790, 793, 796, 799, 802, 961, 964, 967, 970, 973, 976, 979, 982, 985, 988], + [ 764, 767, 770, 773, 776, 779, 782, 785, 788, 791, 794, 797, 800, 803, 962, 965, 968, 971, 974, 977, 980, 983, 986, 989], + [ 804, 807, 810, 813, 816, 819, 822, 825, 828, 831, 834, 837, 840, 843, 990, 993, 996, 999, 1002, 1005, 1008, 1011, 1014, 1017], + [ 805, 808, 811, 814, 817, 820, 823, 826, 829, 832, 835, 838, 841, 844, 991, 994, 997, 1000, 1003, 1006, 1009, 1012, 1015, 1018], + [ 806, 809, 812, 815, 818, 821, 824, 827, 830, 833, 836, 839, 842, 845, 992, 995, 998, 1001, 1004, 1007, 1010, 1013, 1016, 1019], + [ 846, 849, 852, 855, 858, 861, 864, 867, 870, 873, 876, 879, 882, 885, 1020, 1023, 1026, 1029, 1032, 1035, 1038, 1041, 1044, 1047], + [ 847, 850, 853, 856, 859, 862, 865, 868, 871, 874, 877, 880, 883, 886, 1021, 1024, 1027, 1030, 1033, 1036, 1039, 1042, 1045, 1048], + [ 848, 851, 854, 857, 860, 863, 866, 869, 872, 875, 878, 881, 884, 887, 1022, 1025, 1028, 1031, 1034, 1037, 1040, 1043, 1046, 1049], + [ 888, 891, 894, 897, 900, 903, 906, 909, 912, 915, 918, 921, 924, 927, 1050, 1053, 1056, 1059, 1062, 1065, 1068, 1071, 1074, 1077], + [ 889, 892, 895, 898, 901, 904, 907, 910, 913, 916, 919, 922, 925, 928, 1051, 1054, 1057, 1060, 1063, 1066, 1069, 1072, 1075, 1078], + [ 890, 893, 896, 899, 902, 905, 908, 911, 914, 917, 920, 923, 926, 929, 1052, 1055, 1058, 1061, 1064, 1067, 1070, 1073, 1076, 1079]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # Tile group col major + tiles = TensorTiler2D.group_tiler( + tensor_dims, + tile_dims=(3, 2), + tile_group_dims=(5, 7), + tile_group_col_major=True, + allow_partial=True, + ) + reference_tiles = TensorTileSequence.from_tiles( + [ + TensorTile( + tensor_dims, offset=0, sizes=[1, 7, 15, 2], strides=[0, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=14, sizes=[1, 5, 15, 2], strides=[0, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=360, sizes=[1, 7, 15, 2], strides=[0, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=374, sizes=[1, 5, 15, 2], strides=[0, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=720, sizes=[1, 7, 15, 2], strides=[0, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=734, sizes=[1, 5, 15, 2], strides=[0, 2, 24, 1] + ), + ] + ) + assert tiles == reference_tiles + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 30, 31, 60, 61, 90, 91, 120, 121, 150, 151, 180, 181, 210, 211, 240, 241, 270, 271, 300, 301, 330, 331], + [ 2, 3, 32, 33, 62, 63, 92, 93, 122, 123, 152, 153, 182, 183, 212, 213, 242, 243, 272, 273, 302, 303, 332, 333], + [ 4, 5, 34, 35, 64, 65, 94, 95, 124, 125, 154, 155, 184, 185, 214, 215, 244, 245, 274, 275, 304, 305, 334, 335], + [ 6, 7, 36, 37, 66, 67, 96, 97, 126, 127, 156, 157, 186, 187, 216, 217, 246, 247, 276, 277, 306, 307, 336, 337], + [ 8, 9, 38, 39, 68, 69, 98, 99, 128, 129, 158, 159, 188, 189, 218, 219, 248, 249, 278, 279, 308, 309, 338, 339], + [ 10, 11, 40, 41, 70, 71, 100, 101, 130, 131, 160, 161, 190, 191, 220, 221, 250, 251, 280, 281, 310, 311, 340, 341], + [ 12, 13, 42, 43, 72, 73, 102, 103, 132, 133, 162, 163, 192, 193, 222, 223, 252, 253, 282, 283, 312, 313, 342, 343], + [ 14, 15, 44, 45, 74, 75, 104, 105, 134, 135, 164, 165, 194, 195, 224, 225, 254, 255, 284, 285, 314, 315, 344, 345], + [ 16, 17, 46, 47, 76, 77, 106, 107, 136, 137, 166, 167, 196, 197, 226, 227, 256, 257, 286, 287, 316, 317, 346, 347], + [ 18, 19, 48, 49, 78, 79, 108, 109, 138, 139, 168, 169, 198, 199, 228, 229, 258, 259, 288, 289, 318, 319, 348, 349], + [ 20, 21, 50, 51, 80, 81, 110, 111, 140, 141, 170, 171, 200, 201, 230, 231, 260, 261, 290, 291, 320, 321, 350, 351], + [ 22, 23, 52, 53, 82, 83, 112, 113, 142, 143, 172, 173, 202, 203, 232, 233, 262, 263, 292, 293, 322, 323, 352, 353], + [ 24, 25, 54, 55, 84, 85, 114, 115, 144, 145, 174, 175, 204, 205, 234, 235, 264, 265, 294, 295, 324, 325, 354, 355], + [ 26, 27, 56, 57, 86, 87, 116, 117, 146, 147, 176, 177, 206, 207, 236, 237, 266, 267, 296, 297, 326, 327, 356, 357], + [ 28, 29, 58, 59, 88, 89, 118, 119, 148, 149, 178, 179, 208, 209, 238, 239, 268, 269, 298, 299, 328, 329, 358, 359], + [ 360, 361, 390, 391, 420, 421, 450, 451, 480, 481, 510, 511, 540, 541, 570, 571, 600, 601, 630, 631, 660, 661, 690, 691], + [ 362, 363, 392, 393, 422, 423, 452, 453, 482, 483, 512, 513, 542, 543, 572, 573, 602, 603, 632, 633, 662, 663, 692, 693], + [ 364, 365, 394, 395, 424, 425, 454, 455, 484, 485, 514, 515, 544, 545, 574, 575, 604, 605, 634, 635, 664, 665, 694, 695], + [ 366, 367, 396, 397, 426, 427, 456, 457, 486, 487, 516, 517, 546, 547, 576, 577, 606, 607, 636, 637, 666, 667, 696, 697], + [ 368, 369, 398, 399, 428, 429, 458, 459, 488, 489, 518, 519, 548, 549, 578, 579, 608, 609, 638, 639, 668, 669, 698, 699], + [ 370, 371, 400, 401, 430, 431, 460, 461, 490, 491, 520, 521, 550, 551, 580, 581, 610, 611, 640, 641, 670, 671, 700, 701], + [ 372, 373, 402, 403, 432, 433, 462, 463, 492, 493, 522, 523, 552, 553, 582, 583, 612, 613, 642, 643, 672, 673, 702, 703], + [ 374, 375, 404, 405, 434, 435, 464, 465, 494, 495, 524, 525, 554, 555, 584, 585, 614, 615, 644, 645, 674, 675, 704, 705], + [ 376, 377, 406, 407, 436, 437, 466, 467, 496, 497, 526, 527, 556, 557, 586, 587, 616, 617, 646, 647, 676, 677, 706, 707], + [ 378, 379, 408, 409, 438, 439, 468, 469, 498, 499, 528, 529, 558, 559, 588, 589, 618, 619, 648, 649, 678, 679, 708, 709], + [ 380, 381, 410, 411, 440, 441, 470, 471, 500, 501, 530, 531, 560, 561, 590, 591, 620, 621, 650, 651, 680, 681, 710, 711], + [ 382, 383, 412, 413, 442, 443, 472, 473, 502, 503, 532, 533, 562, 563, 592, 593, 622, 623, 652, 653, 682, 683, 712, 713], + [ 384, 385, 414, 415, 444, 445, 474, 475, 504, 505, 534, 535, 564, 565, 594, 595, 624, 625, 654, 655, 684, 685, 714, 715], + [ 386, 387, 416, 417, 446, 447, 476, 477, 506, 507, 536, 537, 566, 567, 596, 597, 626, 627, 656, 657, 686, 687, 716, 717], + [ 388, 389, 418, 419, 448, 449, 478, 479, 508, 509, 538, 539, 568, 569, 598, 599, 628, 629, 658, 659, 688, 689, 718, 719], + [ 720, 721, 750, 751, 780, 781, 810, 811, 840, 841, 870, 871, 900, 901, 930, 931, 960, 961, 990, 991, 1020, 1021, 1050, 1051], + [ 722, 723, 752, 753, 782, 783, 812, 813, 842, 843, 872, 873, 902, 903, 932, 933, 962, 963, 992, 993, 1022, 1023, 1052, 1053], + [ 724, 725, 754, 755, 784, 785, 814, 815, 844, 845, 874, 875, 904, 905, 934, 935, 964, 965, 994, 995, 1024, 1025, 1054, 1055], + [ 726, 727, 756, 757, 786, 787, 816, 817, 846, 847, 876, 877, 906, 907, 936, 937, 966, 967, 996, 997, 1026, 1027, 1056, 1057], + [ 728, 729, 758, 759, 788, 789, 818, 819, 848, 849, 878, 879, 908, 909, 938, 939, 968, 969, 998, 999, 1028, 1029, 1058, 1059], + [ 730, 731, 760, 761, 790, 791, 820, 821, 850, 851, 880, 881, 910, 911, 940, 941, 970, 971, 1000, 1001, 1030, 1031, 1060, 1061], + [ 732, 733, 762, 763, 792, 793, 822, 823, 852, 853, 882, 883, 912, 913, 942, 943, 972, 973, 1002, 1003, 1032, 1033, 1062, 1063], + [ 734, 735, 764, 765, 794, 795, 824, 825, 854, 855, 884, 885, 914, 915, 944, 945, 974, 975, 1004, 1005, 1034, 1035, 1064, 1065], + [ 736, 737, 766, 767, 796, 797, 826, 827, 856, 857, 886, 887, 916, 917, 946, 947, 976, 977, 1006, 1007, 1036, 1037, 1066, 1067], + [ 738, 739, 768, 769, 798, 799, 828, 829, 858, 859, 888, 889, 918, 919, 948, 949, 978, 979, 1008, 1009, 1038, 1039, 1068, 1069], + [ 740, 741, 770, 771, 800, 801, 830, 831, 860, 861, 890, 891, 920, 921, 950, 951, 980, 981, 1010, 1011, 1040, 1041, 1070, 1071], + [ 742, 743, 772, 773, 802, 803, 832, 833, 862, 863, 892, 893, 922, 923, 952, 953, 982, 983, 1012, 1013, 1042, 1043, 1072, 1073], + [ 744, 745, 774, 775, 804, 805, 834, 835, 864, 865, 894, 895, 924, 925, 954, 955, 984, 985, 1014, 1015, 1044, 1045, 1074, 1075], + [ 746, 747, 776, 777, 806, 807, 836, 837, 866, 867, 896, 897, 926, 927, 956, 957, 986, 987, 1016, 1017, 1046, 1047, 1076, 1077], + [ 748, 749, 778, 779, 808, 809, 838, 839, 868, 869, 898, 899, 928, 929, 958, 959, 988, 989, 1018, 1019, 1048, 1049, 1078, 1079]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # iter col major + tiles = TensorTiler2D.group_tiler( + tensor_dims, + tile_dims=(3, 2), + tile_group_dims=(5, 7), + iter_col_major=True, + allow_partial=True, + ) + reference_tiles = TensorTileSequence.from_tiles( + [ + TensorTile( + tensor_dims, offset=0, sizes=[5, 7, 3, 2], strides=[72, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=360, sizes=[5, 7, 3, 2], strides=[72, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=720, sizes=[5, 7, 3, 2], strides=[72, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=14, sizes=[5, 5, 3, 2], strides=[72, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=374, sizes=[5, 5, 3, 2], strides=[72, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=734, sizes=[5, 5, 3, 2], strides=[72, 2, 24, 1] + ), + ] + ) + assert tiles == reference_tiles + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 6, 7, 12, 13, 18, 19, 24, 25, 30, 31, 36, 37, 630, 631, 636, 637, 642, 643, 648, 649, 654, 655], + [ 2, 3, 8, 9, 14, 15, 20, 21, 26, 27, 32, 33, 38, 39, 632, 633, 638, 639, 644, 645, 650, 651, 656, 657], + [ 4, 5, 10, 11, 16, 17, 22, 23, 28, 29, 34, 35, 40, 41, 634, 635, 640, 641, 646, 647, 652, 653, 658, 659], + [ 42, 43, 48, 49, 54, 55, 60, 61, 66, 67, 72, 73, 78, 79, 660, 661, 666, 667, 672, 673, 678, 679, 684, 685], + [ 44, 45, 50, 51, 56, 57, 62, 63, 68, 69, 74, 75, 80, 81, 662, 663, 668, 669, 674, 675, 680, 681, 686, 687], + [ 46, 47, 52, 53, 58, 59, 64, 65, 70, 71, 76, 77, 82, 83, 664, 665, 670, 671, 676, 677, 682, 683, 688, 689], + [ 84, 85, 90, 91, 96, 97, 102, 103, 108, 109, 114, 115, 120, 121, 690, 691, 696, 697, 702, 703, 708, 709, 714, 715], + [ 86, 87, 92, 93, 98, 99, 104, 105, 110, 111, 116, 117, 122, 123, 692, 693, 698, 699, 704, 705, 710, 711, 716, 717], + [ 88, 89, 94, 95, 100, 101, 106, 107, 112, 113, 118, 119, 124, 125, 694, 695, 700, 701, 706, 707, 712, 713, 718, 719], + [ 126, 127, 132, 133, 138, 139, 144, 145, 150, 151, 156, 157, 162, 163, 720, 721, 726, 727, 732, 733, 738, 739, 744, 745], + [ 128, 129, 134, 135, 140, 141, 146, 147, 152, 153, 158, 159, 164, 165, 722, 723, 728, 729, 734, 735, 740, 741, 746, 747], + [ 130, 131, 136, 137, 142, 143, 148, 149, 154, 155, 160, 161, 166, 167, 724, 725, 730, 731, 736, 737, 742, 743, 748, 749], + [ 168, 169, 174, 175, 180, 181, 186, 187, 192, 193, 198, 199, 204, 205, 750, 751, 756, 757, 762, 763, 768, 769, 774, 775], + [ 170, 171, 176, 177, 182, 183, 188, 189, 194, 195, 200, 201, 206, 207, 752, 753, 758, 759, 764, 765, 770, 771, 776, 777], + [ 172, 173, 178, 179, 184, 185, 190, 191, 196, 197, 202, 203, 208, 209, 754, 755, 760, 761, 766, 767, 772, 773, 778, 779], + [ 210, 211, 216, 217, 222, 223, 228, 229, 234, 235, 240, 241, 246, 247, 780, 781, 786, 787, 792, 793, 798, 799, 804, 805], + [ 212, 213, 218, 219, 224, 225, 230, 231, 236, 237, 242, 243, 248, 249, 782, 783, 788, 789, 794, 795, 800, 801, 806, 807], + [ 214, 215, 220, 221, 226, 227, 232, 233, 238, 239, 244, 245, 250, 251, 784, 785, 790, 791, 796, 797, 802, 803, 808, 809], + [ 252, 253, 258, 259, 264, 265, 270, 271, 276, 277, 282, 283, 288, 289, 810, 811, 816, 817, 822, 823, 828, 829, 834, 835], + [ 254, 255, 260, 261, 266, 267, 272, 273, 278, 279, 284, 285, 290, 291, 812, 813, 818, 819, 824, 825, 830, 831, 836, 837], + [ 256, 257, 262, 263, 268, 269, 274, 275, 280, 281, 286, 287, 292, 293, 814, 815, 820, 821, 826, 827, 832, 833, 838, 839], + [ 294, 295, 300, 301, 306, 307, 312, 313, 318, 319, 324, 325, 330, 331, 840, 841, 846, 847, 852, 853, 858, 859, 864, 865], + [ 296, 297, 302, 303, 308, 309, 314, 315, 320, 321, 326, 327, 332, 333, 842, 843, 848, 849, 854, 855, 860, 861, 866, 867], + [ 298, 299, 304, 305, 310, 311, 316, 317, 322, 323, 328, 329, 334, 335, 844, 845, 850, 851, 856, 857, 862, 863, 868, 869], + [ 336, 337, 342, 343, 348, 349, 354, 355, 360, 361, 366, 367, 372, 373, 870, 871, 876, 877, 882, 883, 888, 889, 894, 895], + [ 338, 339, 344, 345, 350, 351, 356, 357, 362, 363, 368, 369, 374, 375, 872, 873, 878, 879, 884, 885, 890, 891, 896, 897], + [ 340, 341, 346, 347, 352, 353, 358, 359, 364, 365, 370, 371, 376, 377, 874, 875, 880, 881, 886, 887, 892, 893, 898, 899], + [ 378, 379, 384, 385, 390, 391, 396, 397, 402, 403, 408, 409, 414, 415, 900, 901, 906, 907, 912, 913, 918, 919, 924, 925], + [ 380, 381, 386, 387, 392, 393, 398, 399, 404, 405, 410, 411, 416, 417, 902, 903, 908, 909, 914, 915, 920, 921, 926, 927], + [ 382, 383, 388, 389, 394, 395, 400, 401, 406, 407, 412, 413, 418, 419, 904, 905, 910, 911, 916, 917, 922, 923, 928, 929], + [ 420, 421, 426, 427, 432, 433, 438, 439, 444, 445, 450, 451, 456, 457, 930, 931, 936, 937, 942, 943, 948, 949, 954, 955], + [ 422, 423, 428, 429, 434, 435, 440, 441, 446, 447, 452, 453, 458, 459, 932, 933, 938, 939, 944, 945, 950, 951, 956, 957], + [ 424, 425, 430, 431, 436, 437, 442, 443, 448, 449, 454, 455, 460, 461, 934, 935, 940, 941, 946, 947, 952, 953, 958, 959], + [ 462, 463, 468, 469, 474, 475, 480, 481, 486, 487, 492, 493, 498, 499, 960, 961, 966, 967, 972, 973, 978, 979, 984, 985], + [ 464, 465, 470, 471, 476, 477, 482, 483, 488, 489, 494, 495, 500, 501, 962, 963, 968, 969, 974, 975, 980, 981, 986, 987], + [ 466, 467, 472, 473, 478, 479, 484, 485, 490, 491, 496, 497, 502, 503, 964, 965, 970, 971, 976, 977, 982, 983, 988, 989], + [ 504, 505, 510, 511, 516, 517, 522, 523, 528, 529, 534, 535, 540, 541, 990, 991, 996, 997, 1002, 1003, 1008, 1009, 1014, 1015], + [ 506, 507, 512, 513, 518, 519, 524, 525, 530, 531, 536, 537, 542, 543, 992, 993, 998, 999, 1004, 1005, 1010, 1011, 1016, 1017], + [ 508, 509, 514, 515, 520, 521, 526, 527, 532, 533, 538, 539, 544, 545, 994, 995, 1000, 1001, 1006, 1007, 1012, 1013, 1018, 1019], + [ 546, 547, 552, 553, 558, 559, 564, 565, 570, 571, 576, 577, 582, 583, 1020, 1021, 1026, 1027, 1032, 1033, 1038, 1039, 1044, 1045], + [ 548, 549, 554, 555, 560, 561, 566, 567, 572, 573, 578, 579, 584, 585, 1022, 1023, 1028, 1029, 1034, 1035, 1040, 1041, 1046, 1047], + [ 550, 551, 556, 557, 562, 563, 568, 569, 574, 575, 580, 581, 586, 587, 1024, 1025, 1030, 1031, 1036, 1037, 1042, 1043, 1048, 1049], + [ 588, 589, 594, 595, 600, 601, 606, 607, 612, 613, 618, 619, 624, 625, 1050, 1051, 1056, 1057, 1062, 1063, 1068, 1069, 1074, 1075], + [ 590, 591, 596, 597, 602, 603, 608, 609, 614, 615, 620, 621, 626, 627, 1052, 1053, 1058, 1059, 1064, 1065, 1070, 1071, 1076, 1077], + [ 592, 593, 598, 599, 604, 605, 610, 611, 616, 617, 622, 623, 628, 629, 1054, 1055, 1060, 1061, 1066, 1067, 1072, 1073, 1078, 1079]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # all col major + tiles = TensorTiler2D.group_tiler( + tensor_dims, + tile_dims=(3, 2), + tile_group_dims=(5, 7), + tile_col_major=True, + tile_group_col_major=True, + iter_col_major=True, + allow_partial=True, + ) + reference_tiles = TensorTileSequence.from_tiles( + [ + TensorTile( + tensor_dims, offset=0, sizes=[7, 5, 2, 3], strides=[2, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=360, sizes=[7, 5, 2, 3], strides=[2, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=720, sizes=[7, 5, 2, 3], strides=[2, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=14, sizes=[5, 5, 2, 3], strides=[2, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=374, sizes=[5, 5, 2, 3], strides=[2, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=734, sizes=[5, 5, 2, 3], strides=[2, 72, 1, 24] + ), + ] + ) + assert tiles == reference_tiles + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 3, 30, 33, 60, 63, 90, 93, 120, 123, 150, 153, 180, 183, 630, 633, 660, 663, 690, 693, 720, 723, 750, 753], + [ 1, 4, 31, 34, 61, 64, 91, 94, 121, 124, 151, 154, 181, 184, 631, 634, 661, 664, 691, 694, 721, 724, 751, 754], + [ 2, 5, 32, 35, 62, 65, 92, 95, 122, 125, 152, 155, 182, 185, 632, 635, 662, 665, 692, 695, 722, 725, 752, 755], + [ 6, 9, 36, 39, 66, 69, 96, 99, 126, 129, 156, 159, 186, 189, 636, 639, 666, 669, 696, 699, 726, 729, 756, 759], + [ 7, 10, 37, 40, 67, 70, 97, 100, 127, 130, 157, 160, 187, 190, 637, 640, 667, 670, 697, 700, 727, 730, 757, 760], + [ 8, 11, 38, 41, 68, 71, 98, 101, 128, 131, 158, 161, 188, 191, 638, 641, 668, 671, 698, 701, 728, 731, 758, 761], + [ 12, 15, 42, 45, 72, 75, 102, 105, 132, 135, 162, 165, 192, 195, 642, 645, 672, 675, 702, 705, 732, 735, 762, 765], + [ 13, 16, 43, 46, 73, 76, 103, 106, 133, 136, 163, 166, 193, 196, 643, 646, 673, 676, 703, 706, 733, 736, 763, 766], + [ 14, 17, 44, 47, 74, 77, 104, 107, 134, 137, 164, 167, 194, 197, 644, 647, 674, 677, 704, 707, 734, 737, 764, 767], + [ 18, 21, 48, 51, 78, 81, 108, 111, 138, 141, 168, 171, 198, 201, 648, 651, 678, 681, 708, 711, 738, 741, 768, 771], + [ 19, 22, 49, 52, 79, 82, 109, 112, 139, 142, 169, 172, 199, 202, 649, 652, 679, 682, 709, 712, 739, 742, 769, 772], + [ 20, 23, 50, 53, 80, 83, 110, 113, 140, 143, 170, 173, 200, 203, 650, 653, 680, 683, 710, 713, 740, 743, 770, 773], + [ 24, 27, 54, 57, 84, 87, 114, 117, 144, 147, 174, 177, 204, 207, 654, 657, 684, 687, 714, 717, 744, 747, 774, 777], + [ 25, 28, 55, 58, 85, 88, 115, 118, 145, 148, 175, 178, 205, 208, 655, 658, 685, 688, 715, 718, 745, 748, 775, 778], + [ 26, 29, 56, 59, 86, 89, 116, 119, 146, 149, 176, 179, 206, 209, 656, 659, 686, 689, 716, 719, 746, 749, 776, 779], + [ 210, 213, 240, 243, 270, 273, 300, 303, 330, 333, 360, 363, 390, 393, 780, 783, 810, 813, 840, 843, 870, 873, 900, 903], + [ 211, 214, 241, 244, 271, 274, 301, 304, 331, 334, 361, 364, 391, 394, 781, 784, 811, 814, 841, 844, 871, 874, 901, 904], + [ 212, 215, 242, 245, 272, 275, 302, 305, 332, 335, 362, 365, 392, 395, 782, 785, 812, 815, 842, 845, 872, 875, 902, 905], + [ 216, 219, 246, 249, 276, 279, 306, 309, 336, 339, 366, 369, 396, 399, 786, 789, 816, 819, 846, 849, 876, 879, 906, 909], + [ 217, 220, 247, 250, 277, 280, 307, 310, 337, 340, 367, 370, 397, 400, 787, 790, 817, 820, 847, 850, 877, 880, 907, 910], + [ 218, 221, 248, 251, 278, 281, 308, 311, 338, 341, 368, 371, 398, 401, 788, 791, 818, 821, 848, 851, 878, 881, 908, 911], + [ 222, 225, 252, 255, 282, 285, 312, 315, 342, 345, 372, 375, 402, 405, 792, 795, 822, 825, 852, 855, 882, 885, 912, 915], + [ 223, 226, 253, 256, 283, 286, 313, 316, 343, 346, 373, 376, 403, 406, 793, 796, 823, 826, 853, 856, 883, 886, 913, 916], + [ 224, 227, 254, 257, 284, 287, 314, 317, 344, 347, 374, 377, 404, 407, 794, 797, 824, 827, 854, 857, 884, 887, 914, 917], + [ 228, 231, 258, 261, 288, 291, 318, 321, 348, 351, 378, 381, 408, 411, 798, 801, 828, 831, 858, 861, 888, 891, 918, 921], + [ 229, 232, 259, 262, 289, 292, 319, 322, 349, 352, 379, 382, 409, 412, 799, 802, 829, 832, 859, 862, 889, 892, 919, 922], + [ 230, 233, 260, 263, 290, 293, 320, 323, 350, 353, 380, 383, 410, 413, 800, 803, 830, 833, 860, 863, 890, 893, 920, 923], + [ 234, 237, 264, 267, 294, 297, 324, 327, 354, 357, 384, 387, 414, 417, 804, 807, 834, 837, 864, 867, 894, 897, 924, 927], + [ 235, 238, 265, 268, 295, 298, 325, 328, 355, 358, 385, 388, 415, 418, 805, 808, 835, 838, 865, 868, 895, 898, 925, 928], + [ 236, 239, 266, 269, 296, 299, 326, 329, 356, 359, 386, 389, 416, 419, 806, 809, 836, 839, 866, 869, 896, 899, 926, 929], + [ 420, 423, 450, 453, 480, 483, 510, 513, 540, 543, 570, 573, 600, 603, 930, 933, 960, 963, 990, 993, 1020, 1023, 1050, 1053], + [ 421, 424, 451, 454, 481, 484, 511, 514, 541, 544, 571, 574, 601, 604, 931, 934, 961, 964, 991, 994, 1021, 1024, 1051, 1054], + [ 422, 425, 452, 455, 482, 485, 512, 515, 542, 545, 572, 575, 602, 605, 932, 935, 962, 965, 992, 995, 1022, 1025, 1052, 1055], + [ 426, 429, 456, 459, 486, 489, 516, 519, 546, 549, 576, 579, 606, 609, 936, 939, 966, 969, 996, 999, 1026, 1029, 1056, 1059], + [ 427, 430, 457, 460, 487, 490, 517, 520, 547, 550, 577, 580, 607, 610, 937, 940, 967, 970, 997, 1000, 1027, 1030, 1057, 1060], + [ 428, 431, 458, 461, 488, 491, 518, 521, 548, 551, 578, 581, 608, 611, 938, 941, 968, 971, 998, 1001, 1028, 1031, 1058, 1061], + [ 432, 435, 462, 465, 492, 495, 522, 525, 552, 555, 582, 585, 612, 615, 942, 945, 972, 975, 1002, 1005, 1032, 1035, 1062, 1065], + [ 433, 436, 463, 466, 493, 496, 523, 526, 553, 556, 583, 586, 613, 616, 943, 946, 973, 976, 1003, 1006, 1033, 1036, 1063, 1066], + [ 434, 437, 464, 467, 494, 497, 524, 527, 554, 557, 584, 587, 614, 617, 944, 947, 974, 977, 1004, 1007, 1034, 1037, 1064, 1067], + [ 438, 441, 468, 471, 498, 501, 528, 531, 558, 561, 588, 591, 618, 621, 948, 951, 978, 981, 1008, 1011, 1038, 1041, 1068, 1071], + [ 439, 442, 469, 472, 499, 502, 529, 532, 559, 562, 589, 592, 619, 622, 949, 952, 979, 982, 1009, 1012, 1039, 1042, 1069, 1072], + [ 440, 443, 470, 473, 500, 503, 530, 533, 560, 563, 590, 593, 620, 623, 950, 953, 980, 983, 1010, 1013, 1040, 1043, 1070, 1073], + [ 444, 447, 474, 477, 504, 507, 534, 537, 564, 567, 594, 597, 624, 627, 954, 957, 984, 987, 1014, 1017, 1044, 1047, 1074, 1077], + [ 445, 448, 475, 478, 505, 508, 535, 538, 565, 568, 595, 598, 625, 628, 955, 958, 985, 988, 1015, 1018, 1045, 1048, 1075, 1078], + [ 446, 449, 476, 479, 506, 509, 536, 539, 566, 569, 596, 599, 626, 629, 956, 959, 986, 989, 1016, 1019, 1046, 1049, 1076, 1079]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # pattern repeat + tiles = TensorTiler2D.group_tiler( + tensor_dims, + tile_dims=(3, 2), + tile_group_dims=(5, 7), + tile_col_major=True, + allow_partial=True, + pattern_repeat=4, + ) + reference_tiles = TensorTileSequence.from_tiles( + [ + TensorTile( + tensor_dims, offset=0, sizes=[4, 5, 14, 3], strides=[0, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=14, sizes=[4, 5, 10, 3], strides=[0, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=360, sizes=[4, 5, 14, 3], strides=[0, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=374, sizes=[4, 5, 10, 3], strides=[0, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=720, sizes=[4, 5, 14, 3], strides=[0, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=734, sizes=[4, 5, 10, 3], strides=[0, 72, 1, 24] + ), + ] + ) + assert tiles == reference_tiles + + # fmt: off + ref_access_order_tensor = np.array([ + [ 630, 633, 636, 639, 642, 645, 648, 651, 654, 657, 660, 663, 666, 669, 1290, 1293, 1296, 1299, 1302, 1305, 1308, 1311, 1314, 1317], + [ 631, 634, 637, 640, 643, 646, 649, 652, 655, 658, 661, 664, 667, 670, 1291, 1294, 1297, 1300, 1303, 1306, 1309, 1312, 1315, 1318], + [ 632, 635, 638, 641, 644, 647, 650, 653, 656, 659, 662, 665, 668, 671, 1292, 1295, 1298, 1301, 1304, 1307, 1310, 1313, 1316, 1319], + [ 672, 675, 678, 681, 684, 687, 690, 693, 696, 699, 702, 705, 708, 711, 1320, 1323, 1326, 1329, 1332, 1335, 1338, 1341, 1344, 1347], + [ 673, 676, 679, 682, 685, 688, 691, 694, 697, 700, 703, 706, 709, 712, 1321, 1324, 1327, 1330, 1333, 1336, 1339, 1342, 1345, 1348], + [ 674, 677, 680, 683, 686, 689, 692, 695, 698, 701, 704, 707, 710, 713, 1322, 1325, 1328, 1331, 1334, 1337, 1340, 1343, 1346, 1349], + [ 714, 717, 720, 723, 726, 729, 732, 735, 738, 741, 744, 747, 750, 753, 1350, 1353, 1356, 1359, 1362, 1365, 1368, 1371, 1374, 1377], + [ 715, 718, 721, 724, 727, 730, 733, 736, 739, 742, 745, 748, 751, 754, 1351, 1354, 1357, 1360, 1363, 1366, 1369, 1372, 1375, 1378], + [ 716, 719, 722, 725, 728, 731, 734, 737, 740, 743, 746, 749, 752, 755, 1352, 1355, 1358, 1361, 1364, 1367, 1370, 1373, 1376, 1379], + [ 756, 759, 762, 765, 768, 771, 774, 777, 780, 783, 786, 789, 792, 795, 1380, 1383, 1386, 1389, 1392, 1395, 1398, 1401, 1404, 1407], + [ 757, 760, 763, 766, 769, 772, 775, 778, 781, 784, 787, 790, 793, 796, 1381, 1384, 1387, 1390, 1393, 1396, 1399, 1402, 1405, 1408], + [ 758, 761, 764, 767, 770, 773, 776, 779, 782, 785, 788, 791, 794, 797, 1382, 1385, 1388, 1391, 1394, 1397, 1400, 1403, 1406, 1409], + [ 798, 801, 804, 807, 810, 813, 816, 819, 822, 825, 828, 831, 834, 837, 1410, 1413, 1416, 1419, 1422, 1425, 1428, 1431, 1434, 1437], + [ 799, 802, 805, 808, 811, 814, 817, 820, 823, 826, 829, 832, 835, 838, 1411, 1414, 1417, 1420, 1423, 1426, 1429, 1432, 1435, 1438], + [ 800, 803, 806, 809, 812, 815, 818, 821, 824, 827, 830, 833, 836, 839, 1412, 1415, 1418, 1421, 1424, 1427, 1430, 1433, 1436, 1439], + [2070, 2073, 2076, 2079, 2082, 2085, 2088, 2091, 2094, 2097, 2100, 2103, 2106, 2109, 2730, 2733, 2736, 2739, 2742, 2745, 2748, 2751, 2754, 2757], + [2071, 2074, 2077, 2080, 2083, 2086, 2089, 2092, 2095, 2098, 2101, 2104, 2107, 2110, 2731, 2734, 2737, 2740, 2743, 2746, 2749, 2752, 2755, 2758], + [2072, 2075, 2078, 2081, 2084, 2087, 2090, 2093, 2096, 2099, 2102, 2105, 2108, 2111, 2732, 2735, 2738, 2741, 2744, 2747, 2750, 2753, 2756, 2759], + [2112, 2115, 2118, 2121, 2124, 2127, 2130, 2133, 2136, 2139, 2142, 2145, 2148, 2151, 2760, 2763, 2766, 2769, 2772, 2775, 2778, 2781, 2784, 2787], + [2113, 2116, 2119, 2122, 2125, 2128, 2131, 2134, 2137, 2140, 2143, 2146, 2149, 2152, 2761, 2764, 2767, 2770, 2773, 2776, 2779, 2782, 2785, 2788], + [2114, 2117, 2120, 2123, 2126, 2129, 2132, 2135, 2138, 2141, 2144, 2147, 2150, 2153, 2762, 2765, 2768, 2771, 2774, 2777, 2780, 2783, 2786, 2789], + [2154, 2157, 2160, 2163, 2166, 2169, 2172, 2175, 2178, 2181, 2184, 2187, 2190, 2193, 2790, 2793, 2796, 2799, 2802, 2805, 2808, 2811, 2814, 2817], + [2155, 2158, 2161, 2164, 2167, 2170, 2173, 2176, 2179, 2182, 2185, 2188, 2191, 2194, 2791, 2794, 2797, 2800, 2803, 2806, 2809, 2812, 2815, 2818], + [2156, 2159, 2162, 2165, 2168, 2171, 2174, 2177, 2180, 2183, 2186, 2189, 2192, 2195, 2792, 2795, 2798, 2801, 2804, 2807, 2810, 2813, 2816, 2819], + [2196, 2199, 2202, 2205, 2208, 2211, 2214, 2217, 2220, 2223, 2226, 2229, 2232, 2235, 2820, 2823, 2826, 2829, 2832, 2835, 2838, 2841, 2844, 2847], + [2197, 2200, 2203, 2206, 2209, 2212, 2215, 2218, 2221, 2224, 2227, 2230, 2233, 2236, 2821, 2824, 2827, 2830, 2833, 2836, 2839, 2842, 2845, 2848], + [2198, 2201, 2204, 2207, 2210, 2213, 2216, 2219, 2222, 2225, 2228, 2231, 2234, 2237, 2822, 2825, 2828, 2831, 2834, 2837, 2840, 2843, 2846, 2849], + [2238, 2241, 2244, 2247, 2250, 2253, 2256, 2259, 2262, 2265, 2268, 2271, 2274, 2277, 2850, 2853, 2856, 2859, 2862, 2865, 2868, 2871, 2874, 2877], + [2239, 2242, 2245, 2248, 2251, 2254, 2257, 2260, 2263, 2266, 2269, 2272, 2275, 2278, 2851, 2854, 2857, 2860, 2863, 2866, 2869, 2872, 2875, 2878], + [2240, 2243, 2246, 2249, 2252, 2255, 2258, 2261, 2264, 2267, 2270, 2273, 2276, 2279, 2852, 2855, 2858, 2861, 2864, 2867, 2870, 2873, 2876, 2879], + [3510, 3513, 3516, 3519, 3522, 3525, 3528, 3531, 3534, 3537, 3540, 3543, 3546, 3549, 4170, 4173, 4176, 4179, 4182, 4185, 4188, 4191, 4194, 4197], + [3511, 3514, 3517, 3520, 3523, 3526, 3529, 3532, 3535, 3538, 3541, 3544, 3547, 3550, 4171, 4174, 4177, 4180, 4183, 4186, 4189, 4192, 4195, 4198], + [3512, 3515, 3518, 3521, 3524, 3527, 3530, 3533, 3536, 3539, 3542, 3545, 3548, 3551, 4172, 4175, 4178, 4181, 4184, 4187, 4190, 4193, 4196, 4199], + [3552, 3555, 3558, 3561, 3564, 3567, 3570, 3573, 3576, 3579, 3582, 3585, 3588, 3591, 4200, 4203, 4206, 4209, 4212, 4215, 4218, 4221, 4224, 4227], + [3553, 3556, 3559, 3562, 3565, 3568, 3571, 3574, 3577, 3580, 3583, 3586, 3589, 3592, 4201, 4204, 4207, 4210, 4213, 4216, 4219, 4222, 4225, 4228], + [3554, 3557, 3560, 3563, 3566, 3569, 3572, 3575, 3578, 3581, 3584, 3587, 3590, 3593, 4202, 4205, 4208, 4211, 4214, 4217, 4220, 4223, 4226, 4229], + [3594, 3597, 3600, 3603, 3606, 3609, 3612, 3615, 3618, 3621, 3624, 3627, 3630, 3633, 4230, 4233, 4236, 4239, 4242, 4245, 4248, 4251, 4254, 4257], + [3595, 3598, 3601, 3604, 3607, 3610, 3613, 3616, 3619, 3622, 3625, 3628, 3631, 3634, 4231, 4234, 4237, 4240, 4243, 4246, 4249, 4252, 4255, 4258], + [3596, 3599, 3602, 3605, 3608, 3611, 3614, 3617, 3620, 3623, 3626, 3629, 3632, 3635, 4232, 4235, 4238, 4241, 4244, 4247, 4250, 4253, 4256, 4259], + [3636, 3639, 3642, 3645, 3648, 3651, 3654, 3657, 3660, 3663, 3666, 3669, 3672, 3675, 4260, 4263, 4266, 4269, 4272, 4275, 4278, 4281, 4284, 4287], + [3637, 3640, 3643, 3646, 3649, 3652, 3655, 3658, 3661, 3664, 3667, 3670, 3673, 3676, 4261, 4264, 4267, 4270, 4273, 4276, 4279, 4282, 4285, 4288], + [3638, 3641, 3644, 3647, 3650, 3653, 3656, 3659, 3662, 3665, 3668, 3671, 3674, 3677, 4262, 4265, 4268, 4271, 4274, 4277, 4280, 4283, 4286, 4289], + [3678, 3681, 3684, 3687, 3690, 3693, 3696, 3699, 3702, 3705, 3708, 3711, 3714, 3717, 4290, 4293, 4296, 4299, 4302, 4305, 4308, 4311, 4314, 4317], + [3679, 3682, 3685, 3688, 3691, 3694, 3697, 3700, 3703, 3706, 3709, 3712, 3715, 3718, 4291, 4294, 4297, 4300, 4303, 4306, 4309, 4312, 4315, 4318], + [3680, 3683, 3686, 3689, 3692, 3695, 3698, 3701, 3704, 3707, 3710, 3713, 3716, 3719, 4292, 4295, 4298, 4301, 4304, 4307, 4310, 4313, 4316, 4319]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 4).all() + + # CHECK: Pass! + print("Pass!") + + +# CHECK-LABEL: group_tiler_partial_col +@construct_test +def group_tiler_partial_col(): + + # All row major + tensor_dims = (3 * 4 * 3, 2 * 7 * 2) + tiles = TensorTiler2D.group_tiler( + tensor_dims, tile_dims=(3, 2), tile_group_dims=(5, 7), allow_partial=True + ) + reference_tiles = TensorTileSequence.from_tiles( + [ + TensorTile( + tensor_dims, offset=0, sizes=[5, 7, 3, 2], strides=[84, 2, 28, 1] + ), + TensorTile( + tensor_dims, offset=14, sizes=[5, 7, 3, 2], strides=[84, 2, 28, 1] + ), + TensorTile( + tensor_dims, offset=420, sizes=[5, 7, 3, 2], strides=[84, 2, 28, 1] + ), + TensorTile( + tensor_dims, offset=434, sizes=[5, 7, 3, 2], strides=[84, 2, 28, 1] + ), + TensorTile( + tensor_dims, offset=840, sizes=[2, 7, 3, 2], strides=[84, 2, 28, 1] + ), + TensorTile( + tensor_dims, offset=854, sizes=[2, 7, 3, 2], strides=[84, 2, 28, 1] + ), + ] + ) + assert tiles == reference_tiles + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 6, 7, 12, 13, 18, 19, 24, 25, 30, 31, 36, 37, 210, 211, 216, 217, 222, 223, 228, 229, 234, 235, 240, 241, 246, 247], + [ 2, 3, 8, 9, 14, 15, 20, 21, 26, 27, 32, 33, 38, 39, 212, 213, 218, 219, 224, 225, 230, 231, 236, 237, 242, 243, 248, 249], + [ 4, 5, 10, 11, 16, 17, 22, 23, 28, 29, 34, 35, 40, 41, 214, 215, 220, 221, 226, 227, 232, 233, 238, 239, 244, 245, 250, 251], + [ 42, 43, 48, 49, 54, 55, 60, 61, 66, 67, 72, 73, 78, 79, 252, 253, 258, 259, 264, 265, 270, 271, 276, 277, 282, 283, 288, 289], + [ 44, 45, 50, 51, 56, 57, 62, 63, 68, 69, 74, 75, 80, 81, 254, 255, 260, 261, 266, 267, 272, 273, 278, 279, 284, 285, 290, 291], + [ 46, 47, 52, 53, 58, 59, 64, 65, 70, 71, 76, 77, 82, 83, 256, 257, 262, 263, 268, 269, 274, 275, 280, 281, 286, 287, 292, 293], + [ 84, 85, 90, 91, 96, 97, 102, 103, 108, 109, 114, 115, 120, 121, 294, 295, 300, 301, 306, 307, 312, 313, 318, 319, 324, 325, 330, 331], + [ 86, 87, 92, 93, 98, 99, 104, 105, 110, 111, 116, 117, 122, 123, 296, 297, 302, 303, 308, 309, 314, 315, 320, 321, 326, 327, 332, 333], + [ 88, 89, 94, 95, 100, 101, 106, 107, 112, 113, 118, 119, 124, 125, 298, 299, 304, 305, 310, 311, 316, 317, 322, 323, 328, 329, 334, 335], + [ 126, 127, 132, 133, 138, 139, 144, 145, 150, 151, 156, 157, 162, 163, 336, 337, 342, 343, 348, 349, 354, 355, 360, 361, 366, 367, 372, 373], + [ 128, 129, 134, 135, 140, 141, 146, 147, 152, 153, 158, 159, 164, 165, 338, 339, 344, 345, 350, 351, 356, 357, 362, 363, 368, 369, 374, 375], + [ 130, 131, 136, 137, 142, 143, 148, 149, 154, 155, 160, 161, 166, 167, 340, 341, 346, 347, 352, 353, 358, 359, 364, 365, 370, 371, 376, 377], + [ 168, 169, 174, 175, 180, 181, 186, 187, 192, 193, 198, 199, 204, 205, 378, 379, 384, 385, 390, 391, 396, 397, 402, 403, 408, 409, 414, 415], + [ 170, 171, 176, 177, 182, 183, 188, 189, 194, 195, 200, 201, 206, 207, 380, 381, 386, 387, 392, 393, 398, 399, 404, 405, 410, 411, 416, 417], + [ 172, 173, 178, 179, 184, 185, 190, 191, 196, 197, 202, 203, 208, 209, 382, 383, 388, 389, 394, 395, 400, 401, 406, 407, 412, 413, 418, 419], + [ 420, 421, 426, 427, 432, 433, 438, 439, 444, 445, 450, 451, 456, 457, 630, 631, 636, 637, 642, 643, 648, 649, 654, 655, 660, 661, 666, 667], + [ 422, 423, 428, 429, 434, 435, 440, 441, 446, 447, 452, 453, 458, 459, 632, 633, 638, 639, 644, 645, 650, 651, 656, 657, 662, 663, 668, 669], + [ 424, 425, 430, 431, 436, 437, 442, 443, 448, 449, 454, 455, 460, 461, 634, 635, 640, 641, 646, 647, 652, 653, 658, 659, 664, 665, 670, 671], + [ 462, 463, 468, 469, 474, 475, 480, 481, 486, 487, 492, 493, 498, 499, 672, 673, 678, 679, 684, 685, 690, 691, 696, 697, 702, 703, 708, 709], + [ 464, 465, 470, 471, 476, 477, 482, 483, 488, 489, 494, 495, 500, 501, 674, 675, 680, 681, 686, 687, 692, 693, 698, 699, 704, 705, 710, 711], + [ 466, 467, 472, 473, 478, 479, 484, 485, 490, 491, 496, 497, 502, 503, 676, 677, 682, 683, 688, 689, 694, 695, 700, 701, 706, 707, 712, 713], + [ 504, 505, 510, 511, 516, 517, 522, 523, 528, 529, 534, 535, 540, 541, 714, 715, 720, 721, 726, 727, 732, 733, 738, 739, 744, 745, 750, 751], + [ 506, 507, 512, 513, 518, 519, 524, 525, 530, 531, 536, 537, 542, 543, 716, 717, 722, 723, 728, 729, 734, 735, 740, 741, 746, 747, 752, 753], + [ 508, 509, 514, 515, 520, 521, 526, 527, 532, 533, 538, 539, 544, 545, 718, 719, 724, 725, 730, 731, 736, 737, 742, 743, 748, 749, 754, 755], + [ 546, 547, 552, 553, 558, 559, 564, 565, 570, 571, 576, 577, 582, 583, 756, 757, 762, 763, 768, 769, 774, 775, 780, 781, 786, 787, 792, 793], + [ 548, 549, 554, 555, 560, 561, 566, 567, 572, 573, 578, 579, 584, 585, 758, 759, 764, 765, 770, 771, 776, 777, 782, 783, 788, 789, 794, 795], + [ 550, 551, 556, 557, 562, 563, 568, 569, 574, 575, 580, 581, 586, 587, 760, 761, 766, 767, 772, 773, 778, 779, 784, 785, 790, 791, 796, 797], + [ 588, 589, 594, 595, 600, 601, 606, 607, 612, 613, 618, 619, 624, 625, 798, 799, 804, 805, 810, 811, 816, 817, 822, 823, 828, 829, 834, 835], + [ 590, 591, 596, 597, 602, 603, 608, 609, 614, 615, 620, 621, 626, 627, 800, 801, 806, 807, 812, 813, 818, 819, 824, 825, 830, 831, 836, 837], + [ 592, 593, 598, 599, 604, 605, 610, 611, 616, 617, 622, 623, 628, 629, 802, 803, 808, 809, 814, 815, 820, 821, 826, 827, 832, 833, 838, 839], + [ 840, 841, 846, 847, 852, 853, 858, 859, 864, 865, 870, 871, 876, 877, 924, 925, 930, 931, 936, 937, 942, 943, 948, 949, 954, 955, 960, 961], + [ 842, 843, 848, 849, 854, 855, 860, 861, 866, 867, 872, 873, 878, 879, 926, 927, 932, 933, 938, 939, 944, 945, 950, 951, 956, 957, 962, 963], + [ 844, 845, 850, 851, 856, 857, 862, 863, 868, 869, 874, 875, 880, 881, 928, 929, 934, 935, 940, 941, 946, 947, 952, 953, 958, 959, 964, 965], + [ 882, 883, 888, 889, 894, 895, 900, 901, 906, 907, 912, 913, 918, 919, 966, 967, 972, 973, 978, 979, 984, 985, 990, 991, 996, 997, 1002, 1003], + [ 884, 885, 890, 891, 896, 897, 902, 903, 908, 909, 914, 915, 920, 921, 968, 969, 974, 975, 980, 981, 986, 987, 992, 993, 998, 999, 1004, 1005], + [ 886, 887, 892, 893, 898, 899, 904, 905, 910, 911, 916, 917, 922, 923, 970, 971, 976, 977, 982, 983, 988, 989, 994, 995, 1000, 1001, 1006, 1007]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # Tile col major + tiles = TensorTiler2D.group_tiler( + tensor_dims, + tile_dims=(3, 2), + tile_group_dims=(5, 7), + tile_col_major=True, + allow_partial=True, + ) + reference_tiles = TensorTileSequence.from_tiles( + [ + TensorTile( + tensor_dims, offset=0, sizes=[1, 5, 14, 3], strides=[0, 84, 1, 28] + ), + TensorTile( + tensor_dims, offset=14, sizes=[1, 5, 14, 3], strides=[0, 84, 1, 28] + ), + TensorTile( + tensor_dims, offset=420, sizes=[1, 5, 14, 3], strides=[0, 84, 1, 28] + ), + TensorTile( + tensor_dims, offset=434, sizes=[1, 5, 14, 3], strides=[0, 84, 1, 28] + ), + TensorTile( + tensor_dims, offset=840, sizes=[1, 2, 14, 3], strides=[0, 84, 1, 28] + ), + TensorTile( + tensor_dims, offset=854, sizes=[1, 2, 14, 3], strides=[0, 84, 1, 28] + ), + ] + ) + assert tiles == reference_tiles + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 210, 213, 216, 219, 222, 225, 228, 231, 234, 237, 240, 243, 246, 249], + [ 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 211, 214, 217, 220, 223, 226, 229, 232, 235, 238, 241, 244, 247, 250], + [ 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41, 212, 215, 218, 221, 224, 227, 230, 233, 236, 239, 242, 245, 248, 251], + [ 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 252, 255, 258, 261, 264, 267, 270, 273, 276, 279, 282, 285, 288, 291], + [ 43, 46, 49, 52, 55, 58, 61, 64, 67, 70, 73, 76, 79, 82, 253, 256, 259, 262, 265, 268, 271, 274, 277, 280, 283, 286, 289, 292], + [ 44, 47, 50, 53, 56, 59, 62, 65, 68, 71, 74, 77, 80, 83, 254, 257, 260, 263, 266, 269, 272, 275, 278, 281, 284, 287, 290, 293], + [ 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 294, 297, 300, 303, 306, 309, 312, 315, 318, 321, 324, 327, 330, 333], + [ 85, 88, 91, 94, 97, 100, 103, 106, 109, 112, 115, 118, 121, 124, 295, 298, 301, 304, 307, 310, 313, 316, 319, 322, 325, 328, 331, 334], + [ 86, 89, 92, 95, 98, 101, 104, 107, 110, 113, 116, 119, 122, 125, 296, 299, 302, 305, 308, 311, 314, 317, 320, 323, 326, 329, 332, 335], + [ 126, 129, 132, 135, 138, 141, 144, 147, 150, 153, 156, 159, 162, 165, 336, 339, 342, 345, 348, 351, 354, 357, 360, 363, 366, 369, 372, 375], + [ 127, 130, 133, 136, 139, 142, 145, 148, 151, 154, 157, 160, 163, 166, 337, 340, 343, 346, 349, 352, 355, 358, 361, 364, 367, 370, 373, 376], + [ 128, 131, 134, 137, 140, 143, 146, 149, 152, 155, 158, 161, 164, 167, 338, 341, 344, 347, 350, 353, 356, 359, 362, 365, 368, 371, 374, 377], + [ 168, 171, 174, 177, 180, 183, 186, 189, 192, 195, 198, 201, 204, 207, 378, 381, 384, 387, 390, 393, 396, 399, 402, 405, 408, 411, 414, 417], + [ 169, 172, 175, 178, 181, 184, 187, 190, 193, 196, 199, 202, 205, 208, 379, 382, 385, 388, 391, 394, 397, 400, 403, 406, 409, 412, 415, 418], + [ 170, 173, 176, 179, 182, 185, 188, 191, 194, 197, 200, 203, 206, 209, 380, 383, 386, 389, 392, 395, 398, 401, 404, 407, 410, 413, 416, 419], + [ 420, 423, 426, 429, 432, 435, 438, 441, 444, 447, 450, 453, 456, 459, 630, 633, 636, 639, 642, 645, 648, 651, 654, 657, 660, 663, 666, 669], + [ 421, 424, 427, 430, 433, 436, 439, 442, 445, 448, 451, 454, 457, 460, 631, 634, 637, 640, 643, 646, 649, 652, 655, 658, 661, 664, 667, 670], + [ 422, 425, 428, 431, 434, 437, 440, 443, 446, 449, 452, 455, 458, 461, 632, 635, 638, 641, 644, 647, 650, 653, 656, 659, 662, 665, 668, 671], + [ 462, 465, 468, 471, 474, 477, 480, 483, 486, 489, 492, 495, 498, 501, 672, 675, 678, 681, 684, 687, 690, 693, 696, 699, 702, 705, 708, 711], + [ 463, 466, 469, 472, 475, 478, 481, 484, 487, 490, 493, 496, 499, 502, 673, 676, 679, 682, 685, 688, 691, 694, 697, 700, 703, 706, 709, 712], + [ 464, 467, 470, 473, 476, 479, 482, 485, 488, 491, 494, 497, 500, 503, 674, 677, 680, 683, 686, 689, 692, 695, 698, 701, 704, 707, 710, 713], + [ 504, 507, 510, 513, 516, 519, 522, 525, 528, 531, 534, 537, 540, 543, 714, 717, 720, 723, 726, 729, 732, 735, 738, 741, 744, 747, 750, 753], + [ 505, 508, 511, 514, 517, 520, 523, 526, 529, 532, 535, 538, 541, 544, 715, 718, 721, 724, 727, 730, 733, 736, 739, 742, 745, 748, 751, 754], + [ 506, 509, 512, 515, 518, 521, 524, 527, 530, 533, 536, 539, 542, 545, 716, 719, 722, 725, 728, 731, 734, 737, 740, 743, 746, 749, 752, 755], + [ 546, 549, 552, 555, 558, 561, 564, 567, 570, 573, 576, 579, 582, 585, 756, 759, 762, 765, 768, 771, 774, 777, 780, 783, 786, 789, 792, 795], + [ 547, 550, 553, 556, 559, 562, 565, 568, 571, 574, 577, 580, 583, 586, 757, 760, 763, 766, 769, 772, 775, 778, 781, 784, 787, 790, 793, 796], + [ 548, 551, 554, 557, 560, 563, 566, 569, 572, 575, 578, 581, 584, 587, 758, 761, 764, 767, 770, 773, 776, 779, 782, 785, 788, 791, 794, 797], + [ 588, 591, 594, 597, 600, 603, 606, 609, 612, 615, 618, 621, 624, 627, 798, 801, 804, 807, 810, 813, 816, 819, 822, 825, 828, 831, 834, 837], + [ 589, 592, 595, 598, 601, 604, 607, 610, 613, 616, 619, 622, 625, 628, 799, 802, 805, 808, 811, 814, 817, 820, 823, 826, 829, 832, 835, 838], + [ 590, 593, 596, 599, 602, 605, 608, 611, 614, 617, 620, 623, 626, 629, 800, 803, 806, 809, 812, 815, 818, 821, 824, 827, 830, 833, 836, 839], + [ 840, 843, 846, 849, 852, 855, 858, 861, 864, 867, 870, 873, 876, 879, 924, 927, 930, 933, 936, 939, 942, 945, 948, 951, 954, 957, 960, 963], + [ 841, 844, 847, 850, 853, 856, 859, 862, 865, 868, 871, 874, 877, 880, 925, 928, 931, 934, 937, 940, 943, 946, 949, 952, 955, 958, 961, 964], + [ 842, 845, 848, 851, 854, 857, 860, 863, 866, 869, 872, 875, 878, 881, 926, 929, 932, 935, 938, 941, 944, 947, 950, 953, 956, 959, 962, 965], + [ 882, 885, 888, 891, 894, 897, 900, 903, 906, 909, 912, 915, 918, 921, 966, 969, 972, 975, 978, 981, 984, 987, 990, 993, 996, 999, 1002, 1005], + [ 883, 886, 889, 892, 895, 898, 901, 904, 907, 910, 913, 916, 919, 922, 967, 970, 973, 976, 979, 982, 985, 988, 991, 994, 997, 1000, 1003, 1006], + [ 884, 887, 890, 893, 896, 899, 902, 905, 908, 911, 914, 917, 920, 923, 968, 971, 974, 977, 980, 983, 986, 989, 992, 995, 998, 1001, 1004, 1007]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # Tile group col major + tiles = TensorTiler2D.group_tiler( + tensor_dims, + tile_dims=(3, 2), + tile_group_dims=(5, 7), + tile_group_col_major=True, + allow_partial=True, + ) + reference_tiles = TensorTileSequence.from_tiles( + [ + TensorTile( + tensor_dims, offset=0, sizes=[1, 7, 15, 2], strides=[0, 2, 28, 1] + ), + TensorTile( + tensor_dims, offset=14, sizes=[1, 7, 15, 2], strides=[0, 2, 28, 1] + ), + TensorTile( + tensor_dims, offset=420, sizes=[1, 7, 15, 2], strides=[0, 2, 28, 1] + ), + TensorTile( + tensor_dims, offset=434, sizes=[1, 7, 15, 2], strides=[0, 2, 28, 1] + ), + TensorTile( + tensor_dims, offset=840, sizes=[1, 7, 6, 2], strides=[0, 2, 28, 1] + ), + TensorTile( + tensor_dims, offset=854, sizes=[1, 7, 6, 2], strides=[0, 2, 28, 1] + ), + ] + ) + assert tiles == reference_tiles + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 30, 31, 60, 61, 90, 91, 120, 121, 150, 151, 180, 181, 210, 211, 240, 241, 270, 271, 300, 301, 330, 331, 360, 361, 390, 391], + [ 2, 3, 32, 33, 62, 63, 92, 93, 122, 123, 152, 153, 182, 183, 212, 213, 242, 243, 272, 273, 302, 303, 332, 333, 362, 363, 392, 393], + [ 4, 5, 34, 35, 64, 65, 94, 95, 124, 125, 154, 155, 184, 185, 214, 215, 244, 245, 274, 275, 304, 305, 334, 335, 364, 365, 394, 395], + [ 6, 7, 36, 37, 66, 67, 96, 97, 126, 127, 156, 157, 186, 187, 216, 217, 246, 247, 276, 277, 306, 307, 336, 337, 366, 367, 396, 397], + [ 8, 9, 38, 39, 68, 69, 98, 99, 128, 129, 158, 159, 188, 189, 218, 219, 248, 249, 278, 279, 308, 309, 338, 339, 368, 369, 398, 399], + [ 10, 11, 40, 41, 70, 71, 100, 101, 130, 131, 160, 161, 190, 191, 220, 221, 250, 251, 280, 281, 310, 311, 340, 341, 370, 371, 400, 401], + [ 12, 13, 42, 43, 72, 73, 102, 103, 132, 133, 162, 163, 192, 193, 222, 223, 252, 253, 282, 283, 312, 313, 342, 343, 372, 373, 402, 403], + [ 14, 15, 44, 45, 74, 75, 104, 105, 134, 135, 164, 165, 194, 195, 224, 225, 254, 255, 284, 285, 314, 315, 344, 345, 374, 375, 404, 405], + [ 16, 17, 46, 47, 76, 77, 106, 107, 136, 137, 166, 167, 196, 197, 226, 227, 256, 257, 286, 287, 316, 317, 346, 347, 376, 377, 406, 407], + [ 18, 19, 48, 49, 78, 79, 108, 109, 138, 139, 168, 169, 198, 199, 228, 229, 258, 259, 288, 289, 318, 319, 348, 349, 378, 379, 408, 409], + [ 20, 21, 50, 51, 80, 81, 110, 111, 140, 141, 170, 171, 200, 201, 230, 231, 260, 261, 290, 291, 320, 321, 350, 351, 380, 381, 410, 411], + [ 22, 23, 52, 53, 82, 83, 112, 113, 142, 143, 172, 173, 202, 203, 232, 233, 262, 263, 292, 293, 322, 323, 352, 353, 382, 383, 412, 413], + [ 24, 25, 54, 55, 84, 85, 114, 115, 144, 145, 174, 175, 204, 205, 234, 235, 264, 265, 294, 295, 324, 325, 354, 355, 384, 385, 414, 415], + [ 26, 27, 56, 57, 86, 87, 116, 117, 146, 147, 176, 177, 206, 207, 236, 237, 266, 267, 296, 297, 326, 327, 356, 357, 386, 387, 416, 417], + [ 28, 29, 58, 59, 88, 89, 118, 119, 148, 149, 178, 179, 208, 209, 238, 239, 268, 269, 298, 299, 328, 329, 358, 359, 388, 389, 418, 419], + [ 420, 421, 450, 451, 480, 481, 510, 511, 540, 541, 570, 571, 600, 601, 630, 631, 660, 661, 690, 691, 720, 721, 750, 751, 780, 781, 810, 811], + [ 422, 423, 452, 453, 482, 483, 512, 513, 542, 543, 572, 573, 602, 603, 632, 633, 662, 663, 692, 693, 722, 723, 752, 753, 782, 783, 812, 813], + [ 424, 425, 454, 455, 484, 485, 514, 515, 544, 545, 574, 575, 604, 605, 634, 635, 664, 665, 694, 695, 724, 725, 754, 755, 784, 785, 814, 815], + [ 426, 427, 456, 457, 486, 487, 516, 517, 546, 547, 576, 577, 606, 607, 636, 637, 666, 667, 696, 697, 726, 727, 756, 757, 786, 787, 816, 817], + [ 428, 429, 458, 459, 488, 489, 518, 519, 548, 549, 578, 579, 608, 609, 638, 639, 668, 669, 698, 699, 728, 729, 758, 759, 788, 789, 818, 819], + [ 430, 431, 460, 461, 490, 491, 520, 521, 550, 551, 580, 581, 610, 611, 640, 641, 670, 671, 700, 701, 730, 731, 760, 761, 790, 791, 820, 821], + [ 432, 433, 462, 463, 492, 493, 522, 523, 552, 553, 582, 583, 612, 613, 642, 643, 672, 673, 702, 703, 732, 733, 762, 763, 792, 793, 822, 823], + [ 434, 435, 464, 465, 494, 495, 524, 525, 554, 555, 584, 585, 614, 615, 644, 645, 674, 675, 704, 705, 734, 735, 764, 765, 794, 795, 824, 825], + [ 436, 437, 466, 467, 496, 497, 526, 527, 556, 557, 586, 587, 616, 617, 646, 647, 676, 677, 706, 707, 736, 737, 766, 767, 796, 797, 826, 827], + [ 438, 439, 468, 469, 498, 499, 528, 529, 558, 559, 588, 589, 618, 619, 648, 649, 678, 679, 708, 709, 738, 739, 768, 769, 798, 799, 828, 829], + [ 440, 441, 470, 471, 500, 501, 530, 531, 560, 561, 590, 591, 620, 621, 650, 651, 680, 681, 710, 711, 740, 741, 770, 771, 800, 801, 830, 831], + [ 442, 443, 472, 473, 502, 503, 532, 533, 562, 563, 592, 593, 622, 623, 652, 653, 682, 683, 712, 713, 742, 743, 772, 773, 802, 803, 832, 833], + [ 444, 445, 474, 475, 504, 505, 534, 535, 564, 565, 594, 595, 624, 625, 654, 655, 684, 685, 714, 715, 744, 745, 774, 775, 804, 805, 834, 835], + [ 446, 447, 476, 477, 506, 507, 536, 537, 566, 567, 596, 597, 626, 627, 656, 657, 686, 687, 716, 717, 746, 747, 776, 777, 806, 807, 836, 837], + [ 448, 449, 478, 479, 508, 509, 538, 539, 568, 569, 598, 599, 628, 629, 658, 659, 688, 689, 718, 719, 748, 749, 778, 779, 808, 809, 838, 839], + [ 840, 841, 852, 853, 864, 865, 876, 877, 888, 889, 900, 901, 912, 913, 924, 925, 936, 937, 948, 949, 960, 961, 972, 973, 984, 985, 996, 997], + [ 842, 843, 854, 855, 866, 867, 878, 879, 890, 891, 902, 903, 914, 915, 926, 927, 938, 939, 950, 951, 962, 963, 974, 975, 986, 987, 998, 999], + [ 844, 845, 856, 857, 868, 869, 880, 881, 892, 893, 904, 905, 916, 917, 928, 929, 940, 941, 952, 953, 964, 965, 976, 977, 988, 989, 1000, 1001], + [ 846, 847, 858, 859, 870, 871, 882, 883, 894, 895, 906, 907, 918, 919, 930, 931, 942, 943, 954, 955, 966, 967, 978, 979, 990, 991, 1002, 1003], + [ 848, 849, 860, 861, 872, 873, 884, 885, 896, 897, 908, 909, 920, 921, 932, 933, 944, 945, 956, 957, 968, 969, 980, 981, 992, 993, 1004, 1005], + [ 850, 851, 862, 863, 874, 875, 886, 887, 898, 899, 910, 911, 922, 923, 934, 935, 946, 947, 958, 959, 970, 971, 982, 983, 994, 995, 1006, 1007]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # iter col major + tiles = TensorTiler2D.group_tiler( + tensor_dims, + tile_dims=(3, 2), + tile_group_dims=(5, 7), + iter_col_major=True, + allow_partial=True, + ) + reference_tiles = TensorTileSequence.from_tiles( + [ + TensorTile( + tensor_dims, offset=0, sizes=[5, 7, 3, 2], strides=[84, 2, 28, 1] + ), + TensorTile( + tensor_dims, offset=420, sizes=[5, 7, 3, 2], strides=[84, 2, 28, 1] + ), + TensorTile( + tensor_dims, offset=840, sizes=[2, 7, 3, 2], strides=[84, 2, 28, 1] + ), + TensorTile( + tensor_dims, offset=14, sizes=[5, 7, 3, 2], strides=[84, 2, 28, 1] + ), + TensorTile( + tensor_dims, offset=434, sizes=[5, 7, 3, 2], strides=[84, 2, 28, 1] + ), + TensorTile( + tensor_dims, offset=854, sizes=[2, 7, 3, 2], strides=[84, 2, 28, 1] + ), + ] + ) + assert tiles == reference_tiles + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 6, 7, 12, 13, 18, 19, 24, 25, 30, 31, 36, 37, 504, 505, 510, 511, 516, 517, 522, 523, 528, 529, 534, 535, 540, 541], + [ 2, 3, 8, 9, 14, 15, 20, 21, 26, 27, 32, 33, 38, 39, 506, 507, 512, 513, 518, 519, 524, 525, 530, 531, 536, 537, 542, 543], + [ 4, 5, 10, 11, 16, 17, 22, 23, 28, 29, 34, 35, 40, 41, 508, 509, 514, 515, 520, 521, 526, 527, 532, 533, 538, 539, 544, 545], + [ 42, 43, 48, 49, 54, 55, 60, 61, 66, 67, 72, 73, 78, 79, 546, 547, 552, 553, 558, 559, 564, 565, 570, 571, 576, 577, 582, 583], + [ 44, 45, 50, 51, 56, 57, 62, 63, 68, 69, 74, 75, 80, 81, 548, 549, 554, 555, 560, 561, 566, 567, 572, 573, 578, 579, 584, 585], + [ 46, 47, 52, 53, 58, 59, 64, 65, 70, 71, 76, 77, 82, 83, 550, 551, 556, 557, 562, 563, 568, 569, 574, 575, 580, 581, 586, 587], + [ 84, 85, 90, 91, 96, 97, 102, 103, 108, 109, 114, 115, 120, 121, 588, 589, 594, 595, 600, 601, 606, 607, 612, 613, 618, 619, 624, 625], + [ 86, 87, 92, 93, 98, 99, 104, 105, 110, 111, 116, 117, 122, 123, 590, 591, 596, 597, 602, 603, 608, 609, 614, 615, 620, 621, 626, 627], + [ 88, 89, 94, 95, 100, 101, 106, 107, 112, 113, 118, 119, 124, 125, 592, 593, 598, 599, 604, 605, 610, 611, 616, 617, 622, 623, 628, 629], + [ 126, 127, 132, 133, 138, 139, 144, 145, 150, 151, 156, 157, 162, 163, 630, 631, 636, 637, 642, 643, 648, 649, 654, 655, 660, 661, 666, 667], + [ 128, 129, 134, 135, 140, 141, 146, 147, 152, 153, 158, 159, 164, 165, 632, 633, 638, 639, 644, 645, 650, 651, 656, 657, 662, 663, 668, 669], + [ 130, 131, 136, 137, 142, 143, 148, 149, 154, 155, 160, 161, 166, 167, 634, 635, 640, 641, 646, 647, 652, 653, 658, 659, 664, 665, 670, 671], + [ 168, 169, 174, 175, 180, 181, 186, 187, 192, 193, 198, 199, 204, 205, 672, 673, 678, 679, 684, 685, 690, 691, 696, 697, 702, 703, 708, 709], + [ 170, 171, 176, 177, 182, 183, 188, 189, 194, 195, 200, 201, 206, 207, 674, 675, 680, 681, 686, 687, 692, 693, 698, 699, 704, 705, 710, 711], + [ 172, 173, 178, 179, 184, 185, 190, 191, 196, 197, 202, 203, 208, 209, 676, 677, 682, 683, 688, 689, 694, 695, 700, 701, 706, 707, 712, 713], + [ 210, 211, 216, 217, 222, 223, 228, 229, 234, 235, 240, 241, 246, 247, 714, 715, 720, 721, 726, 727, 732, 733, 738, 739, 744, 745, 750, 751], + [ 212, 213, 218, 219, 224, 225, 230, 231, 236, 237, 242, 243, 248, 249, 716, 717, 722, 723, 728, 729, 734, 735, 740, 741, 746, 747, 752, 753], + [ 214, 215, 220, 221, 226, 227, 232, 233, 238, 239, 244, 245, 250, 251, 718, 719, 724, 725, 730, 731, 736, 737, 742, 743, 748, 749, 754, 755], + [ 252, 253, 258, 259, 264, 265, 270, 271, 276, 277, 282, 283, 288, 289, 756, 757, 762, 763, 768, 769, 774, 775, 780, 781, 786, 787, 792, 793], + [ 254, 255, 260, 261, 266, 267, 272, 273, 278, 279, 284, 285, 290, 291, 758, 759, 764, 765, 770, 771, 776, 777, 782, 783, 788, 789, 794, 795], + [ 256, 257, 262, 263, 268, 269, 274, 275, 280, 281, 286, 287, 292, 293, 760, 761, 766, 767, 772, 773, 778, 779, 784, 785, 790, 791, 796, 797], + [ 294, 295, 300, 301, 306, 307, 312, 313, 318, 319, 324, 325, 330, 331, 798, 799, 804, 805, 810, 811, 816, 817, 822, 823, 828, 829, 834, 835], + [ 296, 297, 302, 303, 308, 309, 314, 315, 320, 321, 326, 327, 332, 333, 800, 801, 806, 807, 812, 813, 818, 819, 824, 825, 830, 831, 836, 837], + [ 298, 299, 304, 305, 310, 311, 316, 317, 322, 323, 328, 329, 334, 335, 802, 803, 808, 809, 814, 815, 820, 821, 826, 827, 832, 833, 838, 839], + [ 336, 337, 342, 343, 348, 349, 354, 355, 360, 361, 366, 367, 372, 373, 840, 841, 846, 847, 852, 853, 858, 859, 864, 865, 870, 871, 876, 877], + [ 338, 339, 344, 345, 350, 351, 356, 357, 362, 363, 368, 369, 374, 375, 842, 843, 848, 849, 854, 855, 860, 861, 866, 867, 872, 873, 878, 879], + [ 340, 341, 346, 347, 352, 353, 358, 359, 364, 365, 370, 371, 376, 377, 844, 845, 850, 851, 856, 857, 862, 863, 868, 869, 874, 875, 880, 881], + [ 378, 379, 384, 385, 390, 391, 396, 397, 402, 403, 408, 409, 414, 415, 882, 883, 888, 889, 894, 895, 900, 901, 906, 907, 912, 913, 918, 919], + [ 380, 381, 386, 387, 392, 393, 398, 399, 404, 405, 410, 411, 416, 417, 884, 885, 890, 891, 896, 897, 902, 903, 908, 909, 914, 915, 920, 921], + [ 382, 383, 388, 389, 394, 395, 400, 401, 406, 407, 412, 413, 418, 419, 886, 887, 892, 893, 898, 899, 904, 905, 910, 911, 916, 917, 922, 923], + [ 420, 421, 426, 427, 432, 433, 438, 439, 444, 445, 450, 451, 456, 457, 924, 925, 930, 931, 936, 937, 942, 943, 948, 949, 954, 955, 960, 961], + [ 422, 423, 428, 429, 434, 435, 440, 441, 446, 447, 452, 453, 458, 459, 926, 927, 932, 933, 938, 939, 944, 945, 950, 951, 956, 957, 962, 963], + [ 424, 425, 430, 431, 436, 437, 442, 443, 448, 449, 454, 455, 460, 461, 928, 929, 934, 935, 940, 941, 946, 947, 952, 953, 958, 959, 964, 965], + [ 462, 463, 468, 469, 474, 475, 480, 481, 486, 487, 492, 493, 498, 499, 966, 967, 972, 973, 978, 979, 984, 985, 990, 991, 996, 997, 1002, 1003], + [ 464, 465, 470, 471, 476, 477, 482, 483, 488, 489, 494, 495, 500, 501, 968, 969, 974, 975, 980, 981, 986, 987, 992, 993, 998, 999, 1004, 1005], + [ 466, 467, 472, 473, 478, 479, 484, 485, 490, 491, 496, 497, 502, 503, 970, 971, 976, 977, 982, 983, 988, 989, 994, 995, 1000, 1001, 1006, 1007]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # all col major + tiles = TensorTiler2D.group_tiler( + tensor_dims, + tile_dims=(3, 2), + tile_group_dims=(5, 7), + tile_col_major=True, + tile_group_col_major=True, + iter_col_major=True, + allow_partial=True, + ) + reference_tiles = TensorTileSequence.from_tiles( + [ + TensorTile( + tensor_dims, offset=0, sizes=[7, 5, 2, 3], strides=[2, 84, 1, 28] + ), + TensorTile( + tensor_dims, offset=420, sizes=[7, 5, 2, 3], strides=[2, 84, 1, 28] + ), + TensorTile( + tensor_dims, offset=840, sizes=[7, 2, 2, 3], strides=[2, 84, 1, 28] + ), + TensorTile( + tensor_dims, offset=14, sizes=[7, 5, 2, 3], strides=[2, 84, 1, 28] + ), + TensorTile( + tensor_dims, offset=434, sizes=[7, 5, 2, 3], strides=[2, 84, 1, 28] + ), + TensorTile( + tensor_dims, offset=854, sizes=[7, 2, 2, 3], strides=[2, 84, 1, 28] + ), + ] + ) + assert tiles == reference_tiles + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 3, 30, 33, 60, 63, 90, 93, 120, 123, 150, 153, 180, 183, 504, 507, 534, 537, 564, 567, 594, 597, 624, 627, 654, 657, 684, 687], + [ 1, 4, 31, 34, 61, 64, 91, 94, 121, 124, 151, 154, 181, 184, 505, 508, 535, 538, 565, 568, 595, 598, 625, 628, 655, 658, 685, 688], + [ 2, 5, 32, 35, 62, 65, 92, 95, 122, 125, 152, 155, 182, 185, 506, 509, 536, 539, 566, 569, 596, 599, 626, 629, 656, 659, 686, 689], + [ 6, 9, 36, 39, 66, 69, 96, 99, 126, 129, 156, 159, 186, 189, 510, 513, 540, 543, 570, 573, 600, 603, 630, 633, 660, 663, 690, 693], + [ 7, 10, 37, 40, 67, 70, 97, 100, 127, 130, 157, 160, 187, 190, 511, 514, 541, 544, 571, 574, 601, 604, 631, 634, 661, 664, 691, 694], + [ 8, 11, 38, 41, 68, 71, 98, 101, 128, 131, 158, 161, 188, 191, 512, 515, 542, 545, 572, 575, 602, 605, 632, 635, 662, 665, 692, 695], + [ 12, 15, 42, 45, 72, 75, 102, 105, 132, 135, 162, 165, 192, 195, 516, 519, 546, 549, 576, 579, 606, 609, 636, 639, 666, 669, 696, 699], + [ 13, 16, 43, 46, 73, 76, 103, 106, 133, 136, 163, 166, 193, 196, 517, 520, 547, 550, 577, 580, 607, 610, 637, 640, 667, 670, 697, 700], + [ 14, 17, 44, 47, 74, 77, 104, 107, 134, 137, 164, 167, 194, 197, 518, 521, 548, 551, 578, 581, 608, 611, 638, 641, 668, 671, 698, 701], + [ 18, 21, 48, 51, 78, 81, 108, 111, 138, 141, 168, 171, 198, 201, 522, 525, 552, 555, 582, 585, 612, 615, 642, 645, 672, 675, 702, 705], + [ 19, 22, 49, 52, 79, 82, 109, 112, 139, 142, 169, 172, 199, 202, 523, 526, 553, 556, 583, 586, 613, 616, 643, 646, 673, 676, 703, 706], + [ 20, 23, 50, 53, 80, 83, 110, 113, 140, 143, 170, 173, 200, 203, 524, 527, 554, 557, 584, 587, 614, 617, 644, 647, 674, 677, 704, 707], + [ 24, 27, 54, 57, 84, 87, 114, 117, 144, 147, 174, 177, 204, 207, 528, 531, 558, 561, 588, 591, 618, 621, 648, 651, 678, 681, 708, 711], + [ 25, 28, 55, 58, 85, 88, 115, 118, 145, 148, 175, 178, 205, 208, 529, 532, 559, 562, 589, 592, 619, 622, 649, 652, 679, 682, 709, 712], + [ 26, 29, 56, 59, 86, 89, 116, 119, 146, 149, 176, 179, 206, 209, 530, 533, 560, 563, 590, 593, 620, 623, 650, 653, 680, 683, 710, 713], + [ 210, 213, 240, 243, 270, 273, 300, 303, 330, 333, 360, 363, 390, 393, 714, 717, 744, 747, 774, 777, 804, 807, 834, 837, 864, 867, 894, 897], + [ 211, 214, 241, 244, 271, 274, 301, 304, 331, 334, 361, 364, 391, 394, 715, 718, 745, 748, 775, 778, 805, 808, 835, 838, 865, 868, 895, 898], + [ 212, 215, 242, 245, 272, 275, 302, 305, 332, 335, 362, 365, 392, 395, 716, 719, 746, 749, 776, 779, 806, 809, 836, 839, 866, 869, 896, 899], + [ 216, 219, 246, 249, 276, 279, 306, 309, 336, 339, 366, 369, 396, 399, 720, 723, 750, 753, 780, 783, 810, 813, 840, 843, 870, 873, 900, 903], + [ 217, 220, 247, 250, 277, 280, 307, 310, 337, 340, 367, 370, 397, 400, 721, 724, 751, 754, 781, 784, 811, 814, 841, 844, 871, 874, 901, 904], + [ 218, 221, 248, 251, 278, 281, 308, 311, 338, 341, 368, 371, 398, 401, 722, 725, 752, 755, 782, 785, 812, 815, 842, 845, 872, 875, 902, 905], + [ 222, 225, 252, 255, 282, 285, 312, 315, 342, 345, 372, 375, 402, 405, 726, 729, 756, 759, 786, 789, 816, 819, 846, 849, 876, 879, 906, 909], + [ 223, 226, 253, 256, 283, 286, 313, 316, 343, 346, 373, 376, 403, 406, 727, 730, 757, 760, 787, 790, 817, 820, 847, 850, 877, 880, 907, 910], + [ 224, 227, 254, 257, 284, 287, 314, 317, 344, 347, 374, 377, 404, 407, 728, 731, 758, 761, 788, 791, 818, 821, 848, 851, 878, 881, 908, 911], + [ 228, 231, 258, 261, 288, 291, 318, 321, 348, 351, 378, 381, 408, 411, 732, 735, 762, 765, 792, 795, 822, 825, 852, 855, 882, 885, 912, 915], + [ 229, 232, 259, 262, 289, 292, 319, 322, 349, 352, 379, 382, 409, 412, 733, 736, 763, 766, 793, 796, 823, 826, 853, 856, 883, 886, 913, 916], + [ 230, 233, 260, 263, 290, 293, 320, 323, 350, 353, 380, 383, 410, 413, 734, 737, 764, 767, 794, 797, 824, 827, 854, 857, 884, 887, 914, 917], + [ 234, 237, 264, 267, 294, 297, 324, 327, 354, 357, 384, 387, 414, 417, 738, 741, 768, 771, 798, 801, 828, 831, 858, 861, 888, 891, 918, 921], + [ 235, 238, 265, 268, 295, 298, 325, 328, 355, 358, 385, 388, 415, 418, 739, 742, 769, 772, 799, 802, 829, 832, 859, 862, 889, 892, 919, 922], + [ 236, 239, 266, 269, 296, 299, 326, 329, 356, 359, 386, 389, 416, 419, 740, 743, 770, 773, 800, 803, 830, 833, 860, 863, 890, 893, 920, 923], + [ 420, 423, 432, 435, 444, 447, 456, 459, 468, 471, 480, 483, 492, 495, 924, 927, 936, 939, 948, 951, 960, 963, 972, 975, 984, 987, 996, 999], + [ 421, 424, 433, 436, 445, 448, 457, 460, 469, 472, 481, 484, 493, 496, 925, 928, 937, 940, 949, 952, 961, 964, 973, 976, 985, 988, 997, 1000], + [ 422, 425, 434, 437, 446, 449, 458, 461, 470, 473, 482, 485, 494, 497, 926, 929, 938, 941, 950, 953, 962, 965, 974, 977, 986, 989, 998, 1001], + [ 426, 429, 438, 441, 450, 453, 462, 465, 474, 477, 486, 489, 498, 501, 930, 933, 942, 945, 954, 957, 966, 969, 978, 981, 990, 993, 1002, 1005], + [ 427, 430, 439, 442, 451, 454, 463, 466, 475, 478, 487, 490, 499, 502, 931, 934, 943, 946, 955, 958, 967, 970, 979, 982, 991, 994, 1003, 1006], + [ 428, 431, 440, 443, 452, 455, 464, 467, 476, 479, 488, 491, 500, 503, 932, 935, 944, 947, 956, 959, 968, 971, 980, 983, 992, 995, 1004, 1007]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # pattern repeat + tiles = TensorTiler2D.group_tiler( + tensor_dims, + tile_dims=(3, 2), + tile_group_dims=(5, 7), + tile_col_major=True, + allow_partial=True, + pattern_repeat=3, + ) + reference_tiles = TensorTileSequence.from_tiles( + [ + TensorTile( + tensor_dims, offset=0, sizes=[3, 5, 14, 3], strides=[0, 84, 1, 28] + ), + TensorTile( + tensor_dims, offset=14, sizes=[3, 5, 14, 3], strides=[0, 84, 1, 28] + ), + TensorTile( + tensor_dims, offset=420, sizes=[3, 5, 14, 3], strides=[0, 84, 1, 28] + ), + TensorTile( + tensor_dims, offset=434, sizes=[3, 5, 14, 3], strides=[0, 84, 1, 28] + ), + TensorTile( + tensor_dims, offset=840, sizes=[3, 2, 14, 3], strides=[0, 84, 1, 28] + ), + TensorTile( + tensor_dims, offset=854, sizes=[3, 2, 14, 3], strides=[0, 84, 1, 28] + ), + ] + ) + assert tiles == reference_tiles + # fmt: off + ref_access_order_tensor = np.array([ + [ 420, 423, 426, 429, 432, 435, 438, 441, 444, 447, 450, 453, 456, 459, 1050, 1053, 1056, 1059, 1062, 1065, 1068, 1071, 1074, 1077, 1080, 1083, 1086, 1089], + [ 421, 424, 427, 430, 433, 436, 439, 442, 445, 448, 451, 454, 457, 460, 1051, 1054, 1057, 1060, 1063, 1066, 1069, 1072, 1075, 1078, 1081, 1084, 1087, 1090], + [ 422, 425, 428, 431, 434, 437, 440, 443, 446, 449, 452, 455, 458, 461, 1052, 1055, 1058, 1061, 1064, 1067, 1070, 1073, 1076, 1079, 1082, 1085, 1088, 1091], + [ 462, 465, 468, 471, 474, 477, 480, 483, 486, 489, 492, 495, 498, 501, 1092, 1095, 1098, 1101, 1104, 1107, 1110, 1113, 1116, 1119, 1122, 1125, 1128, 1131], + [ 463, 466, 469, 472, 475, 478, 481, 484, 487, 490, 493, 496, 499, 502, 1093, 1096, 1099, 1102, 1105, 1108, 1111, 1114, 1117, 1120, 1123, 1126, 1129, 1132], + [ 464, 467, 470, 473, 476, 479, 482, 485, 488, 491, 494, 497, 500, 503, 1094, 1097, 1100, 1103, 1106, 1109, 1112, 1115, 1118, 1121, 1124, 1127, 1130, 1133], + [ 504, 507, 510, 513, 516, 519, 522, 525, 528, 531, 534, 537, 540, 543, 1134, 1137, 1140, 1143, 1146, 1149, 1152, 1155, 1158, 1161, 1164, 1167, 1170, 1173], + [ 505, 508, 511, 514, 517, 520, 523, 526, 529, 532, 535, 538, 541, 544, 1135, 1138, 1141, 1144, 1147, 1150, 1153, 1156, 1159, 1162, 1165, 1168, 1171, 1174], + [ 506, 509, 512, 515, 518, 521, 524, 527, 530, 533, 536, 539, 542, 545, 1136, 1139, 1142, 1145, 1148, 1151, 1154, 1157, 1160, 1163, 1166, 1169, 1172, 1175], + [ 546, 549, 552, 555, 558, 561, 564, 567, 570, 573, 576, 579, 582, 585, 1176, 1179, 1182, 1185, 1188, 1191, 1194, 1197, 1200, 1203, 1206, 1209, 1212, 1215], + [ 547, 550, 553, 556, 559, 562, 565, 568, 571, 574, 577, 580, 583, 586, 1177, 1180, 1183, 1186, 1189, 1192, 1195, 1198, 1201, 1204, 1207, 1210, 1213, 1216], + [ 548, 551, 554, 557, 560, 563, 566, 569, 572, 575, 578, 581, 584, 587, 1178, 1181, 1184, 1187, 1190, 1193, 1196, 1199, 1202, 1205, 1208, 1211, 1214, 1217], + [ 588, 591, 594, 597, 600, 603, 606, 609, 612, 615, 618, 621, 624, 627, 1218, 1221, 1224, 1227, 1230, 1233, 1236, 1239, 1242, 1245, 1248, 1251, 1254, 1257], + [ 589, 592, 595, 598, 601, 604, 607, 610, 613, 616, 619, 622, 625, 628, 1219, 1222, 1225, 1228, 1231, 1234, 1237, 1240, 1243, 1246, 1249, 1252, 1255, 1258], + [ 590, 593, 596, 599, 602, 605, 608, 611, 614, 617, 620, 623, 626, 629, 1220, 1223, 1226, 1229, 1232, 1235, 1238, 1241, 1244, 1247, 1250, 1253, 1256, 1259], + [1680, 1683, 1686, 1689, 1692, 1695, 1698, 1701, 1704, 1707, 1710, 1713, 1716, 1719, 2310, 2313, 2316, 2319, 2322, 2325, 2328, 2331, 2334, 2337, 2340, 2343, 2346, 2349], + [1681, 1684, 1687, 1690, 1693, 1696, 1699, 1702, 1705, 1708, 1711, 1714, 1717, 1720, 2311, 2314, 2317, 2320, 2323, 2326, 2329, 2332, 2335, 2338, 2341, 2344, 2347, 2350], + [1682, 1685, 1688, 1691, 1694, 1697, 1700, 1703, 1706, 1709, 1712, 1715, 1718, 1721, 2312, 2315, 2318, 2321, 2324, 2327, 2330, 2333, 2336, 2339, 2342, 2345, 2348, 2351], + [1722, 1725, 1728, 1731, 1734, 1737, 1740, 1743, 1746, 1749, 1752, 1755, 1758, 1761, 2352, 2355, 2358, 2361, 2364, 2367, 2370, 2373, 2376, 2379, 2382, 2385, 2388, 2391], + [1723, 1726, 1729, 1732, 1735, 1738, 1741, 1744, 1747, 1750, 1753, 1756, 1759, 1762, 2353, 2356, 2359, 2362, 2365, 2368, 2371, 2374, 2377, 2380, 2383, 2386, 2389, 2392], + [1724, 1727, 1730, 1733, 1736, 1739, 1742, 1745, 1748, 1751, 1754, 1757, 1760, 1763, 2354, 2357, 2360, 2363, 2366, 2369, 2372, 2375, 2378, 2381, 2384, 2387, 2390, 2393], + [1764, 1767, 1770, 1773, 1776, 1779, 1782, 1785, 1788, 1791, 1794, 1797, 1800, 1803, 2394, 2397, 2400, 2403, 2406, 2409, 2412, 2415, 2418, 2421, 2424, 2427, 2430, 2433], + [1765, 1768, 1771, 1774, 1777, 1780, 1783, 1786, 1789, 1792, 1795, 1798, 1801, 1804, 2395, 2398, 2401, 2404, 2407, 2410, 2413, 2416, 2419, 2422, 2425, 2428, 2431, 2434], + [1766, 1769, 1772, 1775, 1778, 1781, 1784, 1787, 1790, 1793, 1796, 1799, 1802, 1805, 2396, 2399, 2402, 2405, 2408, 2411, 2414, 2417, 2420, 2423, 2426, 2429, 2432, 2435], + [1806, 1809, 1812, 1815, 1818, 1821, 1824, 1827, 1830, 1833, 1836, 1839, 1842, 1845, 2436, 2439, 2442, 2445, 2448, 2451, 2454, 2457, 2460, 2463, 2466, 2469, 2472, 2475], + [1807, 1810, 1813, 1816, 1819, 1822, 1825, 1828, 1831, 1834, 1837, 1840, 1843, 1846, 2437, 2440, 2443, 2446, 2449, 2452, 2455, 2458, 2461, 2464, 2467, 2470, 2473, 2476], + [1808, 1811, 1814, 1817, 1820, 1823, 1826, 1829, 1832, 1835, 1838, 1841, 1844, 1847, 2438, 2441, 2444, 2447, 2450, 2453, 2456, 2459, 2462, 2465, 2468, 2471, 2474, 2477], + [1848, 1851, 1854, 1857, 1860, 1863, 1866, 1869, 1872, 1875, 1878, 1881, 1884, 1887, 2478, 2481, 2484, 2487, 2490, 2493, 2496, 2499, 2502, 2505, 2508, 2511, 2514, 2517], + [1849, 1852, 1855, 1858, 1861, 1864, 1867, 1870, 1873, 1876, 1879, 1882, 1885, 1888, 2479, 2482, 2485, 2488, 2491, 2494, 2497, 2500, 2503, 2506, 2509, 2512, 2515, 2518], + [1850, 1853, 1856, 1859, 1862, 1865, 1868, 1871, 1874, 1877, 1880, 1883, 1886, 1889, 2480, 2483, 2486, 2489, 2492, 2495, 2498, 2501, 2504, 2507, 2510, 2513, 2516, 2519], + [2688, 2691, 2694, 2697, 2700, 2703, 2706, 2709, 2712, 2715, 2718, 2721, 2724, 2727, 2940, 2943, 2946, 2949, 2952, 2955, 2958, 2961, 2964, 2967, 2970, 2973, 2976, 2979], + [2689, 2692, 2695, 2698, 2701, 2704, 2707, 2710, 2713, 2716, 2719, 2722, 2725, 2728, 2941, 2944, 2947, 2950, 2953, 2956, 2959, 2962, 2965, 2968, 2971, 2974, 2977, 2980], + [2690, 2693, 2696, 2699, 2702, 2705, 2708, 2711, 2714, 2717, 2720, 2723, 2726, 2729, 2942, 2945, 2948, 2951, 2954, 2957, 2960, 2963, 2966, 2969, 2972, 2975, 2978, 2981], + [2730, 2733, 2736, 2739, 2742, 2745, 2748, 2751, 2754, 2757, 2760, 2763, 2766, 2769, 2982, 2985, 2988, 2991, 2994, 2997, 3000, 3003, 3006, 3009, 3012, 3015, 3018, 3021], + [2731, 2734, 2737, 2740, 2743, 2746, 2749, 2752, 2755, 2758, 2761, 2764, 2767, 2770, 2983, 2986, 2989, 2992, 2995, 2998, 3001, 3004, 3007, 3010, 3013, 3016, 3019, 3022], + [2732, 2735, 2738, 2741, 2744, 2747, 2750, 2753, 2756, 2759, 2762, 2765, 2768, 2771, 2984, 2987, 2990, 2993, 2996, 2999, 3002, 3005, 3008, 3011, 3014, 3017, 3020, 3023]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 3).all() + + # CHECK: Pass! + print("Pass!") + + +# CHECK-LABEL: group_tiler_partial_both +@construct_test +def group_tiler_partial_both(): + + # All row major + tensor_dims = (3 * 4 * 3, 2 * 6 * 2) + tiles = TensorTiler2D.group_tiler( + tensor_dims, + tile_dims=(3, 2), + tile_group_dims=(5, 7), + allow_partial=True, + ) + + reference_tiles = TensorTileSequence.from_tiles( + [ + TensorTile( + tensor_dims, offset=0, sizes=[5, 7, 3, 2], strides=[72, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=14, sizes=[5, 5, 3, 2], strides=[72, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=360, sizes=[5, 7, 3, 2], strides=[72, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=374, sizes=[5, 5, 3, 2], strides=[72, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=720, sizes=[2, 7, 3, 2], strides=[72, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=734, sizes=[2, 5, 3, 2], strides=[72, 2, 24, 1] + ), + ] + ) + assert tiles == reference_tiles + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 6, 7, 12, 13, 18, 19, 24, 25, 30, 31, 36, 37, 210, 211, 216, 217, 222, 223, 228, 229, 234, 235], + [ 2, 3, 8, 9, 14, 15, 20, 21, 26, 27, 32, 33, 38, 39, 212, 213, 218, 219, 224, 225, 230, 231, 236, 237], + [ 4, 5, 10, 11, 16, 17, 22, 23, 28, 29, 34, 35, 40, 41, 214, 215, 220, 221, 226, 227, 232, 233, 238, 239], + [ 42, 43, 48, 49, 54, 55, 60, 61, 66, 67, 72, 73, 78, 79, 240, 241, 246, 247, 252, 253, 258, 259, 264, 265], + [ 44, 45, 50, 51, 56, 57, 62, 63, 68, 69, 74, 75, 80, 81, 242, 243, 248, 249, 254, 255, 260, 261, 266, 267], + [ 46, 47, 52, 53, 58, 59, 64, 65, 70, 71, 76, 77, 82, 83, 244, 245, 250, 251, 256, 257, 262, 263, 268, 269], + [ 84, 85, 90, 91, 96, 97, 102, 103, 108, 109, 114, 115, 120, 121, 270, 271, 276, 277, 282, 283, 288, 289, 294, 295], + [ 86, 87, 92, 93, 98, 99, 104, 105, 110, 111, 116, 117, 122, 123, 272, 273, 278, 279, 284, 285, 290, 291, 296, 297], + [ 88, 89, 94, 95, 100, 101, 106, 107, 112, 113, 118, 119, 124, 125, 274, 275, 280, 281, 286, 287, 292, 293, 298, 299], + [126, 127, 132, 133, 138, 139, 144, 145, 150, 151, 156, 157, 162, 163, 300, 301, 306, 307, 312, 313, 318, 319, 324, 325], + [128, 129, 134, 135, 140, 141, 146, 147, 152, 153, 158, 159, 164, 165, 302, 303, 308, 309, 314, 315, 320, 321, 326, 327], + [130, 131, 136, 137, 142, 143, 148, 149, 154, 155, 160, 161, 166, 167, 304, 305, 310, 311, 316, 317, 322, 323, 328, 329], + [168, 169, 174, 175, 180, 181, 186, 187, 192, 193, 198, 199, 204, 205, 330, 331, 336, 337, 342, 343, 348, 349, 354, 355], + [170, 171, 176, 177, 182, 183, 188, 189, 194, 195, 200, 201, 206, 207, 332, 333, 338, 339, 344, 345, 350, 351, 356, 357], + [172, 173, 178, 179, 184, 185, 190, 191, 196, 197, 202, 203, 208, 209, 334, 335, 340, 341, 346, 347, 352, 353, 358, 359], + [360, 361, 366, 367, 372, 373, 378, 379, 384, 385, 390, 391, 396, 397, 570, 571, 576, 577, 582, 583, 588, 589, 594, 595], + [362, 363, 368, 369, 374, 375, 380, 381, 386, 387, 392, 393, 398, 399, 572, 573, 578, 579, 584, 585, 590, 591, 596, 597], + [364, 365, 370, 371, 376, 377, 382, 383, 388, 389, 394, 395, 400, 401, 574, 575, 580, 581, 586, 587, 592, 593, 598, 599], + [402, 403, 408, 409, 414, 415, 420, 421, 426, 427, 432, 433, 438, 439, 600, 601, 606, 607, 612, 613, 618, 619, 624, 625], + [404, 405, 410, 411, 416, 417, 422, 423, 428, 429, 434, 435, 440, 441, 602, 603, 608, 609, 614, 615, 620, 621, 626, 627], + [406, 407, 412, 413, 418, 419, 424, 425, 430, 431, 436, 437, 442, 443, 604, 605, 610, 611, 616, 617, 622, 623, 628, 629], + [444, 445, 450, 451, 456, 457, 462, 463, 468, 469, 474, 475, 480, 481, 630, 631, 636, 637, 642, 643, 648, 649, 654, 655], + [446, 447, 452, 453, 458, 459, 464, 465, 470, 471, 476, 477, 482, 483, 632, 633, 638, 639, 644, 645, 650, 651, 656, 657], + [448, 449, 454, 455, 460, 461, 466, 467, 472, 473, 478, 479, 484, 485, 634, 635, 640, 641, 646, 647, 652, 653, 658, 659], + [486, 487, 492, 493, 498, 499, 504, 505, 510, 511, 516, 517, 522, 523, 660, 661, 666, 667, 672, 673, 678, 679, 684, 685], + [488, 489, 494, 495, 500, 501, 506, 507, 512, 513, 518, 519, 524, 525, 662, 663, 668, 669, 674, 675, 680, 681, 686, 687], + [490, 491, 496, 497, 502, 503, 508, 509, 514, 515, 520, 521, 526, 527, 664, 665, 670, 671, 676, 677, 682, 683, 688, 689], + [528, 529, 534, 535, 540, 541, 546, 547, 552, 553, 558, 559, 564, 565, 690, 691, 696, 697, 702, 703, 708, 709, 714, 715], + [530, 531, 536, 537, 542, 543, 548, 549, 554, 555, 560, 561, 566, 567, 692, 693, 698, 699, 704, 705, 710, 711, 716, 717], + [532, 533, 538, 539, 544, 545, 550, 551, 556, 557, 562, 563, 568, 569, 694, 695, 700, 701, 706, 707, 712, 713, 718, 719], + [720, 721, 726, 727, 732, 733, 738, 739, 744, 745, 750, 751, 756, 757, 804, 805, 810, 811, 816, 817, 822, 823, 828, 829], + [722, 723, 728, 729, 734, 735, 740, 741, 746, 747, 752, 753, 758, 759, 806, 807, 812, 813, 818, 819, 824, 825, 830, 831], + [724, 725, 730, 731, 736, 737, 742, 743, 748, 749, 754, 755, 760, 761, 808, 809, 814, 815, 820, 821, 826, 827, 832, 833], + [762, 763, 768, 769, 774, 775, 780, 781, 786, 787, 792, 793, 798, 799, 834, 835, 840, 841, 846, 847, 852, 853, 858, 859], + [764, 765, 770, 771, 776, 777, 782, 783, 788, 789, 794, 795, 800, 801, 836, 837, 842, 843, 848, 849, 854, 855, 860, 861], + [766, 767, 772, 773, 778, 779, 784, 785, 790, 791, 796, 797, 802, 803, 838, 839, 844, 845, 850, 851, 856, 857, 862, 863]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # Tile col major + tiles = TensorTiler2D.group_tiler( + tensor_dims, + tile_dims=(3, 2), + tile_group_dims=(5, 7), + tile_col_major=True, + allow_partial=True, + ) + + reference_tiles = TensorTileSequence.from_tiles( + [ + TensorTile( + tensor_dims, offset=0, sizes=[1, 5, 14, 3], strides=[0, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=14, sizes=[1, 5, 10, 3], strides=[0, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=360, sizes=[1, 5, 14, 3], strides=[0, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=374, sizes=[1, 5, 10, 3], strides=[0, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=720, sizes=[1, 2, 14, 3], strides=[0, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=734, sizes=[1, 2, 10, 3], strides=[0, 72, 1, 24] + ), + ] + ) + assert tiles == reference_tiles + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 210, 213, 216, 219, 222, 225, 228, 231, 234, 237], + [ 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 211, 214, 217, 220, 223, 226, 229, 232, 235, 238], + [ 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41, 212, 215, 218, 221, 224, 227, 230, 233, 236, 239], + [ 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 240, 243, 246, 249, 252, 255, 258, 261, 264, 267], + [ 43, 46, 49, 52, 55, 58, 61, 64, 67, 70, 73, 76, 79, 82, 241, 244, 247, 250, 253, 256, 259, 262, 265, 268], + [ 44, 47, 50, 53, 56, 59, 62, 65, 68, 71, 74, 77, 80, 83, 242, 245, 248, 251, 254, 257, 260, 263, 266, 269], + [ 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 270, 273, 276, 279, 282, 285, 288, 291, 294, 297], + [ 85, 88, 91, 94, 97, 100, 103, 106, 109, 112, 115, 118, 121, 124, 271, 274, 277, 280, 283, 286, 289, 292, 295, 298], + [ 86, 89, 92, 95, 98, 101, 104, 107, 110, 113, 116, 119, 122, 125, 272, 275, 278, 281, 284, 287, 290, 293, 296, 299], + [126, 129, 132, 135, 138, 141, 144, 147, 150, 153, 156, 159, 162, 165, 300, 303, 306, 309, 312, 315, 318, 321, 324, 327], + [127, 130, 133, 136, 139, 142, 145, 148, 151, 154, 157, 160, 163, 166, 301, 304, 307, 310, 313, 316, 319, 322, 325, 328], + [128, 131, 134, 137, 140, 143, 146, 149, 152, 155, 158, 161, 164, 167, 302, 305, 308, 311, 314, 317, 320, 323, 326, 329], + [168, 171, 174, 177, 180, 183, 186, 189, 192, 195, 198, 201, 204, 207, 330, 333, 336, 339, 342, 345, 348, 351, 354, 357], + [169, 172, 175, 178, 181, 184, 187, 190, 193, 196, 199, 202, 205, 208, 331, 334, 337, 340, 343, 346, 349, 352, 355, 358], + [170, 173, 176, 179, 182, 185, 188, 191, 194, 197, 200, 203, 206, 209, 332, 335, 338, 341, 344, 347, 350, 353, 356, 359], + [360, 363, 366, 369, 372, 375, 378, 381, 384, 387, 390, 393, 396, 399, 570, 573, 576, 579, 582, 585, 588, 591, 594, 597], + [361, 364, 367, 370, 373, 376, 379, 382, 385, 388, 391, 394, 397, 400, 571, 574, 577, 580, 583, 586, 589, 592, 595, 598], + [362, 365, 368, 371, 374, 377, 380, 383, 386, 389, 392, 395, 398, 401, 572, 575, 578, 581, 584, 587, 590, 593, 596, 599], + [402, 405, 408, 411, 414, 417, 420, 423, 426, 429, 432, 435, 438, 441, 600, 603, 606, 609, 612, 615, 618, 621, 624, 627], + [403, 406, 409, 412, 415, 418, 421, 424, 427, 430, 433, 436, 439, 442, 601, 604, 607, 610, 613, 616, 619, 622, 625, 628], + [404, 407, 410, 413, 416, 419, 422, 425, 428, 431, 434, 437, 440, 443, 602, 605, 608, 611, 614, 617, 620, 623, 626, 629], + [444, 447, 450, 453, 456, 459, 462, 465, 468, 471, 474, 477, 480, 483, 630, 633, 636, 639, 642, 645, 648, 651, 654, 657], + [445, 448, 451, 454, 457, 460, 463, 466, 469, 472, 475, 478, 481, 484, 631, 634, 637, 640, 643, 646, 649, 652, 655, 658], + [446, 449, 452, 455, 458, 461, 464, 467, 470, 473, 476, 479, 482, 485, 632, 635, 638, 641, 644, 647, 650, 653, 656, 659], + [486, 489, 492, 495, 498, 501, 504, 507, 510, 513, 516, 519, 522, 525, 660, 663, 666, 669, 672, 675, 678, 681, 684, 687], + [487, 490, 493, 496, 499, 502, 505, 508, 511, 514, 517, 520, 523, 526, 661, 664, 667, 670, 673, 676, 679, 682, 685, 688], + [488, 491, 494, 497, 500, 503, 506, 509, 512, 515, 518, 521, 524, 527, 662, 665, 668, 671, 674, 677, 680, 683, 686, 689], + [528, 531, 534, 537, 540, 543, 546, 549, 552, 555, 558, 561, 564, 567, 690, 693, 696, 699, 702, 705, 708, 711, 714, 717], + [529, 532, 535, 538, 541, 544, 547, 550, 553, 556, 559, 562, 565, 568, 691, 694, 697, 700, 703, 706, 709, 712, 715, 718], + [530, 533, 536, 539, 542, 545, 548, 551, 554, 557, 560, 563, 566, 569, 692, 695, 698, 701, 704, 707, 710, 713, 716, 719], + [720, 723, 726, 729, 732, 735, 738, 741, 744, 747, 750, 753, 756, 759, 804, 807, 810, 813, 816, 819, 822, 825, 828, 831], + [721, 724, 727, 730, 733, 736, 739, 742, 745, 748, 751, 754, 757, 760, 805, 808, 811, 814, 817, 820, 823, 826, 829, 832], + [722, 725, 728, 731, 734, 737, 740, 743, 746, 749, 752, 755, 758, 761, 806, 809, 812, 815, 818, 821, 824, 827, 830, 833], + [762, 765, 768, 771, 774, 777, 780, 783, 786, 789, 792, 795, 798, 801, 834, 837, 840, 843, 846, 849, 852, 855, 858, 861], + [763, 766, 769, 772, 775, 778, 781, 784, 787, 790, 793, 796, 799, 802, 835, 838, 841, 844, 847, 850, 853, 856, 859, 862], + [764, 767, 770, 773, 776, 779, 782, 785, 788, 791, 794, 797, 800, 803, 836, 839, 842, 845, 848, 851, 854, 857, 860, 863]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # Tile group col major + tiles = TensorTiler2D.group_tiler( + tensor_dims, + tile_dims=(3, 2), + tile_group_dims=(5, 7), + tile_group_col_major=True, + allow_partial=True, + ) + + reference_tiles = TensorTileSequence.from_tiles( + [ + TensorTile( + tensor_dims, offset=0, sizes=[1, 7, 15, 2], strides=[0, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=14, sizes=[1, 5, 15, 2], strides=[0, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=360, sizes=[1, 7, 15, 2], strides=[0, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=374, sizes=[1, 5, 15, 2], strides=[0, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=720, sizes=[1, 7, 6, 2], strides=[0, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=734, sizes=[1, 5, 6, 2], strides=[0, 2, 24, 1] + ), + ] + ) + assert tiles == reference_tiles + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 30, 31, 60, 61, 90, 91, 120, 121, 150, 151, 180, 181, 210, 211, 240, 241, 270, 271, 300, 301, 330, 331], + [ 2, 3, 32, 33, 62, 63, 92, 93, 122, 123, 152, 153, 182, 183, 212, 213, 242, 243, 272, 273, 302, 303, 332, 333], + [ 4, 5, 34, 35, 64, 65, 94, 95, 124, 125, 154, 155, 184, 185, 214, 215, 244, 245, 274, 275, 304, 305, 334, 335], + [ 6, 7, 36, 37, 66, 67, 96, 97, 126, 127, 156, 157, 186, 187, 216, 217, 246, 247, 276, 277, 306, 307, 336, 337], + [ 8, 9, 38, 39, 68, 69, 98, 99, 128, 129, 158, 159, 188, 189, 218, 219, 248, 249, 278, 279, 308, 309, 338, 339], + [ 10, 11, 40, 41, 70, 71, 100, 101, 130, 131, 160, 161, 190, 191, 220, 221, 250, 251, 280, 281, 310, 311, 340, 341], + [ 12, 13, 42, 43, 72, 73, 102, 103, 132, 133, 162, 163, 192, 193, 222, 223, 252, 253, 282, 283, 312, 313, 342, 343], + [ 14, 15, 44, 45, 74, 75, 104, 105, 134, 135, 164, 165, 194, 195, 224, 225, 254, 255, 284, 285, 314, 315, 344, 345], + [ 16, 17, 46, 47, 76, 77, 106, 107, 136, 137, 166, 167, 196, 197, 226, 227, 256, 257, 286, 287, 316, 317, 346, 347], + [ 18, 19, 48, 49, 78, 79, 108, 109, 138, 139, 168, 169, 198, 199, 228, 229, 258, 259, 288, 289, 318, 319, 348, 349], + [ 20, 21, 50, 51, 80, 81, 110, 111, 140, 141, 170, 171, 200, 201, 230, 231, 260, 261, 290, 291, 320, 321, 350, 351], + [ 22, 23, 52, 53, 82, 83, 112, 113, 142, 143, 172, 173, 202, 203, 232, 233, 262, 263, 292, 293, 322, 323, 352, 353], + [ 24, 25, 54, 55, 84, 85, 114, 115, 144, 145, 174, 175, 204, 205, 234, 235, 264, 265, 294, 295, 324, 325, 354, 355], + [ 26, 27, 56, 57, 86, 87, 116, 117, 146, 147, 176, 177, 206, 207, 236, 237, 266, 267, 296, 297, 326, 327, 356, 357], + [ 28, 29, 58, 59, 88, 89, 118, 119, 148, 149, 178, 179, 208, 209, 238, 239, 268, 269, 298, 299, 328, 329, 358, 359], + [360, 361, 390, 391, 420, 421, 450, 451, 480, 481, 510, 511, 540, 541, 570, 571, 600, 601, 630, 631, 660, 661, 690, 691], + [362, 363, 392, 393, 422, 423, 452, 453, 482, 483, 512, 513, 542, 543, 572, 573, 602, 603, 632, 633, 662, 663, 692, 693], + [364, 365, 394, 395, 424, 425, 454, 455, 484, 485, 514, 515, 544, 545, 574, 575, 604, 605, 634, 635, 664, 665, 694, 695], + [366, 367, 396, 397, 426, 427, 456, 457, 486, 487, 516, 517, 546, 547, 576, 577, 606, 607, 636, 637, 666, 667, 696, 697], + [368, 369, 398, 399, 428, 429, 458, 459, 488, 489, 518, 519, 548, 549, 578, 579, 608, 609, 638, 639, 668, 669, 698, 699], + [370, 371, 400, 401, 430, 431, 460, 461, 490, 491, 520, 521, 550, 551, 580, 581, 610, 611, 640, 641, 670, 671, 700, 701], + [372, 373, 402, 403, 432, 433, 462, 463, 492, 493, 522, 523, 552, 553, 582, 583, 612, 613, 642, 643, 672, 673, 702, 703], + [374, 375, 404, 405, 434, 435, 464, 465, 494, 495, 524, 525, 554, 555, 584, 585, 614, 615, 644, 645, 674, 675, 704, 705], + [376, 377, 406, 407, 436, 437, 466, 467, 496, 497, 526, 527, 556, 557, 586, 587, 616, 617, 646, 647, 676, 677, 706, 707], + [378, 379, 408, 409, 438, 439, 468, 469, 498, 499, 528, 529, 558, 559, 588, 589, 618, 619, 648, 649, 678, 679, 708, 709], + [380, 381, 410, 411, 440, 441, 470, 471, 500, 501, 530, 531, 560, 561, 590, 591, 620, 621, 650, 651, 680, 681, 710, 711], + [382, 383, 412, 413, 442, 443, 472, 473, 502, 503, 532, 533, 562, 563, 592, 593, 622, 623, 652, 653, 682, 683, 712, 713], + [384, 385, 414, 415, 444, 445, 474, 475, 504, 505, 534, 535, 564, 565, 594, 595, 624, 625, 654, 655, 684, 685, 714, 715], + [386, 387, 416, 417, 446, 447, 476, 477, 506, 507, 536, 537, 566, 567, 596, 597, 626, 627, 656, 657, 686, 687, 716, 717], + [388, 389, 418, 419, 448, 449, 478, 479, 508, 509, 538, 539, 568, 569, 598, 599, 628, 629, 658, 659, 688, 689, 718, 719], + [720, 721, 732, 733, 744, 745, 756, 757, 768, 769, 780, 781, 792, 793, 804, 805, 816, 817, 828, 829, 840, 841, 852, 853], + [722, 723, 734, 735, 746, 747, 758, 759, 770, 771, 782, 783, 794, 795, 806, 807, 818, 819, 830, 831, 842, 843, 854, 855], + [724, 725, 736, 737, 748, 749, 760, 761, 772, 773, 784, 785, 796, 797, 808, 809, 820, 821, 832, 833, 844, 845, 856, 857], + [726, 727, 738, 739, 750, 751, 762, 763, 774, 775, 786, 787, 798, 799, 810, 811, 822, 823, 834, 835, 846, 847, 858, 859], + [728, 729, 740, 741, 752, 753, 764, 765, 776, 777, 788, 789, 800, 801, 812, 813, 824, 825, 836, 837, 848, 849, 860, 861], + [730, 731, 742, 743, 754, 755, 766, 767, 778, 779, 790, 791, 802, 803, 814, 815, 826, 827, 838, 839, 850, 851, 862, 863]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # iter col major + tiles = TensorTiler2D.group_tiler( + tensor_dims, + tile_dims=(3, 2), + tile_group_dims=(5, 7), + iter_col_major=True, + allow_partial=True, + ) + + reference_tiles = TensorTileSequence.from_tiles( + [ + TensorTile( + tensor_dims, offset=0, sizes=[5, 7, 3, 2], strides=[72, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=360, sizes=[5, 7, 3, 2], strides=[72, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=720, sizes=[2, 7, 3, 2], strides=[72, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=14, sizes=[5, 5, 3, 2], strides=[72, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=374, sizes=[5, 5, 3, 2], strides=[72, 2, 24, 1] + ), + TensorTile( + tensor_dims, offset=734, sizes=[2, 5, 3, 2], strides=[72, 2, 24, 1] + ), + ] + ) + assert tiles == reference_tiles + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 6, 7, 12, 13, 18, 19, 24, 25, 30, 31, 36, 37, 504, 505, 510, 511, 516, 517, 522, 523, 528, 529], + [ 2, 3, 8, 9, 14, 15, 20, 21, 26, 27, 32, 33, 38, 39, 506, 507, 512, 513, 518, 519, 524, 525, 530, 531], + [ 4, 5, 10, 11, 16, 17, 22, 23, 28, 29, 34, 35, 40, 41, 508, 509, 514, 515, 520, 521, 526, 527, 532, 533], + [ 42, 43, 48, 49, 54, 55, 60, 61, 66, 67, 72, 73, 78, 79, 534, 535, 540, 541, 546, 547, 552, 553, 558, 559], + [ 44, 45, 50, 51, 56, 57, 62, 63, 68, 69, 74, 75, 80, 81, 536, 537, 542, 543, 548, 549, 554, 555, 560, 561], + [ 46, 47, 52, 53, 58, 59, 64, 65, 70, 71, 76, 77, 82, 83, 538, 539, 544, 545, 550, 551, 556, 557, 562, 563], + [ 84, 85, 90, 91, 96, 97, 102, 103, 108, 109, 114, 115, 120, 121, 564, 565, 570, 571, 576, 577, 582, 583, 588, 589], + [ 86, 87, 92, 93, 98, 99, 104, 105, 110, 111, 116, 117, 122, 123, 566, 567, 572, 573, 578, 579, 584, 585, 590, 591], + [ 88, 89, 94, 95, 100, 101, 106, 107, 112, 113, 118, 119, 124, 125, 568, 569, 574, 575, 580, 581, 586, 587, 592, 593], + [126, 127, 132, 133, 138, 139, 144, 145, 150, 151, 156, 157, 162, 163, 594, 595, 600, 601, 606, 607, 612, 613, 618, 619], + [128, 129, 134, 135, 140, 141, 146, 147, 152, 153, 158, 159, 164, 165, 596, 597, 602, 603, 608, 609, 614, 615, 620, 621], + [130, 131, 136, 137, 142, 143, 148, 149, 154, 155, 160, 161, 166, 167, 598, 599, 604, 605, 610, 611, 616, 617, 622, 623], + [168, 169, 174, 175, 180, 181, 186, 187, 192, 193, 198, 199, 204, 205, 624, 625, 630, 631, 636, 637, 642, 643, 648, 649], + [170, 171, 176, 177, 182, 183, 188, 189, 194, 195, 200, 201, 206, 207, 626, 627, 632, 633, 638, 639, 644, 645, 650, 651], + [172, 173, 178, 179, 184, 185, 190, 191, 196, 197, 202, 203, 208, 209, 628, 629, 634, 635, 640, 641, 646, 647, 652, 653], + [210, 211, 216, 217, 222, 223, 228, 229, 234, 235, 240, 241, 246, 247, 654, 655, 660, 661, 666, 667, 672, 673, 678, 679], + [212, 213, 218, 219, 224, 225, 230, 231, 236, 237, 242, 243, 248, 249, 656, 657, 662, 663, 668, 669, 674, 675, 680, 681], + [214, 215, 220, 221, 226, 227, 232, 233, 238, 239, 244, 245, 250, 251, 658, 659, 664, 665, 670, 671, 676, 677, 682, 683], + [252, 253, 258, 259, 264, 265, 270, 271, 276, 277, 282, 283, 288, 289, 684, 685, 690, 691, 696, 697, 702, 703, 708, 709], + [254, 255, 260, 261, 266, 267, 272, 273, 278, 279, 284, 285, 290, 291, 686, 687, 692, 693, 698, 699, 704, 705, 710, 711], + [256, 257, 262, 263, 268, 269, 274, 275, 280, 281, 286, 287, 292, 293, 688, 689, 694, 695, 700, 701, 706, 707, 712, 713], + [294, 295, 300, 301, 306, 307, 312, 313, 318, 319, 324, 325, 330, 331, 714, 715, 720, 721, 726, 727, 732, 733, 738, 739], + [296, 297, 302, 303, 308, 309, 314, 315, 320, 321, 326, 327, 332, 333, 716, 717, 722, 723, 728, 729, 734, 735, 740, 741], + [298, 299, 304, 305, 310, 311, 316, 317, 322, 323, 328, 329, 334, 335, 718, 719, 724, 725, 730, 731, 736, 737, 742, 743], + [336, 337, 342, 343, 348, 349, 354, 355, 360, 361, 366, 367, 372, 373, 744, 745, 750, 751, 756, 757, 762, 763, 768, 769], + [338, 339, 344, 345, 350, 351, 356, 357, 362, 363, 368, 369, 374, 375, 746, 747, 752, 753, 758, 759, 764, 765, 770, 771], + [340, 341, 346, 347, 352, 353, 358, 359, 364, 365, 370, 371, 376, 377, 748, 749, 754, 755, 760, 761, 766, 767, 772, 773], + [378, 379, 384, 385, 390, 391, 396, 397, 402, 403, 408, 409, 414, 415, 774, 775, 780, 781, 786, 787, 792, 793, 798, 799], + [380, 381, 386, 387, 392, 393, 398, 399, 404, 405, 410, 411, 416, 417, 776, 777, 782, 783, 788, 789, 794, 795, 800, 801], + [382, 383, 388, 389, 394, 395, 400, 401, 406, 407, 412, 413, 418, 419, 778, 779, 784, 785, 790, 791, 796, 797, 802, 803], + [420, 421, 426, 427, 432, 433, 438, 439, 444, 445, 450, 451, 456, 457, 804, 805, 810, 811, 816, 817, 822, 823, 828, 829], + [422, 423, 428, 429, 434, 435, 440, 441, 446, 447, 452, 453, 458, 459, 806, 807, 812, 813, 818, 819, 824, 825, 830, 831], + [424, 425, 430, 431, 436, 437, 442, 443, 448, 449, 454, 455, 460, 461, 808, 809, 814, 815, 820, 821, 826, 827, 832, 833], + [462, 463, 468, 469, 474, 475, 480, 481, 486, 487, 492, 493, 498, 499, 834, 835, 840, 841, 846, 847, 852, 853, 858, 859], + [464, 465, 470, 471, 476, 477, 482, 483, 488, 489, 494, 495, 500, 501, 836, 837, 842, 843, 848, 849, 854, 855, 860, 861], + [466, 467, 472, 473, 478, 479, 484, 485, 490, 491, 496, 497, 502, 503, 838, 839, 844, 845, 850, 851, 856, 857, 862, 863]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # all col major + tiles = TensorTiler2D.group_tiler( + tensor_dims, + tile_dims=(3, 2), + tile_group_dims=(5, 7), + iter_col_major=True, + tile_col_major=True, + tile_group_col_major=True, + allow_partial=True, + ) + + reference_tiles = TensorTileSequence.from_tiles( + [ + TensorTile( + tensor_dims, offset=0, sizes=[7, 5, 2, 3], strides=[2, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=360, sizes=[7, 5, 2, 3], strides=[2, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=720, sizes=[7, 2, 2, 3], strides=[2, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=14, sizes=[5, 5, 2, 3], strides=[2, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=374, sizes=[5, 5, 2, 3], strides=[2, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=734, sizes=[5, 2, 2, 3], strides=[2, 72, 1, 24] + ), + ] + ) + assert tiles == reference_tiles + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 3, 30, 33, 60, 63, 90, 93, 120, 123, 150, 153, 180, 183, 504, 507, 534, 537, 564, 567, 594, 597, 624, 627], + [ 1, 4, 31, 34, 61, 64, 91, 94, 121, 124, 151, 154, 181, 184, 505, 508, 535, 538, 565, 568, 595, 598, 625, 628], + [ 2, 5, 32, 35, 62, 65, 92, 95, 122, 125, 152, 155, 182, 185, 506, 509, 536, 539, 566, 569, 596, 599, 626, 629], + [ 6, 9, 36, 39, 66, 69, 96, 99, 126, 129, 156, 159, 186, 189, 510, 513, 540, 543, 570, 573, 600, 603, 630, 633], + [ 7, 10, 37, 40, 67, 70, 97, 100, 127, 130, 157, 160, 187, 190, 511, 514, 541, 544, 571, 574, 601, 604, 631, 634], + [ 8, 11, 38, 41, 68, 71, 98, 101, 128, 131, 158, 161, 188, 191, 512, 515, 542, 545, 572, 575, 602, 605, 632, 635], + [ 12, 15, 42, 45, 72, 75, 102, 105, 132, 135, 162, 165, 192, 195, 516, 519, 546, 549, 576, 579, 606, 609, 636, 639], + [ 13, 16, 43, 46, 73, 76, 103, 106, 133, 136, 163, 166, 193, 196, 517, 520, 547, 550, 577, 580, 607, 610, 637, 640], + [ 14, 17, 44, 47, 74, 77, 104, 107, 134, 137, 164, 167, 194, 197, 518, 521, 548, 551, 578, 581, 608, 611, 638, 641], + [ 18, 21, 48, 51, 78, 81, 108, 111, 138, 141, 168, 171, 198, 201, 522, 525, 552, 555, 582, 585, 612, 615, 642, 645], + [ 19, 22, 49, 52, 79, 82, 109, 112, 139, 142, 169, 172, 199, 202, 523, 526, 553, 556, 583, 586, 613, 616, 643, 646], + [ 20, 23, 50, 53, 80, 83, 110, 113, 140, 143, 170, 173, 200, 203, 524, 527, 554, 557, 584, 587, 614, 617, 644, 647], + [ 24, 27, 54, 57, 84, 87, 114, 117, 144, 147, 174, 177, 204, 207, 528, 531, 558, 561, 588, 591, 618, 621, 648, 651], + [ 25, 28, 55, 58, 85, 88, 115, 118, 145, 148, 175, 178, 205, 208, 529, 532, 559, 562, 589, 592, 619, 622, 649, 652], + [ 26, 29, 56, 59, 86, 89, 116, 119, 146, 149, 176, 179, 206, 209, 530, 533, 560, 563, 590, 593, 620, 623, 650, 653], + [210, 213, 240, 243, 270, 273, 300, 303, 330, 333, 360, 363, 390, 393, 654, 657, 684, 687, 714, 717, 744, 747, 774, 777], + [211, 214, 241, 244, 271, 274, 301, 304, 331, 334, 361, 364, 391, 394, 655, 658, 685, 688, 715, 718, 745, 748, 775, 778], + [212, 215, 242, 245, 272, 275, 302, 305, 332, 335, 362, 365, 392, 395, 656, 659, 686, 689, 716, 719, 746, 749, 776, 779], + [216, 219, 246, 249, 276, 279, 306, 309, 336, 339, 366, 369, 396, 399, 660, 663, 690, 693, 720, 723, 750, 753, 780, 783], + [217, 220, 247, 250, 277, 280, 307, 310, 337, 340, 367, 370, 397, 400, 661, 664, 691, 694, 721, 724, 751, 754, 781, 784], + [218, 221, 248, 251, 278, 281, 308, 311, 338, 341, 368, 371, 398, 401, 662, 665, 692, 695, 722, 725, 752, 755, 782, 785], + [222, 225, 252, 255, 282, 285, 312, 315, 342, 345, 372, 375, 402, 405, 666, 669, 696, 699, 726, 729, 756, 759, 786, 789], + [223, 226, 253, 256, 283, 286, 313, 316, 343, 346, 373, 376, 403, 406, 667, 670, 697, 700, 727, 730, 757, 760, 787, 790], + [224, 227, 254, 257, 284, 287, 314, 317, 344, 347, 374, 377, 404, 407, 668, 671, 698, 701, 728, 731, 758, 761, 788, 791], + [228, 231, 258, 261, 288, 291, 318, 321, 348, 351, 378, 381, 408, 411, 672, 675, 702, 705, 732, 735, 762, 765, 792, 795], + [229, 232, 259, 262, 289, 292, 319, 322, 349, 352, 379, 382, 409, 412, 673, 676, 703, 706, 733, 736, 763, 766, 793, 796], + [230, 233, 260, 263, 290, 293, 320, 323, 350, 353, 380, 383, 410, 413, 674, 677, 704, 707, 734, 737, 764, 767, 794, 797], + [234, 237, 264, 267, 294, 297, 324, 327, 354, 357, 384, 387, 414, 417, 678, 681, 708, 711, 738, 741, 768, 771, 798, 801], + [235, 238, 265, 268, 295, 298, 325, 328, 355, 358, 385, 388, 415, 418, 679, 682, 709, 712, 739, 742, 769, 772, 799, 802], + [236, 239, 266, 269, 296, 299, 326, 329, 356, 359, 386, 389, 416, 419, 680, 683, 710, 713, 740, 743, 770, 773, 800, 803], + [420, 423, 432, 435, 444, 447, 456, 459, 468, 471, 480, 483, 492, 495, 804, 807, 816, 819, 828, 831, 840, 843, 852, 855], + [421, 424, 433, 436, 445, 448, 457, 460, 469, 472, 481, 484, 493, 496, 805, 808, 817, 820, 829, 832, 841, 844, 853, 856], + [422, 425, 434, 437, 446, 449, 458, 461, 470, 473, 482, 485, 494, 497, 806, 809, 818, 821, 830, 833, 842, 845, 854, 857], + [426, 429, 438, 441, 450, 453, 462, 465, 474, 477, 486, 489, 498, 501, 810, 813, 822, 825, 834, 837, 846, 849, 858, 861], + [427, 430, 439, 442, 451, 454, 463, 466, 475, 478, 487, 490, 499, 502, 811, 814, 823, 826, 835, 838, 847, 850, 859, 862], + [428, 431, 440, 443, 452, 455, 464, 467, 476, 479, 488, 491, 500, 503, 812, 815, 824, 827, 836, 839, 848, 851, 860, 863]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # pattern repeat + tiles = TensorTiler2D.group_tiler( + tensor_dims, + tile_dims=(3, 2), + tile_group_dims=(5, 7), + tile_col_major=True, + allow_partial=True, + pattern_repeat=2, + ) + + reference_tiles = TensorTileSequence.from_tiles( + [ + TensorTile( + tensor_dims, offset=0, sizes=[2, 5, 14, 3], strides=[0, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=14, sizes=[2, 5, 10, 3], strides=[0, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=360, sizes=[2, 5, 14, 3], strides=[0, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=374, sizes=[2, 5, 10, 3], strides=[0, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=720, sizes=[2, 2, 14, 3], strides=[0, 72, 1, 24] + ), + TensorTile( + tensor_dims, offset=734, sizes=[2, 2, 10, 3], strides=[0, 72, 1, 24] + ), + ] + ) + assert tiles == reference_tiles + # fmt: off + ref_access_order_tensor = np.array([ + [ 210, 213, 216, 219, 222, 225, 228, 231, 234, 237, 240, 243, 246, 249, 570, 573, 576, 579, 582, 585, 588, 591, 594, 597], + [ 211, 214, 217, 220, 223, 226, 229, 232, 235, 238, 241, 244, 247, 250, 571, 574, 577, 580, 583, 586, 589, 592, 595, 598], + [ 212, 215, 218, 221, 224, 227, 230, 233, 236, 239, 242, 245, 248, 251, 572, 575, 578, 581, 584, 587, 590, 593, 596, 599], + [ 252, 255, 258, 261, 264, 267, 270, 273, 276, 279, 282, 285, 288, 291, 600, 603, 606, 609, 612, 615, 618, 621, 624, 627], + [ 253, 256, 259, 262, 265, 268, 271, 274, 277, 280, 283, 286, 289, 292, 601, 604, 607, 610, 613, 616, 619, 622, 625, 628], + [ 254, 257, 260, 263, 266, 269, 272, 275, 278, 281, 284, 287, 290, 293, 602, 605, 608, 611, 614, 617, 620, 623, 626, 629], + [ 294, 297, 300, 303, 306, 309, 312, 315, 318, 321, 324, 327, 330, 333, 630, 633, 636, 639, 642, 645, 648, 651, 654, 657], + [ 295, 298, 301, 304, 307, 310, 313, 316, 319, 322, 325, 328, 331, 334, 631, 634, 637, 640, 643, 646, 649, 652, 655, 658], + [ 296, 299, 302, 305, 308, 311, 314, 317, 320, 323, 326, 329, 332, 335, 632, 635, 638, 641, 644, 647, 650, 653, 656, 659], + [ 336, 339, 342, 345, 348, 351, 354, 357, 360, 363, 366, 369, 372, 375, 660, 663, 666, 669, 672, 675, 678, 681, 684, 687], + [ 337, 340, 343, 346, 349, 352, 355, 358, 361, 364, 367, 370, 373, 376, 661, 664, 667, 670, 673, 676, 679, 682, 685, 688], + [ 338, 341, 344, 347, 350, 353, 356, 359, 362, 365, 368, 371, 374, 377, 662, 665, 668, 671, 674, 677, 680, 683, 686, 689], + [ 378, 381, 384, 387, 390, 393, 396, 399, 402, 405, 408, 411, 414, 417, 690, 693, 696, 699, 702, 705, 708, 711, 714, 717], + [ 379, 382, 385, 388, 391, 394, 397, 400, 403, 406, 409, 412, 415, 418, 691, 694, 697, 700, 703, 706, 709, 712, 715, 718], + [ 380, 383, 386, 389, 392, 395, 398, 401, 404, 407, 410, 413, 416, 419, 692, 695, 698, 701, 704, 707, 710, 713, 716, 719], + [ 930, 933, 936, 939, 942, 945, 948, 951, 954, 957, 960, 963, 966, 969, 1290, 1293, 1296, 1299, 1302, 1305, 1308, 1311, 1314, 1317], + [ 931, 934, 937, 940, 943, 946, 949, 952, 955, 958, 961, 964, 967, 970, 1291, 1294, 1297, 1300, 1303, 1306, 1309, 1312, 1315, 1318], + [ 932, 935, 938, 941, 944, 947, 950, 953, 956, 959, 962, 965, 968, 971, 1292, 1295, 1298, 1301, 1304, 1307, 1310, 1313, 1316, 1319], + [ 972, 975, 978, 981, 984, 987, 990, 993, 996, 999, 1002, 1005, 1008, 1011, 1320, 1323, 1326, 1329, 1332, 1335, 1338, 1341, 1344, 1347], + [ 973, 976, 979, 982, 985, 988, 991, 994, 997, 1000, 1003, 1006, 1009, 1012, 1321, 1324, 1327, 1330, 1333, 1336, 1339, 1342, 1345, 1348], + [ 974, 977, 980, 983, 986, 989, 992, 995, 998, 1001, 1004, 1007, 1010, 1013, 1322, 1325, 1328, 1331, 1334, 1337, 1340, 1343, 1346, 1349], + [1014, 1017, 1020, 1023, 1026, 1029, 1032, 1035, 1038, 1041, 1044, 1047, 1050, 1053, 1350, 1353, 1356, 1359, 1362, 1365, 1368, 1371, 1374, 1377], + [1015, 1018, 1021, 1024, 1027, 1030, 1033, 1036, 1039, 1042, 1045, 1048, 1051, 1054, 1351, 1354, 1357, 1360, 1363, 1366, 1369, 1372, 1375, 1378], + [1016, 1019, 1022, 1025, 1028, 1031, 1034, 1037, 1040, 1043, 1046, 1049, 1052, 1055, 1352, 1355, 1358, 1361, 1364, 1367, 1370, 1373, 1376, 1379], + [1056, 1059, 1062, 1065, 1068, 1071, 1074, 1077, 1080, 1083, 1086, 1089, 1092, 1095, 1380, 1383, 1386, 1389, 1392, 1395, 1398, 1401, 1404, 1407], + [1057, 1060, 1063, 1066, 1069, 1072, 1075, 1078, 1081, 1084, 1087, 1090, 1093, 1096, 1381, 1384, 1387, 1390, 1393, 1396, 1399, 1402, 1405, 1408], + [1058, 1061, 1064, 1067, 1070, 1073, 1076, 1079, 1082, 1085, 1088, 1091, 1094, 1097, 1382, 1385, 1388, 1391, 1394, 1397, 1400, 1403, 1406, 1409], + [1098, 1101, 1104, 1107, 1110, 1113, 1116, 1119, 1122, 1125, 1128, 1131, 1134, 1137, 1410, 1413, 1416, 1419, 1422, 1425, 1428, 1431, 1434, 1437], + [1099, 1102, 1105, 1108, 1111, 1114, 1117, 1120, 1123, 1126, 1129, 1132, 1135, 1138, 1411, 1414, 1417, 1420, 1423, 1426, 1429, 1432, 1435, 1438], + [1100, 1103, 1106, 1109, 1112, 1115, 1118, 1121, 1124, 1127, 1130, 1133, 1136, 1139, 1412, 1415, 1418, 1421, 1424, 1427, 1430, 1433, 1436, 1439], + [1524, 1527, 1530, 1533, 1536, 1539, 1542, 1545, 1548, 1551, 1554, 1557, 1560, 1563, 1668, 1671, 1674, 1677, 1680, 1683, 1686, 1689, 1692, 1695], + [1525, 1528, 1531, 1534, 1537, 1540, 1543, 1546, 1549, 1552, 1555, 1558, 1561, 1564, 1669, 1672, 1675, 1678, 1681, 1684, 1687, 1690, 1693, 1696], + [1526, 1529, 1532, 1535, 1538, 1541, 1544, 1547, 1550, 1553, 1556, 1559, 1562, 1565, 1670, 1673, 1676, 1679, 1682, 1685, 1688, 1691, 1694, 1697], + [1566, 1569, 1572, 1575, 1578, 1581, 1584, 1587, 1590, 1593, 1596, 1599, 1602, 1605, 1698, 1701, 1704, 1707, 1710, 1713, 1716, 1719, 1722, 1725], + [1567, 1570, 1573, 1576, 1579, 1582, 1585, 1588, 1591, 1594, 1597, 1600, 1603, 1606, 1699, 1702, 1705, 1708, 1711, 1714, 1717, 1720, 1723, 1726], + [1568, 1571, 1574, 1577, 1580, 1583, 1586, 1589, 1592, 1595, 1598, 1601, 1604, 1607, 1700, 1703, 1706, 1709, 1712, 1715, 1718, 1721, 1724, 1727]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 2).all() + + # CHECK: Pass! + print("Pass!") diff --git a/test/python/tensortiler/simple_tiler.py b/test/python/tensortiler/simple_tiler.py new file mode 100644 index 0000000000..616b9cb8cc --- /dev/null +++ b/test/python/tensortiler/simple_tiler.py @@ -0,0 +1,265 @@ +import numpy as np + +from aie.helpers.tensortiler import TensorTile, TensorTileSequence, TensorTiler2D +from util import construct_test + +# RUN: %python %s | FileCheck %s + + +# CHECK-LABEL: simple_tiler +@construct_test +def simple_tiler(): + single_tile = TensorTiler2D.simple_tiler((3, 5)) + assert len(single_tile) == 1 + single_tile[0] == TensorTile( + (3, 5), offset=0, sizes=[1, 1, 3, 5], strides=[0, 0, 5, 1] + ) + ref_access_order_tensor = np.array( + [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14]] + ) + access_order, access_count = single_tile.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + tiles = TensorTiler2D.simple_tiler((9, 4), (3, 2)) + assert len(tiles) == 6 + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 6, 7], + [ 2, 3, 8, 9], + [ 4, 5, 10, 11], + [12, 13, 18, 19], + [14, 15, 20, 21], + [16, 17, 22, 23], + [24, 25, 30, 31], + [26, 27, 32, 33], + [28, 29, 34, 35]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + def offset_fn(step, _prev_offset): + offsets = [0, 2, 12, 14, 24, 26] + return offsets[step] + + tiles2 = TensorTileSequence( + (9, 4), + num_steps=6, + sizes=[1, 1, 3, 2], + strides=[0, 0, 4, 1], + offset_fn=offset_fn, + ) + assert tiles == tiles2 + + tile0_0 = TensorTile((9, 4), offset=0, sizes=[1, 1, 3, 2], strides=[0, 0, 4, 1]) + tile0_1 = TensorTile((9, 4), offset=2, sizes=[1, 1, 3, 2], strides=[0, 0, 4, 1]) + tile1_0 = TensorTile((9, 4), offset=12, sizes=[1, 1, 3, 2], strides=[0, 0, 4, 1]) + tile1_1 = TensorTile((9, 4), offset=14, sizes=[1, 1, 3, 2], strides=[0, 0, 4, 1]) + + assert tiles2[0] == tile0_0 + assert tiles2[1] == tile0_1 + assert tiles2[2] == tile1_0 + assert tiles2[3] == tile1_1 + + access_order, access_count = tiles2.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # Check with column major iter order + tiles_iter_col_major = TensorTiler2D.simple_tiler( + (9, 4), (3, 2), iter_col_major=True + ) + assert tiles_iter_col_major[0] == tile0_0 + assert tiles_iter_col_major[1] == tile1_0 + assert tiles_iter_col_major[3] == tile0_1 + assert tiles_iter_col_major[4] == tile1_1 + + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 18, 19], + [ 2, 3, 20, 21], + [ 4, 5, 22, 23], + [ 6, 7, 24, 25], + [ 8, 9, 26, 27], + [10, 11, 28, 29], + [12, 13, 30, 31], + [14, 15, 32, 33], + [16, 17, 34, 35]]) + # fmt: on + access_order, access_count = tiles_iter_col_major.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + tiles_tile_col_major = TensorTiler2D.simple_tiler( + (9, 4), (3, 2), tile_col_major=True + ) + tile0_0 = TensorTile((9, 4), offset=0, sizes=[1, 1, 2, 3], strides=[0, 0, 1, 4]) + tile0_1 = TensorTile((9, 4), offset=2, sizes=[1, 1, 2, 3], strides=[0, 0, 1, 4]) + tile1_0 = TensorTile((9, 4), offset=12, sizes=[1, 1, 2, 3], strides=[0, 0, 1, 4]) + tile1_1 = TensorTile((9, 4), offset=14, sizes=[1, 1, 2, 3], strides=[0, 0, 1, 4]) + assert tiles_tile_col_major[0] == tile0_0 + assert tiles_tile_col_major[1] == tile0_1 + assert tiles_tile_col_major[2] == tile1_0 + assert tiles_tile_col_major[3] == tile1_1 + + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 3, 6, 9], + [ 1, 4, 7, 10], + [ 2, 5, 8, 11], + [12, 15, 18, 21], + [13, 16, 19, 22], + [14, 17, 20, 23], + [24, 27, 30, 33], + [25, 28, 31, 34], + [26, 29, 32, 35]]) + # fmt: on + access_order, access_count = tiles_tile_col_major.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + tiles_tile_col_major_iter_col_major = TensorTiler2D.simple_tiler( + (9, 4), (3, 2), tile_col_major=True, iter_col_major=True + ) + assert tiles_tile_col_major_iter_col_major[0] == tile0_0 + assert tiles_tile_col_major_iter_col_major[1] == tile1_0 + assert tiles_tile_col_major_iter_col_major[3] == tile0_1 + assert tiles_tile_col_major_iter_col_major[4] == tile1_1 + + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 3, 18, 21], + [ 1, 4, 19, 22], + [ 2, 5, 20, 23], + [ 6, 9, 24, 27], + [ 7, 10, 25, 28], + [ 8, 11, 26, 29], + [12, 15, 30, 33], + [13, 16, 31, 34], + [14, 17, 32, 35]]) + # fmt: on + access_order, access_count = tiles_tile_col_major_iter_col_major.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + tiles_repeat = TensorTiler2D.simple_tiler((9, 4), (3, 2), pattern_repeat=5) + tile_repeat0_0 = TensorTile( + (9, 4), offset=0, sizes=[1, 5, 3, 2], strides=[0, 0, 4, 1] + ) + assert tiles_repeat[0] == tile_repeat0_0 + + # fmt: off + ref_access_order_tensor = np.array([ + [ 24, 25, 54, 55], + [ 26, 27, 56, 57], + [ 28, 29, 58, 59], + [ 84, 85, 114, 115], + [ 86, 87, 116, 117], + [ 88, 89, 118, 119], + [144, 145, 174, 175], + [146, 147, 176, 177], + [148, 149, 178, 179]]) + # fmt: on + access_order, access_count = tiles_repeat.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 5).all() + + tiles_repeat = TensorTiler2D.simple_tiler( + (9, 4), (3, 2), tile_col_major=True, pattern_repeat=5 + ) + tile_repeat0_0 = TensorTile( + (9, 4), offset=0, sizes=[1, 5, 2, 3], strides=[0, 0, 1, 4] + ) + assert tiles_repeat[0] == tile_repeat0_0 + + # fmt: off + ref_access_order_tensor = np.array([ + [ 24, 27, 54, 57], + [ 25, 28, 55, 58], + [ 26, 29, 56, 59], + [ 84, 87, 114, 117], + [ 85, 88, 115, 118], + [ 86, 89, 116, 119], + [144, 147, 174, 177], + [145, 148, 175, 178], + [146, 149, 176, 179]]) + # fmt: on + access_order, access_count = tiles_repeat.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 5).all() + + # CHECK: Pass! + print("Pass!") + + +# CHECK-LABEL: simple_tiler_invalid +@construct_test +def simple_tiler_invalid(): + try: + tiles = TensorTiler2D.simple_tiler( + (), (3, 2), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Bad tensor dims, should fail.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.simple_tiler( + (10, 9, 4), (3, 2), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Too many tensor dims, should fail.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.simple_tiler( + (9, 4), (3, -1), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Bad tile dims, should fail.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.simple_tiler( + (9, 4), (3,), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Too few tile dims, should fail.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.simple_tiler( + (9, 4), (1, 1, 1), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Too many tile dims, should fail.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.simple_tiler( + (9, 4), (3, 2), tile_col_major=True, pattern_repeat=0 + ) + raise ValueError("Invalid repeat.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.simple_tiler( + (9, 4), (4, 2), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Indivisible tile (height)") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.simple_tiler( + (9, 4), (3, 3), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Indivisible tile (width)") + except ValueError: + # good + pass + + # CHECK: Pass! + print("Pass!") diff --git a/test/python/tensortiler/step_tiler.py b/test/python/tensortiler/step_tiler.py new file mode 100644 index 0000000000..11821c60fe --- /dev/null +++ b/test/python/tensortiler/step_tiler.py @@ -0,0 +1,857 @@ +import numpy as np + +from aie.helpers.tensortiler import TensorTile, TensorTiler2D +from util import construct_test + +# RUN: %python %s | FileCheck %s + + +# CHECK-LABEL: step_tiler +@construct_test +def step_tiler(): + # Start with Step == (1, 1) + tiles = TensorTiler2D.step_tiler( + (32, 32), + tile_dims=(2, 2), + tile_group_repeats=(2, 2), + tile_group_steps=(1, 1), + ) + assert len(tiles) == (32 // (2 * 2)) * (32 // (2 * 2)) + assert tiles[0] == TensorTile( + (32, 32), offset=0, sizes=[2, 2, 2, 2], strides=[64, 2, 32, 1] + ) + assert tiles[1] == TensorTile( + (32, 32), offset=4, sizes=[2, 2, 2, 2], strides=[64, 2, 32, 1] + ) + assert tiles[26] == TensorTile( + (32, 32), offset=392, sizes=[2, 2, 2, 2], strides=[64, 2, 32, 1] + ) + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 4, 5, 16, 17, 20, 21, 32, 33, 36, 37, 48, 49, 52, 53, 64, 65, 68, 69, 80, 81, 84, 85, 96, 97, 100, 101, 112, 113, 116, 117], + [ 2, 3, 6, 7, 18, 19, 22, 23, 34, 35, 38, 39, 50, 51, 54, 55, 66, 67, 70, 71, 82, 83, 86, 87, 98, 99, 102, 103, 114, 115, 118, 119], + [ 8, 9, 12, 13, 24, 25, 28, 29, 40, 41, 44, 45, 56, 57, 60, 61, 72, 73, 76, 77, 88, 89, 92, 93, 104, 105, 108, 109, 120, 121, 124, 125], + [ 10, 11, 14, 15, 26, 27, 30, 31, 42, 43, 46, 47, 58, 59, 62, 63, 74, 75, 78, 79, 90, 91, 94, 95, 106, 107, 110, 111, 122, 123, 126, 127], + [ 128, 129, 132, 133, 144, 145, 148, 149, 160, 161, 164, 165, 176, 177, 180, 181, 192, 193, 196, 197, 208, 209, 212, 213, 224, 225, 228, 229, 240, 241, 244, 245], + [ 130, 131, 134, 135, 146, 147, 150, 151, 162, 163, 166, 167, 178, 179, 182, 183, 194, 195, 198, 199, 210, 211, 214, 215, 226, 227, 230, 231, 242, 243, 246, 247], + [ 136, 137, 140, 141, 152, 153, 156, 157, 168, 169, 172, 173, 184, 185, 188, 189, 200, 201, 204, 205, 216, 217, 220, 221, 232, 233, 236, 237, 248, 249, 252, 253], + [ 138, 139, 142, 143, 154, 155, 158, 159, 170, 171, 174, 175, 186, 187, 190, 191, 202, 203, 206, 207, 218, 219, 222, 223, 234, 235, 238, 239, 250, 251, 254, 255], + [ 256, 257, 260, 261, 272, 273, 276, 277, 288, 289, 292, 293, 304, 305, 308, 309, 320, 321, 324, 325, 336, 337, 340, 341, 352, 353, 356, 357, 368, 369, 372, 373], + [ 258, 259, 262, 263, 274, 275, 278, 279, 290, 291, 294, 295, 306, 307, 310, 311, 322, 323, 326, 327, 338, 339, 342, 343, 354, 355, 358, 359, 370, 371, 374, 375], + [ 264, 265, 268, 269, 280, 281, 284, 285, 296, 297, 300, 301, 312, 313, 316, 317, 328, 329, 332, 333, 344, 345, 348, 349, 360, 361, 364, 365, 376, 377, 380, 381], + [ 266, 267, 270, 271, 282, 283, 286, 287, 298, 299, 302, 303, 314, 315, 318, 319, 330, 331, 334, 335, 346, 347, 350, 351, 362, 363, 366, 367, 378, 379, 382, 383], + [ 384, 385, 388, 389, 400, 401, 404, 405, 416, 417, 420, 421, 432, 433, 436, 437, 448, 449, 452, 453, 464, 465, 468, 469, 480, 481, 484, 485, 496, 497, 500, 501], + [ 386, 387, 390, 391, 402, 403, 406, 407, 418, 419, 422, 423, 434, 435, 438, 439, 450, 451, 454, 455, 466, 467, 470, 471, 482, 483, 486, 487, 498, 499, 502, 503], + [ 392, 393, 396, 397, 408, 409, 412, 413, 424, 425, 428, 429, 440, 441, 444, 445, 456, 457, 460, 461, 472, 473, 476, 477, 488, 489, 492, 493, 504, 505, 508, 509], + [ 394, 395, 398, 399, 410, 411, 414, 415, 426, 427, 430, 431, 442, 443, 446, 447, 458, 459, 462, 463, 474, 475, 478, 479, 490, 491, 494, 495, 506, 507, 510, 511], + [ 512, 513, 516, 517, 528, 529, 532, 533, 544, 545, 548, 549, 560, 561, 564, 565, 576, 577, 580, 581, 592, 593, 596, 597, 608, 609, 612, 613, 624, 625, 628, 629], + [ 514, 515, 518, 519, 530, 531, 534, 535, 546, 547, 550, 551, 562, 563, 566, 567, 578, 579, 582, 583, 594, 595, 598, 599, 610, 611, 614, 615, 626, 627, 630, 631], + [ 520, 521, 524, 525, 536, 537, 540, 541, 552, 553, 556, 557, 568, 569, 572, 573, 584, 585, 588, 589, 600, 601, 604, 605, 616, 617, 620, 621, 632, 633, 636, 637], + [ 522, 523, 526, 527, 538, 539, 542, 543, 554, 555, 558, 559, 570, 571, 574, 575, 586, 587, 590, 591, 602, 603, 606, 607, 618, 619, 622, 623, 634, 635, 638, 639], + [ 640, 641, 644, 645, 656, 657, 660, 661, 672, 673, 676, 677, 688, 689, 692, 693, 704, 705, 708, 709, 720, 721, 724, 725, 736, 737, 740, 741, 752, 753, 756, 757], + [ 642, 643, 646, 647, 658, 659, 662, 663, 674, 675, 678, 679, 690, 691, 694, 695, 706, 707, 710, 711, 722, 723, 726, 727, 738, 739, 742, 743, 754, 755, 758, 759], + [ 648, 649, 652, 653, 664, 665, 668, 669, 680, 681, 684, 685, 696, 697, 700, 701, 712, 713, 716, 717, 728, 729, 732, 733, 744, 745, 748, 749, 760, 761, 764, 765], + [ 650, 651, 654, 655, 666, 667, 670, 671, 682, 683, 686, 687, 698, 699, 702, 703, 714, 715, 718, 719, 730, 731, 734, 735, 746, 747, 750, 751, 762, 763, 766, 767], + [ 768, 769, 772, 773, 784, 785, 788, 789, 800, 801, 804, 805, 816, 817, 820, 821, 832, 833, 836, 837, 848, 849, 852, 853, 864, 865, 868, 869, 880, 881, 884, 885], + [ 770, 771, 774, 775, 786, 787, 790, 791, 802, 803, 806, 807, 818, 819, 822, 823, 834, 835, 838, 839, 850, 851, 854, 855, 866, 867, 870, 871, 882, 883, 886, 887], + [ 776, 777, 780, 781, 792, 793, 796, 797, 808, 809, 812, 813, 824, 825, 828, 829, 840, 841, 844, 845, 856, 857, 860, 861, 872, 873, 876, 877, 888, 889, 892, 893], + [ 778, 779, 782, 783, 794, 795, 798, 799, 810, 811, 814, 815, 826, 827, 830, 831, 842, 843, 846, 847, 858, 859, 862, 863, 874, 875, 878, 879, 890, 891, 894, 895], + [ 896, 897, 900, 901, 912, 913, 916, 917, 928, 929, 932, 933, 944, 945, 948, 949, 960, 961, 964, 965, 976, 977, 980, 981, 992, 993, 996, 997, 1008, 1009, 1012, 1013], + [ 898, 899, 902, 903, 914, 915, 918, 919, 930, 931, 934, 935, 946, 947, 950, 951, 962, 963, 966, 967, 978, 979, 982, 983, 994, 995, 998, 999, 1010, 1011, 1014, 1015], + [ 904, 905, 908, 909, 920, 921, 924, 925, 936, 937, 940, 941, 952, 953, 956, 957, 968, 969, 972, 973, 984, 985, 988, 989, 1000, 1001, 1004, 1005, 1016, 1017, 1020, 1021], + [ 906, 907, 910, 911, 922, 923, 926, 927, 938, 939, 942, 943, 954, 955, 958, 959, 970, 971, 974, 975, 986, 987, 990, 991, 1002, 1003, 1006, 1007, 1018, 1019, 1022, 1023]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # Step == (2, 1) + tiles = TensorTiler2D.step_tiler( + (32, 32), + tile_dims=(2, 2), + tile_group_repeats=(2, 2), + tile_group_steps=(2, 1), + ) + assert len(tiles) == (32 // (2 * 2)) * (32 // (2 * 2)) + assert tiles[0] == TensorTile( + (32, 32), offset=0, sizes=[2, 2, 2, 2], strides=[128, 2, 32, 1] + ) + assert tiles[1] == TensorTile( + (32, 32), offset=4, sizes=[2, 2, 2, 2], strides=[128, 2, 32, 1] + ) + assert tiles[26] == TensorTile( + (32, 32), offset=328, sizes=[2, 2, 2, 2], strides=[128, 2, 32, 1] + ) + assert tiles[-1] == TensorTile( + (32, 32), offset=860, sizes=[2, 2, 2, 2], strides=[128, 2, 32, 1] + ) + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 4, 5, 16, 17, 20, 21, 32, 33, 36, 37, 48, 49, 52, 53, 64, 65, 68, 69, 80, 81, 84, 85, 96, 97, 100, 101, 112, 113, 116, 117], + [ 2, 3, 6, 7, 18, 19, 22, 23, 34, 35, 38, 39, 50, 51, 54, 55, 66, 67, 70, 71, 82, 83, 86, 87, 98, 99, 102, 103, 114, 115, 118, 119], + [ 128, 129, 132, 133, 144, 145, 148, 149, 160, 161, 164, 165, 176, 177, 180, 181, 192, 193, 196, 197, 208, 209, 212, 213, 224, 225, 228, 229, 240, 241, 244, 245], + [ 130, 131, 134, 135, 146, 147, 150, 151, 162, 163, 166, 167, 178, 179, 182, 183, 194, 195, 198, 199, 210, 211, 214, 215, 226, 227, 230, 231, 242, 243, 246, 247], + [ 8, 9, 12, 13, 24, 25, 28, 29, 40, 41, 44, 45, 56, 57, 60, 61, 72, 73, 76, 77, 88, 89, 92, 93, 104, 105, 108, 109, 120, 121, 124, 125], + [ 10, 11, 14, 15, 26, 27, 30, 31, 42, 43, 46, 47, 58, 59, 62, 63, 74, 75, 78, 79, 90, 91, 94, 95, 106, 107, 110, 111, 122, 123, 126, 127], + [ 136, 137, 140, 141, 152, 153, 156, 157, 168, 169, 172, 173, 184, 185, 188, 189, 200, 201, 204, 205, 216, 217, 220, 221, 232, 233, 236, 237, 248, 249, 252, 253], + [ 138, 139, 142, 143, 154, 155, 158, 159, 170, 171, 174, 175, 186, 187, 190, 191, 202, 203, 206, 207, 218, 219, 222, 223, 234, 235, 238, 239, 250, 251, 254, 255], + [ 256, 257, 260, 261, 272, 273, 276, 277, 288, 289, 292, 293, 304, 305, 308, 309, 320, 321, 324, 325, 336, 337, 340, 341, 352, 353, 356, 357, 368, 369, 372, 373], + [ 258, 259, 262, 263, 274, 275, 278, 279, 290, 291, 294, 295, 306, 307, 310, 311, 322, 323, 326, 327, 338, 339, 342, 343, 354, 355, 358, 359, 370, 371, 374, 375], + [ 384, 385, 388, 389, 400, 401, 404, 405, 416, 417, 420, 421, 432, 433, 436, 437, 448, 449, 452, 453, 464, 465, 468, 469, 480, 481, 484, 485, 496, 497, 500, 501], + [ 386, 387, 390, 391, 402, 403, 406, 407, 418, 419, 422, 423, 434, 435, 438, 439, 450, 451, 454, 455, 466, 467, 470, 471, 482, 483, 486, 487, 498, 499, 502, 503], + [ 264, 265, 268, 269, 280, 281, 284, 285, 296, 297, 300, 301, 312, 313, 316, 317, 328, 329, 332, 333, 344, 345, 348, 349, 360, 361, 364, 365, 376, 377, 380, 381], + [ 266, 267, 270, 271, 282, 283, 286, 287, 298, 299, 302, 303, 314, 315, 318, 319, 330, 331, 334, 335, 346, 347, 350, 351, 362, 363, 366, 367, 378, 379, 382, 383], + [ 392, 393, 396, 397, 408, 409, 412, 413, 424, 425, 428, 429, 440, 441, 444, 445, 456, 457, 460, 461, 472, 473, 476, 477, 488, 489, 492, 493, 504, 505, 508, 509], + [ 394, 395, 398, 399, 410, 411, 414, 415, 426, 427, 430, 431, 442, 443, 446, 447, 458, 459, 462, 463, 474, 475, 478, 479, 490, 491, 494, 495, 506, 507, 510, 511], + [ 512, 513, 516, 517, 528, 529, 532, 533, 544, 545, 548, 549, 560, 561, 564, 565, 576, 577, 580, 581, 592, 593, 596, 597, 608, 609, 612, 613, 624, 625, 628, 629], + [ 514, 515, 518, 519, 530, 531, 534, 535, 546, 547, 550, 551, 562, 563, 566, 567, 578, 579, 582, 583, 594, 595, 598, 599, 610, 611, 614, 615, 626, 627, 630, 631], + [ 640, 641, 644, 645, 656, 657, 660, 661, 672, 673, 676, 677, 688, 689, 692, 693, 704, 705, 708, 709, 720, 721, 724, 725, 736, 737, 740, 741, 752, 753, 756, 757], + [ 642, 643, 646, 647, 658, 659, 662, 663, 674, 675, 678, 679, 690, 691, 694, 695, 706, 707, 710, 711, 722, 723, 726, 727, 738, 739, 742, 743, 754, 755, 758, 759], + [ 520, 521, 524, 525, 536, 537, 540, 541, 552, 553, 556, 557, 568, 569, 572, 573, 584, 585, 588, 589, 600, 601, 604, 605, 616, 617, 620, 621, 632, 633, 636, 637], + [ 522, 523, 526, 527, 538, 539, 542, 543, 554, 555, 558, 559, 570, 571, 574, 575, 586, 587, 590, 591, 602, 603, 606, 607, 618, 619, 622, 623, 634, 635, 638, 639], + [ 648, 649, 652, 653, 664, 665, 668, 669, 680, 681, 684, 685, 696, 697, 700, 701, 712, 713, 716, 717, 728, 729, 732, 733, 744, 745, 748, 749, 760, 761, 764, 765], + [ 650, 651, 654, 655, 666, 667, 670, 671, 682, 683, 686, 687, 698, 699, 702, 703, 714, 715, 718, 719, 730, 731, 734, 735, 746, 747, 750, 751, 762, 763, 766, 767], + [ 768, 769, 772, 773, 784, 785, 788, 789, 800, 801, 804, 805, 816, 817, 820, 821, 832, 833, 836, 837, 848, 849, 852, 853, 864, 865, 868, 869, 880, 881, 884, 885], + [ 770, 771, 774, 775, 786, 787, 790, 791, 802, 803, 806, 807, 818, 819, 822, 823, 834, 835, 838, 839, 850, 851, 854, 855, 866, 867, 870, 871, 882, 883, 886, 887], + [ 896, 897, 900, 901, 912, 913, 916, 917, 928, 929, 932, 933, 944, 945, 948, 949, 960, 961, 964, 965, 976, 977, 980, 981, 992, 993, 996, 997, 1008, 1009, 1012, 1013], + [ 898, 899, 902, 903, 914, 915, 918, 919, 930, 931, 934, 935, 946, 947, 950, 951, 962, 963, 966, 967, 978, 979, 982, 983, 994, 995, 998, 999, 1010, 1011, 1014, 1015], + [ 776, 777, 780, 781, 792, 793, 796, 797, 808, 809, 812, 813, 824, 825, 828, 829, 840, 841, 844, 845, 856, 857, 860, 861, 872, 873, 876, 877, 888, 889, 892, 893], + [ 778, 779, 782, 783, 794, 795, 798, 799, 810, 811, 814, 815, 826, 827, 830, 831, 842, 843, 846, 847, 858, 859, 862, 863, 874, 875, 878, 879, 890, 891, 894, 895], + [ 904, 905, 908, 909, 920, 921, 924, 925, 936, 937, 940, 941, 952, 953, 956, 957, 968, 969, 972, 973, 984, 985, 988, 989, 1000, 1001, 1004, 1005, 1016, 1017, 1020, 1021], + [ 906, 907, 910, 911, 922, 923, 926, 927, 938, 939, 942, 943, 954, 955, 958, 959, 970, 971, 974, 975, 986, 987, 990, 991, 1002, 1003, 1006, 1007, 1018, 1019, 1022, 1023]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # Step == (1, 2) + tiles = TensorTiler2D.step_tiler( + (32, 32), + tile_dims=(2, 2), + tile_group_repeats=(2, 2), + tile_group_steps=(1, 2), + ) + assert len(tiles) == (32 // (2 * 2)) * (32 // (2 * 2)) + assert tiles[0] == TensorTile( + (32, 32), offset=0, sizes=[2, 2, 2, 2], strides=[64, 4, 32, 1] + ) + assert tiles[1] == TensorTile( + (32, 32), offset=2, sizes=[2, 2, 2, 2], strides=[64, 4, 32, 1] + ) + assert tiles[26] == TensorTile( + (32, 32), offset=392, sizes=[2, 2, 2, 2], strides=[64, 4, 32, 1] + ) + assert tiles[-1] == TensorTile( + (32, 32), offset=922, sizes=[2, 2, 2, 2], strides=[64, 4, 32, 1] + ) + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 16, 17, 4, 5, 20, 21, 32, 33, 48, 49, 36, 37, 52, 53, 64, 65, 80, 81, 68, 69, 84, 85, 96, 97, 112, 113, 100, 101, 116, 117], + [ 2, 3, 18, 19, 6, 7, 22, 23, 34, 35, 50, 51, 38, 39, 54, 55, 66, 67, 82, 83, 70, 71, 86, 87, 98, 99, 114, 115, 102, 103, 118, 119], + [ 8, 9, 24, 25, 12, 13, 28, 29, 40, 41, 56, 57, 44, 45, 60, 61, 72, 73, 88, 89, 76, 77, 92, 93, 104, 105, 120, 121, 108, 109, 124, 125], + [ 10, 11, 26, 27, 14, 15, 30, 31, 42, 43, 58, 59, 46, 47, 62, 63, 74, 75, 90, 91, 78, 79, 94, 95, 106, 107, 122, 123, 110, 111, 126, 127], + [ 128, 129, 144, 145, 132, 133, 148, 149, 160, 161, 176, 177, 164, 165, 180, 181, 192, 193, 208, 209, 196, 197, 212, 213, 224, 225, 240, 241, 228, 229, 244, 245], + [ 130, 131, 146, 147, 134, 135, 150, 151, 162, 163, 178, 179, 166, 167, 182, 183, 194, 195, 210, 211, 198, 199, 214, 215, 226, 227, 242, 243, 230, 231, 246, 247], + [ 136, 137, 152, 153, 140, 141, 156, 157, 168, 169, 184, 185, 172, 173, 188, 189, 200, 201, 216, 217, 204, 205, 220, 221, 232, 233, 248, 249, 236, 237, 252, 253], + [ 138, 139, 154, 155, 142, 143, 158, 159, 170, 171, 186, 187, 174, 175, 190, 191, 202, 203, 218, 219, 206, 207, 222, 223, 234, 235, 250, 251, 238, 239, 254, 255], + [ 256, 257, 272, 273, 260, 261, 276, 277, 288, 289, 304, 305, 292, 293, 308, 309, 320, 321, 336, 337, 324, 325, 340, 341, 352, 353, 368, 369, 356, 357, 372, 373], + [ 258, 259, 274, 275, 262, 263, 278, 279, 290, 291, 306, 307, 294, 295, 310, 311, 322, 323, 338, 339, 326, 327, 342, 343, 354, 355, 370, 371, 358, 359, 374, 375], + [ 264, 265, 280, 281, 268, 269, 284, 285, 296, 297, 312, 313, 300, 301, 316, 317, 328, 329, 344, 345, 332, 333, 348, 349, 360, 361, 376, 377, 364, 365, 380, 381], + [ 266, 267, 282, 283, 270, 271, 286, 287, 298, 299, 314, 315, 302, 303, 318, 319, 330, 331, 346, 347, 334, 335, 350, 351, 362, 363, 378, 379, 366, 367, 382, 383], + [ 384, 385, 400, 401, 388, 389, 404, 405, 416, 417, 432, 433, 420, 421, 436, 437, 448, 449, 464, 465, 452, 453, 468, 469, 480, 481, 496, 497, 484, 485, 500, 501], + [ 386, 387, 402, 403, 390, 391, 406, 407, 418, 419, 434, 435, 422, 423, 438, 439, 450, 451, 466, 467, 454, 455, 470, 471, 482, 483, 498, 499, 486, 487, 502, 503], + [ 392, 393, 408, 409, 396, 397, 412, 413, 424, 425, 440, 441, 428, 429, 444, 445, 456, 457, 472, 473, 460, 461, 476, 477, 488, 489, 504, 505, 492, 493, 508, 509], + [ 394, 395, 410, 411, 398, 399, 414, 415, 426, 427, 442, 443, 430, 431, 446, 447, 458, 459, 474, 475, 462, 463, 478, 479, 490, 491, 506, 507, 494, 495, 510, 511], + [ 512, 513, 528, 529, 516, 517, 532, 533, 544, 545, 560, 561, 548, 549, 564, 565, 576, 577, 592, 593, 580, 581, 596, 597, 608, 609, 624, 625, 612, 613, 628, 629], + [ 514, 515, 530, 531, 518, 519, 534, 535, 546, 547, 562, 563, 550, 551, 566, 567, 578, 579, 594, 595, 582, 583, 598, 599, 610, 611, 626, 627, 614, 615, 630, 631], + [ 520, 521, 536, 537, 524, 525, 540, 541, 552, 553, 568, 569, 556, 557, 572, 573, 584, 585, 600, 601, 588, 589, 604, 605, 616, 617, 632, 633, 620, 621, 636, 637], + [ 522, 523, 538, 539, 526, 527, 542, 543, 554, 555, 570, 571, 558, 559, 574, 575, 586, 587, 602, 603, 590, 591, 606, 607, 618, 619, 634, 635, 622, 623, 638, 639], + [ 640, 641, 656, 657, 644, 645, 660, 661, 672, 673, 688, 689, 676, 677, 692, 693, 704, 705, 720, 721, 708, 709, 724, 725, 736, 737, 752, 753, 740, 741, 756, 757], + [ 642, 643, 658, 659, 646, 647, 662, 663, 674, 675, 690, 691, 678, 679, 694, 695, 706, 707, 722, 723, 710, 711, 726, 727, 738, 739, 754, 755, 742, 743, 758, 759], + [ 648, 649, 664, 665, 652, 653, 668, 669, 680, 681, 696, 697, 684, 685, 700, 701, 712, 713, 728, 729, 716, 717, 732, 733, 744, 745, 760, 761, 748, 749, 764, 765], + [ 650, 651, 666, 667, 654, 655, 670, 671, 682, 683, 698, 699, 686, 687, 702, 703, 714, 715, 730, 731, 718, 719, 734, 735, 746, 747, 762, 763, 750, 751, 766, 767], + [ 768, 769, 784, 785, 772, 773, 788, 789, 800, 801, 816, 817, 804, 805, 820, 821, 832, 833, 848, 849, 836, 837, 852, 853, 864, 865, 880, 881, 868, 869, 884, 885], + [ 770, 771, 786, 787, 774, 775, 790, 791, 802, 803, 818, 819, 806, 807, 822, 823, 834, 835, 850, 851, 838, 839, 854, 855, 866, 867, 882, 883, 870, 871, 886, 887], + [ 776, 777, 792, 793, 780, 781, 796, 797, 808, 809, 824, 825, 812, 813, 828, 829, 840, 841, 856, 857, 844, 845, 860, 861, 872, 873, 888, 889, 876, 877, 892, 893], + [ 778, 779, 794, 795, 782, 783, 798, 799, 810, 811, 826, 827, 814, 815, 830, 831, 842, 843, 858, 859, 846, 847, 862, 863, 874, 875, 890, 891, 878, 879, 894, 895], + [ 896, 897, 912, 913, 900, 901, 916, 917, 928, 929, 944, 945, 932, 933, 948, 949, 960, 961, 976, 977, 964, 965, 980, 981, 992, 993, 1008, 1009, 996, 997, 1012, 1013], + [ 898, 899, 914, 915, 902, 903, 918, 919, 930, 931, 946, 947, 934, 935, 950, 951, 962, 963, 978, 979, 966, 967, 982, 983, 994, 995, 1010, 1011, 998, 999, 1014, 1015], + [ 904, 905, 920, 921, 908, 909, 924, 925, 936, 937, 952, 953, 940, 941, 956, 957, 968, 969, 984, 985, 972, 973, 988, 989, 1000, 1001, 1016, 1017, 1004, 1005, 1020, 1021], + [ 906, 907, 922, 923, 910, 911, 926, 927, 938, 939, 954, 955, 942, 943, 958, 959, 970, 971, 986, 987, 974, 975, 990, 991, 1002, 1003, 1018, 1019, 1006, 1007, 1022, 1023]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # Step == (2, 2) + tiles = TensorTiler2D.step_tiler( + (32, 32), + tile_dims=(2, 2), + tile_group_repeats=(2, 2), + tile_group_steps=(2, 2), + ) + assert len(tiles) == (32 // (2 * 2)) * (32 // (2 * 2)) + assert tiles[0] == TensorTile( + (32, 32), offset=0, sizes=[2, 2, 2, 2], strides=[128, 4, 32, 1] + ) + assert tiles[1] == TensorTile( + (32, 32), offset=2, sizes=[2, 2, 2, 2], strides=[128, 4, 32, 1] + ) + assert tiles[26] == TensorTile( + (32, 32), offset=328, sizes=[2, 2, 2, 2], strides=[128, 4, 32, 1] + ) + assert tiles[-1] == TensorTile( + (32, 32), offset=858, sizes=[2, 2, 2, 2], strides=[128, 4, 32, 1] + ) + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 16, 17, 4, 5, 20, 21, 32, 33, 48, 49, 36, 37, 52, 53, 64, 65, 80, 81, 68, 69, 84, 85, 96, 97, 112, 113, 100, 101, 116, 117], + [ 2, 3, 18, 19, 6, 7, 22, 23, 34, 35, 50, 51, 38, 39, 54, 55, 66, 67, 82, 83, 70, 71, 86, 87, 98, 99, 114, 115, 102, 103, 118, 119], + [ 128, 129, 144, 145, 132, 133, 148, 149, 160, 161, 176, 177, 164, 165, 180, 181, 192, 193, 208, 209, 196, 197, 212, 213, 224, 225, 240, 241, 228, 229, 244, 245], + [ 130, 131, 146, 147, 134, 135, 150, 151, 162, 163, 178, 179, 166, 167, 182, 183, 194, 195, 210, 211, 198, 199, 214, 215, 226, 227, 242, 243, 230, 231, 246, 247], + [ 8, 9, 24, 25, 12, 13, 28, 29, 40, 41, 56, 57, 44, 45, 60, 61, 72, 73, 88, 89, 76, 77, 92, 93, 104, 105, 120, 121, 108, 109, 124, 125], + [ 10, 11, 26, 27, 14, 15, 30, 31, 42, 43, 58, 59, 46, 47, 62, 63, 74, 75, 90, 91, 78, 79, 94, 95, 106, 107, 122, 123, 110, 111, 126, 127], + [ 136, 137, 152, 153, 140, 141, 156, 157, 168, 169, 184, 185, 172, 173, 188, 189, 200, 201, 216, 217, 204, 205, 220, 221, 232, 233, 248, 249, 236, 237, 252, 253], + [ 138, 139, 154, 155, 142, 143, 158, 159, 170, 171, 186, 187, 174, 175, 190, 191, 202, 203, 218, 219, 206, 207, 222, 223, 234, 235, 250, 251, 238, 239, 254, 255], + [ 256, 257, 272, 273, 260, 261, 276, 277, 288, 289, 304, 305, 292, 293, 308, 309, 320, 321, 336, 337, 324, 325, 340, 341, 352, 353, 368, 369, 356, 357, 372, 373], + [ 258, 259, 274, 275, 262, 263, 278, 279, 290, 291, 306, 307, 294, 295, 310, 311, 322, 323, 338, 339, 326, 327, 342, 343, 354, 355, 370, 371, 358, 359, 374, 375], + [ 384, 385, 400, 401, 388, 389, 404, 405, 416, 417, 432, 433, 420, 421, 436, 437, 448, 449, 464, 465, 452, 453, 468, 469, 480, 481, 496, 497, 484, 485, 500, 501], + [ 386, 387, 402, 403, 390, 391, 406, 407, 418, 419, 434, 435, 422, 423, 438, 439, 450, 451, 466, 467, 454, 455, 470, 471, 482, 483, 498, 499, 486, 487, 502, 503], + [ 264, 265, 280, 281, 268, 269, 284, 285, 296, 297, 312, 313, 300, 301, 316, 317, 328, 329, 344, 345, 332, 333, 348, 349, 360, 361, 376, 377, 364, 365, 380, 381], + [ 266, 267, 282, 283, 270, 271, 286, 287, 298, 299, 314, 315, 302, 303, 318, 319, 330, 331, 346, 347, 334, 335, 350, 351, 362, 363, 378, 379, 366, 367, 382, 383], + [ 392, 393, 408, 409, 396, 397, 412, 413, 424, 425, 440, 441, 428, 429, 444, 445, 456, 457, 472, 473, 460, 461, 476, 477, 488, 489, 504, 505, 492, 493, 508, 509], + [ 394, 395, 410, 411, 398, 399, 414, 415, 426, 427, 442, 443, 430, 431, 446, 447, 458, 459, 474, 475, 462, 463, 478, 479, 490, 491, 506, 507, 494, 495, 510, 511], + [ 512, 513, 528, 529, 516, 517, 532, 533, 544, 545, 560, 561, 548, 549, 564, 565, 576, 577, 592, 593, 580, 581, 596, 597, 608, 609, 624, 625, 612, 613, 628, 629], + [ 514, 515, 530, 531, 518, 519, 534, 535, 546, 547, 562, 563, 550, 551, 566, 567, 578, 579, 594, 595, 582, 583, 598, 599, 610, 611, 626, 627, 614, 615, 630, 631], + [ 640, 641, 656, 657, 644, 645, 660, 661, 672, 673, 688, 689, 676, 677, 692, 693, 704, 705, 720, 721, 708, 709, 724, 725, 736, 737, 752, 753, 740, 741, 756, 757], + [ 642, 643, 658, 659, 646, 647, 662, 663, 674, 675, 690, 691, 678, 679, 694, 695, 706, 707, 722, 723, 710, 711, 726, 727, 738, 739, 754, 755, 742, 743, 758, 759], + [ 520, 521, 536, 537, 524, 525, 540, 541, 552, 553, 568, 569, 556, 557, 572, 573, 584, 585, 600, 601, 588, 589, 604, 605, 616, 617, 632, 633, 620, 621, 636, 637], + [ 522, 523, 538, 539, 526, 527, 542, 543, 554, 555, 570, 571, 558, 559, 574, 575, 586, 587, 602, 603, 590, 591, 606, 607, 618, 619, 634, 635, 622, 623, 638, 639], + [ 648, 649, 664, 665, 652, 653, 668, 669, 680, 681, 696, 697, 684, 685, 700, 701, 712, 713, 728, 729, 716, 717, 732, 733, 744, 745, 760, 761, 748, 749, 764, 765], + [ 650, 651, 666, 667, 654, 655, 670, 671, 682, 683, 698, 699, 686, 687, 702, 703, 714, 715, 730, 731, 718, 719, 734, 735, 746, 747, 762, 763, 750, 751, 766, 767], + [ 768, 769, 784, 785, 772, 773, 788, 789, 800, 801, 816, 817, 804, 805, 820, 821, 832, 833, 848, 849, 836, 837, 852, 853, 864, 865, 880, 881, 868, 869, 884, 885], + [ 770, 771, 786, 787, 774, 775, 790, 791, 802, 803, 818, 819, 806, 807, 822, 823, 834, 835, 850, 851, 838, 839, 854, 855, 866, 867, 882, 883, 870, 871, 886, 887], + [ 896, 897, 912, 913, 900, 901, 916, 917, 928, 929, 944, 945, 932, 933, 948, 949, 960, 961, 976, 977, 964, 965, 980, 981, 992, 993, 1008, 1009, 996, 997, 1012, 1013], + [ 898, 899, 914, 915, 902, 903, 918, 919, 930, 931, 946, 947, 934, 935, 950, 951, 962, 963, 978, 979, 966, 967, 982, 983, 994, 995, 1010, 1011, 998, 999, 1014, 1015], + [ 776, 777, 792, 793, 780, 781, 796, 797, 808, 809, 824, 825, 812, 813, 828, 829, 840, 841, 856, 857, 844, 845, 860, 861, 872, 873, 888, 889, 876, 877, 892, 893], + [ 778, 779, 794, 795, 782, 783, 798, 799, 810, 811, 826, 827, 814, 815, 830, 831, 842, 843, 858, 859, 846, 847, 862, 863, 874, 875, 890, 891, 878, 879, 894, 895], + [ 904, 905, 920, 921, 908, 909, 924, 925, 936, 937, 952, 953, 940, 941, 956, 957, 968, 969, 984, 985, 972, 973, 988, 989, 1000, 1001, 1016, 1017, 1004, 1005, 1020, 1021], + [ 906, 907, 922, 923, 910, 911, 926, 927, 938, 939, 954, 955, 942, 943, 958, 959, 970, 971, 986, 987, 974, 975, 990, 991, 1002, 1003, 1018, 1019, 1006, 1007, 1022, 1023]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # Step == (2, 2) + tiles = TensorTiler2D.step_tiler( + (32, 32), + tile_dims=(2, 2), + tile_group_repeats=(2, 2), + tile_group_steps=(2, 2), + ) + assert len(tiles) == (32 // (2 * 2)) * (32 // (2 * 2)) + assert tiles[0] == TensorTile( + (32, 32), offset=0, sizes=[2, 2, 2, 2], strides=[128, 4, 32, 1] + ) + assert tiles[1] == TensorTile( + (32, 32), offset=2, sizes=[2, 2, 2, 2], strides=[128, 4, 32, 1] + ) + assert tiles[26] == TensorTile( + (32, 32), offset=328, sizes=[2, 2, 2, 2], strides=[128, 4, 32, 1] + ) + assert tiles[-1] == TensorTile( + (32, 32), offset=858, sizes=[2, 2, 2, 2], strides=[128, 4, 32, 1] + ) + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 16, 17, 4, 5, 20, 21, 32, 33, 48, 49, 36, 37, 52, 53, 64, 65, 80, 81, 68, 69, 84, 85, 96, 97, 112, 113, 100, 101, 116, 117], + [ 2, 3, 18, 19, 6, 7, 22, 23, 34, 35, 50, 51, 38, 39, 54, 55, 66, 67, 82, 83, 70, 71, 86, 87, 98, 99, 114, 115, 102, 103, 118, 119], + [ 128, 129, 144, 145, 132, 133, 148, 149, 160, 161, 176, 177, 164, 165, 180, 181, 192, 193, 208, 209, 196, 197, 212, 213, 224, 225, 240, 241, 228, 229, 244, 245], + [ 130, 131, 146, 147, 134, 135, 150, 151, 162, 163, 178, 179, 166, 167, 182, 183, 194, 195, 210, 211, 198, 199, 214, 215, 226, 227, 242, 243, 230, 231, 246, 247], + [ 8, 9, 24, 25, 12, 13, 28, 29, 40, 41, 56, 57, 44, 45, 60, 61, 72, 73, 88, 89, 76, 77, 92, 93, 104, 105, 120, 121, 108, 109, 124, 125], + [ 10, 11, 26, 27, 14, 15, 30, 31, 42, 43, 58, 59, 46, 47, 62, 63, 74, 75, 90, 91, 78, 79, 94, 95, 106, 107, 122, 123, 110, 111, 126, 127], + [ 136, 137, 152, 153, 140, 141, 156, 157, 168, 169, 184, 185, 172, 173, 188, 189, 200, 201, 216, 217, 204, 205, 220, 221, 232, 233, 248, 249, 236, 237, 252, 253], + [ 138, 139, 154, 155, 142, 143, 158, 159, 170, 171, 186, 187, 174, 175, 190, 191, 202, 203, 218, 219, 206, 207, 222, 223, 234, 235, 250, 251, 238, 239, 254, 255], + [ 256, 257, 272, 273, 260, 261, 276, 277, 288, 289, 304, 305, 292, 293, 308, 309, 320, 321, 336, 337, 324, 325, 340, 341, 352, 353, 368, 369, 356, 357, 372, 373], + [ 258, 259, 274, 275, 262, 263, 278, 279, 290, 291, 306, 307, 294, 295, 310, 311, 322, 323, 338, 339, 326, 327, 342, 343, 354, 355, 370, 371, 358, 359, 374, 375], + [ 384, 385, 400, 401, 388, 389, 404, 405, 416, 417, 432, 433, 420, 421, 436, 437, 448, 449, 464, 465, 452, 453, 468, 469, 480, 481, 496, 497, 484, 485, 500, 501], + [ 386, 387, 402, 403, 390, 391, 406, 407, 418, 419, 434, 435, 422, 423, 438, 439, 450, 451, 466, 467, 454, 455, 470, 471, 482, 483, 498, 499, 486, 487, 502, 503], + [ 264, 265, 280, 281, 268, 269, 284, 285, 296, 297, 312, 313, 300, 301, 316, 317, 328, 329, 344, 345, 332, 333, 348, 349, 360, 361, 376, 377, 364, 365, 380, 381], + [ 266, 267, 282, 283, 270, 271, 286, 287, 298, 299, 314, 315, 302, 303, 318, 319, 330, 331, 346, 347, 334, 335, 350, 351, 362, 363, 378, 379, 366, 367, 382, 383], + [ 392, 393, 408, 409, 396, 397, 412, 413, 424, 425, 440, 441, 428, 429, 444, 445, 456, 457, 472, 473, 460, 461, 476, 477, 488, 489, 504, 505, 492, 493, 508, 509], + [ 394, 395, 410, 411, 398, 399, 414, 415, 426, 427, 442, 443, 430, 431, 446, 447, 458, 459, 474, 475, 462, 463, 478, 479, 490, 491, 506, 507, 494, 495, 510, 511], + [ 512, 513, 528, 529, 516, 517, 532, 533, 544, 545, 560, 561, 548, 549, 564, 565, 576, 577, 592, 593, 580, 581, 596, 597, 608, 609, 624, 625, 612, 613, 628, 629], + [ 514, 515, 530, 531, 518, 519, 534, 535, 546, 547, 562, 563, 550, 551, 566, 567, 578, 579, 594, 595, 582, 583, 598, 599, 610, 611, 626, 627, 614, 615, 630, 631], + [ 640, 641, 656, 657, 644, 645, 660, 661, 672, 673, 688, 689, 676, 677, 692, 693, 704, 705, 720, 721, 708, 709, 724, 725, 736, 737, 752, 753, 740, 741, 756, 757], + [ 642, 643, 658, 659, 646, 647, 662, 663, 674, 675, 690, 691, 678, 679, 694, 695, 706, 707, 722, 723, 710, 711, 726, 727, 738, 739, 754, 755, 742, 743, 758, 759], + [ 520, 521, 536, 537, 524, 525, 540, 541, 552, 553, 568, 569, 556, 557, 572, 573, 584, 585, 600, 601, 588, 589, 604, 605, 616, 617, 632, 633, 620, 621, 636, 637], + [ 522, 523, 538, 539, 526, 527, 542, 543, 554, 555, 570, 571, 558, 559, 574, 575, 586, 587, 602, 603, 590, 591, 606, 607, 618, 619, 634, 635, 622, 623, 638, 639], + [ 648, 649, 664, 665, 652, 653, 668, 669, 680, 681, 696, 697, 684, 685, 700, 701, 712, 713, 728, 729, 716, 717, 732, 733, 744, 745, 760, 761, 748, 749, 764, 765], + [ 650, 651, 666, 667, 654, 655, 670, 671, 682, 683, 698, 699, 686, 687, 702, 703, 714, 715, 730, 731, 718, 719, 734, 735, 746, 747, 762, 763, 750, 751, 766, 767], + [ 768, 769, 784, 785, 772, 773, 788, 789, 800, 801, 816, 817, 804, 805, 820, 821, 832, 833, 848, 849, 836, 837, 852, 853, 864, 865, 880, 881, 868, 869, 884, 885], + [ 770, 771, 786, 787, 774, 775, 790, 791, 802, 803, 818, 819, 806, 807, 822, 823, 834, 835, 850, 851, 838, 839, 854, 855, 866, 867, 882, 883, 870, 871, 886, 887], + [ 896, 897, 912, 913, 900, 901, 916, 917, 928, 929, 944, 945, 932, 933, 948, 949, 960, 961, 976, 977, 964, 965, 980, 981, 992, 993, 1008, 1009, 996, 997, 1012, 1013], + [ 898, 899, 914, 915, 902, 903, 918, 919, 930, 931, 946, 947, 934, 935, 950, 951, 962, 963, 978, 979, 966, 967, 982, 983, 994, 995, 1010, 1011, 998, 999, 1014, 1015], + [ 776, 777, 792, 793, 780, 781, 796, 797, 808, 809, 824, 825, 812, 813, 828, 829, 840, 841, 856, 857, 844, 845, 860, 861, 872, 873, 888, 889, 876, 877, 892, 893], + [ 778, 779, 794, 795, 782, 783, 798, 799, 810, 811, 826, 827, 814, 815, 830, 831, 842, 843, 858, 859, 846, 847, 862, 863, 874, 875, 890, 891, 878, 879, 894, 895], + [ 904, 905, 920, 921, 908, 909, 924, 925, 936, 937, 952, 953, 940, 941, 956, 957, 968, 969, 984, 985, 972, 973, 988, 989, 1000, 1001, 1016, 1017, 1004, 1005, 1020, 1021], + [ 906, 907, 922, 923, 910, 911, 926, 927, 938, 939, 954, 955, 942, 943, 958, 959, 970, 971, 986, 987, 974, 975, 990, 991, 1002, 1003, 1018, 1019, 1006, 1007, 1022, 1023]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # Repeat across column/row + tiles = TensorTiler2D.step_tiler( + (32, 32), + tile_dims=(2, 2), + tile_group_repeats=(32 // 4, 32 // 4), + tile_group_steps=(2, 2), + ) + assert len(tiles) == 4 # (32//(2*(32//4))) * (32//(2*(32//4))) + assert tiles[0] == TensorTile( + (32, 32), offset=0, sizes=[8, 8, 2, 2], strides=[128, 4, 32, 1] + ) + assert tiles[1] == TensorTile( + (32, 32), offset=2, sizes=[8, 8, 2, 2], strides=[128, 4, 32, 1] + ) + assert tiles[2] == TensorTile( + (32, 32), offset=64, sizes=[8, 8, 2, 2], strides=[128, 4, 32, 1] + ) + assert tiles[3] == TensorTile( + (32, 32), offset=66, sizes=[8, 8, 2, 2], strides=[128, 4, 32, 1] + ) + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 256, 257, 4, 5, 260, 261, 8, 9, 264, 265, 12, 13, 268, 269, 16, 17, 272, 273, 20, 21, 276, 277, 24, 25, 280, 281, 28, 29, 284, 285], + [ 2, 3, 258, 259, 6, 7, 262, 263, 10, 11, 266, 267, 14, 15, 270, 271, 18, 19, 274, 275, 22, 23, 278, 279, 26, 27, 282, 283, 30, 31, 286, 287], + [ 512, 513, 768, 769, 516, 517, 772, 773, 520, 521, 776, 777, 524, 525, 780, 781, 528, 529, 784, 785, 532, 533, 788, 789, 536, 537, 792, 793, 540, 541, 796, 797], + [ 514, 515, 770, 771, 518, 519, 774, 775, 522, 523, 778, 779, 526, 527, 782, 783, 530, 531, 786, 787, 534, 535, 790, 791, 538, 539, 794, 795, 542, 543, 798, 799], + [ 32, 33, 288, 289, 36, 37, 292, 293, 40, 41, 296, 297, 44, 45, 300, 301, 48, 49, 304, 305, 52, 53, 308, 309, 56, 57, 312, 313, 60, 61, 316, 317], + [ 34, 35, 290, 291, 38, 39, 294, 295, 42, 43, 298, 299, 46, 47, 302, 303, 50, 51, 306, 307, 54, 55, 310, 311, 58, 59, 314, 315, 62, 63, 318, 319], + [ 544, 545, 800, 801, 548, 549, 804, 805, 552, 553, 808, 809, 556, 557, 812, 813, 560, 561, 816, 817, 564, 565, 820, 821, 568, 569, 824, 825, 572, 573, 828, 829], + [ 546, 547, 802, 803, 550, 551, 806, 807, 554, 555, 810, 811, 558, 559, 814, 815, 562, 563, 818, 819, 566, 567, 822, 823, 570, 571, 826, 827, 574, 575, 830, 831], + [ 64, 65, 320, 321, 68, 69, 324, 325, 72, 73, 328, 329, 76, 77, 332, 333, 80, 81, 336, 337, 84, 85, 340, 341, 88, 89, 344, 345, 92, 93, 348, 349], + [ 66, 67, 322, 323, 70, 71, 326, 327, 74, 75, 330, 331, 78, 79, 334, 335, 82, 83, 338, 339, 86, 87, 342, 343, 90, 91, 346, 347, 94, 95, 350, 351], + [ 576, 577, 832, 833, 580, 581, 836, 837, 584, 585, 840, 841, 588, 589, 844, 845, 592, 593, 848, 849, 596, 597, 852, 853, 600, 601, 856, 857, 604, 605, 860, 861], + [ 578, 579, 834, 835, 582, 583, 838, 839, 586, 587, 842, 843, 590, 591, 846, 847, 594, 595, 850, 851, 598, 599, 854, 855, 602, 603, 858, 859, 606, 607, 862, 863], + [ 96, 97, 352, 353, 100, 101, 356, 357, 104, 105, 360, 361, 108, 109, 364, 365, 112, 113, 368, 369, 116, 117, 372, 373, 120, 121, 376, 377, 124, 125, 380, 381], + [ 98, 99, 354, 355, 102, 103, 358, 359, 106, 107, 362, 363, 110, 111, 366, 367, 114, 115, 370, 371, 118, 119, 374, 375, 122, 123, 378, 379, 126, 127, 382, 383], + [ 608, 609, 864, 865, 612, 613, 868, 869, 616, 617, 872, 873, 620, 621, 876, 877, 624, 625, 880, 881, 628, 629, 884, 885, 632, 633, 888, 889, 636, 637, 892, 893], + [ 610, 611, 866, 867, 614, 615, 870, 871, 618, 619, 874, 875, 622, 623, 878, 879, 626, 627, 882, 883, 630, 631, 886, 887, 634, 635, 890, 891, 638, 639, 894, 895], + [ 128, 129, 384, 385, 132, 133, 388, 389, 136, 137, 392, 393, 140, 141, 396, 397, 144, 145, 400, 401, 148, 149, 404, 405, 152, 153, 408, 409, 156, 157, 412, 413], + [ 130, 131, 386, 387, 134, 135, 390, 391, 138, 139, 394, 395, 142, 143, 398, 399, 146, 147, 402, 403, 150, 151, 406, 407, 154, 155, 410, 411, 158, 159, 414, 415], + [ 640, 641, 896, 897, 644, 645, 900, 901, 648, 649, 904, 905, 652, 653, 908, 909, 656, 657, 912, 913, 660, 661, 916, 917, 664, 665, 920, 921, 668, 669, 924, 925], + [ 642, 643, 898, 899, 646, 647, 902, 903, 650, 651, 906, 907, 654, 655, 910, 911, 658, 659, 914, 915, 662, 663, 918, 919, 666, 667, 922, 923, 670, 671, 926, 927], + [ 160, 161, 416, 417, 164, 165, 420, 421, 168, 169, 424, 425, 172, 173, 428, 429, 176, 177, 432, 433, 180, 181, 436, 437, 184, 185, 440, 441, 188, 189, 444, 445], + [ 162, 163, 418, 419, 166, 167, 422, 423, 170, 171, 426, 427, 174, 175, 430, 431, 178, 179, 434, 435, 182, 183, 438, 439, 186, 187, 442, 443, 190, 191, 446, 447], + [ 672, 673, 928, 929, 676, 677, 932, 933, 680, 681, 936, 937, 684, 685, 940, 941, 688, 689, 944, 945, 692, 693, 948, 949, 696, 697, 952, 953, 700, 701, 956, 957], + [ 674, 675, 930, 931, 678, 679, 934, 935, 682, 683, 938, 939, 686, 687, 942, 943, 690, 691, 946, 947, 694, 695, 950, 951, 698, 699, 954, 955, 702, 703, 958, 959], + [ 192, 193, 448, 449, 196, 197, 452, 453, 200, 201, 456, 457, 204, 205, 460, 461, 208, 209, 464, 465, 212, 213, 468, 469, 216, 217, 472, 473, 220, 221, 476, 477], + [ 194, 195, 450, 451, 198, 199, 454, 455, 202, 203, 458, 459, 206, 207, 462, 463, 210, 211, 466, 467, 214, 215, 470, 471, 218, 219, 474, 475, 222, 223, 478, 479], + [ 704, 705, 960, 961, 708, 709, 964, 965, 712, 713, 968, 969, 716, 717, 972, 973, 720, 721, 976, 977, 724, 725, 980, 981, 728, 729, 984, 985, 732, 733, 988, 989], + [ 706, 707, 962, 963, 710, 711, 966, 967, 714, 715, 970, 971, 718, 719, 974, 975, 722, 723, 978, 979, 726, 727, 982, 983, 730, 731, 986, 987, 734, 735, 990, 991], + [ 224, 225, 480, 481, 228, 229, 484, 485, 232, 233, 488, 489, 236, 237, 492, 493, 240, 241, 496, 497, 244, 245, 500, 501, 248, 249, 504, 505, 252, 253, 508, 509], + [ 226, 227, 482, 483, 230, 231, 486, 487, 234, 235, 490, 491, 238, 239, 494, 495, 242, 243, 498, 499, 246, 247, 502, 503, 250, 251, 506, 507, 254, 255, 510, 511], + [ 736, 737, 992, 993, 740, 741, 996, 997, 744, 745, 1000, 1001, 748, 749, 1004, 1005, 752, 753, 1008, 1009, 756, 757, 1012, 1013, 760, 761, 1016, 1017, 764, 765, 1020, 1021], + [ 738, 739, 994, 995, 742, 743, 998, 999, 746, 747, 1002, 1003, 750, 751, 1006, 1007, 754, 755, 1010, 1011, 758, 759, 1014, 1015, 762, 763, 1018, 1019, 766, 767, 1022, 1023]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # Repeat one dimension + tiles = TensorTiler2D.step_tiler( + (32, 32), + tile_dims=(2, 2), + tile_group_repeats=(1, 32 // 4), + tile_group_steps=(2, 2), + ) + assert len(tiles) == (32 // (2 * 1)) * (32 // (2 * (32 // 4))) + assert tiles[0] == TensorTile( + (32, 32), offset=0, sizes=[1, 8, 2, 2], strides=[0, 4, 32, 1] + ) + assert tiles[1] == TensorTile( + (32, 32), offset=2, sizes=[1, 8, 2, 2], strides=[0, 4, 32, 1] + ) + assert tiles[26] == TensorTile( + (32, 32), offset=832, sizes=[1, 8, 2, 2], strides=[0, 4, 32, 1] + ) + assert tiles[-1] == TensorTile( + (32, 32), offset=962, sizes=[1, 8, 2, 2], strides=[0, 4, 32, 1] + ) + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 32, 33, 4, 5, 36, 37, 8, 9, 40, 41, 12, 13, 44, 45, 16, 17, 48, 49, 20, 21, 52, 53, 24, 25, 56, 57, 28, 29, 60, 61], + [ 2, 3, 34, 35, 6, 7, 38, 39, 10, 11, 42, 43, 14, 15, 46, 47, 18, 19, 50, 51, 22, 23, 54, 55, 26, 27, 58, 59, 30, 31, 62, 63], + [ 64, 65, 96, 97, 68, 69, 100, 101, 72, 73, 104, 105, 76, 77, 108, 109, 80, 81, 112, 113, 84, 85, 116, 117, 88, 89, 120, 121, 92, 93, 124, 125], + [ 66, 67, 98, 99, 70, 71, 102, 103, 74, 75, 106, 107, 78, 79, 110, 111, 82, 83, 114, 115, 86, 87, 118, 119, 90, 91, 122, 123, 94, 95, 126, 127], + [ 128, 129, 160, 161, 132, 133, 164, 165, 136, 137, 168, 169, 140, 141, 172, 173, 144, 145, 176, 177, 148, 149, 180, 181, 152, 153, 184, 185, 156, 157, 188, 189], + [ 130, 131, 162, 163, 134, 135, 166, 167, 138, 139, 170, 171, 142, 143, 174, 175, 146, 147, 178, 179, 150, 151, 182, 183, 154, 155, 186, 187, 158, 159, 190, 191], + [ 192, 193, 224, 225, 196, 197, 228, 229, 200, 201, 232, 233, 204, 205, 236, 237, 208, 209, 240, 241, 212, 213, 244, 245, 216, 217, 248, 249, 220, 221, 252, 253], + [ 194, 195, 226, 227, 198, 199, 230, 231, 202, 203, 234, 235, 206, 207, 238, 239, 210, 211, 242, 243, 214, 215, 246, 247, 218, 219, 250, 251, 222, 223, 254, 255], + [ 256, 257, 288, 289, 260, 261, 292, 293, 264, 265, 296, 297, 268, 269, 300, 301, 272, 273, 304, 305, 276, 277, 308, 309, 280, 281, 312, 313, 284, 285, 316, 317], + [ 258, 259, 290, 291, 262, 263, 294, 295, 266, 267, 298, 299, 270, 271, 302, 303, 274, 275, 306, 307, 278, 279, 310, 311, 282, 283, 314, 315, 286, 287, 318, 319], + [ 320, 321, 352, 353, 324, 325, 356, 357, 328, 329, 360, 361, 332, 333, 364, 365, 336, 337, 368, 369, 340, 341, 372, 373, 344, 345, 376, 377, 348, 349, 380, 381], + [ 322, 323, 354, 355, 326, 327, 358, 359, 330, 331, 362, 363, 334, 335, 366, 367, 338, 339, 370, 371, 342, 343, 374, 375, 346, 347, 378, 379, 350, 351, 382, 383], + [ 384, 385, 416, 417, 388, 389, 420, 421, 392, 393, 424, 425, 396, 397, 428, 429, 400, 401, 432, 433, 404, 405, 436, 437, 408, 409, 440, 441, 412, 413, 444, 445], + [ 386, 387, 418, 419, 390, 391, 422, 423, 394, 395, 426, 427, 398, 399, 430, 431, 402, 403, 434, 435, 406, 407, 438, 439, 410, 411, 442, 443, 414, 415, 446, 447], + [ 448, 449, 480, 481, 452, 453, 484, 485, 456, 457, 488, 489, 460, 461, 492, 493, 464, 465, 496, 497, 468, 469, 500, 501, 472, 473, 504, 505, 476, 477, 508, 509], + [ 450, 451, 482, 483, 454, 455, 486, 487, 458, 459, 490, 491, 462, 463, 494, 495, 466, 467, 498, 499, 470, 471, 502, 503, 474, 475, 506, 507, 478, 479, 510, 511], + [ 512, 513, 544, 545, 516, 517, 548, 549, 520, 521, 552, 553, 524, 525, 556, 557, 528, 529, 560, 561, 532, 533, 564, 565, 536, 537, 568, 569, 540, 541, 572, 573], + [ 514, 515, 546, 547, 518, 519, 550, 551, 522, 523, 554, 555, 526, 527, 558, 559, 530, 531, 562, 563, 534, 535, 566, 567, 538, 539, 570, 571, 542, 543, 574, 575], + [ 576, 577, 608, 609, 580, 581, 612, 613, 584, 585, 616, 617, 588, 589, 620, 621, 592, 593, 624, 625, 596, 597, 628, 629, 600, 601, 632, 633, 604, 605, 636, 637], + [ 578, 579, 610, 611, 582, 583, 614, 615, 586, 587, 618, 619, 590, 591, 622, 623, 594, 595, 626, 627, 598, 599, 630, 631, 602, 603, 634, 635, 606, 607, 638, 639], + [ 640, 641, 672, 673, 644, 645, 676, 677, 648, 649, 680, 681, 652, 653, 684, 685, 656, 657, 688, 689, 660, 661, 692, 693, 664, 665, 696, 697, 668, 669, 700, 701], + [ 642, 643, 674, 675, 646, 647, 678, 679, 650, 651, 682, 683, 654, 655, 686, 687, 658, 659, 690, 691, 662, 663, 694, 695, 666, 667, 698, 699, 670, 671, 702, 703], + [ 704, 705, 736, 737, 708, 709, 740, 741, 712, 713, 744, 745, 716, 717, 748, 749, 720, 721, 752, 753, 724, 725, 756, 757, 728, 729, 760, 761, 732, 733, 764, 765], + [ 706, 707, 738, 739, 710, 711, 742, 743, 714, 715, 746, 747, 718, 719, 750, 751, 722, 723, 754, 755, 726, 727, 758, 759, 730, 731, 762, 763, 734, 735, 766, 767], + [ 768, 769, 800, 801, 772, 773, 804, 805, 776, 777, 808, 809, 780, 781, 812, 813, 784, 785, 816, 817, 788, 789, 820, 821, 792, 793, 824, 825, 796, 797, 828, 829], + [ 770, 771, 802, 803, 774, 775, 806, 807, 778, 779, 810, 811, 782, 783, 814, 815, 786, 787, 818, 819, 790, 791, 822, 823, 794, 795, 826, 827, 798, 799, 830, 831], + [ 832, 833, 864, 865, 836, 837, 868, 869, 840, 841, 872, 873, 844, 845, 876, 877, 848, 849, 880, 881, 852, 853, 884, 885, 856, 857, 888, 889, 860, 861, 892, 893], + [ 834, 835, 866, 867, 838, 839, 870, 871, 842, 843, 874, 875, 846, 847, 878, 879, 850, 851, 882, 883, 854, 855, 886, 887, 858, 859, 890, 891, 862, 863, 894, 895], + [ 896, 897, 928, 929, 900, 901, 932, 933, 904, 905, 936, 937, 908, 909, 940, 941, 912, 913, 944, 945, 916, 917, 948, 949, 920, 921, 952, 953, 924, 925, 956, 957], + [ 898, 899, 930, 931, 902, 903, 934, 935, 906, 907, 938, 939, 910, 911, 942, 943, 914, 915, 946, 947, 918, 919, 950, 951, 922, 923, 954, 955, 926, 927, 958, 959], + [ 960, 961, 992, 993, 964, 965, 996, 997, 968, 969, 1000, 1001, 972, 973, 1004, 1005, 976, 977, 1008, 1009, 980, 981, 1012, 1013, 984, 985, 1016, 1017, 988, 989, 1020, 1021], + [ 962, 963, 994, 995, 966, 967, 998, 999, 970, 971, 1002, 1003, 974, 975, 1006, 1007, 978, 979, 1010, 1011, 982, 983, 1014, 1015, 986, 987, 1018, 1019, 990, 991, 1022, 1023]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # Repeat other dimension + tiles = TensorTiler2D.step_tiler( + (32, 32), + tile_dims=(2, 2), + tile_group_repeats=(32 // 4, 1), + tile_group_steps=(2, 2), + ) + assert len(tiles) == (32 // (2 * 1)) * (32 // (2 * (32 // 4))) + assert tiles[0] == TensorTile( + (32, 32), offset=0, sizes=[1, 8, 2, 2], strides=[0, 128, 32, 1] + ) + assert tiles[1] == TensorTile( + (32, 32), offset=2, sizes=[1, 8, 2, 2], strides=[0, 128, 32, 1] + ) + assert tiles[26] == TensorTile( + (32, 32), offset=84, sizes=[1, 8, 2, 2], strides=[0, 128, 32, 1] + ) + assert tiles[-1] == TensorTile( + (32, 32), offset=94, sizes=[1, 8, 2, 2], strides=[0, 128, 32, 1] + ) + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 32, 33, 64, 65, 96, 97, 128, 129, 160, 161, 192, 193, 224, 225, 256, 257, 288, 289, 320, 321, 352, 353, 384, 385, 416, 417, 448, 449, 480, 481], + [ 2, 3, 34, 35, 66, 67, 98, 99, 130, 131, 162, 163, 194, 195, 226, 227, 258, 259, 290, 291, 322, 323, 354, 355, 386, 387, 418, 419, 450, 451, 482, 483], + [ 512, 513, 544, 545, 576, 577, 608, 609, 640, 641, 672, 673, 704, 705, 736, 737, 768, 769, 800, 801, 832, 833, 864, 865, 896, 897, 928, 929, 960, 961, 992, 993], + [ 514, 515, 546, 547, 578, 579, 610, 611, 642, 643, 674, 675, 706, 707, 738, 739, 770, 771, 802, 803, 834, 835, 866, 867, 898, 899, 930, 931, 962, 963, 994, 995], + [ 4, 5, 36, 37, 68, 69, 100, 101, 132, 133, 164, 165, 196, 197, 228, 229, 260, 261, 292, 293, 324, 325, 356, 357, 388, 389, 420, 421, 452, 453, 484, 485], + [ 6, 7, 38, 39, 70, 71, 102, 103, 134, 135, 166, 167, 198, 199, 230, 231, 262, 263, 294, 295, 326, 327, 358, 359, 390, 391, 422, 423, 454, 455, 486, 487], + [ 516, 517, 548, 549, 580, 581, 612, 613, 644, 645, 676, 677, 708, 709, 740, 741, 772, 773, 804, 805, 836, 837, 868, 869, 900, 901, 932, 933, 964, 965, 996, 997], + [ 518, 519, 550, 551, 582, 583, 614, 615, 646, 647, 678, 679, 710, 711, 742, 743, 774, 775, 806, 807, 838, 839, 870, 871, 902, 903, 934, 935, 966, 967, 998, 999], + [ 8, 9, 40, 41, 72, 73, 104, 105, 136, 137, 168, 169, 200, 201, 232, 233, 264, 265, 296, 297, 328, 329, 360, 361, 392, 393, 424, 425, 456, 457, 488, 489], + [ 10, 11, 42, 43, 74, 75, 106, 107, 138, 139, 170, 171, 202, 203, 234, 235, 266, 267, 298, 299, 330, 331, 362, 363, 394, 395, 426, 427, 458, 459, 490, 491], + [ 520, 521, 552, 553, 584, 585, 616, 617, 648, 649, 680, 681, 712, 713, 744, 745, 776, 777, 808, 809, 840, 841, 872, 873, 904, 905, 936, 937, 968, 969, 1000, 1001], + [ 522, 523, 554, 555, 586, 587, 618, 619, 650, 651, 682, 683, 714, 715, 746, 747, 778, 779, 810, 811, 842, 843, 874, 875, 906, 907, 938, 939, 970, 971, 1002, 1003], + [ 12, 13, 44, 45, 76, 77, 108, 109, 140, 141, 172, 173, 204, 205, 236, 237, 268, 269, 300, 301, 332, 333, 364, 365, 396, 397, 428, 429, 460, 461, 492, 493], + [ 14, 15, 46, 47, 78, 79, 110, 111, 142, 143, 174, 175, 206, 207, 238, 239, 270, 271, 302, 303, 334, 335, 366, 367, 398, 399, 430, 431, 462, 463, 494, 495], + [ 524, 525, 556, 557, 588, 589, 620, 621, 652, 653, 684, 685, 716, 717, 748, 749, 780, 781, 812, 813, 844, 845, 876, 877, 908, 909, 940, 941, 972, 973, 1004, 1005], + [ 526, 527, 558, 559, 590, 591, 622, 623, 654, 655, 686, 687, 718, 719, 750, 751, 782, 783, 814, 815, 846, 847, 878, 879, 910, 911, 942, 943, 974, 975, 1006, 1007], + [ 16, 17, 48, 49, 80, 81, 112, 113, 144, 145, 176, 177, 208, 209, 240, 241, 272, 273, 304, 305, 336, 337, 368, 369, 400, 401, 432, 433, 464, 465, 496, 497], + [ 18, 19, 50, 51, 82, 83, 114, 115, 146, 147, 178, 179, 210, 211, 242, 243, 274, 275, 306, 307, 338, 339, 370, 371, 402, 403, 434, 435, 466, 467, 498, 499], + [ 528, 529, 560, 561, 592, 593, 624, 625, 656, 657, 688, 689, 720, 721, 752, 753, 784, 785, 816, 817, 848, 849, 880, 881, 912, 913, 944, 945, 976, 977, 1008, 1009], + [ 530, 531, 562, 563, 594, 595, 626, 627, 658, 659, 690, 691, 722, 723, 754, 755, 786, 787, 818, 819, 850, 851, 882, 883, 914, 915, 946, 947, 978, 979, 1010, 1011], + [ 20, 21, 52, 53, 84, 85, 116, 117, 148, 149, 180, 181, 212, 213, 244, 245, 276, 277, 308, 309, 340, 341, 372, 373, 404, 405, 436, 437, 468, 469, 500, 501], + [ 22, 23, 54, 55, 86, 87, 118, 119, 150, 151, 182, 183, 214, 215, 246, 247, 278, 279, 310, 311, 342, 343, 374, 375, 406, 407, 438, 439, 470, 471, 502, 503], + [ 532, 533, 564, 565, 596, 597, 628, 629, 660, 661, 692, 693, 724, 725, 756, 757, 788, 789, 820, 821, 852, 853, 884, 885, 916, 917, 948, 949, 980, 981, 1012, 1013], + [ 534, 535, 566, 567, 598, 599, 630, 631, 662, 663, 694, 695, 726, 727, 758, 759, 790, 791, 822, 823, 854, 855, 886, 887, 918, 919, 950, 951, 982, 983, 1014, 1015], + [ 24, 25, 56, 57, 88, 89, 120, 121, 152, 153, 184, 185, 216, 217, 248, 249, 280, 281, 312, 313, 344, 345, 376, 377, 408, 409, 440, 441, 472, 473, 504, 505], + [ 26, 27, 58, 59, 90, 91, 122, 123, 154, 155, 186, 187, 218, 219, 250, 251, 282, 283, 314, 315, 346, 347, 378, 379, 410, 411, 442, 443, 474, 475, 506, 507], + [ 536, 537, 568, 569, 600, 601, 632, 633, 664, 665, 696, 697, 728, 729, 760, 761, 792, 793, 824, 825, 856, 857, 888, 889, 920, 921, 952, 953, 984, 985, 1016, 1017], + [ 538, 539, 570, 571, 602, 603, 634, 635, 666, 667, 698, 699, 730, 731, 762, 763, 794, 795, 826, 827, 858, 859, 890, 891, 922, 923, 954, 955, 986, 987, 1018, 1019], + [ 28, 29, 60, 61, 92, 93, 124, 125, 156, 157, 188, 189, 220, 221, 252, 253, 284, 285, 316, 317, 348, 349, 380, 381, 412, 413, 444, 445, 476, 477, 508, 509], + [ 30, 31, 62, 63, 94, 95, 126, 127, 158, 159, 190, 191, 222, 223, 254, 255, 286, 287, 318, 319, 350, 351, 382, 383, 414, 415, 446, 447, 478, 479, 510, 511], + [ 540, 541, 572, 573, 604, 605, 636, 637, 668, 669, 700, 701, 732, 733, 764, 765, 796, 797, 828, 829, 860, 861, 892, 893, 924, 925, 956, 957, 988, 989, 1020, 1021], + [ 542, 543, 574, 575, 606, 607, 638, 639, 670, 671, 702, 703, 734, 735, 766, 767, 798, 799, 830, 831, 862, 863, 894, 895, 926, 927, 958, 959, 990, 991, 1022, 1023]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # Different repeats and steps + tiles = TensorTiler2D.step_tiler( + (32, 32), + tile_dims=(2, 2), + tile_group_repeats=(8, 2), + tile_group_steps=(2, 4), + ) + assert len(tiles) == (32 // (2 * 8)) * (32 // (2 * 2)) + assert tiles[0] == TensorTile( + (32, 32), offset=0, sizes=[8, 2, 2, 2], strides=[128, 8, 32, 1] + ) + assert tiles[1] == TensorTile( + (32, 32), offset=2, sizes=[8, 2, 2, 2], strides=[128, 8, 32, 1] + ) + assert tiles[12] == TensorTile( + (32, 32), offset=80, sizes=[8, 2, 2, 2], strides=[128, 8, 32, 1] + ) + assert tiles[-1] == TensorTile( + (32, 32), offset=86, sizes=[8, 2, 2, 2], strides=[128, 8, 32, 1] + ) + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 1, 64, 65, 128, 129, 192, 193, 4, 5, 68, 69, 132, 133, 196, 197, 256, 257, 320, 321, 384, 385, 448, 449, 260, 261, 324, 325, 388, 389, 452, 453], + [ 2, 3, 66, 67, 130, 131, 194, 195, 6, 7, 70, 71, 134, 135, 198, 199, 258, 259, 322, 323, 386, 387, 450, 451, 262, 263, 326, 327, 390, 391, 454, 455], + [ 512, 513, 576, 577, 640, 641, 704, 705, 516, 517, 580, 581, 644, 645, 708, 709, 768, 769, 832, 833, 896, 897, 960, 961, 772, 773, 836, 837, 900, 901, 964, 965], + [ 514, 515, 578, 579, 642, 643, 706, 707, 518, 519, 582, 583, 646, 647, 710, 711, 770, 771, 834, 835, 898, 899, 962, 963, 774, 775, 838, 839, 902, 903, 966, 967], + [ 8, 9, 72, 73, 136, 137, 200, 201, 12, 13, 76, 77, 140, 141, 204, 205, 264, 265, 328, 329, 392, 393, 456, 457, 268, 269, 332, 333, 396, 397, 460, 461], + [ 10, 11, 74, 75, 138, 139, 202, 203, 14, 15, 78, 79, 142, 143, 206, 207, 266, 267, 330, 331, 394, 395, 458, 459, 270, 271, 334, 335, 398, 399, 462, 463], + [ 520, 521, 584, 585, 648, 649, 712, 713, 524, 525, 588, 589, 652, 653, 716, 717, 776, 777, 840, 841, 904, 905, 968, 969, 780, 781, 844, 845, 908, 909, 972, 973], + [ 522, 523, 586, 587, 650, 651, 714, 715, 526, 527, 590, 591, 654, 655, 718, 719, 778, 779, 842, 843, 906, 907, 970, 971, 782, 783, 846, 847, 910, 911, 974, 975], + [ 16, 17, 80, 81, 144, 145, 208, 209, 20, 21, 84, 85, 148, 149, 212, 213, 272, 273, 336, 337, 400, 401, 464, 465, 276, 277, 340, 341, 404, 405, 468, 469], + [ 18, 19, 82, 83, 146, 147, 210, 211, 22, 23, 86, 87, 150, 151, 214, 215, 274, 275, 338, 339, 402, 403, 466, 467, 278, 279, 342, 343, 406, 407, 470, 471], + [ 528, 529, 592, 593, 656, 657, 720, 721, 532, 533, 596, 597, 660, 661, 724, 725, 784, 785, 848, 849, 912, 913, 976, 977, 788, 789, 852, 853, 916, 917, 980, 981], + [ 530, 531, 594, 595, 658, 659, 722, 723, 534, 535, 598, 599, 662, 663, 726, 727, 786, 787, 850, 851, 914, 915, 978, 979, 790, 791, 854, 855, 918, 919, 982, 983], + [ 24, 25, 88, 89, 152, 153, 216, 217, 28, 29, 92, 93, 156, 157, 220, 221, 280, 281, 344, 345, 408, 409, 472, 473, 284, 285, 348, 349, 412, 413, 476, 477], + [ 26, 27, 90, 91, 154, 155, 218, 219, 30, 31, 94, 95, 158, 159, 222, 223, 282, 283, 346, 347, 410, 411, 474, 475, 286, 287, 350, 351, 414, 415, 478, 479], + [ 536, 537, 600, 601, 664, 665, 728, 729, 540, 541, 604, 605, 668, 669, 732, 733, 792, 793, 856, 857, 920, 921, 984, 985, 796, 797, 860, 861, 924, 925, 988, 989], + [ 538, 539, 602, 603, 666, 667, 730, 731, 542, 543, 606, 607, 670, 671, 734, 735, 794, 795, 858, 859, 922, 923, 986, 987, 798, 799, 862, 863, 926, 927, 990, 991], + [ 32, 33, 96, 97, 160, 161, 224, 225, 36, 37, 100, 101, 164, 165, 228, 229, 288, 289, 352, 353, 416, 417, 480, 481, 292, 293, 356, 357, 420, 421, 484, 485], + [ 34, 35, 98, 99, 162, 163, 226, 227, 38, 39, 102, 103, 166, 167, 230, 231, 290, 291, 354, 355, 418, 419, 482, 483, 294, 295, 358, 359, 422, 423, 486, 487], + [ 544, 545, 608, 609, 672, 673, 736, 737, 548, 549, 612, 613, 676, 677, 740, 741, 800, 801, 864, 865, 928, 929, 992, 993, 804, 805, 868, 869, 932, 933, 996, 997], + [ 546, 547, 610, 611, 674, 675, 738, 739, 550, 551, 614, 615, 678, 679, 742, 743, 802, 803, 866, 867, 930, 931, 994, 995, 806, 807, 870, 871, 934, 935, 998, 999], + [ 40, 41, 104, 105, 168, 169, 232, 233, 44, 45, 108, 109, 172, 173, 236, 237, 296, 297, 360, 361, 424, 425, 488, 489, 300, 301, 364, 365, 428, 429, 492, 493], + [ 42, 43, 106, 107, 170, 171, 234, 235, 46, 47, 110, 111, 174, 175, 238, 239, 298, 299, 362, 363, 426, 427, 490, 491, 302, 303, 366, 367, 430, 431, 494, 495], + [ 552, 553, 616, 617, 680, 681, 744, 745, 556, 557, 620, 621, 684, 685, 748, 749, 808, 809, 872, 873, 936, 937, 1000, 1001, 812, 813, 876, 877, 940, 941, 1004, 1005], + [ 554, 555, 618, 619, 682, 683, 746, 747, 558, 559, 622, 623, 686, 687, 750, 751, 810, 811, 874, 875, 938, 939, 1002, 1003, 814, 815, 878, 879, 942, 943, 1006, 1007], + [ 48, 49, 112, 113, 176, 177, 240, 241, 52, 53, 116, 117, 180, 181, 244, 245, 304, 305, 368, 369, 432, 433, 496, 497, 308, 309, 372, 373, 436, 437, 500, 501], + [ 50, 51, 114, 115, 178, 179, 242, 243, 54, 55, 118, 119, 182, 183, 246, 247, 306, 307, 370, 371, 434, 435, 498, 499, 310, 311, 374, 375, 438, 439, 502, 503], + [ 560, 561, 624, 625, 688, 689, 752, 753, 564, 565, 628, 629, 692, 693, 756, 757, 816, 817, 880, 881, 944, 945, 1008, 1009, 820, 821, 884, 885, 948, 949, 1012, 1013], + [ 562, 563, 626, 627, 690, 691, 754, 755, 566, 567, 630, 631, 694, 695, 758, 759, 818, 819, 882, 883, 946, 947, 1010, 1011, 822, 823, 886, 887, 950, 951, 1014, 1015], + [ 56, 57, 120, 121, 184, 185, 248, 249, 60, 61, 124, 125, 188, 189, 252, 253, 312, 313, 376, 377, 440, 441, 504, 505, 316, 317, 380, 381, 444, 445, 508, 509], + [ 58, 59, 122, 123, 186, 187, 250, 251, 62, 63, 126, 127, 190, 191, 254, 255, 314, 315, 378, 379, 442, 443, 506, 507, 318, 319, 382, 383, 446, 447, 510, 511], + [ 568, 569, 632, 633, 696, 697, 760, 761, 572, 573, 636, 637, 700, 701, 764, 765, 824, 825, 888, 889, 952, 953, 1016, 1017, 828, 829, 892, 893, 956, 957, 1020, 1021], + [ 570, 571, 634, 635, 698, 699, 762, 763, 574, 575, 638, 639, 702, 703, 766, 767, 826, 827, 890, 891, 954, 955, 1018, 1019, 830, 831, 894, 895, 958, 959, 1022, 1023]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # Tile col major + tiles = TensorTiler2D.step_tiler( + (32, 32), + tile_dims=(2, 2), + tile_group_repeats=(8, 2), + tile_group_steps=(2, 4), + tile_col_major=True, + ) + assert len(tiles) == (32 // (2 * 8)) * (32 // (2 * 2)) + assert tiles[0] == TensorTile( + (32, 32), offset=0, sizes=[8, 2, 2, 2], strides=[128, 8, 1, 32] + ) + assert tiles[1] == TensorTile( + (32, 32), offset=2, sizes=[8, 2, 2, 2], strides=[128, 8, 1, 32] + ) + assert tiles[12] == TensorTile( + (32, 32), offset=80, sizes=[8, 2, 2, 2], strides=[128, 8, 1, 32] + ) + assert tiles[-1] == TensorTile( + (32, 32), offset=86, sizes=[8, 2, 2, 2], strides=[128, 8, 1, 32] + ) + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 2, 64, 66, 128, 130, 192, 194, 4, 6, 68, 70, 132, 134, 196, 198, 256, 258, 320, 322, 384, 386, 448, 450, 260, 262, 324, 326, 388, 390, 452, 454], + [ 1, 3, 65, 67, 129, 131, 193, 195, 5, 7, 69, 71, 133, 135, 197, 199, 257, 259, 321, 323, 385, 387, 449, 451, 261, 263, 325, 327, 389, 391, 453, 455], + [ 512, 514, 576, 578, 640, 642, 704, 706, 516, 518, 580, 582, 644, 646, 708, 710, 768, 770, 832, 834, 896, 898, 960, 962, 772, 774, 836, 838, 900, 902, 964, 966], + [ 513, 515, 577, 579, 641, 643, 705, 707, 517, 519, 581, 583, 645, 647, 709, 711, 769, 771, 833, 835, 897, 899, 961, 963, 773, 775, 837, 839, 901, 903, 965, 967], + [ 8, 10, 72, 74, 136, 138, 200, 202, 12, 14, 76, 78, 140, 142, 204, 206, 264, 266, 328, 330, 392, 394, 456, 458, 268, 270, 332, 334, 396, 398, 460, 462], + [ 9, 11, 73, 75, 137, 139, 201, 203, 13, 15, 77, 79, 141, 143, 205, 207, 265, 267, 329, 331, 393, 395, 457, 459, 269, 271, 333, 335, 397, 399, 461, 463], + [ 520, 522, 584, 586, 648, 650, 712, 714, 524, 526, 588, 590, 652, 654, 716, 718, 776, 778, 840, 842, 904, 906, 968, 970, 780, 782, 844, 846, 908, 910, 972, 974], + [ 521, 523, 585, 587, 649, 651, 713, 715, 525, 527, 589, 591, 653, 655, 717, 719, 777, 779, 841, 843, 905, 907, 969, 971, 781, 783, 845, 847, 909, 911, 973, 975], + [ 16, 18, 80, 82, 144, 146, 208, 210, 20, 22, 84, 86, 148, 150, 212, 214, 272, 274, 336, 338, 400, 402, 464, 466, 276, 278, 340, 342, 404, 406, 468, 470], + [ 17, 19, 81, 83, 145, 147, 209, 211, 21, 23, 85, 87, 149, 151, 213, 215, 273, 275, 337, 339, 401, 403, 465, 467, 277, 279, 341, 343, 405, 407, 469, 471], + [ 528, 530, 592, 594, 656, 658, 720, 722, 532, 534, 596, 598, 660, 662, 724, 726, 784, 786, 848, 850, 912, 914, 976, 978, 788, 790, 852, 854, 916, 918, 980, 982], + [ 529, 531, 593, 595, 657, 659, 721, 723, 533, 535, 597, 599, 661, 663, 725, 727, 785, 787, 849, 851, 913, 915, 977, 979, 789, 791, 853, 855, 917, 919, 981, 983], + [ 24, 26, 88, 90, 152, 154, 216, 218, 28, 30, 92, 94, 156, 158, 220, 222, 280, 282, 344, 346, 408, 410, 472, 474, 284, 286, 348, 350, 412, 414, 476, 478], + [ 25, 27, 89, 91, 153, 155, 217, 219, 29, 31, 93, 95, 157, 159, 221, 223, 281, 283, 345, 347, 409, 411, 473, 475, 285, 287, 349, 351, 413, 415, 477, 479], + [ 536, 538, 600, 602, 664, 666, 728, 730, 540, 542, 604, 606, 668, 670, 732, 734, 792, 794, 856, 858, 920, 922, 984, 986, 796, 798, 860, 862, 924, 926, 988, 990], + [ 537, 539, 601, 603, 665, 667, 729, 731, 541, 543, 605, 607, 669, 671, 733, 735, 793, 795, 857, 859, 921, 923, 985, 987, 797, 799, 861, 863, 925, 927, 989, 991], + [ 32, 34, 96, 98, 160, 162, 224, 226, 36, 38, 100, 102, 164, 166, 228, 230, 288, 290, 352, 354, 416, 418, 480, 482, 292, 294, 356, 358, 420, 422, 484, 486], + [ 33, 35, 97, 99, 161, 163, 225, 227, 37, 39, 101, 103, 165, 167, 229, 231, 289, 291, 353, 355, 417, 419, 481, 483, 293, 295, 357, 359, 421, 423, 485, 487], + [ 544, 546, 608, 610, 672, 674, 736, 738, 548, 550, 612, 614, 676, 678, 740, 742, 800, 802, 864, 866, 928, 930, 992, 994, 804, 806, 868, 870, 932, 934, 996, 998], + [ 545, 547, 609, 611, 673, 675, 737, 739, 549, 551, 613, 615, 677, 679, 741, 743, 801, 803, 865, 867, 929, 931, 993, 995, 805, 807, 869, 871, 933, 935, 997, 999], + [ 40, 42, 104, 106, 168, 170, 232, 234, 44, 46, 108, 110, 172, 174, 236, 238, 296, 298, 360, 362, 424, 426, 488, 490, 300, 302, 364, 366, 428, 430, 492, 494], + [ 41, 43, 105, 107, 169, 171, 233, 235, 45, 47, 109, 111, 173, 175, 237, 239, 297, 299, 361, 363, 425, 427, 489, 491, 301, 303, 365, 367, 429, 431, 493, 495], + [ 552, 554, 616, 618, 680, 682, 744, 746, 556, 558, 620, 622, 684, 686, 748, 750, 808, 810, 872, 874, 936, 938, 1000, 1002, 812, 814, 876, 878, 940, 942, 1004, 1006], + [ 553, 555, 617, 619, 681, 683, 745, 747, 557, 559, 621, 623, 685, 687, 749, 751, 809, 811, 873, 875, 937, 939, 1001, 1003, 813, 815, 877, 879, 941, 943, 1005, 1007], + [ 48, 50, 112, 114, 176, 178, 240, 242, 52, 54, 116, 118, 180, 182, 244, 246, 304, 306, 368, 370, 432, 434, 496, 498, 308, 310, 372, 374, 436, 438, 500, 502], + [ 49, 51, 113, 115, 177, 179, 241, 243, 53, 55, 117, 119, 181, 183, 245, 247, 305, 307, 369, 371, 433, 435, 497, 499, 309, 311, 373, 375, 437, 439, 501, 503], + [ 560, 562, 624, 626, 688, 690, 752, 754, 564, 566, 628, 630, 692, 694, 756, 758, 816, 818, 880, 882, 944, 946, 1008, 1010, 820, 822, 884, 886, 948, 950, 1012, 1014], + [ 561, 563, 625, 627, 689, 691, 753, 755, 565, 567, 629, 631, 693, 695, 757, 759, 817, 819, 881, 883, 945, 947, 1009, 1011, 821, 823, 885, 887, 949, 951, 1013, 1015], + [ 56, 58, 120, 122, 184, 186, 248, 250, 60, 62, 124, 126, 188, 190, 252, 254, 312, 314, 376, 378, 440, 442, 504, 506, 316, 318, 380, 382, 444, 446, 508, 510], + [ 57, 59, 121, 123, 185, 187, 249, 251, 61, 63, 125, 127, 189, 191, 253, 255, 313, 315, 377, 379, 441, 443, 505, 507, 317, 319, 381, 383, 445, 447, 509, 511], + [ 568, 570, 632, 634, 696, 698, 760, 762, 572, 574, 636, 638, 700, 702, 764, 766, 824, 826, 888, 890, 952, 954, 1016, 1018, 828, 830, 892, 894, 956, 958, 1020, 1022], + [ 569, 571, 633, 635, 697, 699, 761, 763, 573, 575, 637, 639, 701, 703, 765, 767, 825, 827, 889, 891, 953, 955, 1017, 1019, 829, 831, 893, 895, 957, 959, 1021, 1023]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # Tile col major and tile group col major + tiles = TensorTiler2D.step_tiler( + (32, 32), + tile_dims=(2, 2), + tile_group_repeats=(8, 2), + tile_group_steps=(2, 4), + tile_col_major=True, + tile_group_col_major=True, + ) + assert len(tiles) == (32 // (2 * 8)) * (32 // (2 * 2)) + assert tiles[0] == TensorTile( + (32, 32), offset=0, sizes=[2, 8, 2, 2], strides=[8, 128, 1, 32] + ) + assert tiles[1] == TensorTile( + (32, 32), offset=2, sizes=[2, 8, 2, 2], strides=[8, 128, 1, 32] + ) + assert tiles[12] == TensorTile( + (32, 32), offset=80, sizes=[2, 8, 2, 2], strides=[8, 128, 1, 32] + ) + assert tiles[-1] == TensorTile( + (32, 32), offset=86, sizes=[2, 8, 2, 2], strides=[8, 128, 1, 32] + ) + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 2, 64, 66, 128, 130, 192, 194, 32, 34, 96, 98, 160, 162, 224, 226, 256, 258, 320, 322, 384, 386, 448, 450, 288, 290, 352, 354, 416, 418, 480, 482], + [ 1, 3, 65, 67, 129, 131, 193, 195, 33, 35, 97, 99, 161, 163, 225, 227, 257, 259, 321, 323, 385, 387, 449, 451, 289, 291, 353, 355, 417, 419, 481, 483], + [ 512, 514, 576, 578, 640, 642, 704, 706, 544, 546, 608, 610, 672, 674, 736, 738, 768, 770, 832, 834, 896, 898, 960, 962, 800, 802, 864, 866, 928, 930, 992, 994], + [ 513, 515, 577, 579, 641, 643, 705, 707, 545, 547, 609, 611, 673, 675, 737, 739, 769, 771, 833, 835, 897, 899, 961, 963, 801, 803, 865, 867, 929, 931, 993, 995], + [ 4, 6, 68, 70, 132, 134, 196, 198, 36, 38, 100, 102, 164, 166, 228, 230, 260, 262, 324, 326, 388, 390, 452, 454, 292, 294, 356, 358, 420, 422, 484, 486], + [ 5, 7, 69, 71, 133, 135, 197, 199, 37, 39, 101, 103, 165, 167, 229, 231, 261, 263, 325, 327, 389, 391, 453, 455, 293, 295, 357, 359, 421, 423, 485, 487], + [ 516, 518, 580, 582, 644, 646, 708, 710, 548, 550, 612, 614, 676, 678, 740, 742, 772, 774, 836, 838, 900, 902, 964, 966, 804, 806, 868, 870, 932, 934, 996, 998], + [ 517, 519, 581, 583, 645, 647, 709, 711, 549, 551, 613, 615, 677, 679, 741, 743, 773, 775, 837, 839, 901, 903, 965, 967, 805, 807, 869, 871, 933, 935, 997, 999], + [ 8, 10, 72, 74, 136, 138, 200, 202, 40, 42, 104, 106, 168, 170, 232, 234, 264, 266, 328, 330, 392, 394, 456, 458, 296, 298, 360, 362, 424, 426, 488, 490], + [ 9, 11, 73, 75, 137, 139, 201, 203, 41, 43, 105, 107, 169, 171, 233, 235, 265, 267, 329, 331, 393, 395, 457, 459, 297, 299, 361, 363, 425, 427, 489, 491], + [ 520, 522, 584, 586, 648, 650, 712, 714, 552, 554, 616, 618, 680, 682, 744, 746, 776, 778, 840, 842, 904, 906, 968, 970, 808, 810, 872, 874, 936, 938, 1000, 1002], + [ 521, 523, 585, 587, 649, 651, 713, 715, 553, 555, 617, 619, 681, 683, 745, 747, 777, 779, 841, 843, 905, 907, 969, 971, 809, 811, 873, 875, 937, 939, 1001, 1003], + [ 12, 14, 76, 78, 140, 142, 204, 206, 44, 46, 108, 110, 172, 174, 236, 238, 268, 270, 332, 334, 396, 398, 460, 462, 300, 302, 364, 366, 428, 430, 492, 494], + [ 13, 15, 77, 79, 141, 143, 205, 207, 45, 47, 109, 111, 173, 175, 237, 239, 269, 271, 333, 335, 397, 399, 461, 463, 301, 303, 365, 367, 429, 431, 493, 495], + [ 524, 526, 588, 590, 652, 654, 716, 718, 556, 558, 620, 622, 684, 686, 748, 750, 780, 782, 844, 846, 908, 910, 972, 974, 812, 814, 876, 878, 940, 942, 1004, 1006], + [ 525, 527, 589, 591, 653, 655, 717, 719, 557, 559, 621, 623, 685, 687, 749, 751, 781, 783, 845, 847, 909, 911, 973, 975, 813, 815, 877, 879, 941, 943, 1005, 1007], + [ 16, 18, 80, 82, 144, 146, 208, 210, 48, 50, 112, 114, 176, 178, 240, 242, 272, 274, 336, 338, 400, 402, 464, 466, 304, 306, 368, 370, 432, 434, 496, 498], + [ 17, 19, 81, 83, 145, 147, 209, 211, 49, 51, 113, 115, 177, 179, 241, 243, 273, 275, 337, 339, 401, 403, 465, 467, 305, 307, 369, 371, 433, 435, 497, 499], + [ 528, 530, 592, 594, 656, 658, 720, 722, 560, 562, 624, 626, 688, 690, 752, 754, 784, 786, 848, 850, 912, 914, 976, 978, 816, 818, 880, 882, 944, 946, 1008, 1010], + [ 529, 531, 593, 595, 657, 659, 721, 723, 561, 563, 625, 627, 689, 691, 753, 755, 785, 787, 849, 851, 913, 915, 977, 979, 817, 819, 881, 883, 945, 947, 1009, 1011], + [ 20, 22, 84, 86, 148, 150, 212, 214, 52, 54, 116, 118, 180, 182, 244, 246, 276, 278, 340, 342, 404, 406, 468, 470, 308, 310, 372, 374, 436, 438, 500, 502], + [ 21, 23, 85, 87, 149, 151, 213, 215, 53, 55, 117, 119, 181, 183, 245, 247, 277, 279, 341, 343, 405, 407, 469, 471, 309, 311, 373, 375, 437, 439, 501, 503], + [ 532, 534, 596, 598, 660, 662, 724, 726, 564, 566, 628, 630, 692, 694, 756, 758, 788, 790, 852, 854, 916, 918, 980, 982, 820, 822, 884, 886, 948, 950, 1012, 1014], + [ 533, 535, 597, 599, 661, 663, 725, 727, 565, 567, 629, 631, 693, 695, 757, 759, 789, 791, 853, 855, 917, 919, 981, 983, 821, 823, 885, 887, 949, 951, 1013, 1015], + [ 24, 26, 88, 90, 152, 154, 216, 218, 56, 58, 120, 122, 184, 186, 248, 250, 280, 282, 344, 346, 408, 410, 472, 474, 312, 314, 376, 378, 440, 442, 504, 506], + [ 25, 27, 89, 91, 153, 155, 217, 219, 57, 59, 121, 123, 185, 187, 249, 251, 281, 283, 345, 347, 409, 411, 473, 475, 313, 315, 377, 379, 441, 443, 505, 507], + [ 536, 538, 600, 602, 664, 666, 728, 730, 568, 570, 632, 634, 696, 698, 760, 762, 792, 794, 856, 858, 920, 922, 984, 986, 824, 826, 888, 890, 952, 954, 1016, 1018], + [ 537, 539, 601, 603, 665, 667, 729, 731, 569, 571, 633, 635, 697, 699, 761, 763, 793, 795, 857, 859, 921, 923, 985, 987, 825, 827, 889, 891, 953, 955, 1017, 1019], + [ 28, 30, 92, 94, 156, 158, 220, 222, 60, 62, 124, 126, 188, 190, 252, 254, 284, 286, 348, 350, 412, 414, 476, 478, 316, 318, 380, 382, 444, 446, 508, 510], + [ 29, 31, 93, 95, 157, 159, 221, 223, 61, 63, 125, 127, 189, 191, 253, 255, 285, 287, 349, 351, 413, 415, 477, 479, 317, 319, 381, 383, 445, 447, 509, 511], + [ 540, 542, 604, 606, 668, 670, 732, 734, 572, 574, 636, 638, 700, 702, 764, 766, 796, 798, 860, 862, 924, 926, 988, 990, 828, 830, 892, 894, 956, 958, 1020, 1022], + [ 541, 543, 605, 607, 669, 671, 733, 735, 573, 575, 637, 639, 701, 703, 765, 767, 797, 799, 861, 863, 925, 927, 989, 991, 829, 831, 893, 895, 957, 959, 1021, 1023]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # Tile col major and tile group col major and iter col major + tiles = TensorTiler2D.step_tiler( + (32, 32), + tile_dims=(2, 2), + tile_group_repeats=(8, 2), + tile_group_steps=(2, 4), + tile_col_major=True, + tile_group_col_major=True, + iter_col_major=True, + ) + assert len(tiles) == (32 // (2 * 8)) * (32 // (2 * 2)) + assert tiles[0] == TensorTile( + (32, 32), offset=0, sizes=[2, 8, 2, 2], strides=[8, 128, 1, 32] + ) + assert tiles[1] == TensorTile( + (32, 32), offset=64, sizes=[2, 8, 2, 2], strides=[8, 128, 1, 32] + ) + assert tiles[12] == TensorTile( + (32, 32), offset=20, sizes=[2, 8, 2, 2], strides=[8, 128, 1, 32] + ) + assert tiles[-1] == TensorTile( + (32, 32), offset=86, sizes=[2, 8, 2, 2], strides=[8, 128, 1, 32] + ) + # fmt: off + ref_access_order_tensor = np.array([ + [ 0, 2, 128, 130, 256, 258, 384, 386, 32, 34, 160, 162, 288, 290, 416, 418, 512, 514, 640, 642, 768, 770, 896, 898, 544, 546, 672, 674, 800, 802, 928, 930], + [ 1, 3, 129, 131, 257, 259, 385, 387, 33, 35, 161, 163, 289, 291, 417, 419, 513, 515, 641, 643, 769, 771, 897, 899, 545, 547, 673, 675, 801, 803, 929, 931], + [ 64, 66, 192, 194, 320, 322, 448, 450, 96, 98, 224, 226, 352, 354, 480, 482, 576, 578, 704, 706, 832, 834, 960, 962, 608, 610, 736, 738, 864, 866, 992, 994], + [ 65, 67, 193, 195, 321, 323, 449, 451, 97, 99, 225, 227, 353, 355, 481, 483, 577, 579, 705, 707, 833, 835, 961, 963, 609, 611, 737, 739, 865, 867, 993, 995], + [ 4, 6, 132, 134, 260, 262, 388, 390, 36, 38, 164, 166, 292, 294, 420, 422, 516, 518, 644, 646, 772, 774, 900, 902, 548, 550, 676, 678, 804, 806, 932, 934], + [ 5, 7, 133, 135, 261, 263, 389, 391, 37, 39, 165, 167, 293, 295, 421, 423, 517, 519, 645, 647, 773, 775, 901, 903, 549, 551, 677, 679, 805, 807, 933, 935], + [ 68, 70, 196, 198, 324, 326, 452, 454, 100, 102, 228, 230, 356, 358, 484, 486, 580, 582, 708, 710, 836, 838, 964, 966, 612, 614, 740, 742, 868, 870, 996, 998], + [ 69, 71, 197, 199, 325, 327, 453, 455, 101, 103, 229, 231, 357, 359, 485, 487, 581, 583, 709, 711, 837, 839, 965, 967, 613, 615, 741, 743, 869, 871, 997, 999], + [ 8, 10, 136, 138, 264, 266, 392, 394, 40, 42, 168, 170, 296, 298, 424, 426, 520, 522, 648, 650, 776, 778, 904, 906, 552, 554, 680, 682, 808, 810, 936, 938], + [ 9, 11, 137, 139, 265, 267, 393, 395, 41, 43, 169, 171, 297, 299, 425, 427, 521, 523, 649, 651, 777, 779, 905, 907, 553, 555, 681, 683, 809, 811, 937, 939], + [ 72, 74, 200, 202, 328, 330, 456, 458, 104, 106, 232, 234, 360, 362, 488, 490, 584, 586, 712, 714, 840, 842, 968, 970, 616, 618, 744, 746, 872, 874, 1000, 1002], + [ 73, 75, 201, 203, 329, 331, 457, 459, 105, 107, 233, 235, 361, 363, 489, 491, 585, 587, 713, 715, 841, 843, 969, 971, 617, 619, 745, 747, 873, 875, 1001, 1003], + [ 12, 14, 140, 142, 268, 270, 396, 398, 44, 46, 172, 174, 300, 302, 428, 430, 524, 526, 652, 654, 780, 782, 908, 910, 556, 558, 684, 686, 812, 814, 940, 942], + [ 13, 15, 141, 143, 269, 271, 397, 399, 45, 47, 173, 175, 301, 303, 429, 431, 525, 527, 653, 655, 781, 783, 909, 911, 557, 559, 685, 687, 813, 815, 941, 943], + [ 76, 78, 204, 206, 332, 334, 460, 462, 108, 110, 236, 238, 364, 366, 492, 494, 588, 590, 716, 718, 844, 846, 972, 974, 620, 622, 748, 750, 876, 878, 1004, 1006], + [ 77, 79, 205, 207, 333, 335, 461, 463, 109, 111, 237, 239, 365, 367, 493, 495, 589, 591, 717, 719, 845, 847, 973, 975, 621, 623, 749, 751, 877, 879, 1005, 1007], + [ 16, 18, 144, 146, 272, 274, 400, 402, 48, 50, 176, 178, 304, 306, 432, 434, 528, 530, 656, 658, 784, 786, 912, 914, 560, 562, 688, 690, 816, 818, 944, 946], + [ 17, 19, 145, 147, 273, 275, 401, 403, 49, 51, 177, 179, 305, 307, 433, 435, 529, 531, 657, 659, 785, 787, 913, 915, 561, 563, 689, 691, 817, 819, 945, 947], + [ 80, 82, 208, 210, 336, 338, 464, 466, 112, 114, 240, 242, 368, 370, 496, 498, 592, 594, 720, 722, 848, 850, 976, 978, 624, 626, 752, 754, 880, 882, 1008, 1010], + [ 81, 83, 209, 211, 337, 339, 465, 467, 113, 115, 241, 243, 369, 371, 497, 499, 593, 595, 721, 723, 849, 851, 977, 979, 625, 627, 753, 755, 881, 883, 1009, 1011], + [ 20, 22, 148, 150, 276, 278, 404, 406, 52, 54, 180, 182, 308, 310, 436, 438, 532, 534, 660, 662, 788, 790, 916, 918, 564, 566, 692, 694, 820, 822, 948, 950], + [ 21, 23, 149, 151, 277, 279, 405, 407, 53, 55, 181, 183, 309, 311, 437, 439, 533, 535, 661, 663, 789, 791, 917, 919, 565, 567, 693, 695, 821, 823, 949, 951], + [ 84, 86, 212, 214, 340, 342, 468, 470, 116, 118, 244, 246, 372, 374, 500, 502, 596, 598, 724, 726, 852, 854, 980, 982, 628, 630, 756, 758, 884, 886, 1012, 1014], + [ 85, 87, 213, 215, 341, 343, 469, 471, 117, 119, 245, 247, 373, 375, 501, 503, 597, 599, 725, 727, 853, 855, 981, 983, 629, 631, 757, 759, 885, 887, 1013, 1015], + [ 24, 26, 152, 154, 280, 282, 408, 410, 56, 58, 184, 186, 312, 314, 440, 442, 536, 538, 664, 666, 792, 794, 920, 922, 568, 570, 696, 698, 824, 826, 952, 954], + [ 25, 27, 153, 155, 281, 283, 409, 411, 57, 59, 185, 187, 313, 315, 441, 443, 537, 539, 665, 667, 793, 795, 921, 923, 569, 571, 697, 699, 825, 827, 953, 955], + [ 88, 90, 216, 218, 344, 346, 472, 474, 120, 122, 248, 250, 376, 378, 504, 506, 600, 602, 728, 730, 856, 858, 984, 986, 632, 634, 760, 762, 888, 890, 1016, 1018], + [ 89, 91, 217, 219, 345, 347, 473, 475, 121, 123, 249, 251, 377, 379, 505, 507, 601, 603, 729, 731, 857, 859, 985, 987, 633, 635, 761, 763, 889, 891, 1017, 1019], + [ 28, 30, 156, 158, 284, 286, 412, 414, 60, 62, 188, 190, 316, 318, 444, 446, 540, 542, 668, 670, 796, 798, 924, 926, 572, 574, 700, 702, 828, 830, 956, 958], + [ 29, 31, 157, 159, 285, 287, 413, 415, 61, 63, 189, 191, 317, 319, 445, 447, 541, 543, 669, 671, 797, 799, 925, 927, 573, 575, 701, 703, 829, 831, 957, 959], + [ 92, 94, 220, 222, 348, 350, 476, 478, 124, 126, 252, 254, 380, 382, 508, 510, 604, 606, 732, 734, 860, 862, 988, 990, 636, 638, 764, 766, 892, 894, 1020, 1022], + [ 93, 95, 221, 223, 349, 351, 477, 479, 125, 127, 253, 255, 381, 383, 509, 511, 605, 607, 733, 735, 861, 863, 989, 991, 637, 639, 765, 767, 893, 895, 1021, 1023]]) + # fmt: on + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order_tensor).all() + assert (access_count == 1).all() + + # CHECK: Pass! + print("Pass!") + + +# CHECK-LABEL: step_tiler_invalid +@construct_test +def step_tiler_invalid(): + try: + tiles = TensorTiler2D.step_tiler( + (), (3, 2), (1, 1), (1, 1), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Bad tensor dims, should fail.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.step_tiler( + (10, 9, 4), (3, 2), (1, 1), (1, 1), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Too many tensor dims, should fail.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.step_tiler( + (9, 4), (3, -1), (1, 1), (1, 1), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Bad tile dims, should fail.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.step_tiler( + (9, 4), (3,), (1, 1), (1, 1), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Too few tile dims, should fail.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.step_tiler( + (9, 4), (1, 1, 1), (1, 1), (1, 1), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Too many tile dims, should fail.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.step_tiler( + (9, 4), (3, 2), (1, 1), (1, 1), tile_col_major=True, pattern_repeat=0 + ) + raise ValueError("Invalid repeat.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.step_tiler( + (9, 4), (4, 2), (1, 1), (1, 1), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Indivisible tile (height)") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.step_tiler( + (9, 4), (3, 3), (1, 1), (1, 1), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Indivisible tile (width)") + except ValueError: + # good + pass + + try: + tiles = TensorTiler2D.step_tiler( + (9, 4), (3, 2), (1,), (1, 1), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Too few tile group dims, should fail.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.step_tiler( + (9, 4), (3, 2), (1, -1), (1, 1), tile_col_major=True, pattern_repeat=5 + ) + raise ValueError("Bad tile group dims, should fail.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.step_tiler( + (9, 4), (3, 2), (1, 1, 1), (1, 1), tile_col_major=True + ) + raise ValueError("Too many tile group dims, should fail.") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.step_tiler( + (18, 8), (3, 2), (2, 3), (1, 1), tile_col_major=True + ) + raise ValueError( + "Indivisible by tile repeat width (but without allow_partial)." + ) + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.step_tiler( + (18, 8), (3, 2), (4, 2), (1, 1), tile_col_major=True + ) + raise ValueError( + "Indivisible by tile repeat height (but without allow_partial)." + ) + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.step_tiler( + (18, 8), (3, 2), (4, 2), (1, -1), tile_col_major=True + ) + raise ValueError("Bad tile step dims") + except ValueError: + # good + pass + try: + tiles = TensorTiler2D.step_tiler( + (18, 8), (3, 2), (4, 2), (1,), tile_col_major=True + ) + raise ValueError("Too few tile step dims") + except ValueError: + # good + pass + + # CHECK: Pass! + print("Pass!") diff --git a/test/python/tensortiler/step_tiler_partial.py b/test/python/tensortiler/step_tiler_partial.py new file mode 100644 index 0000000000..f5a8a44912 --- /dev/null +++ b/test/python/tensortiler/step_tiler_partial.py @@ -0,0 +1,70 @@ +import numpy as np + +from aie.helpers.tensortiler import TensorTile, TensorTileSequence, TensorTiler2D +from util import construct_test + +# RUN: %python %s | FileCheck %s + + +# CHECK-LABEL: step_tiler_partial_row +@construct_test +def step_tiler_partial_row(): + + # all row major + # tile col major + # tile group col major + # iter col major + # all col major + # pattern repeat + + # CHECK: Pass! + print("Pass!") + + +# CHECK-LABEL: step_tiler_partial_col +@construct_test +def step_tiler_partial_col(): + """ + THIS IS BUGGY: + tensor_dims = (3 * 5 * 3, 2 * 6 * 2) + + # All row major + tiles = TensorTiler2D.step_tiler( + tensor_dims, tile_dims=(3, 2), tile_group_repeats=(5, 7), tile_group_steps=(2, 3), allow_partial=True + ) + print(len(tiles)) + for t in tiles: + print(t) + print(tiles[0]) + print(tiles[1]) + print(tiles[3]) + print(tiles[-1]) + tiles.visualize(plot_access_count=True) + anim = tiles.animate() + HTML(anim.to_jshtml()) + """ + + # all row major + # tile col major + # tile group col major + # iter col major + # all col major + # pattern repeat + + # CHECK: Pass! + print("Pass!") + + +# CHECK-LABEL: step_tiler_partial_both +@construct_test +def step_tiler_partial_both(): + + # all row major + # tile col major + # tile group col major + # iter col major + # all col major + # pattern repeat + + # CHECK: Pass! + print("Pass!") diff --git a/test/python/tensortiler/tensortile.py b/test/python/tensortiler/tensortile.py new file mode 100644 index 0000000000..6c8acf7eae --- /dev/null +++ b/test/python/tensortiler/tensortile.py @@ -0,0 +1,136 @@ +import numpy as np + +from aie.helpers.tensortiler import TensorTile +from util import construct_test + +# RUN: %python %s | FileCheck %s + + +# CHECK-LABEL: tensor_tile +@construct_test +def tensor_tile(): + # Valid + tile = TensorTile((2, 3), 4, sizes=[1, 2], strides=[0, 1]) + + # Check accessors + assert ( + tile.tensor_dims[0] == 2 + and tile.tensor_dims[1] == 3 + and len(tile.tensor_dims) == 2 + ) + assert tile.offset == 4 + assert tile.sizes == [1, 2] + assert tile.strides == [0, 1] + assert tile.transformation_dims == [(1, 0), (2, 1)] + access_order, access_count = tile.access_tensors() + assert ( + access_order == np.array([[-1, -1, -1], [-1, 0, 1]], dtype=access_order.dtype) + ).all() + assert ( + access_count == np.array([[0, 0, 0], [0, 1, 1]], dtype=access_count.dtype) + ).all() + + tile2 = TensorTile((2, 3), 4, sizes=[1, 2], strides=[0, 1]) + assert tile2 == tile + + tile3 = TensorTile((2, 3), 2, sizes=[1, 2], strides=[0, 1]) + assert tile3 != tile + + # CHECK: Pass! + print("Pass!") + + +# CHECK-LABEL: tensor_tile_invalid +@construct_test +def tensor_tile_invalid(): + + # Bad tensor dims + try: + tile = TensorTile((), 4, sizes=[1, 2], strides=[0, 1]) + raise Exception("Should fail, bad dims (no dims)") + except ValueError: + # Good + pass + try: + tile = TensorTile((0, 1), 4, sizes=[1, 2], strides=[0, 1]) + raise Exception("Should fail, bad dims (first dim 0)") + except ValueError: + # Good + pass + try: + tile = TensorTile((1, 0), 4, sizes=[1, 2], strides=[0, 1]) + raise Exception("Should fail, bad dims (second dim 0)") + except ValueError: + # Good + pass + try: + tile = TensorTile((-1, 1), 4, sizes=[1, 2], strides=[0, 1]) + raise Exception("Should fail, bad dims (first dim negative)") + except ValueError: + # Good + pass + try: + tile = TensorTile((1, -1), 4, sizes=[1, 2], strides=[0, 1]) + raise Exception("Should fail, bad dims (second dim negative)") + except ValueError: + # Good + pass + + # Bad offset + try: + tile = TensorTile((2, 3), -1, sizes=[1, 2], strides=[0, 1]) + raise Exception("Should fail, bad offset (negative)") + except ValueError: + # Good + pass + try: + tile = TensorTile((2, 3), 2 * 3, sizes=[1, 2], strides=[0, 1]) + raise Exception("Should fail, bad offset (too large)") + except ValueError: + # Good + pass + + # Bad sizes + try: + tile = TensorTile((2, 3), 2 * 3, sizes=[-1], strides=[1]) + raise Exception("Should fail, size (negative)") + except ValueError: + # Good + pass + try: + tile = TensorTile((2, 3), 2 * 3, sizes=[0], strides=[1]) + raise Exception("Should fail, size (zero)") + except ValueError: + # Good + pass + + # Bad sizes + strides + try: + tile = TensorTile((2, 3), 2 * 3, sizes=[], strides=[]) + raise Exception("Should fail, size and stride empty") + except ValueError: + # Good + pass + try: + tile = TensorTile((2, 3), 2 * 3, sizes=[1], strides=[0, 1]) + raise Exception("Should fail, sizes and strides uneven dimensions") + except ValueError: + # Good + pass + try: + tile = TensorTile((2, 3), 2 * 3, sizes=[1, 1], strides=[0]) + raise Exception("Should fail, sizes and strides uneven dimensions 2") + except ValueError: + # Good + pass + + # Bad strides + try: + tile = TensorTile((2, 3), 2 * 3, sizes=[1], strides=[-1]) + raise Exception("Should fail, bad stride (negative)") + except ValueError: + # Good + pass + + # CHECK: Pass! + print("Pass!") diff --git a/test/python/tensortiler/tensortilesequence.py b/test/python/tensortiler/tensortilesequence.py new file mode 100644 index 0000000000..8058fb5d75 --- /dev/null +++ b/test/python/tensortiler/tensortilesequence.py @@ -0,0 +1,145 @@ +import numpy as np + +from aie.helpers.tensortiler import TensorTile, TensorTileSequence +from util import construct_test + +# RUN: %python %s | FileCheck %s + + +# CHECK-LABEL: tensor_tile_sequence +@construct_test +def tensor_tile_sequence(): + + empty_tiles = TensorTileSequence((2, 2), 0) + assert len(empty_tiles) == 0 + ref_access_order = np.array([[-1, -1], [-1, -1]]) + ref_access_count = np.array([[0, 0], [0, 0]]) + access_order, access_count = empty_tiles.access_tensors() + assert (access_order == ref_access_order).all() + assert (access_count == ref_access_count).all() + + def offset_fn(step, _prev_offset): + return step + + tiles = TensorTileSequence( + (2, 2), 4, sizes=[1, 1], strides=[1, 1], offset_fn=offset_fn + ) + assert len(tiles) == 4 + ref_access_order = np.array([[0, 1], [2, 3]]) + ref_access_count = np.array([[1, 1], [1, 1]]) + access_order, access_count = tiles.access_tensors() + assert (access_order == ref_access_order).all() + assert (access_count == ref_access_count).all() + + tile = TensorTile((2, 2), offset=2, sizes=[1, 1], strides=[1, 1]) + assert tile in tiles + assert tiles[2] == tile + tiles2 = list(iter(tiles)) + assert tiles2[2] == tile + + del tiles[2] + assert not (tile in tiles) + tiles.insert(2, tile) + assert tile in tiles + + tile2 = TensorTile((3, 3), offset=2, sizes=[1, 1], strides=[1, 1]) + assert not (tile2 in tiles) + + tiles3 = TensorTileSequence( + (2, 2), 4, sizes=[1, 1], strides=[1, 1], offset_fn=offset_fn + ) + assert tiles == tiles3 + tiles4 = TensorTileSequence( + (2, 2), 3, sizes=[1, 1], strides=[1, 1], offset_fn=offset_fn + ) + assert tiles != tiles4 + ref_access_order = np.array([[0, 1], [2, -1]]) + ref_access_count = np.array([[1, 1], [1, 0]]) + access_order, access_count = tiles4.access_tensors() + assert (access_order == ref_access_order).all() + assert (access_count == ref_access_count).all() + + tiles4_copy = TensorTileSequence.from_tiles(tiles4) + assert tiles4_copy == tiles4 + access_order, access_count = tiles4_copy.access_tensors() + assert (access_order == ref_access_order).all() + assert (access_count == ref_access_count).all() + + # CHECK: Pass! + print("Pass!") + + +# CHECK-LABEL: tensor_tile_sequence_invalid +@construct_test +def tensor_tile_sequence_invalid(): + def offset_fn(step, _prev_offset): + return step + + try: + tiles = TensorTileSequence( + (0, 2), 4, sizes=[1, 1], strides=[1, 1], offset_fn=offset_fn + ) + raise Exception("Should fail, bad dims") + except ValueError: + # Good + pass + try: + tiles = TensorTileSequence( + (2, 2), -1, sizes=[1, 1], strides=[1, 1], offset_fn=offset_fn + ) + raise Exception("Should fail, bad num steps") + except ValueError: + # Good + pass + try: + tiles = TensorTileSequence( + (2, 2), 1, sizes=[1, 0], strides=[0, 1], offset_fn=offset_fn + ) + raise Exception("Should fail, bad sizes") + except ValueError: + # Good + pass + try: + tiles = TensorTileSequence( + (2, 2), 1, sizes=[1, 1], strides=[-1, 1], offset_fn=offset_fn + ) + raise Exception("Should fail, bad strides") + except ValueError: + # Good + pass + try: + tiles = TensorTileSequence((2, 2), 1, strides=[1, 1], offset_fn=offset_fn) + raise Exception("Should fail, missing sizes") + except ValueError: + # Good + pass + try: + tiles = TensorTileSequence((2, 2), 1, sizes=[1, 1], offset_fn=offset_fn) + raise Exception("Should fail, missing strides") + except ValueError: + # Good + pass + try: + tiles = TensorTileSequence((2, 2), 1, strides=[1, 1], sizes=[1, 1]) + raise Exception("Should fail, missing offset") + except ValueError: + # Good + pass + + tiles = TensorTileSequence((2, 3), 1, offset=0, strides=[0, 1], sizes=[1, 1]) + try: + tiles.append(TensorTile((3, 2), offset=0, strides=[0, 1], sizes=[1, 1])) + raise Exception("Should not be able to add tile with inconsistent tensor dim") + except ValueError: + # Good + pass + + try: + TensorTileSequence.from_tiles([]) + raise Exception("Should not be able to create sequence from no tiles") + except ValueError: + # Good + pass + + # CHECK: Pass! + print("Pass!") diff --git a/test/python/tensortiler/util.py b/test/python/tensortiler/util.py new file mode 100644 index 0000000000..2267760a46 --- /dev/null +++ b/test/python/tensortiler/util.py @@ -0,0 +1,4 @@ +# Run test +def construct_test(f): + print("\nTEST:", f.__name__) + f()