Skip to content

[WebAssembly] Constant fold wasm.dot #149619

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

badumbatish
Copy link
Contributor

Constant fold wasm.dot of constant vectors/splats.

Test case added in llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll

@llvmbot llvmbot added llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes llvm:analysis Includes value tracking, cost tables and constant folding llvm:transforms labels Jul 18, 2025
@llvmbot
Copy link
Member

llvmbot commented Jul 18, 2025

@llvm/pr-subscribers-llvm-analysis

Author: Jasmine Tang (badumbatish)

Changes

Constant fold wasm.dot of constant vectors/splats.

Test case added in llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll


Full diff: https://github.com/llvm/llvm-project/pull/149619.diff

2 Files Affected:

  • (modified) llvm/lib/Analysis/ConstantFolding.cpp (+31)
  • (added) llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll (+39)
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 9c1c2c6e60f02..2304c58b3f95f 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1657,6 +1657,7 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
   case Intrinsic::aarch64_sve_convert_from_svbool:
   case Intrinsic::wasm_alltrue:
   case Intrinsic::wasm_anytrue:
+  case Intrinsic::wasm_dot:
   // WebAssembly float semantics are always known
   case Intrinsic::wasm_trunc_signed:
   case Intrinsic::wasm_trunc_unsigned:
@@ -3826,6 +3827,36 @@ static Constant *ConstantFoldFixedVectorCall(
     }
     return ConstantVector::get(Result);
   }
+  case Intrinsic::wasm_dot: {
+    unsigned NumElements =
+        cast<FixedVectorType>(Operands[0]->getType())->getNumElements();
+
+    assert(NumElements == 8 && NumElements / 2 == Result.size() &&
+           "wasm dot takes i16x8 and produce i32x4");
+    assert(Ty->isIntegerTy());
+    SmallVector<APInt, 8> MulVector;
+
+    for (unsigned I = 0; I < NumElements; ++I) {
+      ConstantInt *Elt0 =
+          cast<ConstantInt>(Operands[0]->getAggregateElement(I));
+      ConstantInt *Elt1 =
+          cast<ConstantInt>(Operands[1]->getAggregateElement(I));
+
+      // sext 32 first, according to specs
+      APInt IMul = Elt0->getValue().sext(32) * Elt1->getValue().sext(32);
+
+      // TODO: imul in specs includes a modulo operation
+      // Is this performed automatically via trunc = true in APInt creation of *
+      MulVector.push_back(IMul);
+    }
+    for (unsigned I = 0; I < Result.size(); ++I) {
+      // Same case as with imul
+      APInt IAdd = MulVector[I] + MulVector[I + Result.size()];
+      Result[I] = ConstantInt::get(Ty, IAdd);
+    }
+
+    return ConstantVector::get(Result);
+  }
   default:
     break;
   }
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll
new file mode 100644
index 0000000000000..02c6649becbce
--- /dev/null
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll
@@ -0,0 +1,39 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+
+; RUN: opt -passes=instsimplify -S < %s | FileCheck %s
+
+; Test that intrinsics wasm dot call are constant folded
+
+target triple = "wasm32-unknown-unknown"
+
+
+define <4 x i32> @dot_zero() {
+; CHECK-LABEL: define <4 x i32> @dot_zero() {
+; CHECK-NEXT:    ret <4 x i32> zeroinitializer
+;
+  %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> zeroinitializer, <8 x i16> zeroinitializer)
+  ret <4 x i32> %res
+}
+
+; a               =   1    2    3    4    5    6    7    8
+; b               =   1    2    3    4    5    6    7    8
+; k1|k2 = a * b   =   1    4    9   16   25   36   49   64
+; k1 + k2         =   (1+25) |  (4+36) | (9+49)  | (16+64)
+; result          =    26    |   40    |   58    |   80
+define <4 x i32> @dot_nonzero() {
+; CHECK-LABEL: define <4 x i32> @dot_nonzero() {
+; CHECK-NEXT:    ret <4 x i32> <i32 26, i32 40, i32 58, i32 80>
+;
+  %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>, <8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>)
+  ret <4 x i32> %res
+}
+
+define <4 x i32> @dot_doubly_negative() {
+; CHECK-LABEL: define <4 x i32> @dot_doubly_negative() {
+; CHECK-NEXT:    ret <4 x i32> splat (i32 2)
+;
+  %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1>, <8 x i16> <i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1>)
+  ret <4 x i32> %res
+}
+
+

@llvmbot
Copy link
Member

llvmbot commented Jul 18, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Jasmine Tang (badumbatish)

Changes

Constant fold wasm.dot of constant vectors/splats.

Test case added in llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll


Full diff: https://github.com/llvm/llvm-project/pull/149619.diff

2 Files Affected:

  • (modified) llvm/lib/Analysis/ConstantFolding.cpp (+31)
  • (added) llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll (+39)
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 9c1c2c6e60f02..2304c58b3f95f 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1657,6 +1657,7 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
   case Intrinsic::aarch64_sve_convert_from_svbool:
   case Intrinsic::wasm_alltrue:
   case Intrinsic::wasm_anytrue:
+  case Intrinsic::wasm_dot:
   // WebAssembly float semantics are always known
   case Intrinsic::wasm_trunc_signed:
   case Intrinsic::wasm_trunc_unsigned:
@@ -3826,6 +3827,36 @@ static Constant *ConstantFoldFixedVectorCall(
     }
     return ConstantVector::get(Result);
   }
+  case Intrinsic::wasm_dot: {
+    unsigned NumElements =
+        cast<FixedVectorType>(Operands[0]->getType())->getNumElements();
+
+    assert(NumElements == 8 && NumElements / 2 == Result.size() &&
+           "wasm dot takes i16x8 and produce i32x4");
+    assert(Ty->isIntegerTy());
+    SmallVector<APInt, 8> MulVector;
+
+    for (unsigned I = 0; I < NumElements; ++I) {
+      ConstantInt *Elt0 =
+          cast<ConstantInt>(Operands[0]->getAggregateElement(I));
+      ConstantInt *Elt1 =
+          cast<ConstantInt>(Operands[1]->getAggregateElement(I));
+
+      // sext 32 first, according to specs
+      APInt IMul = Elt0->getValue().sext(32) * Elt1->getValue().sext(32);
+
+      // TODO: imul in specs includes a modulo operation
+      // Is this performed automatically via trunc = true in APInt creation of *
+      MulVector.push_back(IMul);
+    }
+    for (unsigned I = 0; I < Result.size(); ++I) {
+      // Same case as with imul
+      APInt IAdd = MulVector[I] + MulVector[I + Result.size()];
+      Result[I] = ConstantInt::get(Ty, IAdd);
+    }
+
+    return ConstantVector::get(Result);
+  }
   default:
     break;
   }
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll
new file mode 100644
index 0000000000000..02c6649becbce
--- /dev/null
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll
@@ -0,0 +1,39 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+
+; RUN: opt -passes=instsimplify -S < %s | FileCheck %s
+
+; Test that intrinsics wasm dot call are constant folded
+
+target triple = "wasm32-unknown-unknown"
+
+
+define <4 x i32> @dot_zero() {
+; CHECK-LABEL: define <4 x i32> @dot_zero() {
+; CHECK-NEXT:    ret <4 x i32> zeroinitializer
+;
+  %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> zeroinitializer, <8 x i16> zeroinitializer)
+  ret <4 x i32> %res
+}
+
+; a               =   1    2    3    4    5    6    7    8
+; b               =   1    2    3    4    5    6    7    8
+; k1|k2 = a * b   =   1    4    9   16   25   36   49   64
+; k1 + k2         =   (1+25) |  (4+36) | (9+49)  | (16+64)
+; result          =    26    |   40    |   58    |   80
+define <4 x i32> @dot_nonzero() {
+; CHECK-LABEL: define <4 x i32> @dot_nonzero() {
+; CHECK-NEXT:    ret <4 x i32> <i32 26, i32 40, i32 58, i32 80>
+;
+  %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>, <8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>)
+  ret <4 x i32> %res
+}
+
+define <4 x i32> @dot_doubly_negative() {
+; CHECK-LABEL: define <4 x i32> @dot_doubly_negative() {
+; CHECK-NEXT:    ret <4 x i32> splat (i32 2)
+;
+  %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1>, <8 x i16> <i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1>)
+  ret <4 x i32> %res
+}
+
+

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:analysis Includes value tracking, cost tables and constant folding llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes llvm:transforms
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants