Skip to content

Commit 1b05a28

Browse files
aaraujomtprimak
authored andcommitted
cpu: gemm: make sure to return out of memory if so
Some cases were still not returning proper status in case of out of memory.
1 parent 089420e commit 1b05a28

File tree

4 files changed

+28
-0
lines changed

4 files changed

+28
-0
lines changed

src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1927,13 +1927,22 @@ dnnl_status_t jit_avx512_common_gemm_f32(const char *transa, const char *transb,
19271927
c_buffers = (float *)malloc(
19281928
nthr_m * nthr_n * (nthr_k - 1) * MB * NB * sizeof(float),
19291929
PAGE_4K);
1930+
if (!c_buffers) {
1931+
free(ompstatus_);
1932+
return dnnl_out_of_memory;
1933+
}
19301934
}
19311935

19321936
const size_t ws_elems_per_thr = (size_t)k * 48 + 64;
19331937
const size_t ws_size_per_thr
19341938
= rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
19351939
if (k > STACK_K_CAPACITY) {
19361940
ws_buffers = (float *)malloc(nthr_to_use * ws_size_per_thr, PAGE_4K);
1941+
if (!ws_buffers) {
1942+
free(ompstatus_);
1943+
free(c_buffers);
1944+
return dnnl_out_of_memory;
1945+
}
19371946
}
19381947

19391948
parallel(nthr_to_use, [&](int ithr, int nthr) {

src/cpu/gemm/f32/jit_avx_gemm_f32.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2496,6 +2496,8 @@ dnnl_status_t jit_avx_gemm_f32(const char *transa, const char *transb,
24962496
if (nthr_k > 1) {
24972497
ompstatus_ = (unsigned char *)malloc(
24982498
nthr_to_use * CACHE_LINE_SIZE, CACHE_LINE_SIZE);
2499+
if (!ompstatus_) return dnnl_out_of_memory;
2500+
24992501
ompstatus = (unsigned char volatile *)ompstatus_;
25002502
assert(ompstatus);
25012503

@@ -2505,13 +2507,22 @@ dnnl_status_t jit_avx_gemm_f32(const char *transa, const char *transb,
25052507
c_buffers = (float *)malloc(
25062508
nthr_m * nthr_n * (nthr_k - 1) * MB * NB * sizeof(float),
25072509
PAGE_4K);
2510+
if (!c_buffers) {
2511+
free(ompstatus_);
2512+
return dnnl_out_of_memory;
2513+
}
25082514
}
25092515

25102516
const size_t ws_elems_per_thr = (size_t)k * 16 + 64;
25112517
const size_t ws_size_per_thr
25122518
= rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
25132519
if (k > STACK_K_CAPACITY) {
25142520
ws_buffers = (float *)malloc(nthr_to_use * ws_size_per_thr, PAGE_4K);
2521+
if (!ws_buffers) {
2522+
free(ompstatus_);
2523+
free(c_buffers);
2524+
return dnnl_out_of_memory;
2525+
}
25152526
}
25162527

25172528
parallel(nthr_to_use, [&](int ithr, int nthr) {

src/cpu/gemm/gemm_driver.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1598,6 +1598,11 @@ static dnnl_status_t gemm_threading_driver(
15981598
c_local_storage = (c_type *)malloc(
15991599
sizeof(c_type) * c_local_stride * nthr_goal, PAGE_4K);
16001600

1601+
if (!c_local_storage) {
1602+
free(thread_arg);
1603+
return dnnl_out_of_memory;
1604+
}
1605+
16011606
for (int ithr = 0; ithr < nthr_goal; ithr++) {
16021607
thread_arg[ithr].c_local = c_local_storage + ithr * c_local_stride;
16031608
thread_arg[ithr].ldc_local = ldc_local;

src/cpu/gemm/gemm_pack.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ dnnl_status_t sgemm_pack_get_size(const char *identifier, const char *transa,
165165
bool do_a = utils::one_of(*identifier, 'a', 'A');
166166
float alpha = 1.0f;
167167
gemm_pack_storage_shell_t shell {dnnl_get_max_threads()};
168+
if (!shell.get()) return dnnl_out_of_memory;
168169

169170
result = gemm_pack_driver<float, float, float>(identifier, transa, transb,
170171
M, N, K, &alpha, lda, ldb, nullptr, &shell, true);
@@ -204,6 +205,7 @@ dnnl_status_t gemm_bf16bf16f32_pack_get_size(const char *identifier,
204205

205206
float alpha = 1.0f;
206207
gemm_pack_storage_shell_t shell {dnnl_get_max_threads()};
208+
if (!shell.get()) return dnnl_out_of_memory;
207209

208210
result = gemm_pack_driver<bfloat16_t, bfloat16_t, float>(identifier, transa,
209211
transb, &M_s32, &N_s32, &K_s32, &alpha, &lda_s32, &ldb_s32, nullptr,
@@ -244,6 +246,7 @@ dnnl_status_t gemm_x8x8s32_pack_get_size(const char *identifier,
244246
bool do_a = utils::one_of(*identifier, 'a', 'A');
245247
float alpha = 1.0f;
246248
gemm_pack_storage_shell_t shell {dnnl_get_max_threads(), do_a, !do_a};
249+
if (!shell.get()) return dnnl_out_of_memory;
247250

248251
if (!use_reference_igemm<a_dt, b_dt>()) {
249252
result = gemm_pack_driver<a_dt, b_dt, int32_t>(identifier, transa,

0 commit comments

Comments
 (0)