include/boost/capy/when_all.hpp

96.9% Lines (95/98) 91.3% Functions (484/530)
include/boost/capy/when_all.hpp
Line TLA Hits Source Code
1 //
2 // Copyright (c) 2026 Steve Gerbino
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 //
7 // Official repository: https://github.com/cppalliance/capy
8 //
9
10 #ifndef BOOST_CAPY_WHEN_ALL_HPP
11 #define BOOST_CAPY_WHEN_ALL_HPP
12
13 #include <boost/capy/detail/config.hpp>
14 #include <boost/capy/concept/executor.hpp>
15 #include <boost/capy/concept/io_awaitable.hpp>
16 #include <coroutine>
17 #include <boost/capy/ex/io_env.hpp>
18 #include <boost/capy/ex/frame_allocator.hpp>
19 #include <boost/capy/task.hpp>
20
21 #include <array>
22 #include <atomic>
23 #include <exception>
24 #include <optional>
25 #include <stop_token>
26 #include <tuple>
27 #include <type_traits>
28 #include <utility>
29
30 namespace boost {
31 namespace capy {
32
33 namespace detail {
34
/** Type traits that drop void types when building a result tuple.

    Void-returning tasks do not contribute a value to the result tuple.
    `wrap_non_void_t<T>` maps a type to a one-element tuple, or to an
    empty tuple when `T` is void; `filter_void_tuple_t<Ts...>` is the
    concatenation of those wrappers, preserving input order.

    Example: filter_void_tuple_t<int, void, string> = tuple<int, string>
*/
template<typename T>
using wrap_non_void_t = std::conditional_t<
    !std::is_void_v<T>,
    std::tuple<T>,
    std::tuple<>>;

template<typename... Ts>
using filter_void_tuple_t = decltype(
    std::tuple_cat(std::declval<wrap_non_void_t<Ts>>()...));
47
/** Holds the result of a single task within when_all.

    @tparam T The (non-void) result type of the task.
*/
template<typename T>
struct result_holder
{
    std::optional<T> value_;

    /** Store the task's result.

        Uses `emplace` so that `T` only has to be move-constructible.
        Assigning into the optional would additionally require `T` to be
        move-assignable, which e.g. lambdas with captures and types with
        const members are not.
    */
    void set(T v)
    {
        value_.emplace(std::move(v));
    }

    /** Move the stored result out.

        @pre `set` has been called; the value is dereferenced unchecked.
    */
    T get() &&
    {
        return std::move(*value_);
    }
};

/** Specialization for void tasks - no value storage needed.
*/
template<>
struct result_holder<void>
{
};
72
73 /** Shared state for when_all operation.
74
75 @tparam Ts The result types of the tasks.
76 */
77 template<typename... Ts>
78 struct when_all_state
79 {
80 static constexpr std::size_t task_count = sizeof...(Ts);
81
82 // Completion tracking - when_all waits for all children
83 std::atomic<std::size_t> remaining_count_;
84
85 // Result storage in input order
86 std::tuple<result_holder<Ts>...> results_;
87
88 // Runner handles - destroyed in await_resume while allocator is valid
89 std::array<std::coroutine_handle<>, task_count> runner_handles_{};
90
91 // Exception storage - first error wins, others discarded
92 std::atomic<bool> has_exception_{false};
93 std::exception_ptr first_exception_;
94
95 // Stop propagation - on error, request stop for siblings
96 std::stop_source stop_source_;
97
98 // Connects parent's stop_token to our stop_source
99 struct stop_callback_fn
100 {
101 std::stop_source* source_;
102 4 void operator()() const { source_->request_stop(); }
103 };
104 using stop_callback_t = std::stop_callback<stop_callback_fn>;
105 std::optional<stop_callback_t> parent_stop_callback_;
106
107 // Parent resumption
108 std::coroutine_handle<> continuation_;
109 io_env const* caller_env_ = nullptr;
110
111 59 when_all_state()
112 59 : remaining_count_(task_count)
113 {
114 59 }
115
116 // Runners self-destruct in final_suspend. No destruction needed here.
117
118 /** Capture an exception (first one wins).
119 */
120 20 void capture_exception(std::exception_ptr ep)
121 {
122 20 bool expected = false;
123 20 if(has_exception_.compare_exchange_strong(
124 expected, true, std::memory_order_relaxed))
125 17 first_exception_ = ep;
126 20 }
127
128 };
129
130 /** Wrapper coroutine that intercepts task completion.
131
132 This runner awaits its assigned task and stores the result in
133 the shared state, or captures the exception and requests stop.
134 */
135 template<typename T, typename... Ts>
136 struct when_all_runner
137 {
138 struct promise_type // : frame_allocating_base // DISABLED FOR TESTING
139 {
140 when_all_state<Ts...>* state_ = nullptr;
141 io_env env_;
142
143 130 when_all_runner get_return_object()
144 {
145 130 return when_all_runner(std::coroutine_handle<promise_type>::from_promise(*this));
146 }
147
148 130 std::suspend_always initial_suspend() noexcept
149 {
150 130 return {};
151 }
152
153 130 auto final_suspend() noexcept
154 {
155 struct awaiter
156 {
157 promise_type* p_;
158
159 58 bool await_ready() const noexcept
160 {
161 58 return false;
162 }
163
164 58 std::coroutine_handle<> await_suspend(std::coroutine_handle<> h) noexcept
165 {
166 // Extract everything needed before self-destruction.
167 58 auto* state = p_->state_;
168 58 auto* counter = &state->remaining_count_;
169 58 auto* caller_env = state->caller_env_;
170 58 auto cont = state->continuation_;
171
172 58 h.destroy();
173
174 // If last runner, dispatch parent for symmetric transfer.
175 58 auto remaining = counter->fetch_sub(1, std::memory_order_acq_rel);
176 58 if(remaining == 1)
177 29 return caller_env->executor.dispatch(cont);
178 29 return std::noop_coroutine();
179 }
180
181 void await_resume() const noexcept
182 {
183 }
184 };
185 130 return awaiter{this};
186 }
187
188 110 void return_void()
189 {
190 110 }
191
192 20 void unhandled_exception()
193 {
194 20 state_->capture_exception(std::current_exception());
195 // Request stop for sibling tasks
196 20 state_->stop_source_.request_stop();
197 20 }
198
199 template<class Awaitable>
200 struct transform_awaiter
201 {
202 std::decay_t<Awaitable> a_;
203 promise_type* p_;
204
205 130 bool await_ready()
206 {
207 130 return a_.await_ready();
208 }
209
210 130 decltype(auto) await_resume()
211 {
212 130 return a_.await_resume();
213 }
214
215 template<class Promise>
216 129 auto await_suspend(std::coroutine_handle<Promise> h)
217 {
218 #ifdef _MSC_VER
219 using R = decltype(a_.await_suspend(h, &p_->env_));
220 if constexpr (std::is_same_v<R, std::coroutine_handle<>>)
221 a_.await_suspend(h, &p_->env_).resume();
222 else
223 return a_.await_suspend(h, &p_->env_);
224 #else
225 129 return a_.await_suspend(h, &p_->env_);
226 #endif
227 }
228 };
229
230 template<class Awaitable>
231 130 auto await_transform(Awaitable&& a)
232 {
233 using A = std::decay_t<Awaitable>;
234 if constexpr (IoAwaitable<A>)
235 {
236 return transform_awaiter<Awaitable>{
237 260 std::forward<Awaitable>(a), this};
238 }
239 else
240 {
241 static_assert(sizeof(A) == 0, "requires IoAwaitable");
242 }
243 130 }
244 };
245
246 std::coroutine_handle<promise_type> h_;
247
248 130 explicit when_all_runner(std::coroutine_handle<promise_type> h)
249 130 : h_(h)
250 {
251 130 }
252
253 // Enable move for all clang versions - some versions need it
254 when_all_runner(when_all_runner&& other) noexcept : h_(std::exchange(other.h_, nullptr)) {}
255
256 // Non-copyable
257 when_all_runner(when_all_runner const&) = delete;
258 when_all_runner& operator=(when_all_runner const&) = delete;
259 when_all_runner& operator=(when_all_runner&&) = delete;
260
261 130 auto release() noexcept
262 {
263 130 return std::exchange(h_, nullptr);
264 }
265 };
266
267 /** Create a runner coroutine for a single awaitable.
268
269 Awaitable is passed directly to ensure proper coroutine frame storage.
270 */
271 template<std::size_t Index, IoAwaitable Awaitable, typename... Ts>
272 when_all_runner<awaitable_result_t<Awaitable>, Ts...>
273 130 make_when_all_runner(Awaitable inner, when_all_state<Ts...>* state)
274 {
275 using T = awaitable_result_t<Awaitable>;
276 if constexpr (std::is_void_v<T>)
277 {
278 co_await std::move(inner);
279 }
280 else
281 {
282 std::get<Index>(state->results_).set(co_await std::move(inner));
283 }
284 260 }
285
286 /** Internal awaitable that launches all runner coroutines and waits.
287
288 This awaitable is used inside the when_all coroutine to handle
289 the concurrent execution of child awaitables.
290 */
291 template<IoAwaitable... Awaitables>
292 class when_all_launcher
293 {
294 using state_type = when_all_state<awaitable_result_t<Awaitables>...>;
295
296 std::tuple<Awaitables...>* awaitables_;
297 state_type* state_;
298
299 public:
300 59 when_all_launcher(
301 std::tuple<Awaitables...>* awaitables,
302 state_type* state)
303 59 : awaitables_(awaitables)
304 59 , state_(state)
305 {
306 59 }
307
308 59 bool await_ready() const noexcept
309 {
310 59 return sizeof...(Awaitables) == 0;
311 }
312
313 59 std::coroutine_handle<> await_suspend(std::coroutine_handle<> continuation, io_env const* caller_env)
314 {
315 59 state_->continuation_ = continuation;
316 59 state_->caller_env_ = caller_env;
317
318 // Forward parent's stop requests to children
319 59 if(caller_env->stop_token.stop_possible())
320 {
321 16 state_->parent_stop_callback_.emplace(
322 8 caller_env->stop_token,
323 8 typename state_type::stop_callback_fn{&state_->stop_source_});
324
325 8 if(caller_env->stop_token.stop_requested())
326 4 state_->stop_source_.request_stop();
327 }
328
329 // CRITICAL: If the last task finishes synchronously then the parent
330 // coroutine resumes, destroying its frame, and destroying this object
331 // prior to the completion of await_suspend. Therefore, await_suspend
332 // must ensure `this` cannot be referenced after calling `launch_one`
333 // for the last time.
334 59 auto token = state_->stop_source_.get_token();
335 [&]<std::size_t... Is>(std::index_sequence<Is...>) {
336 30 (..., launch_one<Is>(caller_env->executor, token));
337 59 }(std::index_sequence_for<Awaitables...>{});
338
339 // Let signal_completion() handle resumption
340 118 return std::noop_coroutine();
341 59 }
342
343 59 void await_resume() const noexcept
344 {
345 // Results are extracted by the when_all coroutine from state
346 59 }
347
348 private:
349 template<std::size_t I>
350 130 void launch_one(executor_ref caller_ex, std::stop_token token)
351 {
352 130 auto runner = make_when_all_runner<I>(
353 130 std::move(std::get<I>(*awaitables_)), state_);
354
355 130 auto h = runner.release();
356 130 h.promise().state_ = state_;
357 130 h.promise().env_ = io_env{caller_ex, token, state_->caller_env_->allocator};
358
359 130 std::coroutine_handle<> ch{h};
360 130 state_->runner_handles_[I] = ch;
361 130 state_->caller_env_->executor.post(ch);
362 260 }
363 };
364
365 /** Compute the result type for when_all.
366
367 Returns void when all tasks are void (P2300 aligned),
368 otherwise returns a tuple with void types filtered out.
369 */
370 template<typename... Ts>
371 using when_all_result_t = std::conditional_t<
372 std::is_same_v<filter_void_tuple_t<Ts...>, std::tuple<>>,
373 void,
374 filter_void_tuple_t<Ts...>>;
375
376 /** Helper to extract a single result, returning empty tuple for void.
377 This is a separate function to work around a GCC-11 ICE that occurs
378 when using nested immediately-invoked lambdas with pack expansion.
379 */
380 template<std::size_t I, typename... Ts>
381 57 auto extract_single_result(when_all_state<Ts...>& state)
382 {
383 using T = std::tuple_element_t<I, std::tuple<Ts...>>;
384 if constexpr (std::is_void_v<T>)
385 4 return std::tuple<>();
386 else
387 53 return std::make_tuple(std::move(std::get<I>(state.results_)).get());
388 }
389
390 /** Extract results from state, filtering void types.
391 */
392 template<typename... Ts>
393 24 auto extract_results(when_all_state<Ts...>& state)
394 {
395 24 return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
396 5 return std::tuple_cat(extract_single_result<Is>(state)...);
397 48 }(std::index_sequence_for<Ts...>{});
398 }
399
400 } // namespace detail
401
402 /** Execute multiple awaitables concurrently and collect their results.
403
404 Launches all awaitables simultaneously and waits for all to complete
405 before returning. Results are collected in input order. If any
406 awaitable throws, cancellation is requested for siblings and the first
407 exception is rethrown after all awaitables complete.
408
409 @li All child awaitables run concurrently on the caller's executor
410 @li Results are returned as a tuple in input order
411 @li Void-returning awaitables do not contribute to the result tuple
412 @li If all awaitables return void, `when_all` returns `task<void>`
413 @li First exception wins; subsequent exceptions are discarded
414 @li Stop is requested for siblings on first error
415 @li Completes only after all children have finished
416
417 @par Thread Safety
418 The returned task must be awaited from a single execution context.
419 Child awaitables execute concurrently but complete through the caller's
420 executor.
421
422 @param awaitables The awaitables to execute concurrently. Each must
423 satisfy @ref IoAwaitable and is consumed (moved-from) when
424 `when_all` is awaited.
425
426 @return A task yielding a tuple of non-void results. Returns
427 `task<void>` when all input awaitables return void.
428
429 @par Example
430
431 @code
432 task<> example()
433 {
434 // Concurrent fetch, results collected in order
435 auto [user, posts] = co_await when_all(
436 fetch_user( id ), // task<User>
437 fetch_posts( id ) // task<std::vector<Post>>
438 );
439
440 // Void awaitables don't contribute to result
441 co_await when_all(
442 log_event( "start" ), // task<void>
443 notify_user( id ) // task<void>
444 );
445 // Returns task<void>, no result tuple
446 }
447 @endcode
448
449 @see IoAwaitable, task
450 */
451 template<IoAwaitable... As>
452 59 [[nodiscard]] auto when_all(As... awaitables)
453 -> task<detail::when_all_result_t<detail::awaitable_result_t<As>...>>
454 {
455 using result_type = detail::when_all_result_t<detail::awaitable_result_t<As>...>;
456
457 // State is stored in the coroutine frame, using the frame allocator
458 detail::when_all_state<detail::awaitable_result_t<As>...> state;
459
460 // Store awaitables in the frame
461 std::tuple<As...> awaitable_tuple(std::move(awaitables)...);
462
463 // Launch all awaitables and wait for completion
464 co_await detail::when_all_launcher<As...>(&awaitable_tuple, &state);
465
466 // Propagate first exception if any.
467 // Safe without explicit acquire: capture_exception() is sequenced-before
468 // signal_completion()'s acq_rel fetch_sub, which synchronizes-with the
469 // last task's decrement that resumes this coroutine.
470 if(state.first_exception_)
471 std::rethrow_exception(state.first_exception_);
472
473 // Extract and return results
474 if constexpr (std::is_void_v<result_type>)
475 co_return;
476 else
477 co_return detail::extract_results(state);
478 118 }
479
480 /// Compute the result type of `when_all` for the given task types.
481 template<typename... Ts>
482 using when_all_result_type = detail::when_all_result_t<Ts...>;
483
484 } // namespace capy
485 } // namespace boost
486
487 #endif
488