-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
22e50f9
commit 074afc8
Showing
6 changed files
with
98 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import pytest | ||
from mlir.ir import OpResult | ||
|
||
from mlir_utils.dialects.ext.tensor import S, empty | ||
from mlir_utils.dialects.ext.arith import constant | ||
from mlir_utils.dialects.util import register_value_caster | ||
|
||
# noinspection PyUnresolvedReferences | ||
from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext | ||
from mlir_utils.types import f64_t, RankedTensorType | ||
|
||
# needed since the fix isn't defined here nor conftest.py | ||
pytest.mark.usefixtures("ctx") | ||
|
||
|
||
def test_caster_registration(ctx: MLIRContext): | ||
sizes = S, 3, S | ||
ten = empty(sizes, f64_t) | ||
assert repr(ten) == "Tensor(%0, tensor<?x3x?xf64>)" | ||
|
||
def dummy_caster(val): | ||
print(val) | ||
return val | ||
|
||
register_value_caster(RankedTensorType.static_typeid, dummy_caster) | ||
ten = empty(sizes, f64_t) | ||
assert repr(ten) == "Tensor(%1, tensor<?x3x?xf64>)" | ||
|
||
register_value_caster(RankedTensorType.static_typeid, dummy_caster, 0) | ||
ten = empty(sizes, f64_t) | ||
assert repr(ten) != "Tensor(%1, tensor<?x3x?xf64>)" | ||
assert isinstance(ten, OpResult) | ||
|
||
one = constant(1) | ||
assert repr(one) == "Scalar(%3, i64)" |