diff --git a/include/libcopp/coroutine/callable_promise.h b/include/libcopp/coroutine/callable_promise.h index 15d7532..a26c800 100644 --- a/include/libcopp/coroutine/callable_promise.h +++ b/include/libcopp/coroutine/callable_promise.h @@ -50,16 +50,24 @@ class LIBCOPP_COPP_API_HEAD_ONLY callable_promise_base : public p set_status(promise_status::kDone); } data_ = std::move(value); + has_return_ = true; } - typename value_type& data() noexcept { return data_; } - const typename value_type& data() const noexcept { return data_; } + inline typename value_type& data() noexcept { return data_; } + inline const typename value_type& data() const noexcept { return data_; } + + inline bool has_return() const noexcept { return has_return_; } protected: typename value_type data_; + bool has_return_ = false; }; -template >> +# if defined(LIBCOPP_MACRO_ENABLE_CONCEPTS) && LIBCOPP_MACRO_ENABLE_CONCEPTS +template +# else +template ::value> > +# endif class LIBCOPP_COPP_API_HEAD_ONLY callable_awaitable_base : public awaitable_base_type { public: using promise_type = TPROMISE; @@ -69,7 +77,7 @@ class LIBCOPP_COPP_API_HEAD_ONLY callable_awaitable_base : public awaitable_base public: callable_awaitable_base(handle_type handle) : callee_{handle} {} - bool await_ready() noexcept { + inline bool await_ready() noexcept { if (!callee_) { return true; } @@ -83,22 +91,53 @@ class LIBCOPP_COPP_API_HEAD_ONLY callable_awaitable_base : public awaitable_base callee_.resume(); } + if (callee_.promise().get_status() >= promise_status::kDone) { + return true; + } + return callee_.done(); } # if defined(LIBCOPP_MACRO_ENABLE_CONCEPTS) && LIBCOPP_MACRO_ENABLE_CONCEPTS template # else - template ::value>> + template ::value> > # endif - void await_suspend(LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle caller) noexcept { - set_caller(caller); - caller.promise().set_waiting_handle(callee_); + inline void await_suspend(LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle caller) noexcept { + if (caller.promise().get_status() < promise_status::kDone) { + set_caller(caller); + caller.promise().set_waiting_handle(callee_); + callee_.promise().add_caller(caller); + } else { + // Already done and can not suspend again + caller.resume(); + } } inline handle_type& get_callee() noexcept { return callee_; } inline const handle_type& get_callee() const noexcept { return callee_; } + protected: + inline void detach() noexcept { + // caller maybe null if the callable is already ready when co_await + auto caller = get_caller(); + auto& callee_promise = get_callee().promise(); + + if (caller) { + callee_promise.remove_caller(caller, true); + caller.promise->set_waiting_handle(nullptr); + set_caller(nullptr); + } + + if (callee_promise.get_status() < promise_status::kDone) { + if (await_ready() || !caller) { + callee_promise.set_status(promise_status::kKilled); + } else { + callee_promise.set_status(caller.promise->get_status()); + } + } + } + private: handle_type callee_; }; @@ -114,26 +153,15 @@ class LIBCOPP_COPP_API_HEAD_ONLY callable_awaitable : public cal public: using base_type::await_ready; using base_type::await_suspend; + using base_type::detach; using base_type::get_callee; using base_type::get_caller; using base_type::set_caller; callable_awaitable(handle_type handle) : base_type(handle) {} - void await_resume() { - // caller maybe null if the callable is already ready when co_await - auto caller = get_caller(); - if (await_ready() || !caller) { - get_callee().promise().set_status(promise_status::kDone); - } else { - get_callee().promise().set_status(caller.promise().get_status()); - } - - get_callee().promise().resume_waiting(); - - if (caller) { - caller.promise().set_waiting_handle(nullptr); - set_caller(nullptr); - } + inline void await_resume() { + detach(); + get_callee().promise().resume_waiting(get_callee(), true); } }; @@ -148,35 +176,22 @@ class LIBCOPP_COPP_API_HEAD_ONLY callable_awaitable : public ca public: using base_type::await_ready; using base_type::await_suspend; + using base_type::detach; using base_type::get_callee; using base_type::get_caller; using base_type::set_caller; callable_awaitable(handle_type handle) : base_type(handle) {} - value_type await_resume() { - bool is_ready = await_ready(); - // caller maybe null if the callable is already ready when co_await - auto caller = get_caller(); - promise_status status; - if (is_ready || !caller) { - status = promise_status::kDone; - } else { - status = caller.promise().get_status(); - } - get_callee().promise().set_status(status); - - get_callee().promise().resume_waiting(); + inline value_type await_resume() { + detach(); + auto& callee_promise = get_callee().promise(); + callee_promise.resume_waiting(get_callee(), true); - if (caller) { - caller.promise().set_waiting_handle(nullptr); - set_caller(nullptr); + if (!callee_promise.has_return()) { + return promise_error_transform()(callee_promise.get_status()); } - if (is_ready) { - return std::move(get_callee().promise().data()); - } else { - return promise_error_transform()(status); - } + return std::move(callee_promise.data()); } }; @@ -193,20 +208,19 @@ class LIBCOPP_COPP_API_HEAD_ONLY callable_future { } struct initial_awaitable { - bool await_ready() const noexcept { return false; } - void await_resume() const noexcept { + inline bool await_ready() const noexcept { return false; } + inline void await_resume() const noexcept { if (handle.promise().get_status() == promise_status::kCreated) { promise_status excepted = promise_status::kCreated; handle.promise().set_status(promise_status::kRunning, &excepted); } } - void await_suspend(LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle caller) noexcept { + inline void await_suspend(LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle caller) noexcept { handle = caller; } LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle handle; }; initial_awaitable initial_suspend() noexcept { return {}; } - LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE suspend_always final_suspend() noexcept { return {}; } void unhandled_exception() { throw; } }; using handle_type = LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle; @@ -237,9 +251,33 @@ class LIBCOPP_COPP_API_HEAD_ONLY callable_future { awaitable_type operator co_await() { return awaitable_type{current_handle_}; } inline bool is_ready() const noexcept { return current_handle_.done(); } + inline promise_status get_status() const noexcept { return current_handle_.promise().get_status(); } + static auto yield_status() noexcept { return promise_base_type::pick_current_status(); } + + /** + * @brief Get the internal handle object + * @note This function is only for internal use(testing), do not use it in your code. + * + * @return internal handle + */ + inline const handle_type& get_internal_handle() const noexcept { return current_handle_; } + + /** + * @brief Get the internal promise object + * @note This function is only for internal use(testing), do not use it in your code. + * + * @return internal promise object + */ inline promise_type& get_internal_promise() noexcept { return current_handle_.promise(); } + + /** + * @brief Get the internal promise object + * @note This function is only for internal use(testing), do not use it in your code. + * + * @return internal promise object + */ inline const promise_type& get_internal_promise() const noexcept { return current_handle_.promise(); } private: diff --git a/include/libcopp/coroutine/std_coroutine_common.h b/include/libcopp/coroutine/std_coroutine_common.h index 9451072..386778f 100644 --- a/include/libcopp/coroutine/std_coroutine_common.h +++ b/include/libcopp/coroutine/std_coroutine_common.h @@ -6,12 +6,18 @@ #include #include +#include #include +#include #if defined(LIBCOPP_MACRO_ENABLE_STD_EXCEPTION_PTR) && LIBCOPP_MACRO_ENABLE_STD_EXCEPTION_PTR # include #endif +#ifdef __cpp_impl_three_way_comparison +# include +#endif + #include "libcopp/future/future.h" #include "libcopp/utils/atomic_int_type.h" @@ -20,12 +26,13 @@ LIBCOPP_COPP_NAMESPACE_BEGIN enum class LIBCOPP_COPP_API_HEAD_ONLY promise_status : uint8_t { - kCreated = 0, - kRunning = 1, - kDone = 2, - kCancle = 3, - kKilled = 4, - kTimeout = 5, + kInvalid = 0, + kCreated = 1, + kRunning = 2, + kDone = 3, + kCancle = 4, + kKilled = 5, + kTimeout = 6, }; class promise_base_type; @@ -36,6 +43,85 @@ concept DerivedPromiseBaseType = std::is_base_of::value; # endif class promise_base_type { + public: + using handle_type = LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle; + using type_erased_handle_type = LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle<>; + struct LIBCOPP_COPP_API_HEAD_ONLY handle_delegate { + type_erased_handle_type handle; + promise_base_type *promise; + +# if defined(LIBCOPP_MACRO_ENABLE_CONCEPTS) && LIBCOPP_MACRO_ENABLE_CONCEPTS + template +# else + template ::value>> +# endif + explicit handle_delegate( + const LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle &origin_handle) noexcept + : handle{origin_handle} { + if (handle) { + promise = &origin_handle.promise(); + } else { + promise = nullptr; + } + } + + explicit handle_delegate(std::nullptr_t) noexcept : handle{nullptr}, promise{nullptr} {} + + friend inline bool operator==(const handle_delegate &l, const handle_delegate &r) noexcept { + return l.handle == r.handle; + } + friend inline bool operator!=(const handle_delegate &l, const handle_delegate &r) noexcept { + return l.handle != r.handle; + } + friend inline bool operator<(const handle_delegate &l, const handle_delegate &r) noexcept { + return l.handle < r.handle; + } + friend inline bool operator<=(const handle_delegate &l, const handle_delegate &r) noexcept { + return l.handle <= r.handle; + } + friend inline bool operator>(const handle_delegate &l, const handle_delegate &r) noexcept { + return l.handle > r.handle; + } + friend inline bool operator>=(const handle_delegate &l, const handle_delegate &r) noexcept { + return l.handle >= r.handle; + } + inline operator bool() const noexcept { return !!handle; } + +# if defined(LIBCOPP_MACRO_ENABLE_CONCEPTS) && LIBCOPP_MACRO_ENABLE_CONCEPTS + template +# else + template ::value>> +# endif + inline handle_delegate &operator=( + const LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle &origin_handle) noexcept { + handle = origin_handle; + if (handle) { + promise = &origin_handle.promise(); + } else { + promise = nullptr; + } + } + inline handle_delegate &operator=(std::nullptr_t) noexcept { + handle = nullptr; + promise = nullptr; + return *this; + } + }; + + struct pick_promise_status_awaitable { + promise_status data; + + LIBCOPP_COPP_API pick_promise_status_awaitable() noexcept; + LIBCOPP_COPP_API pick_promise_status_awaitable(pick_promise_status_awaitable &&other) noexcept; + LIBCOPP_COPP_API ~pick_promise_status_awaitable(); + pick_promise_status_awaitable(const pick_promise_status_awaitable &) = delete; + pick_promise_status_awaitable &operator=(const pick_promise_status_awaitable &) = delete; + + LIBCOPP_COPP_API_HEAD_ONLY inline bool await_ready() const noexcept { return true; } + LIBCOPP_COPP_API_HEAD_ONLY inline promise_status await_resume() const noexcept { return data; } + LIBCOPP_COPP_API_HEAD_ONLY inline void await_suspend(type_erased_handle_type) noexcept {} + }; + public: LIBCOPP_COPP_API promise_base_type(); LIBCOPP_COPP_API ~promise_base_type(); @@ -43,10 +129,8 @@ class promise_base_type { LIBCOPP_COPP_API bool set_status(promise_status value, promise_status *expect = nullptr) noexcept; LIBCOPP_COPP_API promise_status get_status() const noexcept; - LIBCOPP_COPP_API LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle get_waiting_handle() - const noexcept; - LIBCOPP_COPP_API void set_waiting_handle(std::nullptr_t) noexcept; + LIBCOPP_COPP_API void set_waiting_handle(handle_delegate handle); # if defined(LIBCOPP_MACRO_ENABLE_CONCEPTS) && LIBCOPP_MACRO_ENABLE_CONCEPTS template # else @@ -57,20 +141,86 @@ class promise_base_type { if (nullptr == handle) { set_waiting_handle(nullptr); } else { - set_waiting_handle(LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle<>::from_address(handle.address())); + set_waiting_handle(handle_delegate{handle}); + } + } + + /** + * @brief Resume waiting handle, this should only be called in await_resume and after this call, callee maybe + * destroyed + */ +# if defined(LIBCOPP_MACRO_ENABLE_CONCEPTS) && LIBCOPP_MACRO_ENABLE_CONCEPTS + template +# else + template ::value>> +# endif + LIBCOPP_COPP_API_HEAD_ONLY inline void resume_waiting( + const LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle &handle, bool inherit_status) { + resume_waiting(handle_delegate{handle}, inherit_status); + }; + + LIBCOPP_COPP_API void resume_waiting(handle_delegate current_delegate, bool inherit_status); + + // C++20 coroutine + struct LIBCOPP_COPP_API_HEAD_ONLY final_awaitable { + inline bool await_ready() const noexcept { return false; } + inline void await_resume() const noexcept {} + +# if defined(LIBCOPP_MACRO_ENABLE_CONCEPTS) && LIBCOPP_MACRO_ENABLE_CONCEPTS + template +# else + template ::value>> +# endif + inline void await_suspend(LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle self) noexcept { + self.promise().resume_callers(); } + }; + final_awaitable final_suspend() noexcept { return {}; } + + LIBCOPP_COPP_API void add_caller(handle_delegate handle) noexcept; +# if defined(LIBCOPP_MACRO_ENABLE_CONCEPTS) && LIBCOPP_MACRO_ENABLE_CONCEPTS + template +# else + template ::value>> +# endif + LIBCOPP_COPP_API_HEAD_ONLY void add_caller( + const LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle &handle) noexcept { + add_caller(handle_delegate{handle}); + } + + LIBCOPP_COPP_API void remove_caller(handle_delegate handle, bool inherit_status) noexcept; +# if defined(LIBCOPP_MACRO_ENABLE_CONCEPTS) && LIBCOPP_MACRO_ENABLE_CONCEPTS + template +# else + template ::value>> +# endif + LIBCOPP_COPP_API_HEAD_ONLY void remove_caller( + const LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle &handle, bool inherit_status) noexcept { + remove_caller(handle_delegate{handle}, inherit_status); } - LIBCOPP_COPP_API void resume_waiting(); + LIBCOPP_COPP_API pick_promise_status_awaitable yield_value(pick_promise_status_awaitable &&args) const noexcept; + static LIBCOPP_COPP_API_HEAD_ONLY inline pick_promise_status_awaitable pick_current_status() noexcept { return {}; } private: - LIBCOPP_COPP_API void set_waiting_handle(LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle<> handle); + LIBCOPP_COPP_API void resume_callers(); private: // promise_status util::lock::atomic_int_type status_; // We must erase type here, because MSVC use is_empty_v>, which need to calculate the type size - LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle<> current_waiting_; + handle_delegate current_waiting_; + handle_delegate unique_caller_; + + // hash for handle_delegate + struct LIBCOPP_COPP_API_HEAD_ONLY handle_delegate_hash { + inline size_t operator()(const handle_delegate &handle_delegate) const noexcept { + return std::hash()(handle_delegate.handle.address()); + } + }; + + // Mostly, there is only one caller for a promise, we needn't hash map to store one handle + std::unique_ptr> multiple_callers_; }; class awaitable_base_type { @@ -78,11 +228,10 @@ class awaitable_base_type { LIBCOPP_COPP_API awaitable_base_type(); LIBCOPP_COPP_API ~awaitable_base_type(); - LIBCOPP_COPP_API LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle get_caller() - const noexcept; + LIBCOPP_COPP_API promise_base_type::handle_delegate get_caller() const noexcept; - LIBCOPP_COPP_API void set_caller( - const LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle &handle) noexcept; + LIBCOPP_COPP_API void set_caller(promise_base_type::handle_delegate caller) noexcept; + LIBCOPP_COPP_API void set_caller(std::nullptr_t) noexcept; # if defined(LIBCOPP_MACRO_ENABLE_CONCEPTS) && LIBCOPP_MACRO_ENABLE_CONCEPTS template @@ -94,13 +243,12 @@ class awaitable_base_type { if (nullptr == handle) { set_caller(nullptr); } else { - set_caller( - LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle::from_promise(handle.promise())); + set_caller(promise_base_type::handle_delegate{handle}); } } private: - LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle caller_; + promise_base_type::handle_delegate caller_; }; template diff --git a/src/libcopp/coroutine/std_coroutine_common.cpp b/src/libcopp/coroutine/std_coroutine_common.cpp index e1c1819..08d0d73 100644 --- a/src/libcopp/coroutine/std_coroutine_common.cpp +++ b/src/libcopp/coroutine/std_coroutine_common.cpp @@ -4,14 +4,24 @@ #include +#include #include #if defined(LIBCOPP_MACRO_ENABLE_STD_COROUTINE) && LIBCOPP_MACRO_ENABLE_STD_COROUTINE LIBCOPP_COPP_NAMESPACE_BEGIN +LIBCOPP_COPP_API promise_base_type::pick_promise_status_awaitable::pick_promise_status_awaitable() noexcept + : data(promise_status::kInvalid) {} + +LIBCOPP_COPP_API promise_base_type::pick_promise_status_awaitable::pick_promise_status_awaitable( + pick_promise_status_awaitable &&other) noexcept + : data(other.data) {} + +LIBCOPP_COPP_API promise_base_type::pick_promise_status_awaitable::~pick_promise_status_awaitable() {} + LIBCOPP_COPP_API promise_base_type::promise_base_type() - : status_{static_cast(promise_status::kCreated)}, current_waiting_{nullptr} {} + : status_{static_cast(promise_status::kCreated)}, current_waiting_{nullptr}, unique_caller_{nullptr} {} LIBCOPP_COPP_API promise_base_type::~promise_base_type() {} @@ -33,39 +43,95 @@ LIBCOPP_COPP_API promise_status promise_base_type::get_status() const noexcept { return static_cast(status_.load()); } -LIBCOPP_COPP_API LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle -promise_base_type::get_waiting_handle() const noexcept { - return LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle::from_address( - current_waiting_.address()); +LIBCOPP_COPP_API void promise_base_type::set_waiting_handle(std::nullptr_t) noexcept { current_waiting_ = nullptr; } + +LIBCOPP_COPP_API void promise_base_type::set_waiting_handle(handle_delegate handle) { current_waiting_ = handle; } + +LIBCOPP_COPP_API void promise_base_type::resume_waiting(handle_delegate current_delegate, bool inherit_status) { + // Atfer resume(), this object maybe destroyed. + auto waiting_delegate = current_waiting_; + if (waiting_delegate.handle && !waiting_delegate.handle.done()) { + current_waiting_ = nullptr; + // Prevent the waiting coroutine remuse this again. + assert(waiting_delegate.promise); + waiting_delegate.promise->remove_caller(current_delegate, inherit_status); + waiting_delegate.handle.resume(); + } } -LIBCOPP_COPP_API void promise_base_type::set_waiting_handle(std::nullptr_t) noexcept { current_waiting_ = nullptr; } +LIBCOPP_COPP_API void promise_base_type::add_caller(handle_delegate delegate) noexcept { + if (!delegate.handle || delegate.handle.done()) { + return; + } -LIBCOPP_COPP_API void promise_base_type::set_waiting_handle( - LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle<> handle) { - current_waiting_ = handle; + if (!unique_caller_.handle) { + unique_caller_ = delegate; + return; + } + + if (!multiple_callers_) { + multiple_callers_.reset(new std::unordered_set()); + } + multiple_callers_->insert(delegate); } -LIBCOPP_COPP_API void promise_base_type::resume_waiting() { - while (current_waiting_ && !current_waiting_.done()) { - current_waiting_.resume(); +LIBCOPP_COPP_API void promise_base_type::remove_caller(handle_delegate delegate, bool inherit_status) noexcept { + bool has_caller = false; + do { + if (unique_caller_.handle == delegate.handle) { + unique_caller_ = nullptr; + has_caller = true; + break; + } + + if (multiple_callers_) { + has_caller = multiple_callers_->erase(delegate) > 0; + } + } while (false); + + if (has_caller && inherit_status && nullptr != delegate.promise && get_status() < promise_status::kDone && + delegate.promise->get_status() > promise_status::kDone) { + set_status(delegate.promise->get_status()); + } +} + +LIBCOPP_COPP_API promise_base_type::pick_promise_status_awaitable promise_base_type::yield_value( + pick_promise_status_awaitable &&args) const noexcept { + args.data = static_cast(status_.load()); + return args; +} + +LIBCOPP_COPP_API void promise_base_type::resume_callers() { + auto unique_caller = unique_caller_; + unique_caller_ = nullptr; + std::unique_ptr> multiple_callers; + multiple_callers.swap(multiple_callers_); + + // The promise object may be destroyed after first caller.resume() + if (unique_caller.handle && !unique_caller.handle.done()) { + unique_caller.handle.resume(); + } + + if (multiple_callers) { + for (auto &caller : *multiple_callers) { + if (caller.handle && !caller.handle.done()) { + caller.handle.resume(); + } + } } - current_waiting_ = nullptr; } LIBCOPP_COPP_API awaitable_base_type::awaitable_base_type() : caller_{nullptr} {} LIBCOPP_COPP_API awaitable_base_type::~awaitable_base_type() {} -LIBCOPP_COPP_API LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle -awaitable_base_type::get_caller() const noexcept { - return caller_; -} +LIBCOPP_COPP_API promise_base_type::handle_delegate awaitable_base_type::get_caller() const noexcept { return caller_; } -LIBCOPP_COPP_API void awaitable_base_type::set_caller( - const LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle &handle) noexcept { - caller_ = handle; +LIBCOPP_COPP_API void awaitable_base_type::set_caller(promise_base_type::handle_delegate caller) noexcept { + caller_ = caller; } +LIBCOPP_COPP_API void awaitable_base_type::set_caller(std::nullptr_t) noexcept { caller_ = nullptr; } + LIBCOPP_COPP_NAMESPACE_END #endif diff --git a/test/case/callable_promise_test.cpp b/test/case/callable_promise_test.cpp index 2f948ec..72efe14 100644 --- a/test/case/callable_promise_test.cpp +++ b/test/case/callable_promise_test.cpp @@ -15,17 +15,44 @@ struct callable_promise_test_pending_awaitable { bool await_ready() noexcept { return false; } - void await_resume() {} + void await_resume() noexcept { detach(); } - void await_suspend(LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle<> caller) noexcept { +# if defined(LIBCOPP_MACRO_ENABLE_CONCEPTS) && LIBCOPP_MACRO_ENABLE_CONCEPTS + template +# else + template ::value> > +# endif + void await_suspend(LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle caller) noexcept { pending.push_back(caller); + current = caller; } + void detach() noexcept { + if (!current) { + return; + } + + for (auto iter = pending.begin(); iter != pending.end(); ++iter) { + if (*iter == current) { + pending.erase(iter); + current = nullptr; + break; + } + } + } + + callable_promise_test_pending_awaitable() {} + ~callable_promise_test_pending_awaitable() { + // detach when destroyed + detach(); + } + + LIBCOPP_MACRO_STD_COROUTINE_NAMESPACE coroutine_handle<> current = nullptr; + static std::list > pending; static void resume_all() { while (!pending.empty()) { pending.front().resume(); - pending.pop_front(); } } }; @@ -54,6 +81,10 @@ static copp::callable_future callable_func_await_int() { CASE_TEST(callable_promise, callable_future_integer_need_resume) { copp::callable_future f = callable_func_await_int(); + CASE_EXPECT_EQ(static_cast(copp::promise_status::kCreated), static_cast(f.get_status())); + // Start + f.get_internal_handle().resume(); + CASE_EXPECT_NE(static_cast(copp::promise_status::kDone), static_cast(f.get_status())); CASE_EXPECT_FALSE(f.is_ready()); @@ -74,6 +105,10 @@ static copp::callable_future callable_func_await_int_ready() { CASE_TEST(callable_promise, callable_future_integer_ready) { copp::callable_future f = callable_func_await_int_ready(); + CASE_EXPECT_EQ(static_cast(copp::promise_status::kCreated), static_cast(f.get_status())); + // Start + f.get_internal_handle().resume(); + CASE_EXPECT_EQ(static_cast(copp::promise_status::kDone), static_cast(f.get_status())); CASE_EXPECT_TRUE(f.is_ready()); CASE_EXPECT_EQ(64, f.get_internal_promise().data()); @@ -101,6 +136,10 @@ static copp::callable_future callable_func_await_void() { CASE_TEST(callable_promise, callable_future_void_need_resume) { copp::callable_future f = callable_func_await_void(); + CASE_EXPECT_EQ(static_cast(copp::promise_status::kCreated), static_cast(f.get_status())); + // Start + f.get_internal_handle().resume(); + CASE_EXPECT_NE(static_cast(copp::promise_status::kDone), static_cast(f.get_status())); CASE_EXPECT_FALSE(f.is_ready()); callable_promise_test_pending_awaitable::resume_all(); @@ -118,6 +157,10 @@ static copp::callable_future callable_func_await_void_ready() { CASE_TEST(callable_promise, callable_future_void_ready) { copp::callable_future f = callable_func_await_void_ready(); + CASE_EXPECT_EQ(static_cast(copp::promise_status::kCreated), static_cast(f.get_status())); + // Start + f.get_internal_handle().resume(); + callable_promise_test_pending_awaitable::resume_all(); CASE_EXPECT_EQ(static_cast(copp::promise_status::kDone), static_cast(f.get_status())); CASE_EXPECT_TRUE(f.is_ready()); @@ -136,6 +179,10 @@ static copp::callable_future callable_func_int_await_void() { CASE_TEST(callable_promise, callable_future_int_await_void) { copp::callable_future f = callable_func_int_await_void(); + CASE_EXPECT_EQ(static_cast(copp::promise_status::kCreated), static_cast(f.get_status())); + // Start + f.get_internal_handle().resume(); + CASE_EXPECT_NE(static_cast(copp::promise_status::kDone), static_cast(f.get_status())); CASE_EXPECT_FALSE(f.is_ready()); callable_promise_test_pending_awaitable::resume_all(); @@ -144,6 +191,66 @@ CASE_TEST(callable_promise, callable_future_int_await_void) { CASE_EXPECT_EQ(40, f.get_internal_promise().data()); } +static copp::callable_future callable_func_killed_by_caller_l3() { + co_await callable_promise_test_pending_awaitable(); + auto current_status = co_yield copp::callable_future::yield_status(); + CASE_EXPECT_TRUE(copp::promise_status::kKilled == current_status); + + // await again and return immdiately + // co_await callable_promise_test_pending_awaitable(); + co_return -static_cast(current_status); +} + +static copp::callable_future callable_func_killed_by_caller_l2() { + int result = co_await callable_func_killed_by_caller_l3(); + co_return result; +} + +static copp::callable_future callable_func_killed_by_caller_l1() { + int result = co_await callable_func_killed_by_caller_l2(); + co_return result; +} + +CASE_TEST(callable_promise, killed_by_caller_resume_waiting) { + copp::callable_future f = callable_func_killed_by_caller_l1(); + CASE_EXPECT_EQ(static_cast(copp::promise_status::kCreated), static_cast(f.get_status())); + // Start + f.get_internal_handle().resume(); + + CASE_EXPECT_NE(static_cast(copp::promise_status::kDone), static_cast(f.get_status())); + CASE_EXPECT_FALSE(f.is_ready()); + + // Mock to kill by caller + f.get_internal_promise().set_status(copp::promise_status::kKilled); + f.get_internal_handle().resume(); + CASE_EXPECT_TRUE(f.is_ready()); + CASE_EXPECT_EQ(f.get_internal_promise().data(), -static_cast(copp::promise_status::kKilled)); + + // cleanup + callable_promise_test_pending_awaitable::resume_all(); + CASE_EXPECT_EQ(static_cast(copp::promise_status::kKilled), static_cast(f.get_status())); +} + +CASE_TEST(callable_promise, killed_by_caller_drop_generator) { + copp::callable_future f = callable_func_killed_by_caller_l2(); + CASE_EXPECT_EQ(static_cast(copp::promise_status::kCreated), static_cast(f.get_status())); + // Start + f.get_internal_handle().resume(); + + CASE_EXPECT_NE(static_cast(copp::promise_status::kDone), static_cast(f.get_status())); + CASE_EXPECT_FALSE(f.is_ready()); + + // Mock to kill by caller + f.get_internal_promise().set_status(copp::promise_status::kKilled); + f.get_internal_handle().resume(); + CASE_EXPECT_TRUE(f.is_ready()); + CASE_EXPECT_EQ(f.get_internal_promise().data(), -static_cast(copp::promise_status::kKilled)); + + // cleanup + callable_promise_test_pending_awaitable::resume_all(); + CASE_EXPECT_EQ(static_cast(copp::promise_status::kKilled), static_cast(f.get_status())); +} + #else CASE_TEST(callable_promise, disabled) {} #endif \ No newline at end of file