TLA Line data Source code
1 : //
2 : // Copyright (c) 2026 Steve Gerbino
3 : //
4 : // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 : // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 : //
7 : // Official repository: https://github.com/cppalliance/capy
8 : //
9 :
10 : #ifndef BOOST_CAPY_WHEN_ALL_HPP
11 : #define BOOST_CAPY_WHEN_ALL_HPP
12 :
13 : #include <boost/capy/detail/config.hpp>
14 : #include <boost/capy/concept/executor.hpp>
15 : #include <boost/capy/concept/io_awaitable.hpp>
16 : #include <coroutine>
17 : #include <boost/capy/ex/io_env.hpp>
18 : #include <boost/capy/ex/frame_allocator.hpp>
19 : #include <boost/capy/task.hpp>
20 :
21 : #include <array>
22 : #include <atomic>
23 : #include <exception>
24 : #include <optional>
25 : #include <stop_token>
26 : #include <tuple>
27 : #include <type_traits>
28 : #include <utility>
29 :
30 : namespace boost {
31 : namespace capy {
32 :
33 : namespace detail {
34 :
/** Map one type to a tuple fragment: `std::tuple<>` for `void`,
    otherwise `std::tuple<T>`.

    `void` cannot be a tuple element, so each type is first turned into
    a fragment and the fragments are concatenated by
    `filter_void_tuple_t`.
*/
template<typename T>
using wrap_non_void_t = std::conditional_t<
    !std::is_void_v<T>,
    std::tuple<T>,
    std::tuple<>>;

/** Tuple of the non-void types in a pack, in their original order.

    Void-returning tasks do not contribute a value to the result tuple,
    so their types are dropped here.

    Example: `filter_void_tuple_t<int, void, std::string>` is
    `std::tuple<int, std::string>`; an empty or all-void pack yields
    `std::tuple<>`.
*/
template<typename... Ts>
using filter_void_tuple_t = decltype(
    std::tuple_cat(std::declval<wrap_non_void_t<Ts>>()...));
47 :
/** Slot holding one task's result inside when_all.

    `set` stores a value (replacing any earlier one) and `get`, called
    on an rvalue holder, moves the stored value out. Calling `get`
    before `set` is a precondition violation.

    @tparam T The task's result type.
*/
template<typename T>
struct result_holder
{
    std::optional<T> value_;

    // Store the task's result; the argument is moved into the slot.
    void set(T v)
    {
        value_ = std::move(v);
    }

    // Move the stored result out. Requires a prior call to set().
    T get() &&
    {
        return *std::move(value_);
    }
};

/** A void task produces no value, so its slot stores nothing.
*/
template<>
struct result_holder<void>
{
};
72 :
73 : /** Shared state for when_all operation.
74 :
75 : @tparam Ts The result types of the tasks.
76 : */
77 : template<typename... Ts>
78 : struct when_all_state
79 : {
80 : static constexpr std::size_t task_count = sizeof...(Ts);
81 :
82 : // Completion tracking - when_all waits for all children
83 : std::atomic<std::size_t> remaining_count_;
84 :
85 : // Result storage in input order
86 : std::tuple<result_holder<Ts>...> results_;
87 :
88 : // Runner handles - destroyed in await_resume while allocator is valid
89 : std::array<std::coroutine_handle<>, task_count> runner_handles_{};
90 :
91 : // Exception storage - first error wins, others discarded
92 : std::atomic<bool> has_exception_{false};
93 : std::exception_ptr first_exception_;
94 :
95 : // Stop propagation - on error, request stop for siblings
96 : std::stop_source stop_source_;
97 :
98 : // Connects parent's stop_token to our stop_source
99 : struct stop_callback_fn
100 : {
101 : std::stop_source* source_;
102 4 : void operator()() const { source_->request_stop(); }
103 : };
104 : using stop_callback_t = std::stop_callback<stop_callback_fn>;
105 : std::optional<stop_callback_t> parent_stop_callback_;
106 :
107 : // Parent resumption
108 : std::coroutine_handle<> continuation_;
109 : io_env const* caller_env_ = nullptr;
110 :
111 59 : when_all_state()
112 59 : : remaining_count_(task_count)
113 : {
114 59 : }
115 :
116 : // Runners self-destruct in final_suspend. No destruction needed here.
117 :
118 : /** Capture an exception (first one wins).
119 : */
120 20 : void capture_exception(std::exception_ptr ep)
121 : {
122 20 : bool expected = false;
123 20 : if(has_exception_.compare_exchange_strong(
124 : expected, true, std::memory_order_relaxed))
125 17 : first_exception_ = ep;
126 20 : }
127 :
128 : };
129 :
/** Wrapper coroutine that intercepts task completion.

    This runner awaits its assigned task and stores the result in
    the shared state, or captures the exception and requests stop.

    Lifecycle: the coroutine starts suspended. The launcher fills in
    `state_` and `env_` on the promise, then posts the handle to the
    caller's executor. When the body finishes, final_suspend destroys the
    frame and only then decrements the shared completion counter.

    @tparam T   Result type of the awaited task. It is not referenced in this
                type's body; presumably it keeps runner types distinct.
    @tparam Ts  Result types of all tasks in the when_all.
*/
template<typename T, typename... Ts>
struct when_all_runner
{
    // NOTE(review): with frame_allocating_base disabled, this promise
    // declares no operator new, so runner frames come from the global
    // allocator rather than the frame allocator. Confirm before release.
    struct promise_type // : frame_allocating_base // DISABLED FOR TESTING
    {
        when_all_state<Ts...>* state_ = nullptr; // set by the launcher after creation
        io_env env_;                             // environment passed to the awaited task

        // Wrap this frame's handle; launch_one takes it back out via release().
        when_all_runner get_return_object()
        {
            return when_all_runner(std::coroutine_handle<promise_type>::from_promise(*this));
        }

        // Start suspended so state_ and env_ can be filled in before the
        // body runs; the executor's post performs the first resume.
        std::suspend_always initial_suspend() noexcept
        {
            return {};
        }

        auto final_suspend() noexcept
        {
            struct awaiter
            {
                promise_type* p_;

                bool await_ready() const noexcept
                {
                    return false;
                }

                // Destroys this runner's frame, then decrements the shared
                // counter. Destroying first guarantees that every runner frame
                // is gone by the time the parent is resumed. After
                // h.destroy(), neither the frame nor p_ may be touched.
                std::coroutine_handle<> await_suspend(std::coroutine_handle<> h) noexcept
                {
                    // Extract everything needed before self-destruction.
                    auto* state = p_->state_;
                    auto* counter = &state->remaining_count_;
                    auto* caller_env = state->caller_env_;
                    auto cont = state->continuation_;

                    h.destroy();

                    // If last runner, dispatch parent for symmetric transfer.
                    // acq_rel: releases this runner's writes (result or
                    // exception) and, on the last decrement, acquires every
                    // other runner's writes before the parent resumes.
                    auto remaining = counter->fetch_sub(1, std::memory_order_acq_rel);
                    if(remaining == 1)
                        return caller_env->executor.dispatch(cont);
                    return std::noop_coroutine();
                }

                // Never reached: the frame is destroyed in await_suspend, so
                // the coroutine is never resumed past final_suspend.
                void await_resume() const noexcept
                {
                }
            };
            return awaiter{this};
        }

        // Nothing to do: a non-void result was already stored by the body.
        void return_void()
        {
        }

        // Any exception escaping the body (from the awaited task or from
        // storing its result) is recorded, and stop is requested for the
        // siblings. Control then continues to final_suspend.
        void unhandled_exception()
        {
            state_->capture_exception(std::current_exception());
            // Request stop for sibling tasks
            state_->stop_source_.request_stop();
        }

        // Adapts an IoAwaitable to the standard awaiter protocol by passing
        // this runner's io_env as the extra await_suspend argument.
        template<class Awaitable>
        struct transform_awaiter
        {
            std::decay_t<Awaitable> a_;
            promise_type* p_;

            bool await_ready()
            {
                return a_.await_ready();
            }

            decltype(auto) await_resume()
            {
                return a_.await_resume();
            }

            template<class Promise>
            auto await_suspend(std::coroutine_handle<Promise> h)
            {
#ifdef _MSC_VER
                // When a handle is returned, resume it directly instead of
                // using symmetric transfer; presumably a workaround for MSVC
                // symmetric-transfer issues. TODO confirm.
                using R = decltype(a_.await_suspend(h, &p_->env_));
                if constexpr (std::is_same_v<R, std::coroutine_handle<>>)
                    a_.await_suspend(h, &p_->env_).resume();
                else
                    return a_.await_suspend(h, &p_->env_);
#else
                return a_.await_suspend(h, &p_->env_);
#endif
            }
        };

        // Only IoAwaitables may be co_awaited inside a runner; anything else
        // is rejected at compile time.
        template<class Awaitable>
        auto await_transform(Awaitable&& a)
        {
            using A = std::decay_t<Awaitable>;
            if constexpr (IoAwaitable<A>)
            {
                return transform_awaiter<Awaitable>{
                    std::forward<Awaitable>(a), this};
            }
            else
            {
                static_assert(sizeof(A) == 0, "requires IoAwaitable");
            }
        }
    };

    // Non-owning in practice. There is no destructor, so the frame must be
    // taken with release(); otherwise a created-but-unreleased runner leaks.
    std::coroutine_handle<promise_type> h_;

    explicit when_all_runner(std::coroutine_handle<promise_type> h)
        : h_(h)
    {
    }

    // Enable move for all clang versions - some versions need it
    when_all_runner(when_all_runner&& other) noexcept : h_(std::exchange(other.h_, nullptr)) {}

    // Non-copyable
    when_all_runner(when_all_runner const&) = delete;
    when_all_runner& operator=(when_all_runner const&) = delete;
    when_all_runner& operator=(when_all_runner&&) = delete;

    // Hand the coroutine handle to the caller, leaving this object empty.
    auto release() noexcept
    {
        return std::exchange(h_, nullptr);
    }
};
266 :
267 : /** Create a runner coroutine for a single awaitable.
268 :
269 : Awaitable is passed directly to ensure proper coroutine frame storage.
270 : */
271 : template<std::size_t Index, IoAwaitable Awaitable, typename... Ts>
272 : when_all_runner<awaitable_result_t<Awaitable>, Ts...>
273 130 : make_when_all_runner(Awaitable inner, when_all_state<Ts...>* state)
274 : {
275 : using T = awaitable_result_t<Awaitable>;
276 : if constexpr (std::is_void_v<T>)
277 : {
278 : co_await std::move(inner);
279 : }
280 : else
281 : {
282 : std::get<Index>(state->results_).set(co_await std::move(inner));
283 : }
284 260 : }
285 :
/** Internal awaitable that launches all runner coroutines and waits.

    This awaitable is used inside the when_all coroutine to handle
    the concurrent execution of child awaitables. It creates one runner per
    awaitable and posts each to the caller's executor. The awaiting
    coroutine is resumed by the last runner to finish, not by this object.
*/
template<IoAwaitable... Awaitables>
class when_all_launcher
{
    using state_type = when_all_state<awaitable_result_t<Awaitables>...>;

    // Both point into the when_all coroutine frame.
    std::tuple<Awaitables...>* awaitables_;
    state_type* state_;

public:
    when_all_launcher(
        std::tuple<Awaitables...>* awaitables,
        state_type* state)
        : awaitables_(awaitables)
        , state_(state)
    {
    }

    // With no children there is nothing to wait for, so skip suspension.
    bool await_ready() const noexcept
    {
        return sizeof...(Awaitables) == 0;
    }

    std::coroutine_handle<> await_suspend(std::coroutine_handle<> continuation, io_env const* caller_env)
    {
        state_->continuation_ = continuation;
        state_->caller_env_ = caller_env;

        // Forward parent's stop requests to children
        if(caller_env->stop_token.stop_possible())
        {
            state_->parent_stop_callback_.emplace(
                caller_env->stop_token,
                typename state_type::stop_callback_fn{&state_->stop_source_});

            // NOTE(review): std::stop_callback's constructor already invokes
            // the callback synchronously if stop was requested, so this check
            // looks redundant (harmless, request_stop is idempotent).
            if(caller_env->stop_token.stop_requested())
                state_->stop_source_.request_stop();
        }

        // CRITICAL: If the last task finishes synchronously then the parent
        // coroutine resumes, destroying its frame, and destroying this object
        // prior to the completion of await_suspend. Therefore, await_suspend
        // must ensure `this` cannot be referenced after calling `launch_one`
        // for the last time.
        // The same applies to *state_ and *awaitables_, which live in that
        // frame. `token` shares ownership of the stop state, so its
        // destruction here is safe either way.
        // NOTE(review): if a later launch_one throws (e.g. frame allocation),
        // runners already posted still reference *state_ while the exception
        // propagates to the parent; confirm this cannot happen or handle it.
        auto token = state_->stop_source_.get_token();
        [&]<std::size_t... Is>(std::index_sequence<Is...>) {
            (..., launch_one<Is>(caller_env->executor, token));
        }(std::index_sequence_for<Awaitables...>{});

        // Resumption is performed by the last runner to finish (in its
        // final_suspend), so there is nothing to transfer to here.
        return std::noop_coroutine();
    }

    void await_resume() const noexcept
    {
        // Results are extracted by the when_all coroutine from state
    }

private:
    // Create runner I, hand it its state and environment, and post it.
    // The post must be the last use of `this`/state_ (see await_suspend).
    template<std::size_t I>
    void launch_one(executor_ref caller_ex, std::stop_token token)
    {
        auto runner = make_when_all_runner<I>(
            std::move(std::get<I>(*awaitables_)), state_);

        auto h = runner.release();
        h.promise().state_ = state_;
        // Children see the when_all stop token, not the parent's directly.
        h.promise().env_ = io_env{caller_ex, token, state_->caller_env_->allocator};

        std::coroutine_handle<> ch{h};
        state_->runner_handles_[I] = ch;
        state_->caller_env_->executor.post(ch);
    }
};
364 :
365 : /** Compute the result type for when_all.
366 :
367 : Returns void when all tasks are void (P2300 aligned),
368 : otherwise returns a tuple with void types filtered out.
369 : */
370 : template<typename... Ts>
371 : using when_all_result_t = std::conditional_t<
372 : std::is_same_v<filter_void_tuple_t<Ts...>, std::tuple<>>,
373 : void,
374 : filter_void_tuple_t<Ts...>>;
375 :
376 : /** Helper to extract a single result, returning empty tuple for void.
377 : This is a separate function to work around a GCC-11 ICE that occurs
378 : when using nested immediately-invoked lambdas with pack expansion.
379 : */
380 : template<std::size_t I, typename... Ts>
381 57 : auto extract_single_result(when_all_state<Ts...>& state)
382 : {
383 : using T = std::tuple_element_t<I, std::tuple<Ts...>>;
384 : if constexpr (std::is_void_v<T>)
385 4 : return std::tuple<>();
386 : else
387 53 : return std::make_tuple(std::move(std::get<I>(state.results_)).get());
388 : }
389 :
390 : /** Extract results from state, filtering void types.
391 : */
392 : template<typename... Ts>
393 24 : auto extract_results(when_all_state<Ts...>& state)
394 : {
395 43 : return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
396 25 : return std::tuple_cat(extract_single_result<Is>(state)...);
397 48 : }(std::index_sequence_for<Ts...>{});
398 : }
399 :
400 : } // namespace detail
401 :
402 : /** Execute multiple awaitables concurrently and collect their results.
403 :
404 : Launches all awaitables simultaneously and waits for all to complete
405 : before returning. Results are collected in input order. If any
406 : awaitable throws, cancellation is requested for siblings and the first
407 : exception is rethrown after all awaitables complete.
408 :
409 : @li All child awaitables run concurrently on the caller's executor
410 : @li Results are returned as a tuple in input order
411 : @li Void-returning awaitables do not contribute to the result tuple
412 : @li If all awaitables return void, `when_all` returns `task<void>`
413 : @li First exception wins; subsequent exceptions are discarded
414 : @li Stop is requested for siblings on first error
415 : @li Completes only after all children have finished
416 :
417 : @par Thread Safety
418 : The returned task must be awaited from a single execution context.
419 : Child awaitables execute concurrently but complete through the caller's
420 : executor.
421 :
422 : @param awaitables The awaitables to execute concurrently. Each must
423 : satisfy @ref IoAwaitable and is consumed (moved-from) when
424 : `when_all` is awaited.
425 :
426 : @return A task yielding a tuple of non-void results. Returns
427 : `task<void>` when all input awaitables return void.
428 :
429 : @par Example
430 :
431 : @code
432 : task<> example()
433 : {
434 : // Concurrent fetch, results collected in order
435 : auto [user, posts] = co_await when_all(
436 : fetch_user( id ), // task<User>
437 : fetch_posts( id ) // task<std::vector<Post>>
438 : );
439 :
440 : // Void awaitables don't contribute to result
441 : co_await when_all(
442 : log_event( "start" ), // task<void>
443 : notify_user( id ) // task<void>
444 : );
445 : // Returns task<void>, no result tuple
446 : }
447 : @endcode
448 :
449 : @see IoAwaitable, task
450 : */
451 : template<IoAwaitable... As>
452 59 : [[nodiscard]] auto when_all(As... awaitables)
453 : -> task<detail::when_all_result_t<detail::awaitable_result_t<As>...>>
454 : {
455 : using result_type = detail::when_all_result_t<detail::awaitable_result_t<As>...>;
456 :
457 : // State is stored in the coroutine frame, using the frame allocator
458 : detail::when_all_state<detail::awaitable_result_t<As>...> state;
459 :
460 : // Store awaitables in the frame
461 : std::tuple<As...> awaitable_tuple(std::move(awaitables)...);
462 :
463 : // Launch all awaitables and wait for completion
464 : co_await detail::when_all_launcher<As...>(&awaitable_tuple, &state);
465 :
466 : // Propagate first exception if any.
467 : // Safe without explicit acquire: capture_exception() is sequenced-before
468 : // signal_completion()'s acq_rel fetch_sub, which synchronizes-with the
469 : // last task's decrement that resumes this coroutine.
470 : if(state.first_exception_)
471 : std::rethrow_exception(state.first_exception_);
472 :
473 : // Extract and return results
474 : if constexpr (std::is_void_v<result_type>)
475 : co_return;
476 : else
477 : co_return detail::extract_results(state);
478 118 : }
479 :
480 : /// Compute the result type of `when_all` for the given task types.
481 : template<typename... Ts>
482 : using when_all_result_type = detail::when_all_result_t<Ts...>;
483 :
484 : } // namespace capy
485 : } // namespace boost
486 :
487 : #endif
|