Skip to content

Commit

Permalink
Fix drogon::Task<> not destructing internal object (drogonframework#729)
Browse files Browse the repository at this point in the history
  • Loading branch information
marty1885 authored Mar 1, 2021
1 parent af2bd6b commit 8bd1f56
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 14 deletions.
93 changes: 80 additions & 13 deletions lib/inc/drogon/utils/coroutine.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,13 @@ struct is_awaitable<
template <typename T>
constexpr bool is_awaitable_v = is_awaitable<T>::value;

template <typename T>
struct final_awaiter
{
bool await_ready() noexcept
{
return false;
}
template <typename T>
auto await_suspend(std::coroutine_handle<T> handle) noexcept
{
return handle.promise().continuation_;
Expand All @@ -97,7 +97,7 @@ struct final_awaiter
};

template <typename T = void>
struct Task
struct [[nodiscard]] Task
{
struct promise_type;
using handle_type = std::coroutine_handle<promise_type>;
Expand All @@ -106,7 +106,7 @@ struct Task
{
}
Task(const Task &) = delete;
Task(Task &&other)
Task(Task && other)
{
coro_ = other.coro_;
other.coro_ = nullptr;
Expand Down Expand Up @@ -141,14 +141,23 @@ struct Task

auto final_suspend() noexcept
{
return final_awaiter<promise_type>{};
return final_awaiter{};
}

void unhandled_exception()
{
exception_ = std::current_exception();
}
const T &result() const

T &&result() &&
{
if (exception_ != nullptr)
std::rethrow_exception(exception_);
assert(value.has_value() == true);
return std::move(value.value());
}

T &result() &
{
if (exception_ != nullptr)
std::rethrow_exception(exception_);
Expand All @@ -167,15 +176,15 @@ struct Task
};
bool await_ready() const
{
return coro_.done();
return !coro_ || coro_.done();
}
std::coroutine_handle<> await_suspend(std::coroutine_handle<> awaiting)
{
coro_.promise().setContinuation(awaiting);
return coro_;
}

auto operator co_await() const noexcept
auto operator co_await() const &noexcept
{
struct awaiter
{
Expand All @@ -185,7 +194,7 @@ struct Task
}
bool await_ready() noexcept
{
return false;
return !coro_ || coro_.done();
}
auto await_suspend(std::coroutine_handle<> handle) noexcept
{
Expand All @@ -194,7 +203,36 @@ struct Task
}
T await_resume()
{
return coro_.promise().result();
auto &&v = coro_.promise().result();
return v;
}

private:
handle_type coro_;
};
return awaiter(coro_);
}

auto operator co_await() const &&noexcept
{
struct awaiter
{
public:
explicit awaiter(handle_type coro) : coro_(coro)
{
}
bool await_ready() noexcept
{
return !coro_ || coro_.done();
}
auto await_suspend(std::coroutine_handle<> handle) noexcept
{
coro_.promise().setContinuation(handle);
return coro_;
}
T await_resume()
{
return std::move(coro_.promise().result());
}

private:
Expand All @@ -206,7 +244,7 @@ struct Task
};

template <>
struct Task<void>
struct [[nodiscard]] Task<void>
{
struct promise_type;
using handle_type = std::coroutine_handle<promise_type>;
Expand All @@ -215,7 +253,7 @@ struct Task<void>
{
}
Task(const Task &) = delete;
Task(Task &&other)
Task(Task && other)
{
coro_ = other.coro_;
other.coro_ = nullptr;
Expand Down Expand Up @@ -248,7 +286,7 @@ struct Task<void>
}
auto final_suspend() noexcept
{
return final_awaiter<promise_type>{};
return final_awaiter{};
}
void unhandled_exception()
{
Expand All @@ -275,7 +313,35 @@ struct Task<void>
coro_.promise().setContinuation(awaiting);
return coro_;
}
auto operator co_await() const noexcept
auto operator co_await() const &noexcept
{
struct awaiter
{
public:
explicit awaiter(handle_type coro) : coro_(coro)
{
}
bool await_ready() noexcept
{
return !coro_ || coro_.done();
}
auto await_suspend(std::coroutine_handle<> handle) noexcept
{
coro_.promise().setContinuation(handle);
return coro_;
}
auto await_resume()
{
coro_.promise().result();
}

private:
handle_type coro_;
};
return awaiter(coro_);
}

auto operator co_await() const &&noexcept
{
struct awaiter
{
Expand Down Expand Up @@ -504,6 +570,7 @@ inline auto co_future(Await await) noexcept
}(std::move(prom), std::move(await));
return fut;
}

namespace internal
{
struct TimerAwaiter : CallbackAwaiter<void>
Expand Down
46 changes: 45 additions & 1 deletion lib/tests/CoroutineTest.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,35 @@
#include <drogon/utils/coroutine.h>
#include <exception>
#include <memory>
#include <type_traits>
#include <iostream>

using namespace drogon;

namespace drogon::internal
{
struct SomeStruct
{
~SomeStruct()
{
beenDestructed = true;
}
static bool beenDestructed;
};

bool SomeStruct::beenDestructed = false;

struct StructAwaiter : public CallbackAwaiter<std::shared_ptr<SomeStruct>>
{
void await_suspend(std::coroutine_handle<> handle)
{
setValue(std::make_shared<SomeStruct>());
handle.resume();
}
};

} // namespace drogon::internal

int main()
{
// Basic checks making sure coroutine works as expected
Expand Down Expand Up @@ -39,7 +64,7 @@ int main()

try
{
f();
co_await f();
std::cerr << "Exception should have been thrown\n";
exit(1);
}
Expand All @@ -60,5 +85,24 @@ int main()
};
sync_wait(throw_in_task());

// Test coroutine destruction
auto destruct = []() -> Task<> {
auto awaitStruct = []() -> Task<std::shared_ptr<internal::SomeStruct>> {
co_return co_await internal::StructAwaiter();
};

auto awaitNothing = [awaitStruct]() -> Task<> {
co_await awaitStruct();
};

co_await awaitNothing();
};
sync_wait(destruct());
if (internal::SomeStruct::beenDestructed == false)
{
std::cerr << "Coroutine didn't destruct allocated object.\n";
exit(1);
}

std::cout << "Done testing coroutines. No error." << std::endl;
}

0 comments on commit 8bd1f56

Please sign in to comment.