Skip to content

Commit 3339d74

Browse files
authored
Fix Blas diffuse on P^3 computation (#1526)
* Fix Blas diffuse on P^3 computation * fixup
1 parent a1d95f5 commit 3339d74

File tree

6 files changed

+301
-5
lines changed

6 files changed

+301
-5
lines changed

enzyme/Enzyme/Utils.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1121,7 +1121,10 @@ Function *getOrInsertMemcpyMat(Module &Mod, Type *elementType, PointerType *PT,
11211121
#if LLVM_VERSION_MAJOR >= 15
11221122
if (Mod.getContext().supportsTypedPointers()) {
11231123
#endif
1124-
assert(PT->getPointerElementType() == elementType);
1124+
#if LLVM_VERSION_MAJOR >= 13
1125+
if (!PT->isOpaquePointerTy())
1126+
#endif
1127+
assert(PT->getPointerElementType() == elementType);
11251128
#if LLVM_VERSION_MAJOR >= 15
11261129
}
11271130
#endif

enzyme/Enzyme/Utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,10 @@ static inline DIFFE_TYPE whatType(llvm::Type *arg, DerivativeMode mode,
441441
if (!arg->getContext().supportsTypedPointers()) {
442442
return DIFFE_TYPE::DUP_ARG;
443443
}
444+
#elif LLVM_VERSION_MAJOR >= 13
445+
if (arg->isOpaquePointerTy()) {
446+
return DIFFE_TYPE::DUP_ARG;
447+
}
444448
#endif
445449
switch (whatType(arg->getPointerElementType(), mode, integersAreConstant,
446450
seen)) {
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
; RUN: if [ %llvmver -lt 16 ] && [ %llvmver -ge 14 ] ; then %opt < %s %loadEnzyme -opaque-pointers -enzyme -enzyme-preopt=false -mem2reg -early-cse -instsimplify -jump-threading -adce -S | FileCheck %s; fi
2+
; RUN: if [ %llvmver -ge 14 ]; then %opt < %s %newLoadEnzyme -opaque-pointers -passes="enzyme,function(mem2reg,early-cse,instsimplify,jump-threading,adce)" -enzyme-preopt=false -S | FileCheck %s ; fi
3+
4+
; ModuleID = '../examples/big/big_inlined_correctness.cpp'
5+
source_filename = "../examples/big/big_inlined_correctness.cpp"
6+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
7+
target triple = "x86_64-unknown-linux-gnu"
8+
9+
%struct.Prod = type { ptr, double }
10+
11+
declare i32 @dgemm_(ptr nocapture noundef readonly %transa_t, ptr nocapture noundef readonly %transb_t, ptr nocapture noundef readonly %m, ptr nocapture noundef readonly %n, ptr nocapture noundef readonly %k, ptr nocapture noundef readonly %alpha, ptr nocapture noundef readonly %a, ptr nocapture noundef readonly %lda, ptr nocapture noundef readonly %b, ptr nocapture noundef readonly %ldb, ptr nocapture noundef readonly %beta, ptr nocapture noundef %c, ptr nocapture noundef readonly %ldc)
12+
13+
; Function Attrs: mustprogress noinline nounwind uwtable
14+
define dso_local void @_Z3mulR4ProdPd(ptr nocapture noundef nonnull align 8 dereferenceable(16) %P, ptr noalias nocapture noundef readonly %rhs) {
15+
entry:
16+
%N = alloca i8, align 1
17+
%ten = alloca i32, align 4
18+
%one = alloca double, align 8
19+
%zero = alloca double, align 8
20+
%calloc = call dereferenceable_or_null(32) ptr @calloc(i64 1, i64 32)
21+
store i8 78, ptr %N, align 1
22+
store i32 2, ptr %ten, align 4
23+
store double 1.000000e+00, ptr %one, align 8
24+
store double 0.000000e+00, ptr %zero, align 8
25+
%call1 = call i32 @dgemm_(ptr noundef nonnull %N, ptr noundef nonnull %N, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %one, ptr noundef %rhs, ptr noundef nonnull %ten, ptr noundef %rhs, ptr noundef nonnull %ten, ptr noundef nonnull %one, ptr noundef %calloc, ptr noundef nonnull %ten)
26+
%0 = load ptr, ptr %P, align 8
27+
%call2 = call i32 @dgemm_(ptr noundef nonnull %N, ptr noundef nonnull %N, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %ten, ptr noundef nonnull %one, ptr noundef %calloc, ptr noundef nonnull %ten, ptr noundef %rhs, ptr noundef nonnull %ten, ptr noundef nonnull %zero, ptr noundef %0, ptr noundef nonnull %ten)
28+
%alpha = getelementptr inbounds %struct.Prod, ptr %P, i64 0, i32 1
29+
store double 0.000000e+00, ptr %alpha, align 8
30+
ret void
31+
}
32+
33+
declare noalias ptr @malloc(i64)
34+
35+
; Function Attrs: mustprogress nounwind uwtable
36+
define dso_local double @_Z8simulatePd(ptr nocapture noundef readonly %P) {
37+
entry:
38+
%M = alloca %struct.Prod, align 8
39+
%call = tail call noalias dereferenceable_or_null(32) ptr @malloc(i64 noundef 32)
40+
store ptr %call, ptr %M, align 8
41+
%alpha = getelementptr inbounds %struct.Prod, ptr %M, i64 0, i32 1
42+
store double 1.000000e+00, ptr %alpha, align 8
43+
call void @_Z3mulR4ProdPd(ptr noundef nonnull align 8 dereferenceable(16) %M, ptr noundef %P)
44+
%0 = load ptr, ptr %M, align 8
45+
%1 = load double, ptr %0, align 8
46+
ret double %1
47+
}
48+
49+
define void @caller(ptr %A, ptr %Adup) {
50+
entry:
51+
call void (...) @_Z17__enzyme_autodiffz(ptr noundef nonnull @_Z8simulatePd, metadata !"enzyme_dup", ptr noundef nonnull %A, ptr noundef nonnull %Adup)
52+
ret void
53+
}
54+
55+
declare void @_Z17__enzyme_autodiffz(...)
56+
57+
declare noalias noundef ptr @calloc(i64 noundef, i64 noundef)
58+
59+
; we must actually save or set the matmul
60+
; CHECK: define internal void @diffe_Z3mulR4ProdPd(ptr nocapture align 8 dereferenceable(16) %P, ptr nocapture align 8 %"P'", ptr noalias nocapture readonly %rhs, ptr nocapture %"rhs'", { ptr, ptr, ptr, ptr } %tapeArg)
61+
; CHECK-NEXT: invertentry:
62+
; CHECK-NEXT: %byref.transpose.transb = alloca i8, align 1
63+
; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double, align 8
64+
; CHECK-NEXT: %byref.transpose.transa = alloca i8, align 1
65+
; CHECK-NEXT: %byref.constant.fp.1.05 = alloca double, align 8
66+
; CHECK-NEXT: %byref.constant.char.G = alloca i8, align 1
67+
; CHECK-NEXT: %byref.constant.int.0 = alloca i32, align 4
68+
; CHECK-NEXT: %byref.constant.int.06 = alloca i32, align 4
69+
; CHECK-NEXT: %byref.constant.fp.1.07 = alloca double, align 8
70+
; CHECK-NEXT: %[[i0:.+]] = alloca i32, align 4
71+
; CHECK-NEXT: %byref.transpose.transb11 = alloca i8, align 1
72+
; CHECK-NEXT: %byref.constant.fp.1.014 = alloca double, align 8
73+
; CHECK-NEXT: %byref.transpose.transa16 = alloca i8, align 1
74+
; CHECK-NEXT: %byref.constant.fp.1.019 = alloca double, align 8
75+
; CHECK-NEXT: %byref.constant.char.G20 = alloca i8, align 1
76+
; CHECK-NEXT: %byref.constant.int.021 = alloca i32, align 4
77+
; CHECK-NEXT: %byref.constant.int.022 = alloca i32, align 4
78+
; CHECK-NEXT: %byref.constant.fp.1.023 = alloca double, align 8
79+
; CHECK-NEXT: %[[i1:.+]] = alloca i32, align 4
80+
; CHECK-NEXT: %malloccall3 = alloca i8, i64 8, align 8
81+
; CHECK-NEXT: %malloccall = alloca i8, i64 1, align 1
82+
; CHECK-NEXT: %malloccall2 = alloca i8, i64 8, align 8
83+
; CHECK-NEXT: %malloccall1 = alloca i8, i64 4, align 4
84+
; CHECK-NEXT: %"calloc'mi" = extractvalue { ptr, ptr, ptr, ptr } %tapeArg, 2
85+
; CHECK-NEXT: %calloc = extractvalue { ptr, ptr, ptr, ptr } %tapeArg, 3
86+
; CHECK-NEXT: store i8 78, ptr %malloccall, align 1
87+
; CHECK-NEXT: store i32 2, ptr %malloccall1, align 4
88+
; CHECK-NEXT: store double 1.000000e+00, ptr %malloccall2, align 8
89+
; CHECK-NEXT: store double 0.000000e+00, ptr %malloccall3, align 8
90+
; CHECK-NEXT: %"'il_phi" = extractvalue { ptr, ptr, ptr, ptr } %tapeArg, 1
91+
; CHECK-NEXT: %[[i2:.+]] = extractvalue { ptr, ptr, ptr, ptr } %tapeArg, 0
92+
; CHECK-NEXT: %"alpha'ipg" = getelementptr inbounds %struct.Prod, ptr %"P'", i64 0, i32 1
93+
; CHECK-NEXT: store double 0.000000e+00, ptr %"alpha'ipg", align 8
94+
; CHECK-NEXT: %ld.transb = load i8, ptr %malloccall, align 1
95+
; CHECK-NEXT: %[[i3:.+]] = icmp eq i8 %ld.transb, 110
96+
; CHECK-NEXT: %[[i4:.+]] = select i1 %[[i3]], i8 116, i8 0
97+
; CHECK-NEXT: %[[i5:.+]] = icmp eq i8 %ld.transb, 78
98+
; CHECK-NEXT: %[[i6:.+]] = select i1 %[[i5]], i8 84, i8 %[[i4]]
99+
; CHECK-NEXT: %[[i7:.+]] = icmp eq i8 %ld.transb, 116
100+
; CHECK-NEXT: %[[i8:.+]] = select i1 %[[i7]], i8 110, i8 %[[i6]]
101+
; CHECK-NEXT: %[[i9:.+]] = icmp eq i8 %ld.transb, 84
102+
; CHECK-NEXT: %[[i10:.+]] = select i1 %[[i9]], i8 78, i8 %[[i8]]
103+
; CHECK-NEXT: store i8 %[[i10]], ptr %byref.transpose.transb, align 1
104+
; CHECK-NEXT: %ld.row.trans = load i8, ptr %malloccall, align 1
105+
; CHECK-NEXT: %[[i11:.+]] = icmp eq i8 %ld.row.trans, 110
106+
; CHECK-NEXT: %[[i12:.+]] = icmp eq i8 %ld.row.trans, 78
107+
; CHECK-NEXT: %[[i13:.+]] = or i1 %[[i12]], %[[i11]]
108+
; CHECK-NEXT: %[[i14:.+]] = select i1 %[[i13]], ptr %byref.transpose.transb, ptr %malloccall
109+
; CHECK-NEXT: %[[i15:.+]] = select i1 %[[i13]], ptr %"'il_phi", ptr %rhs
110+
; CHECK-NEXT: %[[i16:.+]] = select i1 %[[i13]], ptr %rhs, ptr %"'il_phi"
111+
; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.0, align 8
112+
; CHECK-NEXT: call void @dgemm_(ptr %malloccall, ptr %[[i14]], ptr %malloccall1, ptr %malloccall1, ptr %malloccall1, ptr %malloccall2, ptr %[[i15]], ptr %malloccall1, ptr %[[i16]], ptr %malloccall1, ptr %byref.constant.fp.1.0, ptr %"calloc'mi", ptr %malloccall1, i32 1, i32 1)
113+
; CHECK-NEXT: %ld.transa = load i8, ptr %malloccall, align 1
114+
; CHECK-NEXT: %[[i17:.+]] = icmp eq i8 %ld.transa, 110
115+
; CHECK-NEXT: %[[i18:.+]] = select i1 %[[i17]], i8 116, i8 0
116+
; CHECK-NEXT: %[[i19:.+]] = icmp eq i8 %ld.transa, 78
117+
; CHECK-NEXT: %[[i20:.+]] = select i1 %[[i19]], i8 84, i8 %[[i18]]
118+
; CHECK-NEXT: %[[i21:.+]] = icmp eq i8 %ld.transa, 116
119+
; CHECK-NEXT: %[[i22:.+]] = select i1 %[[i21]], i8 110, i8 %[[i20]]
120+
; CHECK-NEXT: %[[i23:.+]] = icmp eq i8 %ld.transa, 84
121+
; CHECK-NEXT: %[[i24:.+]] = select i1 %[[i23]], i8 78, i8 %[[i22]]
122+
; CHECK-NEXT: store i8 %[[i24]], ptr %byref.transpose.transa, align 1
123+
; CHECK-NEXT: %ld.row.trans2 = load i8, ptr %malloccall, align 1
124+
; CHECK-NEXT: %[[i25:.+]] = icmp eq i8 %ld.row.trans2, 110
125+
; CHECK-NEXT: %[[i26:.+]] = icmp eq i8 %ld.row.trans2, 78
126+
; CHECK-NEXT: %[[i27:.+]] = or i1 %[[i26]], %[[i25]]
127+
; CHECK-NEXT: %[[i28:.+]] = select i1 %[[i27]], ptr %byref.transpose.transa, ptr %malloccall
128+
; CHECK-NEXT: %[[i29:.+]] = select i1 %[[i27]], ptr %[[i2]], ptr %"'il_phi"
129+
; CHECK-NEXT: %[[i30:.+]] = select i1 %[[i27]], ptr %"'il_phi", ptr %[[i2]]
130+
; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.05, align 8
131+
; CHECK-NEXT: call void @dgemm_(ptr %[[i28]], ptr %malloccall, ptr %malloccall1, ptr %malloccall1, ptr %malloccall1, ptr %malloccall2, ptr %[[i29]], ptr %malloccall1, ptr %[[i30]], ptr %malloccall1, ptr %byref.constant.fp.1.05, ptr %"rhs'", ptr %malloccall1, i32 1, i32 1)
132+
; CHECK-NEXT: store i8 71, ptr %byref.constant.char.G, align 1
133+
; CHECK-NEXT: store i32 0, ptr %byref.constant.int.0, align 4
134+
; CHECK-NEXT: store i32 0, ptr %byref.constant.int.06, align 4
135+
; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.07, align 8
136+
; CHECK-NEXT: call void @dlascl_(ptr %byref.constant.char.G, ptr %byref.constant.int.0, ptr %byref.constant.int.06, ptr %byref.constant.fp.1.07, ptr %malloccall3, ptr %malloccall1, ptr %malloccall1, ptr %"'il_phi", ptr %malloccall1, ptr %[[i0]], i32 1)
137+
; CHECK-NEXT: tail call void @free(ptr nonnull %[[i2]])
138+
; CHECK-NEXT: %ld.transb10 = load i8, ptr %malloccall, align 1
139+
; CHECK-NEXT: %[[i31:.+]] = icmp eq i8 %ld.transb10, 110
140+
; CHECK-NEXT: %[[i32:.+]] = select i1 %[[i31]], i8 116, i8 0
141+
; CHECK-NEXT: %[[i33:.+]] = icmp eq i8 %ld.transb10, 78
142+
; CHECK-NEXT: %[[i34:.+]] = select i1 %[[i33]], i8 84, i8 %[[i32]]
143+
; CHECK-NEXT: %[[i35:.+]] = icmp eq i8 %ld.transb10, 116
144+
; CHECK-NEXT: %[[i36:.+]] = select i1 %[[i35]], i8 110, i8 %[[i34]]
145+
; CHECK-NEXT: %[[i37:.+]] = icmp eq i8 %ld.transb10, 84
146+
; CHECK-NEXT: %[[i38:.+]] = select i1 %[[i37]], i8 78, i8 %[[i36]]
147+
; CHECK-NEXT: store i8 %[[i38]], ptr %byref.transpose.transb11, align 1
148+
; CHECK-NEXT: %ld.row.trans12 = load i8, ptr %malloccall, align 1
149+
; CHECK-NEXT: %[[i39:.+]] = icmp eq i8 %ld.row.trans12, 110
150+
; CHECK-NEXT: %[[i40:.+]] = icmp eq i8 %ld.row.trans12, 78
151+
; CHECK-NEXT: %[[i41:.+]] = or i1 %[[i40]], %[[i39]]
152+
; CHECK-NEXT: %[[i42:.+]] = select i1 %[[i41]], ptr %byref.transpose.transb11, ptr %malloccall
153+
; CHECK-NEXT: %[[i43:.+]] = select i1 %[[i41]], ptr %"calloc'mi", ptr %rhs
154+
; CHECK-NEXT: %[[i44:.+]] = select i1 %[[i41]], ptr %rhs, ptr %"calloc'mi"
155+
; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.014, align 8
156+
; CHECK-NEXT: call void @dgemm_(ptr %malloccall, ptr %[[i42]], ptr %malloccall1, ptr %malloccall1, ptr %malloccall1, ptr %malloccall2, ptr %[[i43]], ptr %malloccall1, ptr %[[i44]], ptr %malloccall1, ptr %byref.constant.fp.1.014, ptr %"rhs'", ptr %malloccall1, i32 1, i32 1)
157+
; CHECK-NEXT: %ld.transa15 = load i8, ptr %malloccall, align 1
158+
; CHECK-NEXT: %[[i45:.+]] = icmp eq i8 %ld.transa15, 110
159+
; CHECK-NEXT: %[[i46:.+]] = select i1 %[[i45]], i8 116, i8 0
160+
; CHECK-NEXT: %[[i47:.+]] = icmp eq i8 %ld.transa15, 78
161+
; CHECK-NEXT: %[[i48:.+]] = select i1 %[[i47]], i8 84, i8 %[[i46]]
162+
; CHECK-NEXT: %[[i49:.+]] = icmp eq i8 %ld.transa15, 116
163+
; CHECK-NEXT: %[[i50:.+]] = select i1 %[[i49]], i8 110, i8 %[[i48]]
164+
; CHECK-NEXT: %[[i51:.+]] = icmp eq i8 %ld.transa15, 84
165+
; CHECK-NEXT: %[[i52:.+]] = select i1 %[[i51]], i8 78, i8 %[[i50]]
166+
; CHECK-NEXT: store i8 %[[i52]], ptr %byref.transpose.transa16, align 1
167+
; CHECK-NEXT: %ld.row.trans17 = load i8, ptr %malloccall, align 1
168+
; CHECK-NEXT: %[[i53:.+]] = icmp eq i8 %ld.row.trans17, 110
169+
; CHECK-NEXT: %[[i54:.+]] = icmp eq i8 %ld.row.trans17, 78
170+
; CHECK-NEXT: %[[i55:.+]] = or i1 %[[i54]], %[[i53]]
171+
; CHECK-NEXT: %[[i56:.+]] = select i1 %[[i55:.+]], ptr %byref.transpose.transa16, ptr %malloccall
172+
; CHECK-NEXT: %[[i57:.+]] = select i1 %[[i55:.+]], ptr %rhs, ptr %"calloc'mi"
173+
; CHECK-NEXT: %[[i58:.+]] = select i1 %[[i55:.+]], ptr %"calloc'mi", ptr %rhs
174+
; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.019, align 8
175+
; CHECK-NEXT: call void @dgemm_(ptr %[[i56]], ptr %malloccall, ptr %malloccall1, ptr %malloccall1, ptr %malloccall1, ptr %malloccall2, ptr %57, ptr %malloccall1, ptr %58, ptr %malloccall1, ptr %byref.constant.fp.1.019, ptr %"rhs'", ptr %malloccall1, i32 1, i32 1)
176+
; CHECK-NEXT: store i8 71, ptr %byref.constant.char.G20, align 1
177+
; CHECK-NEXT: store i32 0, ptr %byref.constant.int.021, align 4
178+
; CHECK-NEXT: store i32 0, ptr %byref.constant.int.022, align 4
179+
; CHECK-NEXT: store double 1.000000e+00, ptr %byref.constant.fp.1.023, align 8
180+
; CHECK-NEXT: call void @dlascl_(ptr %byref.constant.char.G20, ptr %byref.constant.int.021, ptr %byref.constant.int.022, ptr %byref.constant.fp.1.023, ptr %malloccall2, ptr %malloccall1, ptr %malloccall1, ptr %"calloc'mi", ptr %malloccall1, ptr %[[i1]], i32 1)
181+
; CHECK-NEXT: call void @free(ptr nonnull %"calloc'mi")
182+
; CHECK-NEXT: call void @free(ptr %calloc)
183+
; CHECK-NEXT: ret void
184+
; CHECK-NEXT: }
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
// a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions...
2+
// RUN: if [ %llvmver -ge 12 ]; then %clang -O0 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - ; fi
3+
// RUN: if [ %llvmver -ge 12 ]; then %clang -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - ; fi
4+
// RUN: if [ %llvmver -ge 12 ]; then %clang -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - ; fi
5+
// RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=1 | %lli - ; fi
6+
// RUN: if [ %llvmver -ge 12 ]; then %clang -O0 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi
7+
// RUN: if [ %llvmver -ge 12 ]; then %clang -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi
8+
// RUN: if [ %llvmver -ge 12 ]; then %clang -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi
9+
// RUN: if [ %llvmver -ge 12 ]; then %clang -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-lapack-copy=0 | %lli - ; fi
10+
11+
#include <stdio.h>
12+
#include <stdint.h>
13+
#include <stdlib.h>
14+
#include "test_utils.h"
15+
#include "../blas_inline.h"
16+
17+
extern int enzyme_dup;
18+
extern int enzyme_dupnoneed;
19+
extern int enzyme_out;
20+
extern int enzyme_const;
21+
22+
#include <assert.h>
23+
24+
void __enzyme_autodiff(void*, ...);
25+
26+
const size_t n = 20;
27+
28+
#include <string.h>
29+
struct Prod {
30+
double* out;
31+
double alpha;
32+
};
33+
34+
__attribute__((noinline))
35+
void mul(struct Prod* P, double* __restrict__ rhs) {
36+
double* tmp= (double*)malloc(sizeof(double)*n*n);
37+
memset(tmp, 0, n*n*sizeof(double));
38+
char N = 'N';
39+
int ten = n;
40+
double one = 1.0;
41+
double zero = 0.0;
42+
43+
dgemm_(&N, &N, &ten, &ten, &ten, &one, rhs, &ten, rhs, &ten, &one, tmp, &ten);
44+
dgemm_(&N, &N, &ten, &ten, &ten, &one, tmp, &ten, rhs, &ten, &zero, P->out, &ten);
45+
P->alpha = 0;
46+
return;
47+
}
48+
49+
double simulate(double* P) {
50+
struct Prod M;
51+
M.out = (double*)malloc(sizeof(double)*n*n);
52+
M.alpha = 1.0;
53+
mul(&M, P);
54+
return M.out[0];
55+
// double *out = (double*)malloc(sizeof(double)*n*n);
56+
// dgemm_(&N, &N, &ten, &ten, &ten, &one, P1.data(), &ten, P.data(), &ten, &zero, &out[0], &ten);
57+
// return P1(0, 0);
58+
}
59+
60+
int main(int argc, char **argv) {
61+
62+
double A[n * n];
63+
double Adup[n * n];
64+
double Adup_fd[n * n];
65+
66+
for (int i = 0; i < n; i++) {
67+
for (int j = 0; j < n; j++) {
68+
A[n*i + j] = j == i ? 0.3 : 0.1;
69+
Adup[n*i + j] = 0.0;
70+
Adup_fd[n*i + j] = 0.0;
71+
}
72+
}
73+
74+
double delta = 0.001;
75+
delta = delta * delta;
76+
77+
double fx = simulate(A);
78+
printf("f(A) = %f\n", fx);
79+
80+
// if (argc == 2) {
81+
__enzyme_autodiff((void *)simulate, enzyme_dup, &A[0], &Adup[0]);
82+
printf("dP(0,0) = %f, dP(0,1) = %f, dP(1,0) = %f\n", Adup[0], Adup[1], Adup[2]);
83+
//}
84+
85+
for (int i = 0; i < n*n; i++) {
86+
A[i] += delta / 2;
87+
double fx2 = simulate(A);
88+
A[i] -= delta;
89+
double fx3 = simulate(A);
90+
A[i] += delta/2;
91+
Adup_fd[i] = (fx2 - fx3) / delta;
92+
93+
printf("dA_fd[%d]=%f\n", i, Adup_fd[i]);
94+
95+
APPROX_EQ(Adup[i], Adup_fd[i], 1e-6);
96+
}
97+
98+
return 0;
99+
}

enzyme/test/Integration/blas_inline.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ int xerbla_(const char *srname, integer *info, int len)
2828
return 0;
2929
}
3030
__attribute__((noinline))
31-
logical lsame_(char *ca, char *cb, int, int)
31+
logical lsame_(char *ca, char *cb, int ca_size, int cb_size)
3232
{
3333
/* System generated locals */
3434
logical ret_val;
@@ -764,12 +764,17 @@ doublereal ddot_(integer *n, doublereal *dx, integer *incx, doublereal *dy,
764764
} /* ddot_ */
765765

766766
__attribute__((noinline))
767-
/* Subroutine */ int dgemm_(const char *transa, const char *transb, const integer *m, const integer *
767+
/* Subroutine */ int dgemm_(const char *transa_t, const char *transb_t, const integer *m, const integer *
768768
n, const integer *k, const doublereal *alpha, const doublereal *a, const integer *lda,
769769
const doublereal *b, const integer *ldb, const doublereal *beta, doublereal *c, const integer
770770
*ldc)
771771
{
772772

773+
char transa_v = *transa_t;
774+
char* transa = &transa_v;
775+
776+
char transb_v = *transb_t;
777+
char* transb = &transb_v;
773778

774779
/* System generated locals */
775780
integer a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset, i__1, i__2,

enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,9 @@ void emit_BLASDiffUse(TGPattern &pattern, llvm::raw_ostream &os) {
114114
}
115115
}
116116

117-
os << " if (!shadow && need_" << argname << " && !cache_" << argname
118-
<< ")\n"
117+
os << " if (!shadow && need_" << argname
118+
<< " && ((cacheMode && overwritten_args_ptr) ? !cache_" << argname
119+
<< " : true ))\n"
119120
<< " return true;\n";
120121
os << " }\n";
121122
}

0 commit comments

Comments
 (0)