Skip to content

Commit

Permalink
Fixed support for static variable inside a coroutine.
Browse files Browse the repository at this point in the history
  • Loading branch information
andreasfertig committed Jul 26, 2024
1 parent 4f78c7d commit 7751994
Show file tree
Hide file tree
Showing 5 changed files with 301 additions and 6 deletions.
5 changes: 3 additions & 2 deletions CodeGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1262,7 +1262,7 @@ void CodeGenerator::InsertArg(const VarDecl* stmt)
HandleLocalStaticNonTrivialClass(stmt);

} else {
if(InsertVarDecl()) {
if(InsertVarDecl(stmt)) {
const auto desugaredType = GetType(GetDesugarType(stmt->getType()));

const bool isMemberPointer{isa<MemberPointerType>(desugaredType.getTypePtrOrNull())};
Expand Down Expand Up @@ -2542,7 +2542,8 @@ void CodeGenerator::InsertArg(const ForStmt* stmt)
WrapInParens(
[&]() {
if(const auto* init = stmt->getInit()) {
MultiStmtDeclCodeGenerator codeGenerator{mOutputFormatHelper, mLambdaStack, InsertVarDecl()};
MultiStmtDeclCodeGenerator codeGenerator{
mOutputFormatHelper, mLambdaStack, InsertVarDecl(nullptr)};
codeGenerator.InsertArg(init);

} else {
Expand Down
6 changes: 3 additions & 3 deletions CodeGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ class CodeGenerator
void EndLifetimeScope();

protected:
virtual bool InsertVarDecl() { return true; }
virtual bool InsertVarDecl(const VarDecl*) { return true; }
virtual bool SkipSpaceAfterVarDecl() { return false; }
virtual bool InsertComma() { return false; }
virtual bool InsertSemi() { return true; }
Expand Down Expand Up @@ -500,7 +500,7 @@ class MultiStmtDeclCodeGenerator final : public CodeGenerator
OnceFalse mInsertComma{}; //! Insert the comma after we have generated the first \c VarDecl and we are about to
//! insert another one.

bool InsertVarDecl() override { return mInsertVarDecl; }
bool InsertVarDecl(const VarDecl*) override { return mInsertVarDecl; }
bool InsertComma() override { return mInsertComma; }
bool InsertSemi() override { return false; }
};
Expand Down Expand Up @@ -563,7 +563,7 @@ class CoroutinesCodeGenerator final : public CodeGenerator
std::string GetFrameName() const { return mFrameName; }

protected:
bool InsertVarDecl() override { return mInsertVarDecl; }
bool InsertVarDecl(const VarDecl* vd) override { return mInsertVarDecl or (vd and vd->isStaticLocal()); }
bool SkipSpaceAfterVarDecl() override { return not mInsertVarDecl; }

private:
Expand Down
6 changes: 5 additions & 1 deletion CoroutinesCodeGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ class CoroutineASTTransformer : public StmtVisitor<CoroutineASTTransformer>
void VisitDeclRefExpr(DeclRefExpr* stmt)
{
if(auto* vd = dyn_cast_or_null<VarDecl>(stmt->getDecl())) {
RETURN_IF(not vd->isLocalVarDeclOrParm() or not Contains(mVarNamePrefix, vd));
RETURN_IF(not vd->isLocalVarDeclOrParm() or vd->isStaticLocal() or not Contains(mVarNamePrefix, vd));

auto* memberExpr = mVarNamePrefix[vd];

Expand All @@ -241,6 +241,10 @@ class CoroutineASTTransformer : public StmtVisitor<CoroutineASTTransformer>
{
for(auto* decl : stmt->decls()) {
if(auto* varDecl = dyn_cast_or_null<VarDecl>(decl)) {
if(varDecl->isStaticLocal()) {
continue;
}

// add this point a placement-new would be appropriate for at least some cases.

auto* field = AddField(mASTData, GetName(*varDecl), varDecl->getType());
Expand Down
76 changes: 76 additions & 0 deletions tests/EduCoroutineStaticVarTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// cmdline:-std=c++20
// cmdlineinsights:-edu-show-coroutine-transformation

#include <coroutine>
#include <cstdio>
#include <exception> // std::terminate
#include <iostream>
#include <list>
#include <new>
#include <string_view>
#include <utility>

using namespace std::string_literals;
using namespace std::string_view_literals;

struct Task {
struct promise_type {
Task get_return_object() noexcept { return {}; }
std::suspend_never initial_suspend() noexcept { return {}; }
std::suspend_never final_suspend() noexcept { return {}; }
void return_void() noexcept {}
void unhandled_exception() noexcept {}
};
};

struct Scheduler;

struct awaiter : std::suspend_always {
Scheduler* _sched;

explicit awaiter(Scheduler& sched)
: _sched{&sched}
{}
void await_suspend(std::coroutine_handle<> coro) const noexcept;
};

struct Scheduler {
std::list<std::coroutine_handle<>> _tasks{};

bool schedule()
{
auto task = _tasks.front();
_tasks.pop_front();

if(not task.done()) { task.resume(); }

return not _tasks.empty();
}

auto suspend() { return awaiter{*this}; }
};

void awaiter::await_suspend(std::coroutine_handle<> coro) const noexcept
{
_sched->_tasks.push_back(coro);
}

Task taskA(Scheduler& sched)
{
std::cout << "Hello, from task A\n"sv;

co_await sched.suspend();

static std::string res{"a is back doing work\n"s};
std::cout << res;
}

int main()
{
Scheduler scheduler{};

taskA(scheduler);

while(scheduler.schedule()) {}
}

214 changes: 214 additions & 0 deletions tests/EduCoroutineStaticVarTest.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
/*************************************************************************************
* NOTE: The coroutine transformation you've enabled is a hand coded transformation! *
* Most of it is _not_ present in the AST. What you see is an approximation. *
*************************************************************************************/
#include <coroutine>
#include <cstdio>
#include <exception>
#include <iostream>
#include <list>
#include <new>
#include <string_view>
#include <utility>

using namespace std::string_literals;
using namespace std::string_view_literals;

struct Task
{
struct promise_type
{
inline Task get_return_object() noexcept
{
return {};
}

inline std::suspend_never initial_suspend() noexcept
{
return {};
}

inline std::suspend_never final_suspend() noexcept
{
return {};
}

inline void return_void() noexcept
{
}

inline void unhandled_exception() noexcept
{
}

// inline constexpr promise_type() noexcept = default;
};

};


struct Scheduler;

struct awaiter : public std::suspend_always
{
Scheduler * _sched;
inline explicit awaiter(Scheduler & sched)
: std::suspend_always()
, _sched{&sched}
{
}

void await_suspend(std::coroutine_handle<void> coro) const noexcept;

};


struct Scheduler
{
std::list<std::coroutine_handle<void>, std::allocator<std::coroutine_handle<void> > > _tasks = std::list<std::coroutine_handle<void>, std::allocator<std::coroutine_handle<void> > >{};
inline bool schedule()
{
std::coroutine_handle<void> task = std::coroutine_handle<void>(this->_tasks.front());
this->_tasks.pop_front();
if(!task.done()) {
task.resume();
}

return !this->_tasks.empty();
}

inline awaiter suspend()
{
return awaiter{*this};
}

// inline ~Scheduler() noexcept = default;
};


void awaiter::await_suspend(std::coroutine_handle<void> coro) const noexcept
{
this->_sched->_tasks.push_back(coro);
}


struct __taskAFrame
{
void (*resume_fn)(__taskAFrame *);
void (*destroy_fn)(__taskAFrame *);
std::__coroutine_traits_sfinae<Task>::promise_type __promise;
int __suspend_index;
bool __initial_await_suspend_called;
Scheduler & sched;
std::suspend_never __suspend_58_6;
awaiter __suspend_62_18;
std::suspend_never __suspend_58_6_1;
};

Task taskA(Scheduler & sched)
{
/* Allocate the frame including the promise */
/* Note: The actual parameter new is __builtin_coro_size */
__taskAFrame * __f = reinterpret_cast<__taskAFrame *>(operator new(sizeof(__taskAFrame)));
__f->__suspend_index = 0;
__f->__initial_await_suspend_called = false;
__f->sched = std::forward<Scheduler &>(sched);

/* Construct the promise. */
new (&__f->__promise)std::__coroutine_traits_sfinae<Task>::promise_type{};

/* Forward declare the resume and destroy function. */
void __taskAResume(__taskAFrame * __f);
void __taskADestroy(__taskAFrame * __f);

/* Assign the resume and destroy function pointers. */
__f->resume_fn = &__taskAResume;
__f->destroy_fn = &__taskADestroy;

/* Call the made up function with the coroutine body for initial suspend.
This function will be called subsequently by coroutine_handle<>::resume()
which calls __builtin_coro_resume(__handle_) */
__taskAResume(__f);


return __f->__promise.get_return_object();
}

/* This function invoked by coroutine_handle<>::resume() */
void __taskAResume(__taskAFrame * __f)
{
try
{
/* Create a switch to get to the correct resume point */
switch(__f->__suspend_index) {
case 0: break;
case 1: goto __resume_taskA_1;
case 2: goto __resume_taskA_2;
}

/* co_await EduCoroutineStaticVarTest.cpp:58 */
__f->__suspend_58_6 = __f->__promise.initial_suspend();
if(!__f->__suspend_58_6.await_ready()) {
__f->__suspend_58_6.await_suspend(std::coroutine_handle<Task::promise_type>::from_address(static_cast<void *>(__f)).operator std::coroutine_handle<void>());
__f->__suspend_index = 1;
__f->__initial_await_suspend_called = true;
return;
}

__resume_taskA_1:
__f->__suspend_58_6.await_resume();
std::operator<<(std::cout, std::operator""sv("Hello, from task A\n", 19UL));

/* co_await EduCoroutineStaticVarTest.cpp:62 */
__f->__suspend_62_18 = __f->sched.suspend();
if(!__f->__suspend_62_18.await_ready()) {
__f->__suspend_62_18.await_suspend(std::coroutine_handle<Task::promise_type>::from_address(static_cast<void *>(__f)).operator std::coroutine_handle<void>());
__f->__suspend_index = 2;
return;
}

__resume_taskA_2:
__f->__suspend_62_18.await_resume();
static std::basic_string<char, std::char_traits<char>, std::allocator<char> > res = {std::operator""s("a is back doing work\n", 21UL)};
std::operator<<(std::cout, res);
goto __final_suspend;
} catch(...) {
if(!__f->__initial_await_suspend_called) {
throw ;
}

__f->__promise.unhandled_exception();
}

__final_suspend:

/* co_await EduCoroutineStaticVarTest.cpp:58 */
__f->__suspend_58_6_1 = __f->__promise.final_suspend();
if(!__f->__suspend_58_6_1.await_ready()) {
__f->__suspend_58_6_1.await_suspend(std::coroutine_handle<Task::promise_type>::from_address(static_cast<void *>(__f)).operator std::coroutine_handle<void>());
return;
}

__f->destroy_fn(__f);
}

/* This function invoked by coroutine_handle<>::destroy() */
void __taskADestroy(__taskAFrame * __f)
{
/* destroy all variables with dtors */
__f->~__taskAFrame();
/* Deallocating the coroutine frame */
/* Note: The actual argument to delete is __builtin_coro_frame with the promise as parameter */
operator delete(static_cast<void *>(__f));
}


int main()
{
Scheduler scheduler = {{std::list<std::coroutine_handle<void>, std::allocator<std::coroutine_handle<void> > >{}}};
taskA(scheduler);
while(scheduler.schedule()) {
}

return 0;
}

0 comments on commit 7751994

Please sign in to comment.