-
Notifications
You must be signed in to change notification settings - Fork 25
Open
Description
// Location alias for an unknown source location. Note: nothing in this
// module refers to #loc2 — the @main arguments spell `loc(unknown)` inline.
#loc2 = loc(unknown)
// StableHLO module emitted by ZML, compiled for 4 partitions and 1 replica.
//
// All sharding in this module is expressed in the XLA/MHLO form — string
// `mhlo.sharding` attributes plus `@Sharding` custom calls — rather than as
// Shardy `sdy.sharding` attributes.
module @zml attributes {mhlo.num_partitions = 4 : i32, mhlo.num_replicas = 1 : i32} {
// @main: multiplies two 128x128 f16 matrices. Both arguments and the result
// are declared with sharding "{devices=[4,1]<=[4]}" (in XLA's HloSharding
// notation: dimension 0 split 4 ways across devices 0..3).
func.func @main(%arg0: tensor<128x128xf16> {mhlo.layout_mode = "default", mhlo.sharding = "{devices=[4,1]<=[4]}"} loc(unknown), %arg1: tensor<128x128xf16> {mhlo.layout_mode = "default", mhlo.sharding = "{devices=[4,1]<=[4]}"} loc(unknown)) -> (tensor<128x128xf16> {mhlo.layout_mode = "default", mhlo.sharding = "{devices=[4,1]<=[4]}"}) {
// `@Sharding` custom call (XLA's sharding-annotation op) pins %arg0 to
// "{devices=[1,4]<=[4]}" (dimension 1 split), which differs from the
// "{devices=[4,1]<=[4]}" sharding the argument is declared with.
%0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{devices=[1,4]<=[4]}"} : (tensor<128x128xf16>) -> tensor<128x128xf16> loc(#loc3)
// Pins %arg1 to the same sharding it is declared with.
%1 = stablehlo.custom_call @Sharding(%arg1) {backend_config = "", mhlo.sharding = "{devices=[4,1]<=[4]}"} : (tensor<128x128xf16>) -> tensor<128x128xf16> loc(#loc3)
// Plain matrix product: contracts dim 1 of %0 with dim 0 of %1; there are
// no batching dimensions.
%2 = stablehlo.dot_general %0, %1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<128x128xf16>, tensor<128x128xf16>) -> tensor<128x128xf16> loc(#loc6)
// Pins the product back to "{devices=[4,1]<=[4]}", matching the declared
// result sharding.
%3 = stablehlo.custom_call @Sharding(%2) {backend_config = "", mhlo.sharding = "{devices=[4,1]<=[4]}"} : (tensor<128x128xf16>) -> tensor<128x128xf16> loc(#loc3)
return %3 : tensor<128x128xf16> loc(#loc1)
} loc(#loc1)
} loc(#loc5)
// Source-location aliases pointing into ZML's Zig sources ("file":line:col).
// #loc5 names the module's location "main", wrapping #loc. #loc6, attached
// to the dot_general above, is defined after these and wraps #loc4.
#loc = loc("module.zig":84:39)
#loc1 = loc("module.zig":339:39)
#loc3 = loc("tensor.zig":187:44)
#loc4 = loc("tensor.zig":1179:47)
#loc5 = loc("main"(#loc))
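#loc6 = loc("dot({m=128,k=128!,f16},{k=128!,n=128,f16},contracting={ { 1, 0 } },batching={ }"(#loc4))

I'm generating this MLIR from the ZML project. How can I take this module and pass it to `sdy_opt`? The module expresses sharding only through `mhlo.sharding` attributes and `@Sharding` custom calls, not `sdy.sharding` attributes. Unfortunately, I can't use JAX to generate the `sdy.sharding` attributes.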
Metadata
Metadata
Assignees
Labels
No labels