
Parsing/Printing generic format for DotDimensionNumbers #2828

@erick-xanadu

Description


What happened?

Hello,

There is a comment in the StableHLO source showing how dot_dimension_numbers can be printed and parsed:

// Generic:
//    dot_dimension_numbers = #stablehlo.dot<
//      lhs_batching_dimensions = [],
//      lhs_contracting_dimensions = [1],
//      rhs_batching_dimensions = [],
//      rhs_contracting_dimensions = [0]
//    >
//    dot_dimension_numbers = #stablehlo.dot<
//      lhs_batching_dimensions = [0],
//      lhs_contracting_dimensions = [2],
//      rhs_batching_dimensions = [0],
//      rhs_contracting_dimensions = [1]
//    >
//
// Custom:
//    contracting_dims = [1] x [0]
//    batching_dims = [0] x [0], contracting_dims = [2] x [1]
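The custom form above can be mimicked in a few lines of Python. This is a hypothetical sketch: the names format_dim_pair and print_custom_dot_dims are made up and are not StableHLO APIs. It also assumes, based only on the two examples, that batching_dims is dropped when both batching lists are empty and that contracting_dims is always printed.

```python
from typing import Sequence


def format_dim_pair(lhs: Sequence[int], rhs: Sequence[int]) -> str:
    """Render an lhs/rhs pair of dimension lists as '[a, b] x [c, d]'."""
    def fmt(dims: Sequence[int]) -> str:
        return "[" + ", ".join(str(d) for d in dims) + "]"
    return f"{fmt(lhs)} x {fmt(rhs)}"


def print_custom_dot_dims(lhs_batching: Sequence[int],
                          rhs_batching: Sequence[int],
                          lhs_contracting: Sequence[int],
                          rhs_contracting: Sequence[int]) -> str:
    """Hypothetical re-creation of the custom dot-dimension syntax."""
    parts = []
    # Assumption: batching_dims is omitted when both batching lists are empty.
    if lhs_batching or rhs_batching:
        parts.append("batching_dims = " + format_dim_pair(lhs_batching, rhs_batching))
    # Assumption: contracting_dims is always printed.
    parts.append("contracting_dims = " + format_dim_pair(lhs_contracting, rhs_contracting))
    return ", ".join(parts)


print(print_custom_dot_dims([], [], [1], [0]))
print(print_custom_dot_dims([0], [0], [2], [1]))
```

The two calls reproduce the two custom-form lines from the comment.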

However, a colleague of mine found that the parser also accepts the following:

//    dot_dimension_numbers = #stablehlo.dot<
//      lhs_contracting_dimensions = [1],
//      rhs_contracting_dimensions = [0]
//    >

This seems fairly harmless, but it is interesting that printing in the generic format also omits {lhs,rhs}_batching_dimensions. For example:

from jaxlib.mlir.dialects import stablehlo as jstablehlo  # pylint: disable=no-name-in-module
from jaxlib.mlir.ir import Context as jContext  # pylint: disable=no-name-in-module
from jaxlib.mlir.ir import Module as jModule  # pylint: disable=no-name-in-module
from jax._src.interpreters import mlir as mlir_interpreter

# A dot_general whose dot_dimension_numbers omit both batching-dimension fields.
program = """
func.func @dot1(%arg0: tensor<4xf64>,
                %arg1: tensor<4xf64>) -> tensor<f64> {
  %0 = "stablehlo.dot_general"(%arg0, %arg1)
     {dot_dimension_numbers = #stablehlo.dot<
                                lhs_contracting_dimensions = [0],
                                rhs_contracting_dimensions = [0]>,
      precision_config = [#stablehlo<precision DEFAULT>,
                          #stablehlo<precision DEFAULT>]}
           : (tensor<4xf64>, tensor<4xf64>) -> tensor<f64>
  func.return %0 : tensor<f64>
}
"""

with jContext() as ctx:
  # Make the upstream dialects (func, ...) available, load them, then register StableHLO.
  ctx.append_dialect_registry(mlir_interpreter.upstream_dialects)
  ctx.load_all_available_dialects()
  jstablehlo.register_dialect(ctx)  # pylint: disable=no-member
  mod = jModule.parse(program)
  # Print in the generic op form to see the attribute exactly as stored.
  print(mod.operation.get_asm(binary=False, print_generic_op_form=True, assume_verified=False))

Output:

"builtin.module"() ({
  "func.func"() <{function_type = (tensor<4xf64>, tensor<4xf64>) -> tensor<f64>, sym_name = "dot1"}> ({
  ^bb0(%arg0: tensor<4xf64>, %arg1: tensor<4xf64>):
    // note: no lhs/rhs batching dimensions are printed
    %0 = "stablehlo.dot_general"(%arg0, %arg1) <{dot_dimension_numbers = #stablehlo.dot<lhs_contracting_dimensions = [0], rhs_contracting_dimensions = [0]>, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]}> : (tensor<4xf64>, tensor<4xf64>) -> tensor<f64>
    "func.return"(%0) : (tensor<f64>) -> ()
  }) : () -> ()
}) : () -> ()

This is a minor issue, but it is surprising because {lhs,rhs}_batching_dimensions are not marked as optional anywhere.
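One plausible explanation, not yet confirmed against the StableHLO C++ code, is that the attribute printer skips any field that still holds its default (empty) value, and the parser starts every field at that default, so a missing field simply parses as []. The following is a minimal pure-Python model of that assumed behavior. print_dot and parse_dot are hypothetical helpers, not part of any library, and the regex parser only handles the flat field = [ints] syntax used here.

```python
import re
from typing import Dict, List

# Field order as it appears in the generic examples above.
FIELDS = (
    "lhs_batching_dimensions",
    "lhs_contracting_dimensions",
    "rhs_batching_dimensions",
    "rhs_contracting_dimensions",
)
PREFIX = "#stablehlo.dot<"


def print_dot(dims: Dict[str, List[int]]) -> str:
    # Assumed printer behavior: skip any field whose value is the default (empty list).
    body = ", ".join(
        f"{name} = [{', '.join(str(d) for d in dims[name])}]"
        for name in FIELDS
        if dims.get(name)
    )
    return f"{PREFIX}{body}>"


def parse_dot(text: str) -> Dict[str, List[int]]:
    # Assumed parser behavior: every field starts at its default, so omitted fields stay [].
    dims: Dict[str, List[int]] = {name: [] for name in FIELDS}
    text = text.strip()
    if not (text.startswith(PREFIX) and text.endswith(">")):
        raise ValueError("not a #stablehlo.dot attribute")
    inner = text[len(PREFIX):-1]
    for name, values in re.findall(r"(\w+)\s*=\s*\[([^\]]*)\]", inner):
        if name not in dims:
            raise ValueError(f"unknown field {name!r}")
        dims[name] = [int(v) for v in values.split(",") if v.strip()]
    return dims


explicit = ("#stablehlo.dot<lhs_batching_dimensions = [], lhs_contracting_dimensions = [0], "
            "rhs_batching_dimensions = [], rhs_contracting_dimensions = [0]>")
implicit = "#stablehlo.dot<lhs_contracting_dimensions = [0], rhs_contracting_dimensions = [0]>"

# Both spellings parse to the same value, and printing drops the empty batching fields.
assert parse_dot(explicit) == parse_dot(implicit)
print(print_dot(parse_dot(explicit)))
```

Under this model the explicit and implicit spellings are indistinguishable after parsing, which would match the generic output above.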


Version information

No response
