-
Notifications
You must be signed in to change notification settings - Fork 165
Open
Description
What happened?
Hello,
There is a source comment showing how `dot_dimension_numbers` can be printed and parsed:
// Generic:
// dot_dimension_numbers = #stablehlo.dot<
// lhs_batching_dimensions = [],
// lhs_contracting_dimensions = [1],
// rhs_batching_dimensions = [],
// rhs_contracting_dimensions = [0]
// >
// dot_dimension_numbers = #stablehlo.dot<
// lhs_batching_dimensions = [0],
// lhs_contracting_dimensions = [2],
// rhs_batching_dimensions = [0],
// rhs_contracting_dimensions = [1]
// >
//
// Custom:
// contracting_dims = [1] x [0]
// batching_dims = [0] x [0], contracting_dims = [2] x [1]
However, a colleague of mine found that the following form, which omits both batching-dimension fields, also parses:
// dot_dimension_numbers = #stablehlo.dot<
// lhs_contracting_dimensions = [1],
// rhs_contracting_dimensions = [0]
// >
This seems harmless, but it is interesting that printing in the generic format also omits the `{lhs,rhs}_batching_dimensions` fields. For example:
from jaxlib.mlir.dialects import stablehlo as jstablehlo # pylint: disable=no-name-in-module
from jaxlib.mlir.dialects import func
from jaxlib.mlir.ir import Context as jContext # pylint: disable=no-name-in-module
from jaxlib.mlir.ir import Module as jModule # pylint: disable=no-name-in-module
from jax._src.interpreters import mlir as mlir_interpreter
program = """
func.func @dot1(%arg0: tensor<4xf64,>,
%arg1: tensor<4xf64>) -> tensor<f64> {
%0 = "stablehlo.dot_general"(%arg0, %arg1)
{dot_dimension_numbers = #stablehlo.dot<
lhs_contracting_dimensions = [0],
rhs_contracting_dimensions = [0]>,
precision_config = [#stablehlo<precision DEFAULT>,
#stablehlo<precision DEFAULT>]}
: (tensor<4xf64>, tensor<4xf64>) -> tensor<f64>
func.return %0 : tensor<f64>
}
"""
from jaxlib.mlir import ir
with jContext() as ctx:
ctx.load_all_available_dialects()
ctx.append_dialect_registry(mlir_interpreter.upstream_dialects)
jstablehlo.register_dialect(ctx) # pylint: disable=no-member
mod = jModule.parse(program)
print(mod.operation.get_asm(binary=False, print_generic_op_form=True, assume_verified=False))"builtin.module"() ({
"func.func"() <{function_type = (tensor<4xf64>, tensor<4xf64>) -> tensor<f64>, sym_name = "dot1"}> ({
^bb0(%arg0: tensor<4xf64>, %arg1: tensor<4xf64>):
// <- note: no lhs/rhs batching dimensions in the dot_dimension_numbers attribute below
%0 = "stablehlo.dot_general"(%arg0, %arg1) <{dot_dimension_numbers = #stablehlo.dot<lhs_contracting_dimensions = [0], rhs_contracting_dimensions = [0]>, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]}> : (tensor<4xf64>, tensor<4xf64>) -> tensor<f64>
"func.return"(%0) : (tensor<f64>) -> ()
}) : () -> ()
}) : () -> ()
This seems to be a minor issue, but it is notable because `{lhs,rhs}_batching_dimensions` are not marked as optional anywhere, yet both the parser and the generic printer accept their absence.
Steps to reproduce your issue
- Go to '...'
- Click on '....'
- Scroll down to '....'
- See error
Version information
No response
Metadata
Metadata
Assignees
Labels
No labels