Skip to content

Commit e4185e4

Browse files
dzarukinvpirogov
authored andcommitted
scratchpad: work around corrupted unique_ptr usage
1 parent 2c5ff00 commit e4185e4

File tree

3 files changed

+122
-7
lines changed

3 files changed

+122
-7
lines changed

src/common/scratchpad.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ struct global_scratchpad_t : public scratchpad_t {
6666
global_scratchpad_t(engine_t *engine, size_t size) {
6767
UNUSED(engine);
6868
if (size > size_) {
69-
auto *mem_storage = create_scratchpad_memory_storage(engine, size);
70-
mem_storage_.reset(mem_storage);
69+
delete mem_storage_;
70+
mem_storage_ = create_scratchpad_memory_storage(engine, size);
7171
size_ = size;
7272
}
7373
reference_count_++;
@@ -76,23 +76,28 @@ struct global_scratchpad_t : public scratchpad_t {
7676
~global_scratchpad_t() {
7777
reference_count_--;
7878
if (reference_count_ == 0) {
79-
mem_storage_.reset();
79+
delete mem_storage_;
80+
mem_storage_ = nullptr;
8081
size_ = 0;
8182
}
8283
}
8384

8485
virtual const memory_storage_t *get_memory_storage() const override {
85-
return mem_storage_.get();
86+
return mem_storage_;
8687
}
8788

8889
private:
89-
thread_local static std::unique_ptr<memory_storage_t> mem_storage_;
90+
thread_local static memory_storage_t *mem_storage_;
9091
thread_local static size_t size_;
9192
thread_local static unsigned int reference_count_;
9293
};
9394

94-
thread_local std::unique_ptr<memory_storage_t>
95-
global_scratchpad_t::mem_storage_(nullptr);
95+
// CAVEAT: avoid having non-trivially-constructed thread-local objects. Their
96+
// construction order may depends on the program execution and the final
97+
// destruction order may be such that a thread-local object is destroyed
98+
// before all its users are destroyed thus causing a crash at exit.
99+
// Tested by tests/gtests/test_global_scratchad.cpp
100+
thread_local memory_storage_t *global_scratchpad_t::mem_storage_ = nullptr;
96101
thread_local size_t global_scratchpad_t::size_ = 0;
97102
thread_local unsigned int global_scratchpad_t::reference_count_ = 0;
98103

tests/gtests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ file(GLOB PRIM_TEST_CASES_SRC
8484
test_resampling.cpp
8585
test_isa_mask.cpp
8686
test_isa_iface.cpp
87+
test_global_scratchpad.cpp
8788
)
8889

8990
# Some tests will fail on AArch64 and need to be skipped.
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
/*******************************************************************************
2+
* Copyright 2020 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
#include "dnnl_test_common.hpp"
18+
#include "gtest/gtest.h"
19+
20+
#include "dnnl.hpp"
21+
22+
namespace dnnl {
23+
24+
using dt = memory::data_type;
25+
using tag = memory::format_tag;
26+
27+
// This test checks that globally defined primitive_t object will be
28+
// successfully destroyed after finishing the program despite the order of
29+
// internal objects destruction.
30+
// The cause was thread-local non-trivially-constructed object in
31+
// global_scratchpad_t object which got destroyed before global_scratchpad_t
32+
// causing a crash.
33+
class global_scratchpad : public ::testing::Test {};
34+
35+
struct conv_ctx_t {
36+
conv_ctx_t() : eng_(engine::kind::cpu, 0), c_() {}
37+
38+
struct conv_t {
39+
conv_t()
40+
: src_md()
41+
, wei_md()
42+
, dst_md()
43+
, pd()
44+
, src_mem()
45+
, wei_mem()
46+
, dst_mem()
47+
, prim() {}
48+
49+
memory::desc src_md;
50+
memory::desc wei_md;
51+
memory::desc dst_md;
52+
convolution_forward::primitive_desc pd;
53+
memory src_mem;
54+
memory wei_mem;
55+
memory dst_mem;
56+
primitive prim;
57+
};
58+
59+
void Setup(memory::dims src_dims, memory::dims wei_dims,
60+
memory::dims dst_dims, memory::dims strides_dims,
61+
memory::dims dilations_dims, memory::dims padding_left,
62+
memory::dims padding_right) {
63+
c_.src_md = memory::desc(src_dims, dt::f32, tag::any);
64+
c_.wei_md = memory::desc(wei_dims, dt::f32, tag::any);
65+
c_.dst_md = memory::desc(dst_dims, dt::f32, tag::any);
66+
67+
auto desc = convolution_forward::desc(prop_kind::forward,
68+
algorithm::convolution_direct, c_.src_md, c_.wei_md, c_.dst_md,
69+
strides_dims, dilations_dims, padding_left, padding_right);
70+
71+
c_.pd = convolution_forward::primitive_desc(desc, eng_);
72+
73+
c_.src_mem = memory(c_.pd.src_desc(), eng_);
74+
c_.wei_mem = memory(c_.pd.weights_desc(), eng_);
75+
c_.dst_mem = memory(c_.pd.dst_desc(), eng_);
76+
77+
c_.prim = convolution_forward(c_.pd);
78+
}
79+
80+
engine eng_;
81+
struct conv_t c_;
82+
};
83+
84+
conv_ctx_t global_conv_ctx1;
85+
conv_ctx_t global_conv_ctx2;
86+
87+
TEST(global_scratchpad, TestGlobalScratchpad) {
88+
memory::dims src1 = {1, 1, 3, 4};
89+
memory::dims wei1 = {1, 1, 3, 3};
90+
memory::dims dst1 = {1, 1, 8, 5};
91+
memory::dims str1 = {1, 1};
92+
memory::dims dil1 = {0, 0};
93+
memory::dims pad_l1 = {3, 1};
94+
memory::dims pad_r1 = {4, 2};
95+
global_conv_ctx1.Setup(src1, wei1, dst1, str1, dil1, pad_l1, pad_r1);
96+
97+
memory::dims src2 = {256, 3, 227, 227};
98+
memory::dims wei2 = {96, 3, 11, 11};
99+
memory::dims dst2 = {256, 96, 55, 55};
100+
memory::dims str2 = {4, 4};
101+
memory::dims dil2 = {0, 0};
102+
memory::dims pad_l2 = {0, 0};
103+
memory::dims pad_r2 = {0, 0};
104+
global_conv_ctx2.Setup(src2, wei2, dst2, str2, dil2, pad_l2, pad_r2);
105+
106+
// if something goes wrong, test should return 139 on Linux.
107+
};
108+
109+
} // namespace dnnl

0 commit comments

Comments
 (0)