Skip to content

Commit 7751994

Browse files
committed
Fixed support for static variable inside a coroutine.
1 parent 4f78c7d commit 7751994

File tree

5 files changed

+301
-6
lines changed

5 files changed

+301
-6
lines changed

CodeGenerator.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,7 +1262,7 @@ void CodeGenerator::InsertArg(const VarDecl* stmt)
12621262
HandleLocalStaticNonTrivialClass(stmt);
12631263

12641264
} else {
1265-
if(InsertVarDecl()) {
1265+
if(InsertVarDecl(stmt)) {
12661266
const auto desugaredType = GetType(GetDesugarType(stmt->getType()));
12671267

12681268
const bool isMemberPointer{isa<MemberPointerType>(desugaredType.getTypePtrOrNull())};
@@ -2542,7 +2542,8 @@ void CodeGenerator::InsertArg(const ForStmt* stmt)
25422542
WrapInParens(
25432543
[&]() {
25442544
if(const auto* init = stmt->getInit()) {
2545-
MultiStmtDeclCodeGenerator codeGenerator{mOutputFormatHelper, mLambdaStack, InsertVarDecl()};
2545+
MultiStmtDeclCodeGenerator codeGenerator{
2546+
mOutputFormatHelper, mLambdaStack, InsertVarDecl(nullptr)};
25462547
codeGenerator.InsertArg(init);
25472548

25482549
} else {

CodeGenerator.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ class CodeGenerator
281281
void EndLifetimeScope();
282282

283283
protected:
284-
virtual bool InsertVarDecl() { return true; }
284+
virtual bool InsertVarDecl(const VarDecl*) { return true; }
285285
virtual bool SkipSpaceAfterVarDecl() { return false; }
286286
virtual bool InsertComma() { return false; }
287287
virtual bool InsertSemi() { return true; }
@@ -500,7 +500,7 @@ class MultiStmtDeclCodeGenerator final : public CodeGenerator
500500
OnceFalse mInsertComma{}; //! Insert the comma after we have generated the first \c VarDecl and we are about to
501501
//! insert another one.
502502

503-
bool InsertVarDecl() override { return mInsertVarDecl; }
503+
bool InsertVarDecl(const VarDecl*) override { return mInsertVarDecl; }
504504
bool InsertComma() override { return mInsertComma; }
505505
bool InsertSemi() override { return false; }
506506
};
@@ -563,7 +563,7 @@ class CoroutinesCodeGenerator final : public CodeGenerator
563563
std::string GetFrameName() const { return mFrameName; }
564564

565565
protected:
566-
bool InsertVarDecl() override { return mInsertVarDecl; }
566+
bool InsertVarDecl(const VarDecl* vd) override { return mInsertVarDecl or (vd and vd->isStaticLocal()); }
567567
bool SkipSpaceAfterVarDecl() override { return not mInsertVarDecl; }
568568

569569
private:

CoroutinesCodeGenerator.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ class CoroutineASTTransformer : public StmtVisitor<CoroutineASTTransformer>
229229
void VisitDeclRefExpr(DeclRefExpr* stmt)
230230
{
231231
if(auto* vd = dyn_cast_or_null<VarDecl>(stmt->getDecl())) {
232-
RETURN_IF(not vd->isLocalVarDeclOrParm() or not Contains(mVarNamePrefix, vd));
232+
RETURN_IF(not vd->isLocalVarDeclOrParm() or vd->isStaticLocal() or not Contains(mVarNamePrefix, vd));
233233

234234
auto* memberExpr = mVarNamePrefix[vd];
235235

@@ -241,6 +241,10 @@ class CoroutineASTTransformer : public StmtVisitor<CoroutineASTTransformer>
241241
{
242242
for(auto* decl : stmt->decls()) {
243243
if(auto* varDecl = dyn_cast_or_null<VarDecl>(decl)) {
244+
if(varDecl->isStaticLocal()) {
245+
continue;
246+
}
247+
244248
// add this point a placement-new would be appropriate for at least some cases.
245249

246250
auto* field = AddField(mASTData, GetName(*varDecl), varDecl->getType());

tests/EduCoroutineStaticVarTest.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// cmdline:-std=c++20
2+
// cmdlineinsights:-edu-show-coroutine-transformation
3+
4+
#include <coroutine>
5+
#include <cstdio>
6+
#include <exception> // std::terminate
7+
#include <iostream>
8+
#include <list>
9+
#include <new>
10+
#include <string_view>
11+
#include <utility>
12+
13+
using namespace std::string_literals;
14+
using namespace std::string_view_literals;
15+
16+
struct Task {
17+
struct promise_type {
18+
Task get_return_object() noexcept { return {}; }
19+
std::suspend_never initial_suspend() noexcept { return {}; }
20+
std::suspend_never final_suspend() noexcept { return {}; }
21+
void return_void() noexcept {}
22+
void unhandled_exception() noexcept {}
23+
};
24+
};
25+
26+
struct Scheduler;
27+
28+
struct awaiter : std::suspend_always {
29+
Scheduler* _sched;
30+
31+
explicit awaiter(Scheduler& sched)
32+
: _sched{&sched}
33+
{}
34+
void await_suspend(std::coroutine_handle<> coro) const noexcept;
35+
};
36+
37+
struct Scheduler {
38+
std::list<std::coroutine_handle<>> _tasks{};
39+
40+
bool schedule()
41+
{
42+
auto task = _tasks.front();
43+
_tasks.pop_front();
44+
45+
if(not task.done()) { task.resume(); }
46+
47+
return not _tasks.empty();
48+
}
49+
50+
auto suspend() { return awaiter{*this}; }
51+
};
52+
53+
void awaiter::await_suspend(std::coroutine_handle<> coro) const noexcept
54+
{
55+
_sched->_tasks.push_back(coro);
56+
}
57+
58+
Task taskA(Scheduler& sched)
59+
{
60+
std::cout << "Hello, from task A\n"sv;
61+
62+
co_await sched.suspend();
63+
64+
static std::string res{"a is back doing work\n"s};
65+
std::cout << res;
66+
}
67+
68+
int main()
69+
{
70+
Scheduler scheduler{};
71+
72+
taskA(scheduler);
73+
74+
while(scheduler.schedule()) {}
75+
}
76+
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
/*************************************************************************************
2+
* NOTE: The coroutine transformation you've enabled is a hand coded transformation! *
3+
* Most of it is _not_ present in the AST. What you see is an approximation. *
4+
*************************************************************************************/
5+
#include <coroutine>
6+
#include <cstdio>
7+
#include <exception>
8+
#include <iostream>
9+
#include <list>
10+
#include <new>
11+
#include <string_view>
12+
#include <utility>
13+
14+
using namespace std::string_literals;
15+
using namespace std::string_view_literals;
16+
17+
struct Task
18+
{
19+
struct promise_type
20+
{
21+
inline Task get_return_object() noexcept
22+
{
23+
return {};
24+
}
25+
26+
inline std::suspend_never initial_suspend() noexcept
27+
{
28+
return {};
29+
}
30+
31+
inline std::suspend_never final_suspend() noexcept
32+
{
33+
return {};
34+
}
35+
36+
inline void return_void() noexcept
37+
{
38+
}
39+
40+
inline void unhandled_exception() noexcept
41+
{
42+
}
43+
44+
// inline constexpr promise_type() noexcept = default;
45+
};
46+
47+
};
48+
49+
50+
struct Scheduler;
51+
52+
struct awaiter : public std::suspend_always
53+
{
54+
Scheduler * _sched;
55+
inline explicit awaiter(Scheduler & sched)
56+
: std::suspend_always()
57+
, _sched{&sched}
58+
{
59+
}
60+
61+
void await_suspend(std::coroutine_handle<void> coro) const noexcept;
62+
63+
};
64+
65+
66+
struct Scheduler
67+
{
68+
std::list<std::coroutine_handle<void>, std::allocator<std::coroutine_handle<void> > > _tasks = std::list<std::coroutine_handle<void>, std::allocator<std::coroutine_handle<void> > >{};
69+
inline bool schedule()
70+
{
71+
std::coroutine_handle<void> task = std::coroutine_handle<void>(this->_tasks.front());
72+
this->_tasks.pop_front();
73+
if(!task.done()) {
74+
task.resume();
75+
}
76+
77+
return !this->_tasks.empty();
78+
}
79+
80+
inline awaiter suspend()
81+
{
82+
return awaiter{*this};
83+
}
84+
85+
// inline ~Scheduler() noexcept = default;
86+
};
87+
88+
89+
void awaiter::await_suspend(std::coroutine_handle<void> coro) const noexcept
90+
{
91+
this->_sched->_tasks.push_back(coro);
92+
}
93+
94+
95+
struct __taskAFrame
96+
{
97+
void (*resume_fn)(__taskAFrame *);
98+
void (*destroy_fn)(__taskAFrame *);
99+
std::__coroutine_traits_sfinae<Task>::promise_type __promise;
100+
int __suspend_index;
101+
bool __initial_await_suspend_called;
102+
Scheduler & sched;
103+
std::suspend_never __suspend_58_6;
104+
awaiter __suspend_62_18;
105+
std::suspend_never __suspend_58_6_1;
106+
};
107+
108+
Task taskA(Scheduler & sched)
109+
{
110+
/* Allocate the frame including the promise */
111+
/* Note: The actual parameter new is __builtin_coro_size */
112+
__taskAFrame * __f = reinterpret_cast<__taskAFrame *>(operator new(sizeof(__taskAFrame)));
113+
__f->__suspend_index = 0;
114+
__f->__initial_await_suspend_called = false;
115+
__f->sched = std::forward<Scheduler &>(sched);
116+
117+
/* Construct the promise. */
118+
new (&__f->__promise)std::__coroutine_traits_sfinae<Task>::promise_type{};
119+
120+
/* Forward declare the resume and destroy function. */
121+
void __taskAResume(__taskAFrame * __f);
122+
void __taskADestroy(__taskAFrame * __f);
123+
124+
/* Assign the resume and destroy function pointers. */
125+
__f->resume_fn = &__taskAResume;
126+
__f->destroy_fn = &__taskADestroy;
127+
128+
/* Call the made up function with the coroutine body for initial suspend.
129+
This function will be called subsequently by coroutine_handle<>::resume()
130+
which calls __builtin_coro_resume(__handle_) */
131+
__taskAResume(__f);
132+
133+
134+
return __f->__promise.get_return_object();
135+
}
136+
137+
/* This function invoked by coroutine_handle<>::resume() */
138+
void __taskAResume(__taskAFrame * __f)
139+
{
140+
try
141+
{
142+
/* Create a switch to get to the correct resume point */
143+
switch(__f->__suspend_index) {
144+
case 0: break;
145+
case 1: goto __resume_taskA_1;
146+
case 2: goto __resume_taskA_2;
147+
}
148+
149+
/* co_await EduCoroutineStaticVarTest.cpp:58 */
150+
__f->__suspend_58_6 = __f->__promise.initial_suspend();
151+
if(!__f->__suspend_58_6.await_ready()) {
152+
__f->__suspend_58_6.await_suspend(std::coroutine_handle<Task::promise_type>::from_address(static_cast<void *>(__f)).operator std::coroutine_handle<void>());
153+
__f->__suspend_index = 1;
154+
__f->__initial_await_suspend_called = true;
155+
return;
156+
}
157+
158+
__resume_taskA_1:
159+
__f->__suspend_58_6.await_resume();
160+
std::operator<<(std::cout, std::operator""sv("Hello, from task A\n", 19UL));
161+
162+
/* co_await EduCoroutineStaticVarTest.cpp:62 */
163+
__f->__suspend_62_18 = __f->sched.suspend();
164+
if(!__f->__suspend_62_18.await_ready()) {
165+
__f->__suspend_62_18.await_suspend(std::coroutine_handle<Task::promise_type>::from_address(static_cast<void *>(__f)).operator std::coroutine_handle<void>());
166+
__f->__suspend_index = 2;
167+
return;
168+
}
169+
170+
__resume_taskA_2:
171+
__f->__suspend_62_18.await_resume();
172+
static std::basic_string<char, std::char_traits<char>, std::allocator<char> > res = {std::operator""s("a is back doing work\n", 21UL)};
173+
std::operator<<(std::cout, res);
174+
goto __final_suspend;
175+
} catch(...) {
176+
if(!__f->__initial_await_suspend_called) {
177+
throw ;
178+
}
179+
180+
__f->__promise.unhandled_exception();
181+
}
182+
183+
__final_suspend:
184+
185+
/* co_await EduCoroutineStaticVarTest.cpp:58 */
186+
__f->__suspend_58_6_1 = __f->__promise.final_suspend();
187+
if(!__f->__suspend_58_6_1.await_ready()) {
188+
__f->__suspend_58_6_1.await_suspend(std::coroutine_handle<Task::promise_type>::from_address(static_cast<void *>(__f)).operator std::coroutine_handle<void>());
189+
return;
190+
}
191+
192+
__f->destroy_fn(__f);
193+
}
194+
195+
/* This function invoked by coroutine_handle<>::destroy() */
196+
void __taskADestroy(__taskAFrame * __f)
197+
{
198+
/* destroy all variables with dtors */
199+
__f->~__taskAFrame();
200+
/* Deallocating the coroutine frame */
201+
/* Note: The actual argument to delete is __builtin_coro_frame with the promise as parameter */
202+
operator delete(static_cast<void *>(__f));
203+
}
204+
205+
206+
int main()
207+
{
208+
Scheduler scheduler = {{std::list<std::coroutine_handle<void>, std::allocator<std::coroutine_handle<void> > >{}}};
209+
taskA(scheduler);
210+
while(scheduler.schedule()) {
211+
}
212+
213+
return 0;
214+
}

0 commit comments

Comments
 (0)