Skip to content

Commit

Permalink
Add some(generator)
Browse files Browse the repository at this point in the history
Signed-off-by: WenTao Ou <[email protected]>
  • Loading branch information
owent committed Jun 20, 2022
1 parent 6ab5264 commit 9ff5cb6
Show file tree
Hide file tree
Showing 4 changed files with 280 additions and 14 deletions.
39 changes: 36 additions & 3 deletions include/libcopp/coroutine/callable_promise.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
// clang-format on
#include <assert.h>
#include <type_traits>
#include <vector>
// clang-format off
#include <libcopp/utils/config/stl_include_suffix.h> // NOLINT(build/include_order)
// clang-format on
Expand All @@ -25,6 +26,9 @@

LIBCOPP_COPP_NAMESPACE_BEGIN

template <class TFUTURE>
class LIBCOPP_COPP_API_HEAD_ONLY some_delegate;

template <class TVALUE>
class LIBCOPP_COPP_API_HEAD_ONLY callable_future;

Expand Down Expand Up @@ -104,7 +108,7 @@ class LIBCOPP_COPP_API_HEAD_ONLY callable_promise_base<TVALUE, false> : public p
# if defined(LIBCOPP_MACRO_ENABLE_CONCEPTS) && LIBCOPP_MACRO_ENABLE_CONCEPTS
template <DerivedPromiseBaseType TPROMISE>
# else
template <class TPROMISE, typename = std::enable_if_t<std::is_base_of<promise_base_type, TPROMISE>::value> >
template <class TPROMISE, typename = std::enable_if_t<std::is_base_of<promise_base_type, TPROMISE>::value>>
# endif
class LIBCOPP_COPP_API_HEAD_ONLY callable_awaitable_base : public awaitable_base_type {
public:
Expand Down Expand Up @@ -143,7 +147,7 @@ class LIBCOPP_COPP_API_HEAD_ONLY callable_awaitable_base : public awaitable_base
# if defined(LIBCOPP_MACRO_ENABLE_CONCEPTS) && LIBCOPP_MACRO_ENABLE_CONCEPTS
template <DerivedPromiseBaseType TCPROMISE>
# else
template <class TCPROMISE, typename = std::enable_if_t<std::is_base_of<promise_base_type, TCPROMISE>::value> >
template <class TCPROMISE, typename = std::enable_if_t<std::is_base_of<promise_base_type, TCPROMISE>::value>>
# endif
inline void await_suspend(LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle<TCPROMISE> caller) noexcept {
if (caller.promise().get_status() < promise_status::kDone) {
Expand All @@ -169,7 +173,7 @@ class LIBCOPP_COPP_API_HEAD_ONLY callable_awaitable_base : public awaitable_base
UTIL_FORCEINLINE const handle_type& get_callee() const noexcept { return callee_; }

protected:
inline void detach() noexcept {
void detach() noexcept {
// caller maybe null if the callable is already ready when co_await
auto caller = get_caller();
auto& callee_promise = get_callee().promise();
Expand Down Expand Up @@ -423,6 +427,35 @@ class LIBCOPP_COPP_API_HEAD_ONLY callable_future {
handle_type current_handle_;
};

template <class TFUTURE>
struct some_ready {
using type = std::vector<std::reference_wrapper<TFUTURE>>;
};

template <class TCONTAINER>
struct some_ready_container {
using container_type = typename std::decay<TCONTAINER>::type;
using value_type = typename std::decay<typename container_type::value_type>::type;
};

# if defined(LIBCOPP_MACRO_ENABLE_CONCEPTS) && LIBCOPP_MACRO_ENABLE_CONCEPTS
template <class TREADY_CONTAINER, class TCONTAINER>
LIBCOPP_COPP_API_HEAD_ONLY callable_future<promise_status> some(
TREADY_CONTAINER&&ready_futures, size_t ready_count, TCONTAINER&&pending_futures) requires std::convertible_to <
typename std::decay<TREADY_CONTAINER>::type,
typename some_ready<typename some_ready_container<TCONTAINER>::value_type>::type > {
# else
template <class TREADY_CONTAINER, class TCONTAINER,
class = typename std::enable_if<std::is_same<
typename std::decay<TREADY_CONTAINER>::type,
typename some_ready<typename some_ready_container<TCONTAINER>::value_type>::type>::value>::type>
LIBCOPP_COPP_API_HEAD_ONLY callable_future<promise_status> some(TREADY_CONTAINER&& ready_futures, size_t ready_count,
TCONTAINER&& pending_futures) {
# endif
return some_delegate<typename some_ready_container<TCONTAINER>::value_type>::run(
std::forward<TREADY_CONTAINER>(ready_futures), ready_count, std::forward<TCONTAINER>(pending_futures));
}

LIBCOPP_COPP_NAMESPACE_END

#endif
209 changes: 204 additions & 5 deletions include/libcopp/coroutine/generator_promise.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
// clang-format on
#include <assert.h>
#include <functional>
#include <list>
#include <type_traits>

#if defined(LIBCOPP_MACRO_ENABLE_STD_EXCEPTION_PTR) && LIBCOPP_MACRO_ENABLE_STD_EXCEPTION_PTR
Expand Down Expand Up @@ -68,7 +69,7 @@ class LIBCOPP_COPP_API_HEAD_ONLY generator_context_base {
# if defined(LIBCOPP_MACRO_ENABLE_CONCEPTS) && LIBCOPP_MACRO_ENABLE_CONCEPTS
template <DerivedPromiseBaseType TPROMISE>
# else
template <class TPROMISE, typename = std::enable_if_t<std::is_base_of<promise_base_type, TPROMISE>::value> >
template <class TPROMISE, typename = std::enable_if_t<std::is_base_of<promise_base_type, TPROMISE>::value>>
# endif
UTIL_FORCEINLINE LIBCOPP_COPP_API_HEAD_ONLY void add_caller(
const LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle<TPROMISE>& handle) noexcept {
Expand All @@ -80,7 +81,7 @@ class LIBCOPP_COPP_API_HEAD_ONLY generator_context_base {
# if defined(LIBCOPP_MACRO_ENABLE_CONCEPTS) && LIBCOPP_MACRO_ENABLE_CONCEPTS
template <DerivedPromiseBaseType TPROMISE>
# else
template <class TPROMISE, typename = std::enable_if_t<std::is_base_of<promise_base_type, TPROMISE>::value> >
template <class TPROMISE, typename = std::enable_if_t<std::is_base_of<promise_base_type, TPROMISE>::value>>
# endif
UTIL_FORCEINLINE LIBCOPP_COPP_API_HEAD_ONLY void remove_caller(
const LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle<TPROMISE>& handle, bool inherit_status) noexcept {
Expand Down Expand Up @@ -180,7 +181,7 @@ class LIBCOPP_COPP_API_HEAD_ONLY generator_context_delegate<TVALUE, false> : pub
template <class TVALUE>
class LIBCOPP_COPP_API_HEAD_ONLY generator_context
: public generator_context_delegate<TVALUE, std::is_void<typename std::decay<TVALUE>::type>::value>,
public std::enable_shared_from_this<generator_context<TVALUE> > {
public std::enable_shared_from_this<generator_context<TVALUE>> {
public:
using base_type = generator_context_delegate<TVALUE, std::is_void<typename std::decay<TVALUE>::type>::value>;
using value_type = typename base_type::value_type;
Expand Down Expand Up @@ -222,7 +223,7 @@ class LIBCOPP_COPP_API_HEAD_ONLY generator_awaitable_base : public awaitable_bas
# if defined(LIBCOPP_MACRO_ENABLE_CONCEPTS) && LIBCOPP_MACRO_ENABLE_CONCEPTS
template <DerivedPromiseBaseType TCPROMISE>
# else
template <class TCPROMISE, typename = std::enable_if_t<std::is_base_of<promise_base_type, TCPROMISE>::value> >
template <class TCPROMISE, typename = std::enable_if_t<std::is_base_of<promise_base_type, TCPROMISE>::value>>
# endif
inline void await_suspend(LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle<TCPROMISE> caller) noexcept {
if (nullptr != context_ && caller.promise().get_status() < promise_status::kDone) {
Expand All @@ -243,7 +244,7 @@ class LIBCOPP_COPP_API_HEAD_ONLY generator_awaitable_base : public awaitable_bas
}

protected:
inline promise_status detach() noexcept {
promise_status detach() noexcept {
promise_status result_status;
COPP_UNLIKELY_IF (nullptr == context_) {
result_status = promise_status::kInvalid;
Expand Down Expand Up @@ -439,11 +440,209 @@ class LIBCOPP_COPP_API_HEAD_ONLY generator_future {
UTIL_FORCEINLINE std::shared_ptr<context_type>& get_context() noexcept { return context_; }

private:
template <class TFUTURE>
friend class LIBCOPP_COPP_API_HEAD_ONLY some_delegate;

std::shared_ptr<context_type> context_;
await_suspend_callback_type await_suspend_callback_;
await_resume_callback_type await_resume_callback_;
};

// some
template <class TVALUE>
class LIBCOPP_COPP_API_HEAD_ONLY some_delegate<generator_future<TVALUE>> {
public:
using future_type = generator_future<TVALUE>;
using value_type = future_type::value_type;
using ready_output_type = typename some_ready<future_type>::type;

private:
struct context_type {
std::list<future_type*> pending;
ready_output_type ready;
size_t ready_bound = 0;
size_t scan_bound = 0;
promise_status status = promise_status::kCreated;
promise_caller_manager::handle_delegate caller_handle = promise_caller_manager::handle_delegate(nullptr);
};

static void suspend_future(const promise_caller_manager::handle_delegate& caller, future_type& generator) {
generator.get_context()->add_caller(caller);

// Custom event. awaitable object may be deleted after this call
if (generator.await_suspend_callback_) {
generator.await_suspend_callback_(generator.get_context());
}
}

static void resume_future(const promise_caller_manager::handle_delegate& caller, future_type& generator) {
generator.get_context()->remove_caller(caller);

// Custom event
if (generator.await_resume_callback_) {
generator.await_resume_callback_(*generator.get_context());
}
}

static void force_resume_all(context_type& context) {
for (auto& pending_future : context.pending) {
resume_future(context.caller_handle, *pending_future);
}

context.caller_handle = nullptr;
}

static void scan_ready(context_type& context) {
auto iter = context.pending.begin();

while (iter != context.pending.end()) {
if ((*iter)->is_pending()) {
++iter;
continue;
}
future_type& future = **iter;
context.ready.push_back(std::ref(future));
iter = context.pending.erase(iter);

resume_future(context.caller_handle, future);
}
}

public:
class awaitable_type : public awaitable_base_type {
public:
awaitable_type(context_type* context) : context_(context) {}

inline bool await_ready() noexcept {
if (nullptr == context_) {
return true;
}

if (context_->ready.size() >= context_->ready_bound) {
return true;
}

return context_->pending.empty();
}

# if defined(LIBCOPP_MACRO_ENABLE_CONCEPTS) && LIBCOPP_MACRO_ENABLE_CONCEPTS
template <DerivedPromiseBaseType TCPROMISE>
# else
template <class TCPROMISE, typename = std::enable_if_t<std::is_base_of<promise_base_type, TCPROMISE>::value>>
# endif
inline void await_suspend(LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle<TCPROMISE> caller) noexcept {
if (nullptr == context_ || caller.promise().get_status() >= promise_status::kDone) {
// Already done and can not suspend again
caller.resume();
return;
}

set_caller(caller);

// Allow kill resume to forward error information
caller.promise().set_flag(promise_flag::kInternalWaitting, true);

// set caller for all futures
if (!context_->caller_handle) {
context_->caller_handle = caller;
// Copy pending here, the callback may call resume and will change the pending list
std::list<future_type*> copy_pending = context_->pending;
for (auto& pending_future : copy_pending) {
suspend_future(context_->caller_handle, *pending_future);
}
}
}

void await_resume() {
// caller maybe null if the callable is already ready when co_await
auto caller = get_caller();
if (caller) {
if (nullptr != caller.promise) {
caller.promise->set_flag(promise_flag::kInternalWaitting, false);
}
set_caller(nullptr);
}

if (nullptr == context_) {
return;
}

++context_->scan_bound;
if (context_->scan_bound >= context_->ready_bound) {
scan_ready(*context_);
context_->scan_bound = context_->ready.size();
}
}

private:
context_type* context_;
};

struct promise_type {
context_type* context_;

promise_type(context_type* context) : context_(context) {}
promise_type(const promise_type&) = delete;
promise_type(promise_type&&) = delete;
promise_type& operator=(const promise_type&) = delete;
promise_type& operator=(promise_type&&) = delete;
~promise_type() {
COPP_LIKELY_IF (nullptr != context_ && !context_->caller_handle) {
force_resume_all(*context_);
}
}

inline awaitable_type operator co_await() { return awaitable_type{context_}; }
};

template <class TCONTAINER>
static callable_future<promise_status> run(ready_output_type& ready_futures, size_t ready_count,
TCONTAINER&& futures) {
context_type context;

for (auto& future_object : futures) {
if (future_object.is_pending()) {
context.pending.push_back(&future_object);
} else {
context.ready.push_back(future_object);
}
}

if (context.ready.size() >= ready_count) {
context.ready.swap(ready_futures);
co_return promise_status::kDone;
}

if (ready_count >= context.pending.size() + ready_futures.size()) {
ready_count = context.pending.size() + ready_futures.size();
}
context.ready_bound = ready_count;
context.scan_bound = context.ready.size();
context.status = promise_status::kRunning;

{
promise_type some_promise{&context};
while (context.status < promise_status::kDone) {
// Killed by caller
auto current_status = co_yield callable_future<promise_status>::yield_status();
if (current_status >= promise_status::kDone) {
context.status = current_status;
break;
}

co_await some_promise;
}

// destroy promise object and detach handles
}

co_return context.status;
}

private:
std::shared_ptr<context_type> context_;
};

LIBCOPP_COPP_NAMESPACE_END

#endif
2 changes: 1 addition & 1 deletion include/libcotask/task_promise.h
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ class LIBCOPP_COPP_API_HEAD_ONLY task_awaitable_base : public LIBCOPP_COPP_NAMES
}

protected:
inline task_status_type detach() noexcept {
task_status_type detach() noexcept {
task_status_type result_status;
COPP_UNLIKELY_IF (nullptr == context_) {
result_status = task_status_type::kInvalid;
Expand Down
Loading

0 comments on commit 9ff5cb6

Please sign in to comment.