/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include #include #include #include #include #include namespace folly { namespace channels { namespace detail { template class ChannelProcessorImpl { public: ChannelProcessorImpl( std::vector> executors, std::shared_ptr rateLimiter, MergeChannel mergeChannel, Receiver> mergeChannelReceiver) : implState_(make_intrusive( std::move(executors), std::move(rateLimiter))), channels_(std::move(mergeChannel)), handle_(consumeChannelWithCallback( std::move(mergeChannelReceiver), implState_->executors[0], [](folly::Try>) -> folly::coro::Task { // Do nothing co_return true; })) {} template void addChannel(KeyType key, ReceiverType receiver, OnUpdateFunc onUpdate) { using InputValueType = typename ReceiverType::ValueType; channels_.removeReceiver(key); channels_.addNewReceiver( std::move(key), transform( std::move(receiver), Transformer( implState_, std::move(onUpdate)))); } template < typename InitializeArg, typename InitializeFunc, typename OnUpdateFunc> void addResumableChannelWithState( KeyType key, InitializeArg initializeArg, InitializeFunc initialize, OnUpdateFunc onUpdate) { addResumableChannelWithState( std::move(key), std::move(initializeArg), std::move(initialize), std::move(onUpdate), NoChannelState()); } template < typename InitializeArg, typename InitializeFunc, typename OnUpdateFunc, typename ChannelState> void addResumableChannelWithState( KeyType key, InitializeArg initializeArg, InitializeFunc initialize, OnUpdateFunc onUpdate, ChannelState channelState) { using ReceiverType = typename decltype(initialize( std::move(initializeArg), channelState))::StorageType; using InputValueType = typename ReceiverType::ValueType; channels_.removeReceiver(key); channels_.addNewReceiver( std::move(key), resumableTransform( std::move(initializeArg), ResumableTransformer< InitializeArg, InputValueType, InitializeFunc, OnUpdateFunc, ChannelState>( implState_, std::move(initialize), std::move(onUpdate), std::move(channelState)))); } void removeChannel(const KeyType& keyType) { channels_.removeReceiver(keyType); } private: struct NoChannelState {}; template < typename Function, typename ReturnType = typename std::invoke_result_t::StorageType> static folly::coro::Task catchNonCoroException(Function func) { auto result = folly::makeTryWith(std::move(func)); if (result.hasException()) { return folly::coro::makeErrorTask( std::move(result.exception())); } else { return std::move(result.value()); } } struct ImplState : public IntrusivePtrBase { ImplState( std::vector> _executors, std::shared_ptr _rateLimiter) : executors(std::move(_executors)), rateLimiter(std::move(_rateLimiter)) {} std::vector> executors; std::shared_ptr rateLimiter; }; template class Transformer : public std::tuple { public: Transformer(intrusive_ptr implState, OnUpdateFunc onUpdate) : std::tuple(std::move(onUpdate)), implState_(std::move(implState)) {} folly::Executor::KeepAlive getExecutor() { return implState_->executors [std::hash()(this) % implState_->executors.size()]; } std::shared_ptr getRateLimiter() { return implState_->rateLimiter; } folly::coro::AsyncGenerator transformValue( folly::Try value) { auto result = co_await folly::coro::co_awaitTry(catchNonCoroException( [&] { return std::get(*this)(std::move(value)); })); if (result.template hasException() || result.template hasException()) { co_yield folly::coro::co_error(OnClosedException()); } else if (result.hasException()) { LOG(FATAL) << fmt::format( "Encountered exception from callback when consuming channel of " "type {}: {}", typeid(InputValueType).name(), result.exception().what()); } } private: intrusive_ptr implState_; }; template < typename InitializeArg, typename InputValueType, typename InitializeFunc, typename OnUpdateFunc, typename ChannelState> class ResumableTransformer : public std::tuple { public: ResumableTransformer( intrusive_ptr implState, InitializeFunc initialize, OnUpdateFunc onUpdate, ChannelState channelState) : std::tuple( std::move(initialize), std::move(onUpdate), std::move(channelState)), implState_(std::move(implState)) {} folly::Executor::KeepAlive getExecutor() { return implState_->executors [std::hash()(this) % implState_->executors.size()]; } std::shared_ptr getRateLimiter() { return implState_->rateLimiter; } folly::coro::Task< std::pair, Receiver>> initializeTransform(InitializeArg initializeArg) { auto result = co_await folly::coro::co_awaitTry( initialize(std::move(initializeArg))); if (result.template hasException() || result.template hasException()) { co_yield folly::coro::co_error(OnClosedException()); } else if (result.hasException()) { LOG(FATAL) << folly::sformat( "Encountered exception from callback when consuming channel of " "type {}: {}", typeid(InputValueType).name(), result.exception().what()); } co_return std::make_pair( std::vector(), std::move(result.value())); } folly::coro::AsyncGenerator transformValue( folly::Try value) { auto result = co_await folly::coro::co_awaitTry(onUpdate(std::move(value))); if (result .template hasException>()) { co_yield folly::coro::co_error(std::move(result.exception())); } else if ( result.template hasException() || result.template hasException()) { co_yield folly::coro::co_error(OnClosedException()); } else if (result.hasException()) { LOG(FATAL) << folly::sformat( "Encountered exception from callback when consuming channel of " "type {}: {}", typeid(InputValueType).name(), result.exception().what()); } } private: folly::coro::Task> initialize( InitializeArg initializeArg) { if constexpr (std::is_same_v) { co_return co_await catchNonCoroException([&] { return std::get(*this)(std::move(initializeArg)); }); } else { co_return co_await catchNonCoroException([&] { return std::get(*this)( std::move(initializeArg), std::get(*this)); }); } } folly::coro::Task onUpdate(folly::Try value) { if constexpr (std::is_same_v) { co_await catchNonCoroException( [&] { return std::get(*this)(std::move(value)); }); } else { co_await catchNonCoroException([&] { return std::get(*this)( std::move(value), std::get(*this)); }); } } intrusive_ptr implState_; }; intrusive_ptr implState_; MergeChannel channels_; ChannelCallbackHandle handle_; }; } // namespace detail template ChannelProcessor::ChannelProcessor( std::unique_ptr> impl) : impl_(std::move(impl)) {} template template void ChannelProcessor::addChannel( KeyType key, ReceiverType receiver, OnUpdateFunc onUpdate) { impl_->addChannel(std::move(key), std::move(receiver), std::move(onUpdate)); } template template < typename InitializeArg, typename InitializeFunc, typename OnUpdateFunc> void ChannelProcessor::addResumableChannel( KeyType key, InitializeArg initializeArg, InitializeFunc initialize, OnUpdateFunc onUpdate) { impl_->addResumableChannel( std::move(key), std::move(initializeArg), std::move(initialize), std::move(onUpdate)); } template template < typename InitializeArg, typename InitializeFunc, typename OnUpdateFunc, typename ChannelState> void ChannelProcessor::addResumableChannelWithState( KeyType key, InitializeArg initializeArg, InitializeFunc initialize, OnUpdateFunc onUpdate, ChannelState channelState) { impl_->addResumableChannelWithState( std::move(key), std::move(initializeArg), std::move(initialize), std::move(onUpdate), std::move(channelState)); } template void ChannelProcessor::removeChannel(const KeyType& keyType) { impl_->removeChannel(keyType); } template void ChannelProcessor::close() && { impl_.reset(); } template ChannelProcessor createChannelProcessor( folly::Executor::KeepAlive<> executor, std::shared_ptr rateLimiter, size_t numSequencedExecutors) { CHECK_GT(numSequencedExecutors, 0); auto executors = std::vector>(); for (size_t i = 0; i < numSequencedExecutors; i++) { executors.push_back(folly::SerialExecutor::create(executor)); } auto [mergeChannelReceiver, mergeChannel] = createMergeChannel(executors[0]); return ChannelProcessor( std::make_unique>( std::move(executors), std::move(rateLimiter), std::move(mergeChannel), std::move(mergeChannelReceiver))); } } // namespace channels } // namespace folly