Skip to content

Commit

Permalink
Finish callable_promise
Browse files Browse the repository at this point in the history
Signed-off-by: owentou <[email protected]>
  • Loading branch information
owent committed Jan 18, 2022
1 parent b45b37a commit 9315053
Show file tree
Hide file tree
Showing 4 changed files with 450 additions and 91 deletions.
134 changes: 86 additions & 48 deletions include/libcopp/coroutine/callable_promise.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,24 @@ class LIBCOPP_COPP_API_HEAD_ONLY callable_promise_base<TVALUE, false> : public p
set_status(promise_status::kDone);
}
data_ = std::move(value);
has_return_ = true;
}

typename value_type& data() noexcept { return data_; }
const typename value_type& data() const noexcept { return data_; }
inline typename value_type& data() noexcept { return data_; }
inline const typename value_type& data() const noexcept { return data_; }

inline bool has_return() const noexcept { return has_return_; }

protected:
typename value_type data_;
bool has_return_ = false;
};

template <class TPROMISE, typename = std::enable_if_t<std::is_base_of_v<promise_base_type, TPROMISE>>>
# if defined(LIBCOPP_MACRO_ENABLE_CONCEPTS) && LIBCOPP_MACRO_ENABLE_CONCEPTS
template <DerivedPromiseBaseType TPROMISE>
# else
template <class TPROMISE, typename = std::enable_if_t<std::is_base_of<promise_base_type, TPROMISE>::value> >
# endif
class LIBCOPP_COPP_API_HEAD_ONLY callable_awaitable_base : public awaitable_base_type {
public:
using promise_type = TPROMISE;
Expand All @@ -69,7 +77,7 @@ class LIBCOPP_COPP_API_HEAD_ONLY callable_awaitable_base : public awaitable_base
public:
callable_awaitable_base(handle_type handle) : callee_{handle} {}

bool await_ready() noexcept {
inline bool await_ready() noexcept {
if (!callee_) {
return true;
}
Expand All @@ -83,22 +91,53 @@ class LIBCOPP_COPP_API_HEAD_ONLY callable_awaitable_base : public awaitable_base
callee_.resume();
}

if (callee_.promise().get_status() >= promise_status::kDone) {
return true;
}

return callee_.done();
}

# if defined(LIBCOPP_MACRO_ENABLE_CONCEPTS) && LIBCOPP_MACRO_ENABLE_CONCEPTS
template <DerivedPromiseBaseType TCPROMISE>
# else
template <class TCPROMISE, typename = std::enable_if_t<std::is_base_of<promise_base_type, TCPROMISE>::value>>
template <class TCPROMISE, typename = std::enable_if_t<std::is_base_of<promise_base_type, TCPROMISE>::value> >
# endif
void await_suspend(LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle<TCPROMISE> caller) noexcept {
set_caller(caller);
caller.promise().set_waiting_handle(callee_);
inline void await_suspend(LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle<TCPROMISE> caller) noexcept {
if (caller.promise().get_status() < promise_status::kDone) {
set_caller(caller);
caller.promise().set_waiting_handle(callee_);
callee_.promise().add_caller(caller);
} else {
// Already done and can not suspend again
caller.resume();
}
}

inline handle_type& get_callee() noexcept { return callee_; }
inline const handle_type& get_callee() const noexcept { return callee_; }

protected:
inline void detach() noexcept {
// caller maybe null if the callable is already ready when co_await
auto caller = get_caller();
auto& callee_promise = get_callee().promise();

if (caller) {
callee_promise.remove_caller(caller, true);
caller.promise->set_waiting_handle(nullptr);
set_caller(nullptr);
}

if (callee_promise.get_status() < promise_status::kDone) {
if (await_ready() || !caller) {
callee_promise.set_status(promise_status::kKilled);
} else {
callee_promise.set_status(caller.promise->get_status());
}
}
}

private:
handle_type callee_;
};
Expand All @@ -114,26 +153,15 @@ class LIBCOPP_COPP_API_HEAD_ONLY callable_awaitable<TPROMISE, true> : public cal
public:
using base_type::await_ready;
using base_type::await_suspend;
using base_type::detach;
using base_type::get_callee;
using base_type::get_caller;
using base_type::set_caller;
callable_awaitable(handle_type handle) : base_type(handle) {}

void await_resume() {
// caller maybe null if the callable is already ready when co_await
auto caller = get_caller();
if (await_ready() || !caller) {
get_callee().promise().set_status(promise_status::kDone);
} else {
get_callee().promise().set_status(caller.promise().get_status());
}

get_callee().promise().resume_waiting();

if (caller) {
caller.promise().set_waiting_handle(nullptr);
set_caller(nullptr);
}
inline void await_resume() {
detach();
get_callee().promise().resume_waiting(get_callee(), true);
}
};

Expand All @@ -148,35 +176,22 @@ class LIBCOPP_COPP_API_HEAD_ONLY callable_awaitable<TPROMISE, false> : public ca
public:
using base_type::await_ready;
using base_type::await_suspend;
using base_type::detach;
using base_type::get_callee;
using base_type::get_caller;
using base_type::set_caller;
callable_awaitable(handle_type handle) : base_type(handle) {}

value_type await_resume() {
bool is_ready = await_ready();
// caller maybe null if the callable is already ready when co_await
auto caller = get_caller();
promise_status status;
if (is_ready || !caller) {
status = promise_status::kDone;
} else {
status = caller.promise().get_status();
}
get_callee().promise().set_status(status);

get_callee().promise().resume_waiting();
inline value_type await_resume() {
detach();
auto& callee_promise = get_callee().promise();
callee_promise.resume_waiting(get_callee(), true);

if (caller) {
caller.promise().set_waiting_handle(nullptr);
set_caller(nullptr);
if (!callee_promise.has_return()) {
return promise_error_transform<value_type>()(callee_promise.get_status());
}

if (is_ready) {
return std::move(get_callee().promise().data());
} else {
return promise_error_transform<value_type>()(status);
}
return std::move(callee_promise.data());
}
};

Expand All @@ -193,20 +208,19 @@ class LIBCOPP_COPP_API_HEAD_ONLY callable_future {
}

struct initial_awaitable {
bool await_ready() const noexcept { return false; }
void await_resume() const noexcept {
inline bool await_ready() const noexcept { return false; }
inline void await_resume() const noexcept {
if (handle.promise().get_status() == promise_status::kCreated) {
promise_status excepted = promise_status::kCreated;
handle.promise().set_status(promise_status::kRunning, &excepted);
}
}
void await_suspend(LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle<promise_type> caller) noexcept {
inline void await_suspend(LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle<promise_type> caller) noexcept {
handle = caller;
}
LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle<promise_type> handle;
};
initial_awaitable initial_suspend() noexcept { return {}; }
LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE suspend_always final_suspend() noexcept { return {}; }
void unhandled_exception() { throw; }
};
using handle_type = LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle<promise_type>;
Expand Down Expand Up @@ -237,9 +251,33 @@ class LIBCOPP_COPP_API_HEAD_ONLY callable_future {
awaitable_type operator co_await() { return awaitable_type{current_handle_}; }

inline bool is_ready() const noexcept { return current_handle_.done(); }

inline promise_status get_status() const noexcept { return current_handle_.promise().get_status(); }

static auto yield_status() noexcept { return promise_base_type::pick_current_status(); }

/**
* @brief Get the internal handle object
* @note This function is only for internal use(testing), do not use it in your code.
*
* @return internal handle
*/
inline const handle_type& get_internal_handle() const noexcept { return current_handle_; }

/**
* @brief Get the internal promise object
* @note This function is only for internal use(testing), do not use it in your code.
*
* @return internal promise object
*/
inline promise_type& get_internal_promise() noexcept { return current_handle_.promise(); }

/**
* @brief Get the internal promise object
* @note This function is only for internal use(testing), do not use it in your code.
*
* @return internal promise object
*/
inline const promise_type& get_internal_promise() const noexcept { return current_handle_.promise(); }

private:
Expand Down
Loading

0 comments on commit 9315053

Please sign in to comment.