Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tests][aot] Add test for externalized parameters #202

Open
wants to merge 28 commits into
base: main
Choose a base branch
from

Conversation

vinayakdsci
Copy link

Progresses towards iree-org/iree#18564.

Adds a test for the example given in Using external parameters section in the IREE PyTorch guide.

harsh-nod and others added 28 commits September 23, 2024 20:13
…g#161)

This PR modifies the insertion point for iter args to ensure that the
iter args are in the same order as the init args and outputs. This
simplifies the mapping between init args, iter args and outputs.

Signed-off-by: Harsh Menon <[email protected]>
Fixes iree-org#85

PR based on the work of @maxbartel 

Requires changes in torch-mlir:
[llvm/torch-mlir/#3688](llvm/torch-mlir#3688)

Adds the mutable modifier to a global buffer and lifts said buffer to a
global if there is a store-producer node associated with it.

Signed-off-by: Christopher McGirr <[email protected]>
Co-authored-by: Maximilian Bartel <[email protected]>
…iree-org#162)

This PR introduces changes to handle elementwise or general arithmetic
operations after we did some tiled-loop-reduction ("Reduction")
operation.

The main problem with the current stack is indexing_dims information for
Reduction relies on the user. This would work if it's user/consumer is
tkw.write, but in other cases such as BinaryPyOp or UnaryPyOp, it will
lack such information.

To make matters worst BinaryPyOp/UnaryPyOp depends on it's src/producer
for indexing dim, while Reduction op depends on it's dst/consumer for
its' indexing dim information. This would ended up causing infinite loop
between UnaryPyOp/BinaryPyOp <-> Reduction.

This PR fixes the indexing dimension logic Reduction and GetResult
(required for expanded Reduction) to be based on it's reduction axis(for
Reduction) and it's source/consumer information.

---------

Signed-off-by: Stanley Winata <[email protected]>
This PR removes the need for propagating indices using
post expansion. The new approach propagates the MMA
indices to the MMA dimensions of all tensors (rather
than just MMA nodes) and then specializes them depending
on whether they lie within the backward slices of the
LHS and RHS or forward slices of the ACC.

---------

Signed-off-by: Harsh Menon <[email protected]>
This PR adds more documentation about tkw. Specifically, it provides a
first draft of the introduction and adds a section on memory access
patterns.

Signed-off-by: Harsh Menon <[email protected]>
…g#166)

The main motivation behind this PR is to enable multiple induction
variable/iterArg on the same tiled "Reduction" loop. To enable above we
did a couple things:

1. Enable lowering/expansion on `operator.getitem` (the op that extract
multiple results in python i.e `res0, res1 = fn`) by templating it
on`GetResult(CustomOp)` since they have the same args and interface and
can reuse most of the indexing/expansion helper.

2. Introduce `res_idx`, a variable to represent which result index of an
op we are referring to, during expansion and context map. This is useful
for ops that has more than one results / variables as outputs.

3. bug fix in expand_reduction, where we hoist out iterating and
expanding of `reduction.init_args` out of the loop that iterates and
expands over the `yield`/`return_val` of the reduction loop. It is
expected that the size of `init_args` is the same as size of
`yield`/`return_val`. Hence if we had N iter_args/yields, we ended up
expanding the `init_args` N x N time instead of N times. We haven't seen
it thus far because we have been only playing with 1 init_arg/iterArg,
and 1x1 == 1.

4. Introduce a canonicalization pattern to fold chains of GetResult.
this is because GetResult by semantic/design is only expected to extract
and have one result. Hence a chain of GetResult should just be replaced
by itself. This help clean up the IR.

num.4 also helps circumvent issue where Reduction and GetResult is
expanded completely by itself not following the DFS structure per
dimension like the rest of the expansion code. This becomes especially
problematic for multiple IterArg since Getitem is not expecting its'
source value to be expanded without it.

---------

Signed-off-by: Stanley Winata <[email protected]>
Instead of generating individual element comparisons and doing
`vector.insertelement` generate the whole mask using vector ops.

Add support for vector codegen when generating MLIR IR from sympy
expressions. Add method `IndexingContext.iota` to generate special
symbols which map to `(1,2 ... n-1)` vec expressions. `gen_sympy_index`
will start to generate vector ops when encountering such symbols,
inserting proper `splat`'s between scalar vals when necessary.

---------

Signed-off-by: Ivan Butygin <[email protected]>
…#179)

* Adds an option to `aot.export(import_symbolic_shape_expressions=True)`
to enable emission of torch-mlir symbolic shape constraints. This is
currently set to False until IREE is ready to ingest these by default.

Rough sequence of work in IREE proper:

* Custom lowering of `torch.symbolic_int` and
`torch.bind_symbolic_shape` ops to IREE util "assume" ops. Note that we
are only planning to lower "terminal" bindings (basically function
arguments and a couple of other such categories).
* Canonicalizations to ensure that assume equalities are == 0 (versus
the native form from torch where they assume a non zero equality).
* Fusion will clone corresponding bindings on dependent dims into
dispatch regions.
* Existing linalg shape analysis extended and queryable by codegen.

---------

Signed-off-by: Stella Laurenzo <[email protected]>
This PR adds code to construct the epilogue, kernel
and prologue once we have computed a schedule. We
simulate rotating registers in software and add
visualization tools to show the pipelined graphs.

---------

Signed-off-by: Harsh Menon <[email protected]>
This PR adds support for dynamic dimensions in the
kernels. The user specifies the dynamic dimensions
by
- Not adding them to the hyperparameter dictionary
- Explicitly specifying them with the dynamic_symbols kwarg
  and the dynamic_symbols_mapping kwarg to specify which
  values to use for the dynamic dims at runtime

This PR does not modify the codegen and so incorrect or
unsupported values for the dynamic dims will result
in incorrect results. (garbage in -> garbage out)

---------

Signed-off-by: Harsh Menon <[email protected]>
…ee-org#184)

* Rework how we are lowering `rational` sympy expressions, instead of
delayed materialization via lambdas introduce `_Rational` type and
propagate `numerator/denominator` values independently. Division will
only be materialized on explicit `sympy.floor/ceiling` op.
* Rework how igemm test cases are generated and introduce few real
shapes.
* Use custom pytest markers to separate perf/non-perf tests

---------

Signed-off-by: Ivan Butygin <[email protected]>
The motivation of this pass is to generalize the register analysis pass
which is used to determine the thread shape of TKW.Register, to all
other operations.

One main use case for such is to allow reduction, and later on
"broadcast" to use thread shape information from the kernel as opposed
to relying on vector_shape which may not always be valid.

We generalize the register analysis metho by finding a few anchor ops
who's thread shape information is determined, and then propagate to it's
successors and ancestors.

In addition to that we also implemented a couple helper
function/attributes.

1. Control_fn on BFS, ForwardSlice, BackwardSlice. This is to make it
easier for us to control/stop the search when we hit ops we do not want
to explore. In this case, we do not want to explore/propagate onto other
anchor ops and their children.

2. Introducing parent_op to IterArg and region of Reduction, for
developer ergonomics.

3. Move handling of IterArg and GetUser in BackwardSlice/BFS's get_input
exploration phase to be handled individually as opposed to being handled
when its' consumer is being explored. Previously to explore/propagate
IterArg/GetUser, we need to explore its' consumer, just exploring
IterArg/GetUser will not get handled correctly. This is useful for the
case where we want to propagate/explore mma.acc (usually IterArg)
directly.

---------

Signed-off-by: Stanley Winata <[email protected]>
We would like this to be controlled with a flag.

Signed-off-by: Harsh Menon <[email protected]>
Our tests are flaky, `fail-fast: false` won't allow failing builds abort
other.

Signed-off-by: Ivan Butygin <[email protected]>
Initial version of IGEMM benchmarking.

* If `--runperf` pytest option is set, generate IREE ref code and run
both TKW and ref code with `run_bench=True`
* Add `--dump-perf-files-path` option to save perf info files into
provided directory (filenames based on test name and params)

---------

Signed-off-by: Ivan Butygin <[email protected]>
* Add `arith.andi`, `arith.cmpi`, `vector.maskedload`, `vector.gather`,
`vector.contant_mask`, `vector.insertelement`, `vectot.splat`, support
non-splatted contants.
* Add `interpret_ndrange` helper

---------

Signed-off-by: Ivan Butygin <[email protected]>
Motivation of this PR is to be able to codegen/lower broadcast properly.
With that in mind, we implemented these things:

1. BroadcastOp class, op and lowering, to represent and store
broadcasting information. Mostly S.T we can query target shape
information and the source operand of broadcast.
2. Treat broadcast-add as an index conflict and handle it by emitting
broadcastOp.

---------

Signed-off-by: Stanley Winata <[email protected]>
This PR adds a flag to dump intermediates which include .ll and .s files
to see what instructions were
generated.

---------

Signed-off-by: Harsh Menon <[email protected]>
* Main CI is flaky, add a separate pipeline, which tests only TK as temp
solution
* Make `pytest` output more verbose
* Remove unnecessary stuff from perf pipeline

---------

Signed-off-by: Ivan Butygin <[email protected]>
* Move files from files from `shark-turbine` to `iree/turbine`.
* Update imports
* Update `setup.py`
* Make backward redirect `shark-turbine` -> `iree.turbine` (do we need
this?)

Progress on iree-org#28

---------

Signed-off-by: Ivan Butygin <[email protected]>
Copy link

@kumardeepakamd kumardeepakamd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG

Copy link
Contributor

@monorimet monorimet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Overall this looks OK, but I think in the future we might want more than just validating numerics.
A follow-up might break this down further or add more tests that ensure parameter externalization occurs as expected (make sure we don't have copies of buffers/parameters or leftover inlined parameters in the generated IR). These are nitpicks based on problems I've run into often on the SD modeling side. ( I believe some of these are covered in existing tests )

Also, If we're adding in tests for external parameters like so, perhaps a follow-up for AOT save/load on disk may cover more use cases.

The sort of problems I hope we validate a path for:

  • saving and loading .safetensors or .irpa from a torch module to a parameter index
  • intake of parameter files from common sources, ensure seamless plug-in of parameters to appropriately constructed compiled modules.

The motive behind these points are cases like Stable Diffusion deployment, where checkpoints are sourced by users and "plugged in" to a given inference solution. If we don't seamlessly intake those parameters, we either have to generate new IR for them or save a copy of parameters with modified parameter keys.

LGTM though. Thanks.

@vinayakdsci
Copy link
Author

@monorimet sure, I agree with all you say. I can add tests in follow up PRs too, for each of the cases that you mention.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

10 participants