Skip to content

Commit 305f68e

Browse files
authored
optional use of blas copy for caching(#1226)
* add runtime [de]activation of blas copy
1 parent 1c20c36 commit 305f68e

File tree

8 files changed

+134
-508
lines changed

8 files changed

+134
-508
lines changed

.github/workflows/bcload.yml

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,6 @@ jobs:
1414
build: ["Release"] # "RelWithDebInfo"
1515
os: [ubuntu-20.04]
1616

17-
exclude:
18-
# How to install FileCheck on ubuntu-18.04?
19-
- os: ubuntu-18.04
20-
llvm: 8
21-
2217
timeout-minutes: 30
2318
steps:
2419
- name: add llvm
@@ -45,46 +40,3 @@ jobs:
4540
cmake .. -DLLVM_EXTERNAL_LIT=`which lit` -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DLLVM_DIR=/usr/lib/llvm-${{ matrix.llvm }}/lib/cmake/llvm
4641
- name: make check-bcpass
4742
run: cd enzyme/build && make check-bcpass -j`nproc`
48-
49-
build-container:
50-
name: Bitcode loading CI LLVM ${{ matrix.llvm }} ${{ matrix.build }} ubuntu-18.04
51-
runs-on: ubuntu-latest
52-
container: ubuntu:18.04
53-
54-
strategy:
55-
fail-fast: false
56-
matrix:
57-
llvm: ["7", "8", "9", "10", "11", "12", "13"] #, "14"]
58-
build: ["Release"] # "RelWithDebInfo"
59-
60-
exclude:
61-
- llvm: 8
62-
63-
timeout-minutes: 30
64-
steps:
65-
- name: add llvm
66-
run: |
67-
apt-get -q update
68-
apt-get install -y ca-certificates software-properties-common wget gnupg2 python3 python3-pip sed git ssh zlib1g-dev
69-
wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key| apt-key add -
70-
apt-add-repository "deb http://apt.llvm.org/`lsb_release -c | cut -f2`/ llvm-toolchain-`lsb_release -c | cut -f2`-${{ matrix.llvm }} main" || true
71-
apt-get install -y autoconf cmake gcc g++ libtool gfortran llvm-${{ matrix.llvm }}-dev libomp-${{ matrix.llvm }}-dev clang-${{ matrix.llvm }} libclang-${{ matrix.llvm }}-dev libeigen3-dev libboost-dev
72-
python3 -m pip install --upgrade pip setuptools
73-
python3 -m pip install lit
74-
touch /usr/lib/llvm-${{ matrix.llvm }}/bin/yaml-bench
75-
if [[ '${{ matrix.llvm }}' == '7' || '${{ matrix.llvm }}' == '8' || '${{ matrix.llvm }}' == '9' ]]; then
76-
apt-get install -y llvm-${{ matrix.llvm }}-tools
77-
fi
78-
if [[ '${{ matrix.llvm }}' == '13' ]]; then
79-
sed -i 's/add_executable(llvm-omp-device-info IMPORTED)//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake
80-
sed -i 's/llvm-omp-device-info//g' /usr/lib/llvm-${{matrix.llvm}}/lib/cmake/llvm/LLVMExports*.cmake
81-
fi
82-
- uses: actions/checkout@v3
83-
- name: mkdir
84-
run: cd enzyme && rm -rf build && mkdir build
85-
- name: cmake
86-
run: |
87-
cd enzyme/build
88-
cmake .. -DLLVM_EXTERNAL_LIT=`which lit` -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DLLVM_DIR=/usr/lib/llvm-${{ matrix.llvm }}/lib/cmake/llvm
89-
- name: make check-bcpass
90-
run: cd enzyme/build && make check-bcpass -j`nproc`

enzyme/Enzyme/Utils.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ LLVMValueRef (*EnzymeSanitizeDerivatives)(LLVMValueRef, LLVMValueRef toset,
6565
LLVMValueRef) = nullptr;
6666

6767
extern llvm::cl::opt<bool> EnzymeZeroCache;
68+
llvm::cl::opt<bool>
69+
EnzymeBlasCopy("enzyme-blas-copy", cl::init(true), cl::Hidden,
70+
cl::desc("Use blas copy calls to cache vectors"));
6871
llvm::cl::opt<bool>
6972
EnzymeFastMath("enzyme-fast-math", cl::init(true), cl::Hidden,
7073
cl::desc("Use fast math on derivative compuation"));
@@ -664,6 +667,21 @@ Function *getOrInsertDifferentialFloatMemcpy(Module &M, Type *elementType,
664667
return F;
665668
}
666669

670+
Function *getOrInsertMemcpyStridedBlas(Module &M, PointerType *T, Type *IT,
671+
BlasInfo blas) {
672+
std::string copy_name =
673+
(blas.prefix + blas.floatType + "copy" + blas.suffix).str();
674+
FunctionType *FT = FunctionType::get(Type::getVoidTy(M.getContext()),
675+
{IT, T, IT, T, IT}, false);
676+
#if LLVM_VERSION_MAJOR >= 9
677+
Function *dmemcpy =
678+
cast<Function>(M.getOrInsertFunction(copy_name, FT).getCallee());
679+
#else
680+
Function *dmemcpy = cast<Function>(M.getOrInsertFunction(copy_name, FT));
681+
#endif
682+
return dmemcpy;
683+
}
684+
667685
Function *getOrInsertMemcpyStrided(Module &M, PointerType *T, Type *IT,
668686
unsigned dstalign, unsigned srcalign) {
669687
Type *elementType = T->getPointerElementType();

enzyme/Enzyme/Utils.h

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ extern "C" {
8686
/// Print additional debug info relevant to performance
8787
extern llvm::cl::opt<bool> EnzymePrintPerf;
8888
extern llvm::cl::opt<bool> EnzymeStrongZero;
89+
extern llvm::cl::opt<bool> EnzymeBlasCopy;
8990
extern void (*CustomErrorHandler)(const char *, LLVMValueRef, ErrorType,
9091
const void *, LLVMValueRef);
9192
}
@@ -603,12 +604,29 @@ static inline bool isCertainPrint(const llvm::StringRef name) {
603604
return false;
604605
}
605606

607+
struct BlasInfo {
608+
llvm::StringRef floatType;
609+
llvm::StringRef prefix;
610+
llvm::StringRef suffix;
611+
llvm::StringRef function;
612+
};
613+
614+
#if LLVM_VERSION_MAJOR >= 16
615+
std::optional<BlasInfo> extractBLAS(llvm::StringRef in);
616+
#else
617+
llvm::Optional<BlasInfo> extractBLAS(llvm::StringRef in);
618+
#endif
619+
606620
/// Create function for type that performs the derivative memcpy on floating
607621
/// point memory
608622
llvm::Function *getOrInsertDifferentialFloatMemcpy(
609623
llvm::Module &M, llvm::Type *T, unsigned dstalign, unsigned srcalign,
610624
unsigned dstaddr, unsigned srcaddr, unsigned bitwidth);
611625

626+
/// Create function for type that performs memcpy with a stride using blas copy
627+
llvm::Function *getOrInsertMemcpyStridedBlas(llvm::Module &M,
628+
llvm::PointerType *T,
629+
llvm::Type *IT, BlasInfo blas);
612630
/// Create function for type that performs memcpy with a stride
613631
llvm::Function *getOrInsertMemcpyStrided(llvm::Module &M, llvm::PointerType *T,
614632
llvm::Type *IT, unsigned dstalign,
@@ -1470,19 +1488,6 @@ static inline bool isNoCapture(const llvm::CallInst *call, size_t idx) {
14701488

14711489
void attributeKnownFunctions(llvm::Function &F);
14721490

1473-
struct BlasInfo {
1474-
llvm::StringRef floatType;
1475-
llvm::StringRef prefix;
1476-
llvm::StringRef suffix;
1477-
llvm::StringRef function;
1478-
};
1479-
1480-
#if LLVM_VERSION_MAJOR >= 16
1481-
std::optional<BlasInfo> extractBLAS(llvm::StringRef in);
1482-
#else
1483-
llvm::Optional<BlasInfo> extractBLAS(llvm::StringRef in);
1484-
#endif
1485-
14861491
llvm::Constant *getUndefinedValueForType(llvm::Type *T, bool forceZero = false);
14871492

14881493
llvm::Value *SanitizeDerivatives(llvm::Value *val, llvm::Value *toset,

enzyme/test/Enzyme/ReverseMode/blas/cblas_ddot.ll

Lines changed: 21 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -119,61 +119,17 @@ entry:
119119

120120
; CHECK: define internal { double*, double* } @[[augMod]](i32 %len, double* noalias %m, double* %"m'", i32 %incm, double* noalias %n, double* %"n'", i32 %incn)
121121
; CHECK-NEXT: entry:
122-
; CHECK-NEXT: %mallocsize = mul nuw nsw i32 %len, 8
123-
; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i32 %mallocsize)
124-
; CHECK-NEXT: %0 = bitcast i8* %malloccall to double*
125-
; CHECK-NEXT: %1 = icmp eq i32 %len, 0
126-
; CHECK-NEXT: br i1 %1, label %__enzyme_memcpy_double_32_da0sa0stride.exit, label %init.idx.i
127-
128-
; CHECK: init.idx.i: ; preds = %entry
129-
; CHECK-NEXT: %a.i = sub nsw i32 1, %len
130-
; CHECK-NEXT: %negidx.i = mul nsw i32 %a.i, %incm
131-
; CHECK-NEXT: %is.neg.i = icmp slt i32 %incm, 0
132-
; CHECK-NEXT: %startidx.i = select i1 %is.neg.i, i32 %negidx.i, i32 0
133-
; CHECK-NEXT: br label %for.body.i
134-
135-
; CHECK: for.body.i: ; preds = %for.body.i, %init.idx.i
136-
; CHECK-NEXT: %idx.i = phi i32 [ 0, %init.idx.i ], [ %idx.next.i, %for.body.i ]
137-
; CHECK-NEXT: %sidx.i = phi i32 [ %startidx.i, %init.idx.i ], [ %sidx.next.i, %for.body.i ]
138-
; CHECK-NEXT: %dst.i.i = getelementptr inbounds double, double* %0, i32 %idx.i
139-
; CHECK-NEXT: %src.i.i = getelementptr inbounds double, double* %m, i32 %sidx.i
140-
; CHECK-NEXT: %src.i.l.i = load double, double* %src.i.i
141-
; CHECK-NEXT: store double %src.i.l.i, double* %dst.i.i
142-
; CHECK-NEXT: %idx.next.i = add nsw i32 %idx.i, 1
143-
; CHECK-NEXT: %sidx.next.i = add nsw i32 %sidx.i, %incm
144-
; CHECK-NEXT: %2 = icmp eq i32 %len, %idx.next.i
145-
; CHECK-NEXT: br i1 %2, label %__enzyme_memcpy_double_32_da0sa0stride.exit, label %for.body.i
146-
147-
; CHECK: __enzyme_memcpy_double_32_da0sa0stride.exit: ; preds = %entry, %for.body.i
148-
; CHECK-NEXT: %mallocsize1 = mul nuw nsw i32 %len, 8
149-
; CHECK-NEXT: %malloccall2 = tail call noalias nonnull i8* @malloc(i32 %mallocsize1)
150-
; CHECK-NEXT: %3 = bitcast i8* %malloccall2 to double*
151-
; CHECK-NEXT: %4 = icmp eq i32 %len, 0
152-
; CHECK-NEXT: br i1 %4, label %__enzyme_memcpy_double_32_da0sa0stride.exit14, label %init.idx.i5
153-
154-
; CHECK: init.idx.i5: ; preds = %__enzyme_memcpy_double_32_da0sa0stride.exit
155-
; CHECK-NEXT: %a.i1 = sub nsw i32 1, %len
156-
; CHECK-NEXT: %negidx.i2 = mul nsw i32 %a.i1, %incn
157-
; CHECK-NEXT: %is.neg.i3 = icmp slt i32 %incn, 0
158-
; CHECK-NEXT: %startidx.i4 = select i1 %is.neg.i3, i32 %negidx.i2, i32 0
159-
; CHECK-NEXT: br label %for.body.i13
160-
161-
; CHECK: for.body.i13: ; preds = %for.body.i13, %init.idx.i5
162-
; CHECK-NEXT: %idx.i6 = phi i32 [ 0, %init.idx.i5 ], [ %idx.next.i11, %for.body.i13 ]
163-
; CHECK-NEXT: %sidx.i7 = phi i32 [ %startidx.i4, %init.idx.i5 ], [ %sidx.next.i12, %for.body.i13 ]
164-
; CHECK-NEXT: %dst.i.i8 = getelementptr inbounds double, double* %3, i32 %idx.i6
165-
; CHECK-NEXT: %src.i.i9 = getelementptr inbounds double, double* %n, i32 %sidx.i7
166-
; CHECK-NEXT: %src.i.l.i10 = load double, double* %src.i.i9
167-
; CHECK-NEXT: store double %src.i.l.i10, double* %dst.i.i8
168-
; CHECK-NEXT: %idx.next.i11 = add nsw i32 %idx.i6, 1
169-
; CHECK-NEXT: %sidx.next.i12 = add nsw i32 %sidx.i7, %incn
170-
; CHECK-NEXT: %5 = icmp eq i32 %len, %idx.next.i11
171-
; CHECK-NEXT: br i1 %5, label %__enzyme_memcpy_double_32_da0sa0stride.exit14, label %for.body.i13
172-
173-
; CHECK: __enzyme_memcpy_double_32_da0sa0stride.exit14: ; preds = %__enzyme_memcpy_double_32_da0sa0stride.exit, %for.body.i13
174-
; CHECK-NEXT: %6 = insertvalue { double*, double* } undef, double* %0, 0
175-
; CHECK-NEXT: %7 = insertvalue { double*, double* } %6, double* %3, 1
176-
; CHECK-NEXT: ret { double*, double* } %7
122+
; CHECK-NEXT: %mallocsize = mul nuw nsw i32 %len, 8
123+
; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i32 %mallocsize)
124+
; CHECK-NEXT: %0 = bitcast i8* %malloccall to double*
125+
; CHECK-NEXT: call void @cblas_dcopy(i32 %len, double* %m, i32 %incm, double* %0, i32 1)
126+
; CHECK-NEXT: %mallocsize1 = mul nuw nsw i32 %len, 8
127+
; CHECK-NEXT: %malloccall2 = tail call noalias nonnull i8* @malloc(i32 %mallocsize1)
128+
; CHECK-NEXT: %1 = bitcast i8* %malloccall2 to double*
129+
; CHECK-NEXT: call void @cblas_dcopy(i32 %len, double* %n, i32 %incn, double* %1, i32 1)
130+
; CHECK-NEXT: %2 = insertvalue { double*, double* } undef, double* %0, 0
131+
; CHECK-NEXT: %3 = insertvalue { double*, double* } %2, double* %1, 1
132+
; CHECK-NEXT: ret { double*, double* } %3
177133
; CHECK-NEXT: }
178134

179135
; CHECK: define internal void @[[revMod]](i32 %len, double* noalias %m, double* %"m'", i32 %incm, double* noalias %n, double* %"n'", i32 %incn, double %differeturn, { double*, double* }
@@ -198,33 +154,11 @@ entry:
198154

199155
; CHECK: define internal double* @augmented_f.6(i32 %len, double* noalias %m, i32 %incm, double* noalias %n, double* %"n'", i32 %incn)
200156
; CHECK-NEXT: entry:
201-
; CHECK-NEXT: %mallocsize = mul nuw nsw i32 %len, 8
202-
; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i32 %mallocsize)
203-
; CHECK-NEXT: %0 = bitcast i8* %malloccall to double*
204-
; CHECK-NEXT: %1 = icmp eq i32 %len, 0
205-
; CHECK-NEXT: br i1 %1, label %__enzyme_memcpy_double_32_da0sa0stride.exit, label %init.idx.i
206-
207-
; CHECK: init.idx.i: ; preds = %entry
208-
; CHECK-NEXT: %a.i = sub nsw i32 1, %len
209-
; CHECK-NEXT: %negidx.i = mul nsw i32 %a.i, %incm
210-
; CHECK-NEXT: %is.neg.i = icmp slt i32 %incm, 0
211-
; CHECK-NEXT: %startidx.i = select i1 %is.neg.i, i32 %negidx.i, i32 0
212-
; CHECK-NEXT: br label %for.body.i
213-
214-
; CHECK: for.body.i: ; preds = %for.body.i, %init.idx.i
215-
; CHECK-NEXT: %idx.i = phi i32 [ 0, %init.idx.i ], [ %idx.next.i, %for.body.i ]
216-
; CHECK-NEXT: %sidx.i = phi i32 [ %startidx.i, %init.idx.i ], [ %sidx.next.i, %for.body.i ]
217-
; CHECK-NEXT: %dst.i.i = getelementptr inbounds double, double* %0, i32 %idx.i
218-
; CHECK-NEXT: %src.i.i = getelementptr inbounds double, double* %m, i32 %sidx.i
219-
; CHECK-NEXT: %src.i.l.i = load double, double* %src.i.i
220-
; CHECK-NEXT: store double %src.i.l.i, double* %dst.i.i
221-
; CHECK-NEXT: %idx.next.i = add nsw i32 %idx.i, 1
222-
; CHECK-NEXT: %sidx.next.i = add nsw i32 %sidx.i, %incm
223-
; CHECK-NEXT: %2 = icmp eq i32 %len, %idx.next.i
224-
; CHECK-NEXT: br i1 %2, label %__enzyme_memcpy_double_32_da0sa0stride.exit, label %for.body.i
225-
226-
; CHECK: __enzyme_memcpy_double_32_da0sa0stride.exit: ; preds = %entry, %for.body.i
227-
; CHECK-NEXT: ret double* %0
157+
; CHECK-NEXT: %mallocsize = mul nuw nsw i32 %len, 8
158+
; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i32 %mallocsize)
159+
; CHECK-NEXT: %0 = bitcast i8* %malloccall to double*
160+
; CHECK-NEXT: call void @cblas_dcopy(i32 %len, double* %m, i32 %incm, double* %0, i32 1)
161+
; CHECK-NEXT: ret double* %0
228162
; CHECK-NEXT: }
229163

230164
; CHECK: define internal void @[[revModFirst]](i32 %len, double* noalias %m, i32 %incm, double* noalias %n, double* %"n'", i32 %incn, double %differeturn, double*
@@ -244,33 +178,11 @@ entry:
244178

245179
; CHECK: define internal double* @[[augModSecond]](i32 %len, double* noalias %m, double* %"m'", i32 %incm, double* noalias %n, i32 %incn)
246180
; CHECK-NEXT: entry:
247-
; CHECK-NEXT: %mallocsize = mul nuw nsw i32 %len, 8
248-
; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i32 %mallocsize)
249-
; CHECK-NEXT: %0 = bitcast i8* %malloccall to double*
250-
; CHECK-NEXT: %1 = icmp eq i32 %len, 0
251-
; CHECK-NEXT: br i1 %1, label %__enzyme_memcpy_double_32_da0sa0stride.exit, label %init.idx.i
252-
253-
; CHECK: init.idx.i: ; preds = %entry
254-
; CHECK-NEXT: %a.i = sub nsw i32 1, %len
255-
; CHECK-NEXT: %negidx.i = mul nsw i32 %a.i, %incn
256-
; CHECK-NEXT: %is.neg.i = icmp slt i32 %incn, 0
257-
; CHECK-NEXT: %startidx.i = select i1 %is.neg.i, i32 %negidx.i, i32 0
258-
; CHECK-NEXT: br label %for.body.i
259-
260-
; CHECK: for.body.i: ; preds = %for.body.i, %init.idx.i
261-
; CHECK-NEXT: %idx.i = phi i32 [ 0, %init.idx.i ], [ %idx.next.i, %for.body.i ]
262-
; CHECK-NEXT: %sidx.i = phi i32 [ %startidx.i, %init.idx.i ], [ %sidx.next.i, %for.body.i ]
263-
; CHECK-NEXT: %dst.i.i = getelementptr inbounds double, double* %0, i32 %idx.i
264-
; CHECK-NEXT: %src.i.i = getelementptr inbounds double, double* %n, i32 %sidx.i
265-
; CHECK-NEXT: %src.i.l.i = load double, double* %src.i.i
266-
; CHECK-NEXT: store double %src.i.l.i, double* %dst.i.i
267-
; CHECK-NEXT: %idx.next.i = add nsw i32 %idx.i, 1
268-
; CHECK-NEXT: %sidx.next.i = add nsw i32 %sidx.i, %incn
269-
; CHECK-NEXT: %2 = icmp eq i32 %len, %idx.next.i
270-
; CHECK-NEXT: br i1 %2, label %__enzyme_memcpy_double_32_da0sa0stride.exit, label %for.body.i
271-
272-
; CHECK: __enzyme_memcpy_double_32_da0sa0stride.exit: ; preds = %entry, %for.body.i
273-
; CHECK-NEXT: ret double* %0
181+
; CHECK-NEXT: %mallocsize = mul nuw nsw i32 %len, 8
182+
; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i32 %mallocsize)
183+
; CHECK-NEXT: %0 = bitcast i8* %malloccall to double*
184+
; CHECK-NEXT: call void @cblas_dcopy(i32 %len, double* %n, i32 %incn, double* %0, i32 1)
185+
; CHECK-NEXT: ret double* %0
274186
; CHECK-NEXT: }
275187

276188

0 commit comments

Comments
 (0)