Skip to content

Commit 66aa960

Browse files
authored
[CI] Enable building StableHLO in CI (#4170)
Context: Due to the LLVM version upgrade, stablehlo CI tests have been temporarily disabled. Ref link of discussion: #4152
1 parent 6240480 commit 66aa960

File tree

8 files changed

+10
-14
lines changed

8 files changed

+10
-14
lines changed

build_tools/build_standalone.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ cmake -GNinja -B"$build_dir" "$llvm_project_dir/llvm" \
2020
-DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \
2121
-DLLVM_ENABLE_PROJECTS=mlir \
2222
-DLLVM_EXTERNAL_PROJECTS="torch-mlir" \
23-
-DTORCH_MLIR_ENABLE_STABLEHLO=OFF \
2423
-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$project_dir" \
2524
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
2625
-DLLVM_ENABLE_ASSERTIONS=ON \

build_tools/ci/build_posix.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ cmake -S "$repo_root/externals/llvm-project/llvm" -B "$build_dir" \
5151
-DLLVM_TARGETS_TO_BUILD=host \
5252
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
5353
-DTORCH_MLIR_ENABLE_LTC=OFF \
54-
-DTORCH_MLIR_ENABLE_STABLEHLO=OFF \
5554
-DTORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS=ON
5655
echo "::endgroup::"
5756

build_tools/ci/test_posix.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ case $torch_version in
2828

2929
# TODO: Need to verify in the stable version
3030
# TODO: Enable for nightly once the Stabelhlo integration is done
31-
# echo "::group::Run FxImporter2Stablehlo e2e integration tests"
32-
# python3 -m e2e_testing.main --config=fx_importer_stablehlo -v
33-
# echo "::endgroup::"
31+
echo "::group::Run FxImporter2Stablehlo e2e integration tests"
32+
python3 -m e2e_testing.main --config=fx_importer_stablehlo -v
33+
echo "::endgroup::"
3434

3535
echo "::group::Run FxImporter TOSA e2e integration tests"
3636
python3 -m e2e_testing.main --config=fx_importer_tosa -v

build_tools/e2eshark_build.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ cmake -S "$repo_root/externals/llvm-project/llvm" -B "$build_dir" \
2323
-DPython3_FIND_VIRTUALENV=ONLY \
2424
-DLLVM_ENABLE_PROJECTS=mlir \
2525
-DLLVM_EXTERNAL_PROJECTS="torch-mlir" \
26-
-DTORCH_MLIR_ENABLE_STABLEHLO=OFF \
2726
-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$PWD" \
2827
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
2928
-DLLVM_TARGETS_TO_BUILD=host

build_tools/python_deploy/build_linux_packages.sh

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,6 @@ function build_in_tree() {
239239
-DLLVM_TARGETS_TO_BUILD=host \
240240
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
241241
-DTORCH_MLIR_ENABLE_LTC=${enable_ltc} \
242-
-DTORCH_MLIR_ENABLE_STABLEHLO=OFF \
243242
-DTORCH_MLIR_USE_INSTALLED_PYTORCH="$torch_from_bin" \
244243
-DTORCH_MLIR_SRC_PYTORCH_REPO=${TORCH_MLIR_SRC_PYTORCH_REPO} \
245244
-DTORCH_MLIR_SRC_PYTORCH_BRANCH=${TORCH_MLIR_SRC_PYTORCH_BRANCH} \
@@ -406,7 +405,6 @@ function build_out_of_tree() {
406405
-DMLIR_DIR="/main_checkout/torch-mlir/llvm-build/lib/cmake/mlir/" \
407406
-DMLIR_ENABLE_BINDINGS_PYTHON=OFF \
408407
-DTORCH_MLIR_ENABLE_LTC=${enable_ltc} \
409-
-DTORCH_MLIR_ENABLE_STABLEHLO=OFF \
410408
-DTORCH_MLIR_USE_INSTALLED_PYTORCH="$torch_from_bin" \
411409
-DTORCH_MLIR_SRC_PYTORCH_REPO=${TORCH_MLIR_SRC_PYTORCH_REPO} \
412410
-DTORCH_MLIR_SRC_PYTORCH_BRANCH=${TORCH_MLIR_SRC_PYTORCH_BRANCH} \

externals/stablehlo

Submodule stablehlo updated 135 files

lib/InitAll.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
3434
#include "stablehlo/conversions/linalg/transforms/Passes.h"
3535
#include "stablehlo/transforms/Passes.h"
36+
#include "stablehlo/transforms/optimization/Passes.h"
3637
#endif
3738

3839
#ifdef TORCH_MLIR_ENABLE_TOSA

test/Conversion/TorchToStablehlo/quantization.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ func.func @test_quantization_per_channel(%arg0: !torch.vtensor<[4,3,7,7],f32>) -
2525
%0 = torch.vtensor.literal(dense<[4.000000e-01, 1.000000e-01, 2.000000e-01, 3.000000e-01]> : tensor<4xf32>) : !torch.vtensor<[4],f32>
2626
%1 = torch.vtensor.literal(dense<[4, 1, 2, 3]> : tensor<4xsi8>) : !torch.vtensor<[4],si8>
2727
%int12 = torch.constant.int 12
28-
%int1 = torch.constant.int 1
28+
%zero = torch.constant.int 0
2929
// CHECK: %[[QUANT:.+]] = stablehlo.uniform_quantize %[[ARG0]]
30-
// CHECK-SAME: (tensor<4x3x7x7xf32>) -> tensor<4x3x7x7x!quant.uniform<i8:f32:1, {0.4{{.*}}:4,0.1{{.*}}:1,0.2{{.*}}:2,0.3{{.*}}:3}>>
31-
%2 = torch.aten.quantize_per_channel %arg0, %0, %1, %int1, %int12 : !torch.vtensor<[4,3,7,7],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],si8>, !torch.int, !torch.int -> !torch.vtensor<[4,3,7,7],!torch.qint8>
30+
// CHECK-SAME: (tensor<4x3x7x7xf32>) -> tensor<4x3x7x7x!quant.uniform<i8:f32:0, {0.4{{.*}}:4,0.1{{.*}}:1,0.2{{.*}}:2,0.3{{.*}}:3}>>
31+
%2 = torch.aten.quantize_per_channel %arg0, %0, %1, %zero, %int12 : !torch.vtensor<[4,3,7,7],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],si8>, !torch.int, !torch.int -> !torch.vtensor<[4,3,7,7],!torch.qint8>
3232
%3 = torch.aten.int_repr %2 : !torch.vtensor<[4,3,7,7],!torch.qint8> -> !torch.vtensor<[4,3,7,7],si8>
3333
// CHECK: %[[DEQ:.+]] = stablehlo.uniform_dequantize %[[QUANT]]
34-
// CHECK-SAME: (tensor<4x3x7x7x!quant.uniform<i8:f32:1, {0.4{{.*}}:4,0.1{{.*}}:1,0.2{{.*}}:2,0.3{{.*}}:3}>>) -> tensor<4x3x7x7xf32>
35-
%4 = torch.aten._make_per_channel_quantized_tensor %3, %0, %1, %int1 : !torch.vtensor<[4,3,7,7],si8>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],si8>, !torch.int -> !torch.vtensor<[4,3,7,7],!torch.qint8>
34+
// CHECK-SAME: (tensor<4x3x7x7x!quant.uniform<i8:f32:0, {0.4{{.*}}:4,0.1{{.*}}:1,0.2{{.*}}:2,0.3{{.*}}:3}>>) -> tensor<4x3x7x7xf32>
35+
%4 = torch.aten._make_per_channel_quantized_tensor %3, %0, %1, %zero : !torch.vtensor<[4,3,7,7],si8>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],si8>, !torch.int -> !torch.vtensor<[4,3,7,7],!torch.qint8>
3636
%5 = torch.aten.dequantize.self %4 : !torch.vtensor<[4,3,7,7],!torch.qint8> -> !torch.vtensor<[4,3,7,7],f32>
3737
return %5 : !torch.vtensor<[4,3,7,7],f32>
3838
}

0 commit comments

Comments
 (0)