/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include #include #include #include namespace folly { namespace channels { namespace detail { template ChannelBridgePtr& senderGetBridge(Sender& sender) { return sender.bridge_; } template bool receiverWait( Receiver& receiver, detail::IChannelCallback* callback) { if (!receiver.buffer_.empty()) { return false; } return receiver.bridge_->receiverWait(callback); } template detail::IChannelCallback* cancelReceiverWait(Receiver& receiver) { return receiver.bridge_->cancelReceiverWait(); } template std::optional> receiverGetValue(Receiver& receiver) { if (receiver.buffer_.empty()) { receiver.buffer_ = receiver.bridge_->receiverGetValues(); if (receiver.buffer_.empty()) { return std::nullopt; } } auto result = std::move(receiver.buffer_.front()); receiver.buffer_.pop(); return result; } template std::pair, detail::ReceiverQueue> receiverUnbuffer(Receiver&& receiver) { return std::make_pair( std::move(receiver.bridge_), std::move(receiver.buffer_)); } } // namespace detail template class Receiver::Waiter : public detail::IChannelCallback { public: Waiter( Receiver* receiver, folly::CancellationToken cancelToken, bool closeOnCancel) : state_(State{.receiver = receiver}), cancelCallback_( makeCancellationCallback(std::move(cancelToken), closeOnCancel)) {} bool await_ready() const noexcept { // We are ready immediately if the receiver is either cancelled or closed. return state_.withRLock( [&](const State& state) { return state.cancelled || !state.receiver; }); } bool await_suspend(folly::coro::coroutine_handle<> awaitingCoroutine) { return state_.withWLock([&](State& state) { if (state.cancelled || !state.receiver || !receiverWait(*state.receiver, this)) { // We will not suspend at all if the receiver is either cancelled or // closed. return false; } state.awaitingCoroutine = awaitingCoroutine; return true; }); } std::optional await_resume() { auto result = getResult(); if (!result.hasValue() && !result.hasException()) { return std::nullopt; } return std::move(result.value()); } folly::Try await_resume_try() { return getResult(); } protected: struct State { Receiver* receiver; folly::coro::coroutine_handle<> awaitingCoroutine; bool cancelled{false}; }; std::unique_ptr makeCancellationCallback( folly::CancellationToken cancelToken, bool closeOnCancel) { if (!cancelToken.canBeCancelled()) { return nullptr; } return std::make_unique( std::move(cancelToken), [this, closeOnCancel] { auto receiver = state_.withWLock([&](State& state) { state.cancelled = true; return std::exchange(state.receiver, nullptr); }); if (!receiver) { return; } if (closeOnCancel) { std::move(*receiver).cancel(); } else { auto* callback = detail::cancelReceiverWait(*receiver); if (callback) { callback->canceled(nullptr); } } }); } void consume(detail::ChannelBridgeBase*) override { resume(); } void canceled(detail::ChannelBridgeBase*) override { resume(); } void resume() { auto awaitingCoroutine = state_.withWLock([&](State& state) { return std::exchange(state.awaitingCoroutine, nullptr); }); awaitingCoroutine.resume(); } folly::Try getResult() { cancelCallback_.reset(); return state_.withWLock([&](State& state) { if (state.cancelled) { return folly::Try( folly::make_exception_wrapper()); } if (!state.receiver) { return folly::Try(); } auto result = std::move(detail::receiverGetValue(*state.receiver).value()); if (!result.hasValue()) { std::move(*state.receiver).cancel(); state.receiver = nullptr; } return result; }); } folly::Synchronized state_; std::unique_ptr cancelCallback_; }; template struct Receiver::NextSemiAwaitable { public: explicit NextSemiAwaitable( Receiver* receiver, bool closeOnCancel, std::optional cancelToken = std::nullopt) : receiver_(receiver), closeOnCancel_(closeOnCancel), cancelToken_(std::move(cancelToken)) {} [[nodiscard]] Waiter operator co_await() { return Waiter( receiver_, cancelToken_.value_or(folly::CancellationToken()), closeOnCancel_); } friend NextSemiAwaitable co_withCancellation( folly::CancellationToken cancelToken, NextSemiAwaitable&& awaitable) { if (awaitable.cancelToken_.has_value()) { return std::move(awaitable); } return NextSemiAwaitable( awaitable.receiver_, awaitable.closeOnCancel_, std::move(cancelToken)); } private: Receiver* receiver_; bool closeOnCancel_; std::optional cancelToken_; }; } // namespace channels } // namespace folly