LCOV - code coverage report
Current view: top level - capy - when_all.hpp (source / functions) Coverage Total Hit Missed
Test: coverage_remapped.info Lines: 98.0 % 98 96 2
Test Date: 2026-02-23 14:24:58 Functions: 89.0 % 617 549 68

           TLA  Line data    Source code
       1                 : //
       2                 : // Copyright (c) 2026 Steve Gerbino
       3                 : //
       4                 : // Distributed under the Boost Software License, Version 1.0. (See accompanying
       5                 : // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
       6                 : //
       7                 : // Official repository: https://github.com/cppalliance/capy
       8                 : //
       9                 : 
      10                 : #ifndef BOOST_CAPY_WHEN_ALL_HPP
      11                 : #define BOOST_CAPY_WHEN_ALL_HPP
      12                 : 
      13                 : #include <boost/capy/detail/config.hpp>
      14                 : #include <boost/capy/concept/executor.hpp>
      15                 : #include <boost/capy/concept/io_awaitable.hpp>
      16                 : #include <coroutine>
      17                 : #include <boost/capy/ex/io_env.hpp>
      18                 : #include <boost/capy/ex/frame_allocator.hpp>
      19                 : #include <boost/capy/task.hpp>
      20                 : 
      21                 : #include <array>
      22                 : #include <atomic>
      23                 : #include <exception>
      24                 : #include <optional>
      25                 : #include <stop_token>
      26                 : #include <tuple>
      27                 : #include <type_traits>
      28                 : #include <utility>
      29                 : 
      30                 : namespace boost {
      31                 : namespace capy {
      32                 : 
      33                 : namespace detail {
      34                 : 
      35                 : /** Type trait to filter void types from a tuple.
      36                 : 
      37                 :     Void-returning tasks do not contribute a value to the result tuple.
      38                 :     This trait computes the filtered result type.
      39                 : 
      40                 :     Example: filter_void_tuple_t<int, void, string> = tuple<int, string>
      41                 : */
      42                 : template<typename T>
      43                 : using wrap_non_void_t = std::conditional_t<std::is_void_v<T>, std::tuple<>, std::tuple<T>>;
      44                 : 
      45                 : template<typename... Ts>
      46                 : using filter_void_tuple_t = decltype(std::tuple_cat(std::declval<wrap_non_void_t<Ts>>()...));
      47                 : 
      48                 : /** Holds the result of a single task within when_all.
      49                 : */
      50                 : template<typename T>
      51                 : struct result_holder
      52                 : {
      53                 :     std::optional<T> value_;
      54                 : 
      55 HIT          60 :     void set(T v)
      56                 :     {
      57              60 :         value_ = std::move(v);
      58              60 :     }
      59                 : 
      60              53 :     T get() &&
      61                 :     {
      62              53 :         return std::move(*value_);
      63                 :     }
      64                 : };
      65                 : 
      66                 : /** Specialization for void tasks - no value storage needed.
      67                 : */
      68                 : template<>
      69                 : struct result_holder<void>
      70                 : {
      71                 : };
      72                 : 
      73                 : /** Shared state for when_all operation.
      74                 : 
      75                 :     @tparam Ts The result types of the tasks.
      76                 : */
      77                 : template<typename... Ts>
      78                 : struct when_all_state
      79                 : {
      80                 :     static constexpr std::size_t task_count = sizeof...(Ts);
      81                 : 
      82                 :     // Completion tracking - when_all waits for all children
      83                 :     std::atomic<std::size_t> remaining_count_;
      84                 : 
      85                 :     // Result storage in input order
      86                 :     std::tuple<result_holder<Ts>...> results_;
      87                 : 
      88                 :     // Runner handles - destroyed in await_resume while allocator is valid
      89                 :     std::array<std::coroutine_handle<>, task_count> runner_handles_{};
      90                 : 
      91                 :     // Exception storage - first error wins, others discarded
      92                 :     std::atomic<bool> has_exception_{false};
      93                 :     std::exception_ptr first_exception_;
      94                 : 
      95                 :     // Stop propagation - on error, request stop for siblings
      96                 :     std::stop_source stop_source_;
      97                 : 
      98                 :     // Connects parent's stop_token to our stop_source
      99                 :     struct stop_callback_fn
     100                 :     {
     101                 :         std::stop_source* source_;
     102               4 :         void operator()() const { source_->request_stop(); }
     103                 :     };
     104                 :     using stop_callback_t = std::stop_callback<stop_callback_fn>;
     105                 :     std::optional<stop_callback_t> parent_stop_callback_;
     106                 : 
     107                 :     // Parent resumption
     108                 :     std::coroutine_handle<> continuation_;
     109                 :     io_env const* caller_env_ = nullptr;
     110                 : 
     111              59 :     when_all_state()
     112              59 :         : remaining_count_(task_count)
     113                 :     {
     114              59 :     }
     115                 : 
     116                 :     // Runners self-destruct in final_suspend. No destruction needed here.
     117                 : 
     118                 :     /** Capture an exception (first one wins).
     119                 :     */
     120              20 :     void capture_exception(std::exception_ptr ep)
     121                 :     {
     122              20 :         bool expected = false;
     123              20 :         if(has_exception_.compare_exchange_strong(
     124                 :             expected, true, std::memory_order_relaxed))
     125              17 :             first_exception_ = ep;
     126              20 :     }
     127                 : 
     128                 : };
     129                 : 
     130                 : /** Wrapper coroutine that intercepts task completion.
     131                 : 
     132                 :     This runner awaits its assigned task and stores the result in
     133                 :     the shared state, or captures the exception and requests stop.
     134                 : */
     135                 : template<typename T, typename... Ts>
     136                 : struct when_all_runner
     137                 : {
     138                 :     struct promise_type // : frame_allocating_base  // DISABLED FOR TESTING
     139                 :     {
     140                 :         when_all_state<Ts...>* state_ = nullptr;
     141                 :         io_env env_;
     142                 : 
     143             130 :         when_all_runner get_return_object()
     144                 :         {
     145             130 :             return when_all_runner(std::coroutine_handle<promise_type>::from_promise(*this));
     146                 :         }
     147                 : 
     148             130 :         std::suspend_always initial_suspend() noexcept
     149                 :         {
     150             130 :             return {};
     151                 :         }
     152                 : 
     153             130 :         auto final_suspend() noexcept
     154                 :         {
     155                 :             struct awaiter
     156                 :             {
     157                 :                 promise_type* p_;
     158                 : 
     159             130 :                 bool await_ready() const noexcept
     160                 :                 {
     161             130 :                     return false;
     162                 :                 }
     163                 : 
     164             130 :                 std::coroutine_handle<> await_suspend(std::coroutine_handle<> h) noexcept
     165                 :                 {
     166                 :                     // Extract everything needed before self-destruction.
     167             130 :                     auto* state = p_->state_;
     168             130 :                     auto* counter = &state->remaining_count_;
     169             130 :                     auto* caller_env = state->caller_env_;
     170             130 :                     auto cont = state->continuation_;
     171                 : 
     172             130 :                     h.destroy();
     173                 : 
     174                 :                     // If last runner, dispatch parent for symmetric transfer.
     175             130 :                     auto remaining = counter->fetch_sub(1, std::memory_order_acq_rel);
     176             130 :                     if(remaining == 1)
     177              59 :                         return caller_env->executor.dispatch(cont);
     178              71 :                     return std::noop_coroutine();
     179                 :                 }
     180                 : 
     181 MIS           0 :                 void await_resume() const noexcept
     182                 :                 {
     183               0 :                 }
     184                 :             };
     185 HIT         130 :             return awaiter{this};
     186                 :         }
     187                 : 
     188             110 :         void return_void()
     189                 :         {
     190             110 :         }
     191                 : 
     192              20 :         void unhandled_exception()
     193                 :         {
     194              20 :             state_->capture_exception(std::current_exception());
     195                 :             // Request stop for sibling tasks
     196              20 :             state_->stop_source_.request_stop();
     197              20 :         }
     198                 : 
     199                 :         template<class Awaitable>
     200                 :         struct transform_awaiter
     201                 :         {
     202                 :             std::decay_t<Awaitable> a_;
     203                 :             promise_type* p_;
     204                 : 
     205             130 :             bool await_ready()
     206                 :             {
     207             130 :                 return a_.await_ready();
     208                 :             }
     209                 : 
     210             130 :             decltype(auto) await_resume()
     211                 :             {
     212             130 :                 return a_.await_resume();
     213                 :             }
     214                 : 
     215                 :             template<class Promise>
     216             129 :             auto await_suspend(std::coroutine_handle<Promise> h)
     217                 :             {
     218                 : #ifdef _MSC_VER
     219                 :                 using R = decltype(a_.await_suspend(h, &p_->env_));
     220                 :                 if constexpr (std::is_same_v<R, std::coroutine_handle<>>)
     221                 :                     a_.await_suspend(h, &p_->env_).resume();
     222                 :                 else
     223                 :                     return a_.await_suspend(h, &p_->env_);
     224                 : #else
     225             129 :                 return a_.await_suspend(h, &p_->env_);
     226                 : #endif
     227                 :             }
     228                 :         };
     229                 : 
     230                 :         template<class Awaitable>
     231             130 :         auto await_transform(Awaitable&& a)
     232                 :         {
     233                 :             using A = std::decay_t<Awaitable>;
     234                 :             if constexpr (IoAwaitable<A>)
     235                 :             {
     236                 :                 return transform_awaiter<Awaitable>{
     237             260 :                     std::forward<Awaitable>(a), this};
     238                 :             }
     239                 :             else
     240                 :             {
     241                 :                 static_assert(sizeof(A) == 0, "requires IoAwaitable");
     242                 :             }
     243             130 :         }
     244                 :     };
     245                 : 
     246                 :     std::coroutine_handle<promise_type> h_;
     247                 : 
     248             130 :     explicit when_all_runner(std::coroutine_handle<promise_type> h)
     249             130 :         : h_(h)
     250                 :     {
     251             130 :     }
     252                 : 
     253                 :     // Enable move for all clang versions - some versions need it
     254                 :     when_all_runner(when_all_runner&& other) noexcept : h_(std::exchange(other.h_, nullptr)) {}
     255                 : 
     256                 :     // Non-copyable
     257                 :     when_all_runner(when_all_runner const&) = delete;
     258                 :     when_all_runner& operator=(when_all_runner const&) = delete;
     259                 :     when_all_runner& operator=(when_all_runner&&) = delete;
     260                 : 
     261             130 :     auto release() noexcept
     262                 :     {
     263             130 :         return std::exchange(h_, nullptr);
     264                 :     }
     265                 : };
     266                 : 
     267                 : /** Create a runner coroutine for a single awaitable.
     268                 : 
     269                 :     Awaitable is passed directly to ensure proper coroutine frame storage.
     270                 : */
     271                 : template<std::size_t Index, IoAwaitable Awaitable, typename... Ts>
     272                 : when_all_runner<awaitable_result_t<Awaitable>, Ts...>
     273             130 : make_when_all_runner(Awaitable inner, when_all_state<Ts...>* state)
     274                 : {
     275                 :     using T = awaitable_result_t<Awaitable>;
     276                 :     if constexpr (std::is_void_v<T>)
     277                 :     {
     278                 :         co_await std::move(inner);
     279                 :     }
     280                 :     else
     281                 :     {
     282                 :         std::get<Index>(state->results_).set(co_await std::move(inner));
     283                 :     }
     284             260 : }
     285                 : 
     286                 : /** Internal awaitable that launches all runner coroutines and waits.
     287                 : 
     288                 :     This awaitable is used inside the when_all coroutine to handle
     289                 :     the concurrent execution of child awaitables.
     290                 : */
     291                 : template<IoAwaitable... Awaitables>
     292                 : class when_all_launcher
     293                 : {
     294                 :     using state_type = when_all_state<awaitable_result_t<Awaitables>...>;
     295                 : 
     296                 :     std::tuple<Awaitables...>* awaitables_;
     297                 :     state_type* state_;
     298                 : 
     299                 : public:
     300              59 :     when_all_launcher(
     301                 :         std::tuple<Awaitables...>* awaitables,
     302                 :         state_type* state)
     303              59 :         : awaitables_(awaitables)
     304              59 :         , state_(state)
     305                 :     {
     306              59 :     }
     307                 : 
     308              59 :     bool await_ready() const noexcept
     309                 :     {
     310              59 :         return sizeof...(Awaitables) == 0;
     311                 :     }
     312                 : 
     313              59 :     std::coroutine_handle<> await_suspend(std::coroutine_handle<> continuation, io_env const* caller_env)
     314                 :     {
     315              59 :         state_->continuation_ = continuation;
     316              59 :         state_->caller_env_ = caller_env;
     317                 : 
     318                 :         // Forward parent's stop requests to children
     319              59 :         if(caller_env->stop_token.stop_possible())
     320                 :         {
     321              16 :             state_->parent_stop_callback_.emplace(
     322               8 :                 caller_env->stop_token,
     323               8 :                 typename state_type::stop_callback_fn{&state_->stop_source_});
     324                 : 
     325               8 :             if(caller_env->stop_token.stop_requested())
     326               4 :                 state_->stop_source_.request_stop();
     327                 :         }
     328                 : 
     329                 :         // CRITICAL: If the last task finishes synchronously then the parent
     330                 :         // coroutine resumes, destroying its frame, and destroying this object
     331                 :         // prior to the completion of await_suspend. Therefore, await_suspend
     332                 :         // must ensure `this` cannot be referenced after calling `launch_one`
     333                 :         // for the last time.
     334              59 :         auto token = state_->stop_source_.get_token();
     335              58 :         [&]<std::size_t... Is>(std::index_sequence<Is...>) {
     336              59 :             (..., launch_one<Is>(caller_env->executor, token));
     337              59 :         }(std::index_sequence_for<Awaitables...>{});
     338                 : 
     339                 :         // Let signal_completion() handle resumption
     340             118 :         return std::noop_coroutine();
     341              59 :     }
     342                 : 
     343              59 :     void await_resume() const noexcept
     344                 :     {
     345                 :         // Results are extracted by the when_all coroutine from state
     346              59 :     }
     347                 : 
     348                 : private:
     349                 :     template<std::size_t I>
     350             130 :     void launch_one(executor_ref caller_ex, std::stop_token token)
     351                 :     {
     352             130 :         auto runner = make_when_all_runner<I>(
     353             130 :             std::move(std::get<I>(*awaitables_)), state_);
     354                 : 
     355             130 :         auto h = runner.release();
     356             130 :         h.promise().state_ = state_;
     357             130 :         h.promise().env_ = io_env{caller_ex, token, state_->caller_env_->allocator};
     358                 : 
     359             130 :         std::coroutine_handle<> ch{h};
     360             130 :         state_->runner_handles_[I] = ch;
     361             130 :         state_->caller_env_->executor.post(ch);
     362             260 :     }
     363                 : };
     364                 : 
     365                 : /** Compute the result type for when_all.
     366                 : 
     367                 :     Returns void when all tasks are void (P2300 aligned),
     368                 :     otherwise returns a tuple with void types filtered out.
     369                 : */
     370                 : template<typename... Ts>
     371                 : using when_all_result_t = std::conditional_t<
     372                 :     std::is_same_v<filter_void_tuple_t<Ts...>, std::tuple<>>,
     373                 :     void,
     374                 :     filter_void_tuple_t<Ts...>>;
     375                 : 
     376                 : /** Helper to extract a single result, returning empty tuple for void.
     377                 :     This is a separate function to work around a GCC-11 ICE that occurs
     378                 :     when using nested immediately-invoked lambdas with pack expansion.
     379                 : */
     380                 : template<std::size_t I, typename... Ts>
     381              57 : auto extract_single_result(when_all_state<Ts...>& state)
     382                 : {
     383                 :     using T = std::tuple_element_t<I, std::tuple<Ts...>>;
     384                 :     if constexpr (std::is_void_v<T>)
     385               4 :         return std::tuple<>();
     386                 :     else
     387              53 :         return std::make_tuple(std::move(std::get<I>(state.results_)).get());
     388                 : }
     389                 : 
     390                 : /** Extract results from state, filtering void types.
     391                 : */
     392                 : template<typename... Ts>
     393              24 : auto extract_results(when_all_state<Ts...>& state)
     394                 : {
     395              43 :     return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
     396              25 :         return std::tuple_cat(extract_single_result<Is>(state)...);
     397              48 :     }(std::index_sequence_for<Ts...>{});
     398                 : }
     399                 : 
     400                 : } // namespace detail
     401                 : 
     402                 : /** Execute multiple awaitables concurrently and collect their results.
     403                 : 
     404                 :     Launches all awaitables simultaneously and waits for all to complete
     405                 :     before returning. Results are collected in input order. If any
     406                 :     awaitable throws, cancellation is requested for siblings and the first
     407                 :     exception is rethrown after all awaitables complete.
     408                 : 
     409                 :     @li All child awaitables run concurrently on the caller's executor
     410                 :     @li Results are returned as a tuple in input order
     411                 :     @li Void-returning awaitables do not contribute to the result tuple
     412                 :     @li If all awaitables return void, `when_all` returns `task<void>`
     413                 :     @li First exception wins; subsequent exceptions are discarded
     414                 :     @li Stop is requested for siblings on first error
     415                 :     @li Completes only after all children have finished
     416                 : 
     417                 :     @par Thread Safety
     418                 :     The returned task must be awaited from a single execution context.
     419                 :     Child awaitables execute concurrently but complete through the caller's
     420                 :     executor.
     421                 : 
     422                 :     @param awaitables The awaitables to execute concurrently. Each must
     423                 :         satisfy @ref IoAwaitable and is consumed (moved-from) when
     424                 :         `when_all` is awaited.
     425                 : 
     426                 :     @return A task yielding a tuple of non-void results. Returns
     427                 :         `task<void>` when all input awaitables return void.
     428                 : 
     429                 :     @par Example
     430                 : 
     431                 :     @code
     432                 :     task<> example()
     433                 :     {
     434                 :         // Concurrent fetch, results collected in order
     435                 :         auto [user, posts] = co_await when_all(
     436                 :             fetch_user( id ),      // task<User>
     437                 :             fetch_posts( id )      // task<std::vector<Post>>
     438                 :         );
     439                 : 
     440                 :         // Void awaitables don't contribute to result
     441                 :         co_await when_all(
     442                 :             log_event( "start" ),  // task<void>
     443                 :             notify_user( id )      // task<void>
     444                 :         );
     445                 :         // Returns task<void>, no result tuple
     446                 :     }
     447                 :     @endcode
     448                 : 
     449                 :     @see IoAwaitable, task
     450                 : */
     451                 : template<IoAwaitable... As>
     452              59 : [[nodiscard]] auto when_all(As... awaitables)
     453                 :     -> task<detail::when_all_result_t<detail::awaitable_result_t<As>...>>
     454                 : {
     455                 :     using result_type = detail::when_all_result_t<detail::awaitable_result_t<As>...>;
     456                 : 
     457                 :     // State is stored in the coroutine frame, using the frame allocator
     458                 :     detail::when_all_state<detail::awaitable_result_t<As>...> state;
     459                 : 
     460                 :     // Store awaitables in the frame
     461                 :     std::tuple<As...> awaitable_tuple(std::move(awaitables)...);
     462                 : 
     463                 :     // Launch all awaitables and wait for completion
     464                 :     co_await detail::when_all_launcher<As...>(&awaitable_tuple, &state);
     465                 : 
     466                 :     // Propagate first exception if any.
     467                 :     // Safe without explicit acquire: capture_exception() is sequenced-before
     468                 :     // signal_completion()'s acq_rel fetch_sub, which synchronizes-with the
     469                 :     // last task's decrement that resumes this coroutine.
     470                 :     if(state.first_exception_)
     471                 :         std::rethrow_exception(state.first_exception_);
     472                 : 
     473                 :     // Extract and return results
     474                 :     if constexpr (std::is_void_v<result_type>)
     475                 :         co_return;
     476                 :     else
     477                 :         co_return detail::extract_results(state);
     478             118 : }
     479                 : 
     480                 : /// Compute the result type of `when_all` for the given task types.
     481                 : template<typename... Ts>
     482                 : using when_all_result_type = detail::when_all_result_t<Ts...>;
     483                 : 
     484                 : } // namespace capy
     485                 : } // namespace boost
     486                 : 
     487                 : #endif
        

Generated by: LCOV version 2.3