/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include #include #include #include #include #include #include #include #include namespace folly { namespace channels { namespace detail { struct CloseResult { CloseResult() {} explicit CloseResult(folly::exception_wrapper _exception) : exception(std::move(_exception)) {} std::optional exception; }; enum class ChannelState { Active, CancellationTriggered, CancellationProcessed }; template ChannelState getSenderState(TSender* sender) { if (sender == nullptr) { return ChannelState::CancellationProcessed; } else if (sender->isSenderClosed()) { return ChannelState::CancellationTriggered; } else { return ChannelState::Active; } } template ChannelState getReceiverState(TReceiver* receiver) { if (receiver == nullptr) { return ChannelState::CancellationProcessed; } else if (receiver->isReceiverCancelled()) { return ChannelState::CancellationTriggered; } else { return ChannelState::Active; } } inline std::ostream& operator<<(std::ostream& os, ChannelState state) { switch (state) { case ChannelState::Active: return os << "Active"; case ChannelState::CancellationTriggered: return os << "CancellationTriggered"; case ChannelState::CancellationProcessed: return os << "CancellationProcessed"; default: return os << "Should never be hit"; } } /** * A cancellation callback that wraps an existing channel callback. When the * callback is fired, this object will trigger cancellation on its cancellation * source (in addition to firing the wrapped callback). */ template class SenderCancellationCallback : public IChannelCallback { public: explicit SenderCancellationCallback( TSender& sender, folly::Executor::KeepAlive executor, IChannelCallback* channelCallback) : sender_(sender), executor_(std::move(executor)), channelCallback_(channelCallback), callbackToFire_(folly::coro::makePromiseContract()) { if (channelCallback_ == nullptr) { // The sender was already canceled runOperationWithSenderCancellation was // even called. This means the cancelled callback already was fired, so // we will not set the callback to fire here. cancelSource_.requestCancellation(); return; } CHECK(sender_); if (!sender_->senderWait(this)) { // The sender was cancelled after runOperationWithSenderCancellation was // called, but before we had a chance to start the operation. This means // that the cancelled callback was never called. We will therefore set it // to fire here, when the operation is complete. cancelSource_.requestCancellation(); callbackToFire_.first.setValue(CallbackToFire::Consume); } } folly::coro::Task onTaskCompleted() { if (!channelCallback_) { co_return; } auto callbackToFire = std::optional(); bool promiseSet = false; if (callbackToFire_.second.isReady()) { // The callback was fired. promiseSet = true; callbackToFire = co_await std::move(callbackToFire_.second); } else { // The callback has not yet been fired. if (!sender_->cancelSenderWait()) { // The sender has been cancelled, but the callback has not been called // yet. Wait for the callback to be called. promiseSet = true; callbackToFire = co_await std::move(callbackToFire_.second); } else if (!sender_->senderWait(channelCallback_)) { // The sender was cancelled between the call to cancelSenderWait and // the call to senderWait. This means that the cancelled callback was // never called. We will therefore set it to fire here. callbackToFire = CallbackToFire::Consume; } } if (!promiseSet) { // Set a default value here, so we don't need to waste time constructing a // broken promise exception when the promise is destructed. This value // will not be read. callbackToFire_.first.setValue(CallbackToFire::Consume); } if (callbackToFire.has_value()) { switch (callbackToFire.value()) { case CallbackToFire::Consume: channelCallback_->consume(sender_.get()); co_return; case CallbackToFire::Canceled: channelCallback_->canceled(sender_.get()); co_return; } } // The sender has not yet been cancelled, and we are now back in the state // where the sender is waiting on the user-provided callback. We are done. } /** * Returns a cancellation token that will trigger when the sender */ folly::CancellationToken getCancellationToken() { return cancelSource_.getToken(); } /** * Requests cancellation, and triggers the consume function on the callback * if the callback was not previously triggered. */ void consume(ChannelBridgeBase*) override { cancelSource_.requestCancellation(); executor_->add([=]() { CHECK(!callbackToFire_.second.isReady()); callbackToFire_.first.setValue(CallbackToFire::Consume); }); } /** * Requests cancellation, and triggers the canceled function on the callback * if the callback was not previously triggered. */ void canceled(ChannelBridgeBase*) override { cancelSource_.requestCancellation(); executor_->add([=]() { CHECK(!callbackToFire_.second.isReady()); callbackToFire_.first.setValue(CallbackToFire::Canceled); }); } private: enum class CallbackToFire { Consume, Canceled }; TSender& sender_; folly::Executor::KeepAlive executor_; IChannelCallback* channelCallback_; folly::CancellationSource cancelSource_; std::pair< folly::coro::Promise, folly::coro::Future> callbackToFire_; }; /** * Any object that produces an output receiver (transform, merge, * MergeChannel, etc) will listen for a cancellation signal from that output * receiver. Once the consumer of the output receiver stops consuming, a * callback will be called that triggers these objects to start cleaning * themselves up (and eventually destroy themselves). * * However, when one of these objects decides to run a user coroutine, they * would like that user coroutine to be able to get notified when that * cancellation signal is received. That allows the coroutine to stop any * long-running operations quickly, rather than running a long time when the * consumer of the output receiver no longer cares about the result. * * This function enables that behavior. It will run the provided operation * coroutine. While that coroutine is running, it will listen to cancellation * events from the output receiver (through its sender). If it receives a * cancellation signal from the sender, it will trigger cancellation of the * operation coroutine. * * Once the coroutine finishes, it will then call the given channel callback * to notify it of the cancellation event (the same way that callback would * have been notified if no coroutine had been started). It will also resume * waiting on the channel callback. * * @param executor: The executor to run the coroutine on. * * @param sender: The sender to use to listen for cancellation. If this is * null, we will assume that cancellation already occurred. * * @param alreadyStartedWaiting: Whether or not the caller already started * listening for a cancellation signal from the output receiver. If so, this * function will temporarily stop waiting with that callback (so it can listen * for the cancellation signal to stop the coroutine). * * @param channelCallbackToRestore: The channel callback to restore once the * coroutine operation is complete. * * @param operation: The operation to run. * * @param token: The rate limiter token for this operation. */ template void runOperationWithSenderCancellation( folly::Executor::KeepAlive executor, TSender& sender, bool alreadyStartedWaiting, IChannelCallback* channelCallbackToRestore, folly::coro::Task operation, RateLimiter::Token token) noexcept { if (alreadyStartedWaiting && (!sender || !sender->cancelSenderWait())) { // The output receiver was cancelled before starting this operation // (indicating that the channel callback already ran). channelCallbackToRestore = nullptr; } folly::coro::co_invoke( [&sender, executor, channelCallbackToRestore, token = std::move(token), operation = std::move(operation)]() mutable -> folly::coro::Task { auto senderCancellationCallback = SenderCancellationCallback( sender, executor, channelCallbackToRestore); auto result = co_await folly::coro::co_awaitTry(folly::coro::co_withCancellation( senderCancellationCallback.getCancellationToken(), std::move(operation))); if (result.hasException()) { LOG(FATAL) << fmt::format( "Unexpected exception when running coroutine operation with " "sender cancellation: {}", result.exception().what()); } co_await senderCancellationCallback.onTaskCompleted(); }) .scheduleOn(executor) .start(); } } // namespace detail } // namespace channels } // namespace folly