diff --git a/.github/workflows/TagIt.yml b/.github/workflows/TagIt.yml new file mode 100644 index 000000000..2c4b889d6 --- /dev/null +++ b/.github/workflows/TagIt.yml @@ -0,0 +1,68 @@ +on: + push: + tags: + # Only match TagIt tags, which always start with this prefix + - 'v20*' + +name: TagIt + +jobs: + build: + name: Release + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v2 + - name: Archive project + id: archive_project + run: | + FILE_NAME=${GITHUB_REPOSITORY#*/}-${GITHUB_REF##*/} + git archive ${{ github.ref }} -o ${FILE_NAME}.zip + git archive ${{ github.ref }} -o ${FILE_NAME}.tar.gz + echo "::set-output name=file_name::${FILE_NAME}" + - name: Compute digests + id: compute_digests + run: | + echo "::set-output name=tgz_256::$(openssl dgst -sha256 ${{ steps.archive_project.outputs.file_name }}.tar.gz)" + echo "::set-output name=tgz_512::$(openssl dgst -sha512 ${{ steps.archive_project.outputs.file_name }}.tar.gz)" + echo "::set-output name=zip_256::$(openssl dgst -sha256 ${{ steps.archive_project.outputs.file_name }}.zip)" + echo "::set-output name=zip_512::$(openssl dgst -sha512 ${{ steps.archive_project.outputs.file_name }}.zip)" + - name: Create Release + id: create_release + uses: actions/create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ github.ref }} + release_name: ${{ github.ref }} + body: | + Automated release from TagIt +
+ File Hashes + +
+ draft: false + prerelease: false + - name: Upload zip + uses: actions/upload-release-asset@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + upload_url: ${{ steps.create_release.outputs.upload_url }} + asset_path: ./${{ steps.archive_project.outputs.file_name }}.zip + asset_name: ${{ steps.archive_project.outputs.file_name }}.zip + asset_content_type: application/zip + - name: Upload tar.gz + uses: actions/upload-release-asset@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + upload_url: ${{ steps.create_release.outputs.upload_url }} + asset_path: ./${{ steps.archive_project.outputs.file_name }}.tar.gz + asset_name: ${{ steps.archive_project.outputs.file_name }}.tar.gz + asset_content_type: application/gzip diff --git a/.gitignore b/.gitignore index bf568c4c9..e00c3821f 100644 --- a/.gitignore +++ b/.gitignore @@ -46,7 +46,6 @@ install_manifest.txt ### Project ### -/build /reactivesocket-cpp/CTestTestfile.cmake /reactivesocket-cpp/ReactiveSocketTest /reactivesocket-cpp/compile_commands.json diff --git a/.travis.yml b/.travis.yml index 8e72dfa19..1c88d192b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,9 +14,6 @@ addons: packages: &common_deps - lcov # Folly dependencies - - autoconf - - autoconf-archive - - automake - binutils-dev - g++ - libboost-all-dev @@ -25,42 +22,27 @@ addons: - libgflags-dev - libgoogle-glog-dev - libiberty-dev - - libjemalloc-dev - liblz4-dev - liblzma-dev - libsnappy-dev - libssl-dev - - libtool - make - - pkg-config - zlib1g-dev matrix: include: - # Set COMPILER environment variable instead of CC or CXX because the latter - # are overriden by Travis. Setting the compiler in Travis doesn't work - # either because it strips version. - - env: COMPILER=clang-4.0 + - env: COMPILER_EVAL="CC=clang-6.0 CXX=clang++-6.0" addons: apt: sources: - *common_srcs - - llvm-toolchain-trusty-4.0 + - llvm-toolchain-trusty-6.0 packages: - *common_deps - - clang-4.0 + - clang-6.0 - libstdc++-4.9-dev - - env: COMPILER=gcc-4.9 - addons: - apt: - sources: - - *common_srcs - packages: - - *common_deps - - g++-4.9 - - - env: COMPILER=gcc-5 + - env: COMPILER_EVAL="CC=gcc-5 CXX=g++-5" addons: apt: sources: @@ -68,8 +50,9 @@ matrix: packages: - *common_deps - g++-5 + - libjemalloc-dev - - env: COMPILER=gcc-6 + - env: COMPILER_EVAL="CC=gcc-6 CXX=g++-6" addons: apt: sources: @@ -77,6 +60,7 @@ matrix: packages: - *common_deps - g++-6 + - libjemalloc-dev env: global: @@ -93,30 +77,38 @@ env: eHz/lHAoLXWg/BhtgQbPmMYYKRrQaH7EKzBbqEHv6PhOk7vLMtdx5X7KmhVuFjpAMbaYoj zwxxH0u+VAnVB5iazzyjhySjvzkvx6pGzZtTnjLJHxKcp9633z4OU= -cache: - directories: - - $HOME/folly - before_script: + - eval "$COMPILER_EVAL" + - export DEP_INSTALL_DIR=$PWD/build/dep-install + # Ubuntu trusty only comes with OpenSSL 1.0.1f, but we require + # at least OpenSSL 1.0.2 for ALPN support. + - curl -L https://github.com/openssl/openssl/archive/OpenSSL_1_1_1.tar.gz -o OpenSSL_1_1_1.tar.gz + - tar -xzf OpenSSL_1_1_1.tar.gz + - cd openssl-OpenSSL_1_1_1 + - ./config --prefix=$DEP_INSTALL_DIR no-shared + - make -j4 + - make install_sw install_ssldirs + - cd .. # Install lcov to coveralls conversion + upload tool. - gem install coveralls-lcov - lcov --version + # Build folly + - ./scripts/build_folly.sh build/folly-src $DEP_INSTALL_DIR script: - - mkdir build - cd build - - cmake -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DRSOCKET_CC=$COMPILER - -DRSOCKET_ASAN=$ASAN -DRSOCKET_INSTALL_DEPS=True + - cmake -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DRSOCKET_ASAN=$ASAN + -DCMAKE_PREFIX_PATH=$DEP_INSTALL_DIR -DRSOCKET_BUILD_WITH_COVERAGE=ON .. - make -j4 - lcov --directory . --zerocounters - - make test + # - make test # - make coverage - - cd .. - - ./scripts/tck_test.sh -c cpp -s cpp - - ./scripts/tck_test.sh -c java -s java - - ./scripts/tck_test.sh -c java -s cpp - - ./scripts/tck_test.sh -c cpp -s java + # - cd .. + # - ./scripts/tck_test.sh -c cpp -s cpp + # - ./scripts/tck_test.sh -c java -s java + # - ./scripts/tck_test.sh -c java -s cpp + # - ./scripts/tck_test.sh -c cpp -s java after_success: # Upload to coveralls. diff --git a/.ycm_extra_conf.py b/.ycm_extra_conf.py deleted file mode 100644 index 4f893b15e..000000000 --- a/.ycm_extra_conf.py +++ /dev/null @@ -1,158 +0,0 @@ -import os -import os.path -import logging -import ycm_core - -BASE_FLAGS = [ - '-xc++', - '-Wall', - '-Wextra', - '-Werror', - '-std=c++11', - '-I.', - '-isystem/usr/lib/', - '-isystem/usr/include/', -] - -SOURCE_EXTENSIONS = [ - '.cpp', - '.cxx', - '.cc', - '.c', - '.m', - '.mm' -] - -HEADER_EXTENSIONS = [ - '.h', - '.hxx', - '.hpp', - '.hh', - '.icc', - '.tcc', -] - - -def IsHeaderFile(filename): - extension = os.path.splitext(filename)[1] - return extension in HEADER_EXTENSIONS - - -def GetCompilationInfoForFile(database, filename): - if IsHeaderFile(filename): - basename = os.path.splitext(filename)[0] - for extension in SOURCE_EXTENSIONS: - replacement_file = basename + extension - if os.path.exists(replacement_file): - compilation_info = database.GetCompilationInfoForFile( - replacement_file) - if compilation_info.compiler_flags_: - return compilation_info - return None - return database.GetCompilationInfoForFile(filename) - - -def FindNearest(path, target): - candidate = os.path.join(path, target) - if(os.path.isfile(candidate) or os.path.isdir(candidate)): - logging.info("Found nearest " + target + " at " + candidate) - return candidate - else: - parent = os.path.dirname(os.path.abspath(path)) - if(parent == path): - raise RuntimeError("Could not find " + target) - return FindNearest(parent, target) - - -def MakeRelativePathsInFlagsAbsolute(flags, working_directory): - if not working_directory: - return list(flags) - new_flags = [] - make_next_absolute = False - path_flags = ['-isystem', '-I', '-iquote', '--sysroot='] - for flag in flags: - new_flag = flag - - if make_next_absolute: - make_next_absolute = False - if not flag.startswith('/'): - new_flag = os.path.join(working_directory, flag) - - for path_flag in path_flags: - if flag == path_flag: - make_next_absolute = True - break - - if flag.startswith(path_flag): - path = flag[len(path_flag):] - new_flag = path_flag + os.path.join(working_directory, path) - break - - if new_flag: - new_flags.append(new_flag) - return new_flags - - -def FlagsForClangComplete(root): - try: - clang_complete_path = FindNearest(root, '.clang_complete') - clang_complete_flags = open( - clang_complete_path, 'r').read().splitlines() - return clang_complete_flags - except: - return None - - -def FlagsForInclude(root): - try: - include_path = FindNearest(root, 'include') - flags = [] - for dirroot, dirnames, filenames in os.walk(include_path): - for dir_path in dirnames: - real_path = os.path.join(dirroot, dir_path) - flags = flags + ["-I" + real_path] - return flags - except: - return None - - -def FlagsForCompilationDatabase(root, filename): - try: - compilation_db_path = FindNearest( - os.path.join(root, 'build'), 'compile_commands.json') - compilation_db_dir = os.path.dirname(compilation_db_path) - logging.info( - "Set compilation database directory to " + compilation_db_dir) - compilation_db = ycm_core.CompilationDatabase(compilation_db_dir) - if not compilation_db: - logging.info("Compilation database file found but unable to load") - return None - compilation_info = GetCompilationInfoForFile(compilation_db, filename) - if not compilation_info: - logging.info( - "No compilation info for " + filename + " in compilation database") - return None - return MakeRelativePathsInFlagsAbsolute( - compilation_info.compiler_flags_, - compilation_info.compiler_working_dir_) - except: - return None - - -def FlagsForFile(filename): - root = os.path.realpath(filename) - compilation_db_flags = FlagsForCompilationDatabase(root, filename) - if compilation_db_flags: - final_flags = compilation_db_flags - else: - final_flags = BASE_FLAGS - clang_flags = FlagsForClangComplete(root) - if clang_flags: - final_flags = final_flags + clang_flags - include_flags = FlagsForInclude(root) - if include_flags: - final_flags = final_flags + include_flags - return { - 'flags': final_flags, - 'do_cache': True - } diff --git a/CMakeLists.txt b/CMakeLists.txt index f115a12fe..c736ccbf0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,23 +1,21 @@ cmake_minimum_required(VERSION 3.2) -# The RSOCKET_CC CMake variable specifies the C compiler, e.g. gcc-4.9. -# The C++ compiler name is obtained by replacing "gcc" with "g++" and "clang" -# with "clang++"". If RSOCKET_CC is not given, the compiler is detected -# automatically. -if (RSOCKET_CC) - set(ENV{CC} ${RSOCKET_CC}) - if (${RSOCKET_CC} MATCHES clang) - string(REPLACE clang clang++ CXX ${RSOCKET_CC}) - else () - string(REPLACE gcc g++ CXX ${RSOCKET_CC}) - endif () - set(ENV{CXX} ${CXX}) -endif () - project(ReactiveSocket) +if (NOT DEFINED CPACK_GENERATOR) + set(CPACK_GENERATOR "RPM") +endif() +include(CPack) + # CMake modules. -set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake/") +set(CMAKE_MODULE_PATH + "${CMAKE_SOURCE_DIR}/cmake/" + # For in-fbsource builds + "${CMAKE_CURRENT_SOURCE_DIR}/../opensource/fbcode_builder/CMake" + # For shipit-transformed builds + "${CMAKE_CURRENT_SOURCE_DIR}/build/fbcode_builder/CMake" + ${CMAKE_MODULE_PATH} +) # Joins arguments and stores the result in ${var}. function(join var) @@ -128,6 +126,8 @@ if (DEFINED ASAN_FLAGS) endif () option(BUILD_BENCHMARKS "Build benchmarks" ON) +option(BUILD_EXAMPLES "Build examples" ON) +option(BUILD_TESTS "Build tests" ON) enable_testing() @@ -136,31 +136,44 @@ include(CTest) include(${CMAKE_SOURCE_DIR}/cmake/InstallFolly.cmake) -# gmock -ExternalProject_Add( - gmock - URL ${CMAKE_CURRENT_SOURCE_DIR}/googletest-release-1.8.0.zip - INSTALL_COMMAND "" -) +if(BUILD_TESTS) + # gmock + ExternalProject_Add( + gmock + URL ${CMAKE_CURRENT_SOURCE_DIR}/googletest-release-1.8.0.zip + INSTALL_COMMAND "" + ) -ExternalProject_Get_Property(gmock source_dir) -set(GMOCK_SOURCE_DIR ${source_dir}) -ExternalProject_Get_Property(gmock binary_dir) -set(GMOCK_BINARY_DIR ${binary_dir}) + ExternalProject_Get_Property(gmock source_dir) + set(GMOCK_SOURCE_DIR ${source_dir}) + ExternalProject_Get_Property(gmock binary_dir) + set(GMOCK_BINARY_DIR ${binary_dir}) -set(GMOCK_LIBS - ${GMOCK_BINARY_DIR}/${CMAKE_CFG_INTDIR}/googlemock/${CMAKE_STATIC_LIBRARY_PREFIX}gmock${CMAKE_STATIC_LIBRARY_SUFFIX} - ${GMOCK_BINARY_DIR}/${CMAKE_CFG_INTDIR}/googlemock/${CMAKE_STATIC_LIBRARY_PREFIX}gmock_main${CMAKE_STATIC_LIBRARY_SUFFIX} + set(GMOCK_LIBS + ${GMOCK_BINARY_DIR}/${CMAKE_CFG_INTDIR}/googlemock/${CMAKE_STATIC_LIBRARY_PREFIX}gmock${CMAKE_STATIC_LIBRARY_SUFFIX} + ${GMOCK_BINARY_DIR}/${CMAKE_CFG_INTDIR}/googlemock/${CMAKE_STATIC_LIBRARY_PREFIX}gmock_main${CMAKE_STATIC_LIBRARY_SUFFIX} ) + include_directories(${GMOCK_SOURCE_DIR}/googlemock/include) + include_directories(${GMOCK_SOURCE_DIR}/googletest/include) + +endif() + set(CMAKE_CXX_STANDARD 14) +include(CheckCXXCompilerFlag) + # Common configuration for all build modes. -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Woverloaded-virtual") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g") +if (NOT MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Woverloaded-virtual") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g") +endif() -set(EXTRA_CXX_FLAGS ${EXTRA_CXX_FLAGS} -Werror) +CHECK_CXX_COMPILER_FLAG(-Wnoexcept-type COMPILER_HAS_W_NOEXCEPT_TYPE) +if (COMPILER_HAS_W_NOEXCEPT_TYPE) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-noexcept-type") +endif() if("${BUILD_TYPE_LOWER}" MATCHES "debug") message("debug mode was set") @@ -177,26 +190,16 @@ find_library(DOUBLE-CONVERSION double-conversion) find_package(OpenSSL REQUIRED) -# Find glog and gflags libraries specifically -find_path(GLOG_INCLUDE_DIR glog/logging.h) -find_path(GFLAGS_INCLUDE_DIR gflags/gflags.h) +find_package(Gflags REQUIRED) -find_library(GLOG_LIBRARY glog) -find_library(GFLAGS_LIBRARY gflags) +# find glog::glog to satisfy the folly dep. +find_package(Glog REQUIRED) -message("gflags include_dir <${GFLAGS_INCLUDE_DIR}> lib <${GFLAGS_LIBRARY}>") -message("glog include_dir <${GLOG_INCLUDE_DIR}> lib <${GLOG_LIBRARY}>") +find_package(fmt CONFIG REQUIRED) include_directories(SYSTEM ${OPENSSL_INCLUDE_DIR}) include_directories(SYSTEM ${GFLAGS_INCLUDE_DIR}) -include_directories(SYSTEM ${GLOG_INCLUDE_DIR}) - -include_directories(${CMAKE_SOURCE_DIR}) - -include_directories(${CMAKE_CURRENT_BINARY_DIR}/reactivestreams/include) -include_directories(${GMOCK_SOURCE_DIR}/googlemock/include) -include_directories(${GMOCK_SOURCE_DIR}/googletest/include) add_subdirectory(yarpl) @@ -240,10 +243,6 @@ add_library( rsocket/framing/FrameProcessor.h rsocket/framing/FrameSerializer.cpp rsocket/framing/FrameSerializer.h - rsocket/framing/FrameSerializer_v0.cpp - rsocket/framing/FrameSerializer_v0.h - rsocket/framing/FrameSerializer_v0_1.cpp - rsocket/framing/FrameSerializer_v0_1.h rsocket/framing/FrameSerializer_v1_0.cpp rsocket/framing/FrameSerializer_v1_0.h rsocket/framing/FrameTransport.h @@ -255,8 +254,10 @@ add_library( rsocket/framing/FramedDuplexConnection.h rsocket/framing/FramedReader.cpp rsocket/framing/FramedReader.h - rsocket/framing/FramedWriter.cpp - rsocket/framing/FramedWriter.h + rsocket/framing/ProtocolVersion.cpp + rsocket/framing/ProtocolVersion.h + rsocket/framing/ResumeIdentificationToken.cpp + rsocket/framing/ResumeIdentificationToken.h rsocket/framing/ScheduledFrameProcessor.cpp rsocket/framing/ScheduledFrameProcessor.h rsocket/framing/ScheduledFrameTransport.cpp @@ -288,6 +289,8 @@ add_library( rsocket/statemachine/ChannelResponder.h rsocket/statemachine/ConsumerBase.cpp rsocket/statemachine/ConsumerBase.h + rsocket/statemachine/FireAndForgetResponder.cpp + rsocket/statemachine/FireAndForgetResponder.h rsocket/statemachine/PublisherBase.cpp rsocket/statemachine/PublisherBase.h rsocket/statemachine/RSocketStateMachine.cpp @@ -300,13 +303,12 @@ add_library( rsocket/statemachine/StreamRequester.h rsocket/statemachine/StreamResponder.cpp rsocket/statemachine/StreamResponder.h - rsocket/statemachine/StreamState.cpp - rsocket/statemachine/StreamState.h rsocket/statemachine/StreamStateMachineBase.cpp rsocket/statemachine/StreamStateMachineBase.h - rsocket/statemachine/StreamsFactory.cpp - rsocket/statemachine/StreamsFactory.h + rsocket/statemachine/StreamFragmentAccumulator.cpp + rsocket/statemachine/StreamFragmentAccumulator.h rsocket/statemachine/StreamsWriter.h + rsocket/statemachine/StreamsWriter.cpp rsocket/transports/tcp/TcpConnectionAcceptor.cpp rsocket/transports/tcp/TcpConnectionAcceptor.h rsocket/transports/tcp/TcpConnectionFactory.cpp @@ -314,10 +316,17 @@ add_library( rsocket/transports/tcp/TcpDuplexConnection.cpp rsocket/transports/tcp/TcpDuplexConnection.h) -target_include_directories(ReactiveSocket PUBLIC "${PROJECT_SOURCE_DIR}/yarpl/include") -target_include_directories(ReactiveSocket PUBLIC "${PROJECT_SOURCE_DIR}/yarpl/src") +target_include_directories( + ReactiveSocket + PUBLIC + $ + $ +) -target_link_libraries(ReactiveSocket yarpl ${GFLAGS_LIBRARY} ${GLOG_LIBRARY}) + +target_link_libraries(ReactiveSocket + PUBLIC yarpl glog::glog gflags + INTERFACE ${EXTRA_LINK_FLAGS}) target_compile_options( ReactiveSocket @@ -325,62 +334,72 @@ target_compile_options( enable_testing() -install(TARGETS ReactiveSocket DESTINATION lib) +install(TARGETS ReactiveSocket EXPORT rsocket-exports DESTINATION lib) install(DIRECTORY rsocket DESTINATION include FILES_MATCHING PATTERN "*.h") +install(EXPORT rsocket-exports NAMESPACE rsocket:: DESTINATION lib/cmake/rsocket) +include(CMakePackageConfigHelpers) +configure_package_config_file( + cmake/rsocket-config.cmake.in + rsocket-config.cmake + INSTALL_DESTINATION lib/cmake/rsocket +) +install( + FILES ${CMAKE_CURRENT_BINARY_DIR}/rsocket-config.cmake + DESTINATION lib/cmake/rsocket +) -# CMake doesn't seem to support "transitive" installing, and I can't access the -# "yarpl" target from this file, so just grab the library file directly. -install(FILES "${CMAKE_CURRENT_BINARY_DIR}/yarpl/libyarpl.a" DESTINATION lib) -install(DIRECTORY yarpl/include/yarpl DESTINATION include FILES_MATCHING PATTERN "*.h") - +if(BUILD_TESTS) add_executable( tests - test/ColdResumptionTest.cpp - test/ConnectionEventsTest.cpp - test/PayloadTest.cpp - test/RSocketClientServerTest.cpp - test/RSocketClientTest.cpp - test/RSocketTests.cpp - test/RSocketTests.h - test/RequestChannelTest.cpp - test/RequestResponseTest.cpp - test/RequestStreamTest.cpp - test/RequestStreamTest_concurrency.cpp - test/Test.cpp - test/WarmResumeManagerTest.cpp - test/WarmResumptionTest.cpp - test/framing/FrameTest.cpp - test/framing/FrameTransportTest.cpp - test/framing/FramedReaderTest.cpp - test/handlers/HelloServiceHandler.cpp - test/handlers/HelloServiceHandler.h - test/handlers/HelloStreamRequestHandler.cpp - test/handlers/HelloStreamRequestHandler.h - test/internal/AllowanceTest.cpp - test/internal/ConnectionSetTest.cpp - test/internal/KeepaliveTimerTest.cpp - test/internal/ResumeIdentificationToken.cpp - test/internal/SetupResumeAcceptorTest.cpp - test/internal/SwappableEventBaseTest.cpp - test/test_utils/ColdResumeManager.cpp - test/test_utils/ColdResumeManager.h - test/test_utils/GenericRequestResponseHandler.h - test/test_utils/MockDuplexConnection.h - test/test_utils/MockKeepaliveTimer.h - test/test_utils/MockRequestHandler.h - test/test_utils/MockStats.h - test/transport/DuplexConnectionTest.cpp - test/transport/DuplexConnectionTest.h - test/transport/TcpDuplexConnectionTest.cpp) - + rsocket/test/ColdResumptionTest.cpp + rsocket/test/ConnectionEventsTest.cpp + rsocket/test/PayloadTest.cpp + rsocket/test/RSocketClientServerTest.cpp + rsocket/test/RSocketClientTest.cpp + rsocket/test/RSocketTests.cpp + rsocket/test/RSocketTests.h + rsocket/test/RequestChannelTest.cpp + rsocket/test/RequestResponseTest.cpp + rsocket/test/RequestStreamTest.cpp + rsocket/test/RequestStreamTest_concurrency.cpp + rsocket/test/Test.cpp + rsocket/test/WarmResumeManagerTest.cpp + rsocket/test/WarmResumptionTest.cpp + rsocket/test/framing/FrameTest.cpp + rsocket/test/framing/FrameTransportTest.cpp + rsocket/test/framing/FramedReaderTest.cpp + rsocket/test/handlers/HelloServiceHandler.cpp + rsocket/test/handlers/HelloServiceHandler.h + rsocket/test/handlers/HelloStreamRequestHandler.cpp + rsocket/test/handlers/HelloStreamRequestHandler.h + rsocket/test/internal/AllowanceTest.cpp + rsocket/test/internal/ConnectionSetTest.cpp + rsocket/test/internal/KeepaliveTimerTest.cpp + rsocket/test/internal/ResumeIdentificationToken.cpp + rsocket/test/internal/SetupResumeAcceptorTest.cpp + rsocket/test/internal/SwappableEventBaseTest.cpp + rsocket/test/statemachine/RSocketStateMachineTest.cpp + rsocket/test/statemachine/StreamStateTest.cpp + rsocket/test/statemachine/StreamsWriterTest.cpp + rsocket/test/test_utils/ColdResumeManager.cpp + rsocket/test/test_utils/ColdResumeManager.h + rsocket/test/test_utils/GenericRequestResponseHandler.h + rsocket/test/test_utils/MockDuplexConnection.h + rsocket/test/test_utils/MockStreamsWriter.h + rsocket/test/test_utils/MockStats.h + rsocket/test/transport/DuplexConnectionTest.cpp + rsocket/test/transport/DuplexConnectionTest.h + rsocket/test/transport/TcpDuplexConnectionTest.cpp) + +add_dependencies(tests gmock) target_link_libraries( tests ReactiveSocket yarpl yarpl-test-utils - ${GMOCK_LIBS} - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + ${GMOCK_LIBS} # This also needs the preceding `add_dependencies` + glog::glog + gflags) target_include_directories(tests PUBLIC "${PROJECT_SOURCE_DIR}/yarpl/test/") target_compile_options( @@ -390,19 +409,18 @@ target_compile_options( add_dependencies(tests gmock yarpl-test-utils ReactiveSocket) add_test(NAME RSocketTests COMMAND tests) -add_test(NAME RSocketTests-0.1 COMMAND tests --rs_use_protocol_version=0.1) ### Fuzzer harnesses add_executable( frame_fuzzer - test/fuzzers/frame_fuzzer.cpp) + rsocket/test/fuzzers/frame_fuzzer.cpp) target_link_libraries( frame_fuzzer ReactiveSocket yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) add_dependencies(frame_fuzzer gmock ReactiveSocket) @@ -410,84 +428,91 @@ add_test( NAME FrameFuzzerTests COMMAND ./scripts/frame_fuzzer_test.sh WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}) +endif() ######################################## # TCK Drivers ######################################## -add_executable( - tckclient - tck-test/client.cpp - tck-test/TestFileParser.cpp - tck-test/TestFileParser.h - tck-test/FlowableSubscriber.cpp - tck-test/FlowableSubscriber.h - tck-test/SingleSubscriber.cpp - tck-test/SingleSubscriber.h - tck-test/TestSuite.cpp - tck-test/TestSuite.h - tck-test/TestInterpreter.cpp - tck-test/TestInterpreter.h - tck-test/TypedCommands.h - tck-test/BaseSubscriber.cpp - tck-test/BaseSubscriber.h) - -target_link_libraries( - tckclient - ReactiveSocket - yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) - -add_executable( - tckserver - tck-test/server.cpp - tck-test/MarbleProcessor.cpp - tck-test/MarbleProcessor.h - test/test_utils/StatsPrinter.cpp - test/test_utils/StatsPrinter.h) - -target_link_libraries( - tckserver - ReactiveSocket - yarpl - ${GFLAGS_LIBRARY} - ${GMOCK_LIBS} - ${GLOG_LIBRARY} - ${DOUBLE-CONVERSION}) +if(BUILD_TESTS) + add_executable( + tckclient + rsocket/tck-test/client.cpp + rsocket/tck-test/TestFileParser.cpp + rsocket/tck-test/TestFileParser.h + rsocket/tck-test/FlowableSubscriber.cpp + rsocket/tck-test/FlowableSubscriber.h + rsocket/tck-test/SingleSubscriber.cpp + rsocket/tck-test/SingleSubscriber.h + rsocket/tck-test/TestSuite.cpp + rsocket/tck-test/TestSuite.h + rsocket/tck-test/TestInterpreter.cpp + rsocket/tck-test/TestInterpreter.h + rsocket/tck-test/TypedCommands.h + rsocket/tck-test/BaseSubscriber.cpp + rsocket/tck-test/BaseSubscriber.h) + + target_link_libraries( + tckclient + ReactiveSocket + yarpl + glog::glog + gflags) + + add_executable( + tckserver + rsocket/tck-test/server.cpp + rsocket/tck-test/MarbleProcessor.cpp + rsocket/tck-test/MarbleProcessor.h + rsocket/test/test_utils/StatsPrinter.cpp + rsocket/test/test_utils/StatsPrinter.h) + + add_dependencies(tckserver gmock) + target_link_libraries( + tckserver + ReactiveSocket + yarpl + ${GMOCK_LIBS} # This also needs the preceding `add_dependencies` + glog::glog + gflags + ${DOUBLE-CONVERSION}) # Download the latest TCK drivers JAR. -set(TCK_DRIVERS_JAR rsocket-tck-drivers-0.9.10.jar) -join(TCK_DRIVERS_URL - "https://oss.jfrog.org/libs-release/io/rsocket/" - "rsocket-tck-drivers/0.9.10/${TCK_DRIVERS_JAR}") -message(STATUS "Downloading ${TCK_DRIVERS_URL}") -file(DOWNLOAD ${TCK_DRIVERS_URL} ${CMAKE_SOURCE_DIR}/${TCK_DRIVERS_JAR}) + set(TCK_DRIVERS_JAR rsocket-tck-drivers-0.9.10.jar) + if (NOT EXISTS ${CMAKE_SOURCE_DIR}/${TCK_DRIVERS_JAR}) + join(TCK_DRIVERS_URL + "https://oss.jfrog.org/libs-release/io/rsocket/" + "rsocket-tck-drivers/0.9.10/${TCK_DRIVERS_JAR}") + message(STATUS "Downloading ${TCK_DRIVERS_URL}") + file(DOWNLOAD ${TCK_DRIVERS_URL} ${CMAKE_SOURCE_DIR}/${TCK_DRIVERS_JAR}) + endif () +endif() ######################################## # Examples ######################################## +if (BUILD_EXAMPLES) add_library( reactivesocket_examples_util - examples/util/ExampleSubscriber.cpp - examples/util/ExampleSubscriber.h - test/test_utils/ColdResumeManager.h - test/test_utils/ColdResumeManager.cpp + rsocket/examples/util/ExampleSubscriber.cpp + rsocket/examples/util/ExampleSubscriber.h + rsocket/test/test_utils/ColdResumeManager.h + rsocket/test/test_utils/ColdResumeManager.cpp ) target_link_libraries( reactivesocket_examples_util yarpl ReactiveSocket - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) # request-response-hello-world add_executable( example_request-response-hello-world-server - examples/request-response-hello-world/RequestResponseHelloWorld_Server.cpp + rsocket/examples/request-response-hello-world/RequestResponseHelloWorld_Server.cpp ) target_link_libraries( @@ -495,12 +520,12 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) add_executable( example_request-response-hello-world-client - examples/request-response-hello-world/RequestResponseHelloWorld_Client.cpp + rsocket/examples/request-response-hello-world/RequestResponseHelloWorld_Client.cpp ) target_link_libraries( @@ -508,14 +533,14 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) # fire-and-forget-hello-world add_executable( example_fire-and-forget-hello-world-server - examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Server.cpp + rsocket/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Server.cpp ) target_link_libraries( @@ -523,12 +548,12 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) add_executable( example_fire-and-forget-hello-world-client - examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Client.cpp + rsocket/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Client.cpp ) target_link_libraries( @@ -536,15 +561,15 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) # stream-hello-world add_executable( example_stream-hello-world-server - examples/stream-hello-world/StreamHelloWorld_Server.cpp + rsocket/examples/stream-hello-world/StreamHelloWorld_Server.cpp ) target_link_libraries( @@ -552,12 +577,12 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) add_executable( example_stream-hello-world-client - examples/stream-hello-world/StreamHelloWorld_Client.cpp + rsocket/examples/stream-hello-world/StreamHelloWorld_Client.cpp ) target_link_libraries( @@ -565,14 +590,14 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) # channel-hello-world add_executable( example_channel-hello-world-server - examples/channel-hello-world/ChannelHelloWorld_Server.cpp + rsocket/examples/channel-hello-world/ChannelHelloWorld_Server.cpp ) target_link_libraries( @@ -580,12 +605,14 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) + + add_executable( example_channel-hello-world-client - examples/channel-hello-world/ChannelHelloWorld_Client.cpp + rsocket/examples/channel-hello-world/ChannelHelloWorld_Client.cpp ) target_link_libraries( @@ -593,14 +620,14 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) # stream-observable-to-flowable add_executable( example_observable-to-flowable-server - examples/stream-observable-to-flowable/StreamObservableToFlowable_Server.cpp + rsocket/examples/stream-observable-to-flowable/StreamObservableToFlowable_Server.cpp ) target_link_libraries( @@ -608,12 +635,12 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) add_executable( example_observable-to-flowable-client - examples/stream-observable-to-flowable/StreamObservableToFlowable_Client.cpp + rsocket/examples/stream-observable-to-flowable/StreamObservableToFlowable_Client.cpp ) target_link_libraries( @@ -621,18 +648,18 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) # conditional-request-handling add_executable( example_conditional-request-handling-server - examples/conditional-request-handling/ConditionalRequestHandling_Server.cpp - examples/conditional-request-handling/TextRequestHandler.h - examples/conditional-request-handling/TextRequestHandler.cpp - examples/conditional-request-handling/JsonRequestHandler.cpp - examples/conditional-request-handling/JsonRequestHandler.h + rsocket/examples/conditional-request-handling/ConditionalRequestHandling_Server.cpp + rsocket/examples/conditional-request-handling/TextRequestHandler.h + rsocket/examples/conditional-request-handling/TextRequestHandler.cpp + rsocket/examples/conditional-request-handling/JsonRequestHandler.cpp + rsocket/examples/conditional-request-handling/JsonRequestHandler.h ) target_link_libraries( @@ -640,12 +667,12 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) add_executable( example_conditional-request-handling-client - examples/conditional-request-handling/ConditionalRequestHandling_Client.cpp + rsocket/examples/conditional-request-handling/ConditionalRequestHandling_Client.cpp ) target_link_libraries( @@ -653,14 +680,14 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) # warm-resumption add_executable( example_resumption-server - examples/resumption/Resumption_Server.cpp + rsocket/examples/resumption/Resumption_Server.cpp ) target_link_libraries( @@ -668,12 +695,12 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) add_executable( example_warm-resumption-client - examples/resumption/WarmResumption_Client.cpp + rsocket/examples/resumption/WarmResumption_Client.cpp ) target_link_libraries( @@ -681,12 +708,12 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) add_executable( example_cold-resumption-client - examples/resumption/ColdResumption_Client.cpp + rsocket/examples/resumption/ColdResumption_Client.cpp ) target_link_libraries( @@ -694,13 +721,15 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) + +endif () # BUILD_EXAMPLES ######################################## # End Examples ######################################## if (BUILD_BENCHMARKS) - add_subdirectory(benchmarks) + add_subdirectory(rsocket/benchmarks) endif () diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..d1abc700d --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,77 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq + diff --git a/LICENSE b/LICENSE index 4d4a15fb0..989e2c59e 100644 --- a/LICENSE +++ b/LICENSE @@ -1,30 +1,201 @@ -BSD License +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ -For reactivesocket-cpp software + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION -Copyright (c) 2016-present, Facebook, Inc. All rights reserved. + 1. Definitions. -Redistribution and use in source and binary forms, with or without modification, -are permitted provided that the following conditions are met: + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. - * Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. - * Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. - * Neither the name Facebook nor the names of its contributors may be used to - endorse or promote products derived from this software without specific - prior written permission. + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR -ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON -ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/PATENTS b/PATENTS deleted file mode 100644 index 3d7f19408..000000000 --- a/PATENTS +++ /dev/null @@ -1,33 +0,0 @@ -Additional Grant of Patent Rights Version 2 - -"Software" means the reactivesocket-cpp software distributed by Facebook, Inc. - -Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software -("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable -(subject to the termination provision below) license under any Necessary -Claims, to make, have made, use, sell, offer to sell, import, and otherwise -transfer the Software. For avoidance of doubt, no license is granted under -Facebook’s rights in any patent claims that are infringed by (i) modifications -to the Software made by you or any third party or (ii) the Software in -combination with any software or other technology. - -The license granted hereunder will terminate, automatically and without notice, -if you (or any of your subsidiaries, corporate affiliates or agents) initiate -directly or indirectly, or take a direct financial interest in, any Patent -Assertion: (i) against Facebook or any of its subsidiaries or corporate -affiliates, (ii) against any party if such Patent Assertion arises in whole or -in part from any software, technology, product or service of Facebook or any of -its subsidiaries or corporate affiliates, or (iii) against any party relating -to the Software. Notwithstanding the foregoing, if Facebook or any of its -subsidiaries or corporate affiliates files a lawsuit alleging patent -infringement against you in the first instance, and you respond by filing a -patent infringement counterclaim in that lawsuit against that party that is -unrelated to the Software, the license granted hereunder will not terminate -under section (i) of this paragraph due to such counterclaim. - -A "Necessary Claim" is a claim of a patent owned by Facebook that is -necessarily infringed by the Software standing alone. - -A "Patent Assertion" is any lawsuit or other action alleging direct, indirect, -or contributory infringement or inducement to infringe any patent, including a -cross-claim or counterclaim. diff --git a/PULL_REQUEST_TEMPLATE.md b/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 000000000..230862230 --- /dev/null +++ b/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,28 @@ +## Motivation and Context + + + + + +## How Has This Been Tested + + + +## Types of changes + + +- [ ] Docs change / refactoring / dependency upgrade +- [ ] Bug fix (non-breaking change which fixes an issue) +- [ ] New feature (non-breaking change which adds functionality) +- [ ] Breaking change (fix or feature that would cause existing functionality to change) + +## Checklist + + + +- [ ] My code follows the code style of this project. +- [ ] My change requires a change to the documentation. +- [ ] I have updated the documentation accordingly. +- [ ] I have read the **CONTRIBUTING** document. +- [ ] I have added tests to cover my changes. +- [ ] All new and existing tests passed. diff --git a/README.md b/README.md index 3b811aca9..1a5339e1c 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ C++ implementation of [RSocket](https://rsocket.io) Install `folly`: ``` -brew install folly +brew install --HEAD folly ``` # Building and running tests @@ -25,3 +25,8 @@ cmake -DCMAKE_BUILD_TYPE=DEBUG ../ make -j ./tests ``` + +# License + +By contributing to rsocket-cpp, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. diff --git a/benchmarks/Benchmarks.cpp b/benchmarks/Benchmarks.cpp deleted file mode 100644 index 8c357fbed..000000000 --- a/benchmarks/Benchmarks.cpp +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include -#include - -int main(int argc, char** argv) { - folly::init(&argc, &argv); - - FLAGS_logtostderr = true; - - LOG(INFO) << "Running benchmarks... (takes minutes)"; - folly::runBenchmarks(); - - return 0; -} diff --git a/benchmarks/Latch.h b/benchmarks/Latch.h deleted file mode 100644 index b8dcc3520..000000000 --- a/benchmarks/Latch.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include - -/// Simple implementation of a latch synchronization primitive, for testing. -class Latch { - public: - explicit Latch(size_t limit) : limit_{limit} {} - - void wait() { - baton_.wait(); - } - - bool timed_wait(std::chrono::milliseconds timeout) { - return baton_.timed_wait(timeout); - } - - void post() { - auto const old = count_.fetch_add(1); - if (old == limit_ - 1) { - baton_.post(); - } - } - - private: - folly::Baton<> baton_; - std::atomic count_{0}; - const size_t limit_{0}; -}; diff --git a/build/README.md b/build/README.md new file mode 100644 index 000000000..fdcb9fdcb --- /dev/null +++ b/build/README.md @@ -0,0 +1,10 @@ +# Building using `fbcode_builder` + +Continuous integration builds are powered by `fbcode_builder`, a tiny tool +shared by several Facebook projects. Its files are in `./fbcode_builder` +(on Github) or in `fbcode/opensource/fbcode_builder` (inside Facebook's +repo). + +Start with the READMEs in the `fbcode_builder` directory. + +`./fbcode_builder_config.py` contains the project-specific configuration. diff --git a/build/deps/github_hashes/facebook/folly-rev.txt b/build/deps/github_hashes/facebook/folly-rev.txt new file mode 100644 index 000000000..cd836348c --- /dev/null +++ b/build/deps/github_hashes/facebook/folly-rev.txt @@ -0,0 +1 @@ +Subproject commit 2a20a79adf8480dffc165aebc02a93937e15ca94 diff --git a/build/fbcode_builder/.gitignore b/build/fbcode_builder/.gitignore new file mode 100644 index 000000000..b98f3edfa --- /dev/null +++ b/build/fbcode_builder/.gitignore @@ -0,0 +1,5 @@ +# Facebook-internal CI builds don't have write permission outside of the +# source tree, so we install all projects into this directory. +/facebook_ci +__pycache__/ +*.pyc diff --git a/build/fbcode_builder/CMake/FBBuildOptions.cmake b/build/fbcode_builder/CMake/FBBuildOptions.cmake new file mode 100644 index 000000000..dbaa29933 --- /dev/null +++ b/build/fbcode_builder/CMake/FBBuildOptions.cmake @@ -0,0 +1,15 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +function (fb_activate_static_library_option) + option(USE_STATIC_DEPS_ON_UNIX + "If enabled, use static dependencies on unix systems. This is generally discouraged." + OFF + ) + # Mark USE_STATIC_DEPS_ON_UNIX as an "advanced" option, since enabling it + # is generally discouraged. + mark_as_advanced(USE_STATIC_DEPS_ON_UNIX) + + if(UNIX AND USE_STATIC_DEPS_ON_UNIX) + SET(CMAKE_FIND_LIBRARY_SUFFIXES ".a" PARENT_SCOPE) + endif() +endfunction() diff --git a/build/fbcode_builder/CMake/FBCMakeParseArgs.cmake b/build/fbcode_builder/CMake/FBCMakeParseArgs.cmake new file mode 100644 index 000000000..933180189 --- /dev/null +++ b/build/fbcode_builder/CMake/FBCMakeParseArgs.cmake @@ -0,0 +1,141 @@ +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Helper function for parsing arguments to a CMake function. +# +# This function is very similar to CMake's built-in cmake_parse_arguments() +# function, with some improvements: +# - This function correctly handles empty arguments. (cmake_parse_arguments() +# ignores empty arguments.) +# - If a multi-value argument is specified more than once, the subsequent +# arguments are appended to the original list rather than replacing it. e.g. +# if "SOURCES" is a multi-value argument, and the argument list contains +# "SOURCES a b c SOURCES x y z" then the resulting value for SOURCES will be +# "a;b;c;x;y;z" rather than "x;y;z" +# - This function errors out by default on unrecognized arguments. You can +# pass in an extra "ALLOW_UNPARSED_ARGS" argument to make it behave like +# cmake_parse_arguments(), and return the unparsed arguments in a +# _UNPARSED_ARGUMENTS variable instead. +# +# It does look like cmake_parse_arguments() handled empty arguments correctly +# from CMake 3.0 through 3.3, but it seems like this was probably broken when +# it was turned into a built-in function in CMake 3.4. Here is discussion and +# patches that fixed this behavior prior to CMake 3.0: +# https://cmake.org/pipermail/cmake-developers/2013-November/020607.html +# +# The one downside to this function over the built-in cmake_parse_arguments() +# is that I don't think we can achieve the PARSE_ARGV behavior in a non-builtin +# function, so we can't properly handle arguments that contain ";". CMake will +# treat the ";" characters as list element separators, and treat it as multiple +# separate arguments. +# +function(fb_cmake_parse_args PREFIX OPTIONS ONE_VALUE_ARGS MULTI_VALUE_ARGS ARGS) + foreach(option IN LISTS ARGN) + if ("${option}" STREQUAL "ALLOW_UNPARSED_ARGS") + set(ALLOW_UNPARSED_ARGS TRUE) + else() + message( + FATAL_ERROR + "unknown optional argument for fb_cmake_parse_args(): ${option}" + ) + endif() + endforeach() + + # Define all options as FALSE in the parent scope to start with + foreach(var_name IN LISTS OPTIONS) + set("${PREFIX}_${var_name}" "FALSE" PARENT_SCOPE) + endforeach() + + # TODO: We aren't extremely strict about error checking for one-value + # arguments here. e.g., we don't complain if a one-value argument is + # followed by another option/one-value/multi-value name rather than an + # argument. We also don't complain if a one-value argument is the last + # argument and isn't followed by a value. + + list(APPEND all_args ${ONE_VALUE_ARGS}) + list(APPEND all_args ${MULTI_VALUE_ARGS}) + set(current_variable) + set(unparsed_args) + foreach(arg IN LISTS ARGS) + list(FIND OPTIONS "${arg}" opt_index) + if("${opt_index}" EQUAL -1) + list(FIND all_args "${arg}" arg_index) + if("${arg_index}" EQUAL -1) + # This argument does not match an argument name, + # must be an argument value + if("${current_variable}" STREQUAL "") + list(APPEND unparsed_args "${arg}") + else() + # Ugh, CMake lists have a pretty fundamental flaw: they cannot + # distinguish between an empty list and a list with a single empty + # element. We track our own SEEN_VALUES_arg setting to help + # distinguish this and behave properly here. + if ("${SEEN_${current_variable}}" AND "${${current_variable}}" STREQUAL "") + set("${current_variable}" ";${arg}") + else() + list(APPEND "${current_variable}" "${arg}") + endif() + set("SEEN_${current_variable}" TRUE) + endif() + else() + # We found a single- or multi-value argument name + set(current_variable "VALUES_${arg}") + set("SEEN_${arg}" TRUE) + endif() + else() + # We found an option variable + set("${PREFIX}_${arg}" "TRUE" PARENT_SCOPE) + set(current_variable) + endif() + endforeach() + + foreach(arg_name IN LISTS ONE_VALUE_ARGS) + if(NOT "${SEEN_${arg_name}}") + unset("${PREFIX}_${arg_name}" PARENT_SCOPE) + elseif(NOT "${SEEN_VALUES_${arg_name}}") + # If the argument was seen but a value wasn't specified, error out. + # We require exactly one value to be specified. + message( + FATAL_ERROR "argument ${arg_name} was specified without a value" + ) + else() + list(LENGTH "VALUES_${arg_name}" num_args) + if("${num_args}" EQUAL 0) + # We know an argument was specified and that we called list(APPEND). + # If CMake thinks the list is empty that means there is really a single + # empty element in the list. + set("${PREFIX}_${arg_name}" "" PARENT_SCOPE) + elseif("${num_args}" EQUAL 1) + list(GET "VALUES_${arg_name}" 0 arg_value) + set("${PREFIX}_${arg_name}" "${arg_value}" PARENT_SCOPE) + else() + message( + FATAL_ERROR "too many arguments specified for ${arg_name}: " + "${VALUES_${arg_name}}" + ) + endif() + endif() + endforeach() + + foreach(arg_name IN LISTS MULTI_VALUE_ARGS) + # If this argument name was never seen, then unset the parent scope + if (NOT "${SEEN_${arg_name}}") + unset("${PREFIX}_${arg_name}" PARENT_SCOPE) + else() + # TODO: Our caller still won't be able to distinguish between an empty + # list and a list with a single empty element. We can tell which is + # which, but CMake lists don't make it easy to show this to our caller. + set("${PREFIX}_${arg_name}" "${VALUES_${arg_name}}" PARENT_SCOPE) + endif() + endforeach() + + # By default we fatal out on unparsed arguments, but return them to the + # caller if ALLOW_UNPARSED_ARGS was specified. + if (DEFINED unparsed_args) + if ("${ALLOW_UNPARSED_ARGS}") + set("${PREFIX}_UNPARSED_ARGUMENTS" "${unparsed_args}" PARENT_SCOPE) + else() + message(FATAL_ERROR "unrecognized arguments: ${unparsed_args}") + endif() + endif() +endfunction() diff --git a/build/fbcode_builder/CMake/FBCompilerSettings.cmake b/build/fbcode_builder/CMake/FBCompilerSettings.cmake new file mode 100644 index 000000000..585c95320 --- /dev/null +++ b/build/fbcode_builder/CMake/FBCompilerSettings.cmake @@ -0,0 +1,13 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +# This file applies common compiler settings that are shared across +# a number of Facebook opensource projects. +# Please use caution and your best judgement before making changes +# to these shared compiler settings in order to avoid accidentally +# breaking a build in another project! + +if (WIN32) + include(FBCompilerSettingsMSVC) +else() + include(FBCompilerSettingsUnix) +endif() diff --git a/build/fbcode_builder/CMake/FBCompilerSettingsMSVC.cmake b/build/fbcode_builder/CMake/FBCompilerSettingsMSVC.cmake new file mode 100644 index 000000000..4efd7e966 --- /dev/null +++ b/build/fbcode_builder/CMake/FBCompilerSettingsMSVC.cmake @@ -0,0 +1,11 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +# This file applies common compiler settings that are shared across +# a number of Facebook opensource projects. +# Please use caution and your best judgement before making changes +# to these shared compiler settings in order to avoid accidentally +# breaking a build in another project! + +add_compile_options( + /wd4250 # 'class1' : inherits 'class2::member' via dominance +) diff --git a/build/fbcode_builder/CMake/FBCompilerSettingsUnix.cmake b/build/fbcode_builder/CMake/FBCompilerSettingsUnix.cmake new file mode 100644 index 000000000..c26ce78b1 --- /dev/null +++ b/build/fbcode_builder/CMake/FBCompilerSettingsUnix.cmake @@ -0,0 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +# This file applies common compiler settings that are shared across +# a number of Facebook opensource projects. +# Please use caution and your best judgement before making changes +# to these shared compiler settings in order to avoid accidentally +# breaking a build in another project! + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Wextra -Wno-deprecated -Wno-deprecated-declarations") diff --git a/build/fbcode_builder/CMake/FBPythonBinary.cmake b/build/fbcode_builder/CMake/FBPythonBinary.cmake new file mode 100644 index 000000000..99c33fb8c --- /dev/null +++ b/build/fbcode_builder/CMake/FBPythonBinary.cmake @@ -0,0 +1,697 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +include(FBCMakeParseArgs) + +# +# This file contains helper functions for building self-executing Python +# binaries. +# +# This is somewhat different than typical python installation with +# distutils/pip/virtualenv/etc. We primarily want to build a standalone +# executable, isolated from other Python packages on the system. We don't want +# to install files into the standard library python paths. This is more +# similar to PEX (https://github.com/pantsbuild/pex) and XAR +# (https://github.com/facebookincubator/xar). (In the future it would be nice +# to update this code to also support directly generating XAR files if XAR is +# available.) +# +# We also want to be able to easily define "libraries" of python files that can +# be shared and re-used between these standalone python executables, and can be +# shared across projects in different repositories. This means that we do need +# a way to "install" libraries so that they are visible to CMake builds in +# other repositories, without actually installing them in the standard python +# library paths. +# + +# If the caller has not already found Python, do so now. +# If we fail to find python now we won't fail immediately, but +# add_fb_python_executable() or add_fb_python_library() will fatal out if they +# are used. +if(NOT TARGET Python3::Interpreter) + # CMake 3.12+ ships with a FindPython3.cmake module. Try using it first. + # We find with QUIET here, since otherwise this generates some noisy warnings + # on versions of CMake before 3.12 + if (WIN32) + # On Windows we need both the Intepreter as well as the Development + # libraries. + find_package(Python3 COMPONENTS Interpreter Development QUIET) + else() + find_package(Python3 COMPONENTS Interpreter QUIET) + endif() + if(Python3_Interpreter_FOUND) + message(STATUS "Found Python 3: ${Python3_EXECUTABLE}") + else() + # Try with the FindPythonInterp.cmake module available in older CMake + # versions. Check to see if the caller has already searched for this + # themselves first. + if(NOT PYTHONINTERP_FOUND) + set(Python_ADDITIONAL_VERSIONS 3 3.6 3.5 3.4 3.3 3.2 3.1) + find_package(PythonInterp) + # TODO: On Windows we require the Python libraries as well. + # We currently do not search for them on this code path. + # For now we require building with CMake 3.12+ on Windows, so that the + # FindPython3 code path above is available. + endif() + if(PYTHONINTERP_FOUND) + if("${PYTHON_VERSION_MAJOR}" GREATER_EQUAL 3) + set(Python3_EXECUTABLE "${PYTHON_EXECUTABLE}") + add_custom_target(Python3::Interpreter) + else() + string( + CONCAT FBPY_FIND_PYTHON_ERR + "found Python ${PYTHON_VERSION_MAJOR}.${PYTHON_VERSION_MINOR}, " + "but need Python 3" + ) + endif() + endif() + endif() +endif() + +# Find our helper program. +# We typically install this in the same directory as this .cmake file. +find_program( + FB_MAKE_PYTHON_ARCHIVE "make_fbpy_archive.py" + PATHS ${CMAKE_MODULE_PATH} +) +set(FB_PY_TEST_MAIN "${CMAKE_CURRENT_LIST_DIR}/fb_py_test_main.py") +set( + FB_PY_TEST_DISCOVER_SCRIPT + "${CMAKE_CURRENT_LIST_DIR}/FBPythonTestAddTests.cmake" +) +set( + FB_PY_WIN_MAIN_C + "${CMAKE_CURRENT_LIST_DIR}/fb_py_win_main.c" +) + +# An option to control the default installation location for +# install_fb_python_library(). This is relative to ${CMAKE_INSTALL_PREFIX} +set( + FBPY_LIB_INSTALL_DIR "lib/fb-py-libs" CACHE STRING + "The subdirectory where FB python libraries should be installed" +) + +# +# Build a self-executing python binary. +# +# This accepts the same arguments as add_fb_python_library(). +# +# In addition, a MAIN_MODULE argument is accepted. This argument specifies +# which module should be started as the __main__ module when the executable is +# run. If left unspecified, a __main__.py script must be present in the +# manifest. +# +function(add_fb_python_executable TARGET) + fb_py_check_available() + + # Parse the arguments + set(one_value_args BASE_DIR NAMESPACE MAIN_MODULE TYPE) + set(multi_value_args SOURCES DEPENDS) + fb_cmake_parse_args( + ARG "" "${one_value_args}" "${multi_value_args}" "${ARGN}" + ) + fb_py_process_default_args(ARG_NAMESPACE ARG_BASE_DIR) + + # Use add_fb_python_library() to perform most of our source handling + add_fb_python_library( + "${TARGET}.main_lib" + BASE_DIR "${ARG_BASE_DIR}" + NAMESPACE "${ARG_NAMESPACE}" + SOURCES ${ARG_SOURCES} + DEPENDS ${ARG_DEPENDS} + ) + + set( + manifest_files + "$" + ) + set( + source_files + "$" + ) + + # The command to build the executable archive. + # + # If we are using CMake 3.8+ we can use COMMAND_EXPAND_LISTS. + # CMP0067 isn't really the policy we care about, but seems like the best way + # to check if we are running 3.8+. + if (POLICY CMP0067) + set(extra_cmd_params COMMAND_EXPAND_LISTS) + set(make_py_args "${manifest_files}") + else() + set(extra_cmd_params) + set(make_py_args --manifest-separator "::" "$") + endif() + + set(output_file "${TARGET}${CMAKE_EXECUTABLE_SUFFIX}") + if(WIN32) + set(zipapp_output "${TARGET}.py_zipapp") + else() + set(zipapp_output "${output_file}") + endif() + set(zipapp_output_file "${zipapp_output}") + + set(is_dir_output FALSE) + if(DEFINED ARG_TYPE) + list(APPEND make_py_args "--type" "${ARG_TYPE}") + if ("${ARG_TYPE}" STREQUAL "dir") + set(is_dir_output TRUE) + # CMake doesn't really seem to like having a directory specified as an + # output; specify the __main__.py file as the output instead. + set(zipapp_output_file "${zipapp_output}/__main__.py") + list(APPEND + extra_cmd_params + COMMAND "${CMAKE_COMMAND}" -E remove_directory "${zipapp_output}" + ) + endif() + endif() + + if(DEFINED ARG_MAIN_MODULE) + list(APPEND make_py_args "--main" "${ARG_MAIN_MODULE}") + endif() + + add_custom_command( + OUTPUT "${zipapp_output_file}" + ${extra_cmd_params} + COMMAND + "${Python3_EXECUTABLE}" "${FB_MAKE_PYTHON_ARCHIVE}" + -o "${zipapp_output}" + ${make_py_args} + DEPENDS + ${source_files} + "${TARGET}.main_lib.py_sources_built" + "${FB_MAKE_PYTHON_ARCHIVE}" + ) + + if(WIN32) + if(is_dir_output) + # TODO: generate a main executable that will invoke Python3 + # with the correct main module inside the output directory + else() + add_executable("${TARGET}.winmain" "${FB_PY_WIN_MAIN_C}") + target_link_libraries("${TARGET}.winmain" Python3::Python) + # The Python3::Python target doesn't seem to be set up completely + # correctly on Windows for some reason, and we have to explicitly add + # ${Python3_LIBRARY_DIRS} to the target link directories. + target_link_directories( + "${TARGET}.winmain" + PUBLIC ${Python3_LIBRARY_DIRS} + ) + add_custom_command( + OUTPUT "${output_file}" + DEPENDS "${TARGET}.winmain" "${zipapp_output_file}" + COMMAND + "cmd.exe" "/c" "copy" "/b" + "${TARGET}.winmain${CMAKE_EXECUTABLE_SUFFIX}+${zipapp_output}" + "${output_file}" + ) + endif() + endif() + + # Add an "ALL" target that depends on force ${TARGET}, + # so that ${TARGET} will be included in the default list of build targets. + add_custom_target("${TARGET}.GEN_PY_EXE" ALL DEPENDS "${output_file}") + + # Allow resolving the executable path for the target that we generate + # via a generator expression like: + # "WATCHMAN_WAIT_PATH=$" + set_property(TARGET "${TARGET}.GEN_PY_EXE" + PROPERTY EXECUTABLE "${CMAKE_CURRENT_BINARY_DIR}/${output_file}") +endfunction() + +# Define a python unittest executable. +# The executable is built using add_fb_python_executable and has the +# following differences: +# +# Each of the source files specified in SOURCES will be imported +# and have unittest discovery performed upon them. +# Those sources will be imported in the top level namespace. +# +# The ENV argument allows specifying a list of "KEY=VALUE" +# pairs that will be used by the test runner to set up the environment +# in the child process prior to running the test. This is useful for +# passing additional configuration to the test. +function(add_fb_python_unittest TARGET) + # Parse the arguments + set(multi_value_args SOURCES DEPENDS ENV PROPERTIES) + set( + one_value_args + WORKING_DIRECTORY BASE_DIR NAMESPACE TEST_LIST DISCOVERY_TIMEOUT + ) + fb_cmake_parse_args( + ARG "" "${one_value_args}" "${multi_value_args}" "${ARGN}" + ) + fb_py_process_default_args(ARG_NAMESPACE ARG_BASE_DIR) + if(NOT ARG_WORKING_DIRECTORY) + # Default the working directory to the current binary directory. + # This matches the default behavior of add_test() and other standard + # test functions like gtest_discover_tests() + set(ARG_WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}") + endif() + if(NOT ARG_TEST_LIST) + set(ARG_TEST_LIST "${TARGET}_TESTS") + endif() + if(NOT ARG_DISCOVERY_TIMEOUT) + set(ARG_DISCOVERY_TIMEOUT 5) + endif() + + # Tell our test program the list of modules to scan for tests. + # We scan all modules directly listed in our SOURCES argument, and skip + # modules that came from dependencies in the DEPENDS list. + # + # This is written into a __test_modules__.py module that the test runner + # will look at. + set( + test_modules_path + "${CMAKE_CURRENT_BINARY_DIR}/${TARGET}_test_modules.py" + ) + file(WRITE "${test_modules_path}" "TEST_MODULES = [\n") + string(REPLACE "." "/" namespace_dir "${ARG_NAMESPACE}") + if (NOT "${namespace_dir}" STREQUAL "") + set(namespace_dir "${namespace_dir}/") + endif() + set(test_modules) + foreach(src_path IN LISTS ARG_SOURCES) + fb_py_compute_dest_path( + abs_source dest_path + "${src_path}" "${namespace_dir}" "${ARG_BASE_DIR}" + ) + string(REPLACE "/" "." module_name "${dest_path}") + string(REGEX REPLACE "\\.py$" "" module_name "${module_name}") + list(APPEND test_modules "${module_name}") + file(APPEND "${test_modules_path}" " '${module_name}',\n") + endforeach() + file(APPEND "${test_modules_path}" "]\n") + + # The __main__ is provided by our runner wrapper/bootstrap + list(APPEND ARG_SOURCES "${FB_PY_TEST_MAIN}=__main__.py") + list(APPEND ARG_SOURCES "${test_modules_path}=__test_modules__.py") + + add_fb_python_executable( + "${TARGET}" + NAMESPACE "${ARG_NAMESPACE}" + BASE_DIR "${ARG_BASE_DIR}" + SOURCES ${ARG_SOURCES} + DEPENDS ${ARG_DEPENDS} + ) + + # Run test discovery after the test executable is built. + # This logic is based on the code for gtest_discover_tests() + set(ctest_file_base "${CMAKE_CURRENT_BINARY_DIR}/${TARGET}") + set(ctest_include_file "${ctest_file_base}_include.cmake") + set(ctest_tests_file "${ctest_file_base}_tests.cmake") + add_custom_command( + TARGET "${TARGET}.GEN_PY_EXE" POST_BUILD + BYPRODUCTS "${ctest_tests_file}" + COMMAND + "${CMAKE_COMMAND}" + -D "TEST_TARGET=${TARGET}" + -D "TEST_INTERPRETER=${Python3_EXECUTABLE}" + -D "TEST_ENV=${ARG_ENV}" + -D "TEST_EXECUTABLE=$" + -D "TEST_WORKING_DIR=${ARG_WORKING_DIRECTORY}" + -D "TEST_LIST=${ARG_TEST_LIST}" + -D "TEST_PREFIX=${TARGET}::" + -D "TEST_PROPERTIES=${ARG_PROPERTIES}" + -D "CTEST_FILE=${ctest_tests_file}" + -P "${FB_PY_TEST_DISCOVER_SCRIPT}" + VERBATIM + ) + + file( + WRITE "${ctest_include_file}" + "if(EXISTS \"${ctest_tests_file}\")\n" + " include(\"${ctest_tests_file}\")\n" + "else()\n" + " add_test(\"${TARGET}_NOT_BUILT\" \"${TARGET}_NOT_BUILT\")\n" + "endif()\n" + ) + set_property( + DIRECTORY APPEND PROPERTY TEST_INCLUDE_FILES + "${ctest_include_file}" + ) +endfunction() + +# +# Define a python library. +# +# If you want to install a python library generated from this rule note that +# you need to use install_fb_python_library() rather than CMake's built-in +# install() function. This will make it available for other downstream +# projects to use in their add_fb_python_executable() and +# add_fb_python_library() calls. (You do still need to use `install(EXPORT)` +# later to install the CMake exports.) +# +# Parameters: +# - BASE_DIR : +# The base directory path to strip off from each source path. All source +# files must be inside this directory. If not specified it defaults to +# ${CMAKE_CURRENT_SOURCE_DIR}. +# - NAMESPACE : +# The destination namespace where these files should be installed in python +# binaries. If not specified, this defaults to the current relative path of +# ${CMAKE_CURRENT_SOURCE_DIR} inside ${CMAKE_SOURCE_DIR}. e.g., a python +# library defined in the directory repo_root/foo/bar will use a default +# namespace of "foo.bar" +# - SOURCES <...>: +# The python source files. +# You may optionally specify as source using the form: PATH=ALIAS where +# PATH is a relative path in the source tree and ALIAS is the relative +# path into which PATH should be rewritten. This is useful for mapping +# an executable script to the main module in a python executable. +# e.g.: `python/bin/watchman-wait=__main__.py` +# - DEPENDS <...>: +# Other python libraries that this one depends on. +# - INSTALL_DIR : +# The directory where this library should be installed. +# install_fb_python_library() must still be called later to perform the +# installation. If a relative path is given it will be treated relative to +# ${CMAKE_INSTALL_PREFIX} +# +# CMake is unfortunately pretty crappy at being able to define custom build +# rules & behaviors. It doesn't support transitive property propagation +# between custom targets; only the built-in add_executable() and add_library() +# targets support transitive properties. +# +# We hack around this janky CMake behavior by (ab)using interface libraries to +# propagate some of the data we want between targets, without actually +# generating a C library. +# +# add_fb_python_library(SOMELIB) generates the following things: +# - An INTERFACE library rule named SOMELIB.py_lib which tracks some +# information about transitive dependencies: +# - the transitive set of source files in the INTERFACE_SOURCES property +# - the transitive set of manifest files that this library depends on in +# the INTERFACE_INCLUDE_DIRECTORIES property. +# - A custom command that generates a SOMELIB.manifest file. +# This file contains the mapping of source files to desired destination +# locations in executables that depend on this library. This manifest file +# will then be read at build-time in order to build executables. +# +function(add_fb_python_library LIB_NAME) + fb_py_check_available() + + # Parse the arguments + # We use fb_cmake_parse_args() rather than cmake_parse_arguments() since + # cmake_parse_arguments() does not handle empty arguments, and it is common + # for callers to want to specify an empty NAMESPACE parameter. + set(one_value_args BASE_DIR NAMESPACE INSTALL_DIR) + set(multi_value_args SOURCES DEPENDS) + fb_cmake_parse_args( + ARG "" "${one_value_args}" "${multi_value_args}" "${ARGN}" + ) + fb_py_process_default_args(ARG_NAMESPACE ARG_BASE_DIR) + + string(REPLACE "." "/" namespace_dir "${ARG_NAMESPACE}") + if (NOT "${namespace_dir}" STREQUAL "") + set(namespace_dir "${namespace_dir}/") + endif() + + if(NOT DEFINED ARG_INSTALL_DIR) + set(install_dir "${FBPY_LIB_INSTALL_DIR}/") + elseif("${ARG_INSTALL_DIR}" STREQUAL "") + set(install_dir "") + else() + set(install_dir "${ARG_INSTALL_DIR}/") + endif() + + # message(STATUS "fb py library ${LIB_NAME}: " + # "NS=${namespace_dir} BASE=${ARG_BASE_DIR}") + + # TODO: In the future it would be nice to support pre-compiling the source + # files. We could emit a rule to compile each source file and emit a + # .pyc/.pyo file here, and then have the manifest reference the pyc/pyo + # files. + + # Define a library target to help pass around information about the library, + # and propagate dependency information. + # + # CMake make a lot of assumptions that libraries are C++ libraries. To help + # avoid confusion we name our target "${LIB_NAME}.py_lib" rather than just + # "${LIB_NAME}". This helps avoid confusion if callers try to use + # "${LIB_NAME}" on their own as a target name. (e.g., attempting to install + # it directly with install(TARGETS) won't work. Callers must use + # install_fb_python_library() instead.) + add_library("${LIB_NAME}.py_lib" INTERFACE) + + # Emit the manifest file. + # + # We write the manifest file to a temporary path first, then copy it with + # configure_file(COPYONLY). This is necessary to get CMake to understand + # that "${manifest_path}" is generated by the CMake configure phase, + # and allow using it as a dependency for add_custom_command(). + # (https://gitlab.kitware.com/cmake/cmake/issues/16367) + set(manifest_path "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}.manifest") + set(tmp_manifest "${manifest_path}.tmp") + file(WRITE "${tmp_manifest}" "FBPY_MANIFEST 1\n") + set(abs_sources) + foreach(src_path IN LISTS ARG_SOURCES) + fb_py_compute_dest_path( + abs_source dest_path + "${src_path}" "${namespace_dir}" "${ARG_BASE_DIR}" + ) + list(APPEND abs_sources "${abs_source}") + target_sources( + "${LIB_NAME}.py_lib" INTERFACE + "$" + "$" + ) + file( + APPEND "${tmp_manifest}" + "${abs_source} :: ${dest_path}\n" + ) + endforeach() + configure_file("${tmp_manifest}" "${manifest_path}" COPYONLY) + + target_include_directories( + "${LIB_NAME}.py_lib" INTERFACE + "$" + "$" + ) + + # Add a target that depends on all of the source files. + # This is needed in case some of the source files are generated. This will + # ensure that these source files are brought up-to-date before we build + # any python binaries that depend on this library. + add_custom_target("${LIB_NAME}.py_sources_built" DEPENDS ${abs_sources}) + add_dependencies("${LIB_NAME}.py_lib" "${LIB_NAME}.py_sources_built") + + # Hook up library dependencies, and also make the *.py_sources_built target + # depend on the sources for all of our dependencies also being up-to-date. + foreach(dep IN LISTS ARG_DEPENDS) + target_link_libraries("${LIB_NAME}.py_lib" INTERFACE "${dep}.py_lib") + + # Mark that our .py_sources_built target depends on each our our dependent + # libraries. This serves two functions: + # - This causes CMake to generate an error message if one of the + # dependencies is never defined. The target_link_libraries() call above + # won't complain if one of the dependencies doesn't exist (since it is + # intended to allow passing in file names for plain library files rather + # than just targets). + # - It ensures that sources for our depencencies are built before any + # executable that depends on us. Note that we depend on "${dep}.py_lib" + # rather than "${dep}.py_sources_built" for this purpose because the + # ".py_sources_built" target won't be available for imported targets. + add_dependencies("${LIB_NAME}.py_sources_built" "${dep}.py_lib") + endforeach() + + # Add a custom command to help with library installation, in case + # install_fb_python_library() is called later for this library. + # add_custom_command() only works with file dependencies defined in the same + # CMakeLists.txt file, so we want to make sure this is defined here, rather + # then where install_fb_python_library() is called. + # This command won't be run by default, but will only be run if it is needed + # by a subsequent install_fb_python_library() call. + # + # This command copies the library contents into the build directory. + # It would be nicer if we could skip this intermediate copy, and just run + # make_fbpy_archive.py at install time to copy them directly to the desired + # installation directory. Unfortunately this is difficult to do, and seems + # to interfere with some of the CMake code that wants to generate a manifest + # of installed files. + set(build_install_dir "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}.lib_install") + add_custom_command( + OUTPUT + "${build_install_dir}/${LIB_NAME}.manifest" + COMMAND "${CMAKE_COMMAND}" -E remove_directory "${build_install_dir}" + COMMAND + "${Python3_EXECUTABLE}" "${FB_MAKE_PYTHON_ARCHIVE}" --type lib-install + --install-dir "${LIB_NAME}" + -o "${build_install_dir}/${LIB_NAME}" "${manifest_path}" + DEPENDS + "${abs_sources}" + "${manifest_path}" + "${FB_MAKE_PYTHON_ARCHIVE}" + ) + add_custom_target( + "${LIB_NAME}.py_lib_install" + DEPENDS "${build_install_dir}/${LIB_NAME}.manifest" + ) + + # Set some properties to pass through the install paths to + # install_fb_python_library() + # + # Passing through ${build_install_dir} allows install_fb_python_library() + # to work even if used from a different CMakeLists.txt file than where + # add_fb_python_library() was called (i.e. such that + # ${CMAKE_CURRENT_BINARY_DIR} is different between the two calls). + set(abs_install_dir "${install_dir}") + if(NOT IS_ABSOLUTE "${abs_install_dir}") + set(abs_install_dir "${CMAKE_INSTALL_PREFIX}/${abs_install_dir}") + endif() + string(REGEX REPLACE "/$" "" abs_install_dir "${abs_install_dir}") + set_target_properties( + "${LIB_NAME}.py_lib_install" + PROPERTIES + INSTALL_DIR "${abs_install_dir}" + BUILD_INSTALL_DIR "${build_install_dir}" + ) +endfunction() + +# +# Install an FB-style packaged python binary. +# +# - DESTINATION : +# Associate the installed target files with the given export-name. +# +function(install_fb_python_executable TARGET) + # Parse the arguments + set(one_value_args DESTINATION) + set(multi_value_args) + fb_cmake_parse_args( + ARG "" "${one_value_args}" "${multi_value_args}" "${ARGN}" + ) + + if(NOT DEFINED ARG_DESTINATION) + set(ARG_DESTINATION bin) + endif() + + install( + PROGRAMS "$" + DESTINATION "${ARG_DESTINATION}" + ) +endfunction() + +# +# Install a python library. +# +# - EXPORT : +# Associate the installed target files with the given export-name. +# +# Note that unlike the built-in CMake install() function we do not accept a +# DESTINATION parameter. Instead, use the INSTALL_DIR parameter to +# add_fb_python_library() to set the installation location. +# +function(install_fb_python_library LIB_NAME) + set(one_value_args EXPORT) + fb_cmake_parse_args(ARG "" "${one_value_args}" "" "${ARGN}") + + # Export our "${LIB_NAME}.py_lib" target so that it will be available to + # downstream projects in our installed CMake config files. + if(DEFINED ARG_EXPORT) + install(TARGETS "${LIB_NAME}.py_lib" EXPORT "${ARG_EXPORT}") + endif() + + # add_fb_python_library() emits a .py_lib_install target that will prepare + # the installation directory. However, it isn't part of the "ALL" target and + # therefore isn't built by default. + # + # Make sure the ALL target depends on it now. We have to do this by + # introducing yet another custom target. + # Add it as a dependency to the ALL target now. + add_custom_target("${LIB_NAME}.py_lib_install_all" ALL) + add_dependencies( + "${LIB_NAME}.py_lib_install_all" "${LIB_NAME}.py_lib_install" + ) + + # Copy the intermediate install directory generated at build time into + # the desired install location. + get_target_property(dest_dir "${LIB_NAME}.py_lib_install" "INSTALL_DIR") + get_target_property( + build_install_dir "${LIB_NAME}.py_lib_install" "BUILD_INSTALL_DIR" + ) + install( + DIRECTORY "${build_install_dir}/${LIB_NAME}" + DESTINATION "${dest_dir}" + ) + install( + FILES "${build_install_dir}/${LIB_NAME}.manifest" + DESTINATION "${dest_dir}" + ) +endfunction() + +# Helper macro to process the BASE_DIR and NAMESPACE arguments for +# add_fb_python_executable() and add_fb_python_executable() +macro(fb_py_process_default_args NAMESPACE_VAR BASE_DIR_VAR) + # If the namespace was not specified, default to the relative path to the + # current directory (starting from the repository root). + if(NOT DEFINED "${NAMESPACE_VAR}") + file( + RELATIVE_PATH "${NAMESPACE_VAR}" + "${CMAKE_SOURCE_DIR}" + "${CMAKE_CURRENT_SOURCE_DIR}" + ) + endif() + + if(NOT DEFINED "${BASE_DIR_VAR}") + # If the base directory was not specified, default to the current directory + set("${BASE_DIR_VAR}" "${CMAKE_CURRENT_SOURCE_DIR}") + else() + # If the base directory was specified, always convert it to an + # absolute path. + get_filename_component("${BASE_DIR_VAR}" "${${BASE_DIR_VAR}}" ABSOLUTE) + endif() +endmacro() + +function(fb_py_check_available) + # Make sure that Python 3 and our make_fbpy_archive.py helper script are + # available. + if(NOT Python3_EXECUTABLE) + if(FBPY_FIND_PYTHON_ERR) + message(FATAL_ERROR "Unable to find Python 3: ${FBPY_FIND_PYTHON_ERR}") + else() + message(FATAL_ERROR "Unable to find Python 3") + endif() + endif() + + if (NOT FB_MAKE_PYTHON_ARCHIVE) + message( + FATAL_ERROR "unable to find make_fbpy_archive.py helper program (it " + "should be located in the same directory as FBPythonBinary.cmake)" + ) + endif() +endfunction() + +function( + fb_py_compute_dest_path + src_path_output dest_path_output src_path namespace_dir base_dir +) + if("${src_path}" MATCHES "=") + # We want to split the string on the `=` sign, but cmake doesn't + # provide much in the way of helpers for this, so we rewrite the + # `=` sign to `;` so that we can treat it as a cmake list and + # then index into the components + string(REPLACE "=" ";" src_path_list "${src_path}") + list(GET src_path_list 0 src_path) + # Note that we ignore the `namespace_dir` in the alias case + # in order to allow aliasing a source to the top level `__main__.py` + # filename. + list(GET src_path_list 1 dest_path) + else() + unset(dest_path) + endif() + + get_filename_component(abs_source "${src_path}" ABSOLUTE) + if(NOT DEFINED dest_path) + file(RELATIVE_PATH rel_src "${ARG_BASE_DIR}" "${abs_source}") + if("${rel_src}" MATCHES "^../") + message( + FATAL_ERROR "${LIB_NAME}: source file \"${abs_source}\" is not inside " + "the base directory ${ARG_BASE_DIR}" + ) + endif() + set(dest_path "${namespace_dir}${rel_src}") + endif() + + set("${src_path_output}" "${abs_source}" PARENT_SCOPE) + set("${dest_path_output}" "${dest_path}" PARENT_SCOPE) +endfunction() diff --git a/build/fbcode_builder/CMake/FBPythonTestAddTests.cmake b/build/fbcode_builder/CMake/FBPythonTestAddTests.cmake new file mode 100644 index 000000000..d73c055d8 --- /dev/null +++ b/build/fbcode_builder/CMake/FBPythonTestAddTests.cmake @@ -0,0 +1,59 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +# Add a command to be emitted to the CTest file +set(ctest_script) +function(add_command CMD) + set(escaped_args "") + foreach(arg ${ARGN}) + # Escape all arguments using "Bracket Argument" syntax + # We could skip this for argument that don't contain any special + # characters if we wanted to make the output slightly more human-friendly. + set(escaped_args "${escaped_args} [==[${arg}]==]") + endforeach() + set(ctest_script "${ctest_script}${CMD}(${escaped_args})\n" PARENT_SCOPE) +endfunction() + +if(NOT EXISTS "${TEST_EXECUTABLE}") + message(FATAL_ERROR "Test executable does not exist: ${TEST_EXECUTABLE}") +endif() +execute_process( + COMMAND ${CMAKE_COMMAND} -E env ${TEST_ENV} "${TEST_INTERPRETER}" "${TEST_EXECUTABLE}" --list-tests + WORKING_DIRECTORY "${TEST_WORKING_DIR}" + OUTPUT_VARIABLE output + RESULT_VARIABLE result +) +if(NOT "${result}" EQUAL 0) + string(REPLACE "\n" "\n " output "${output}") + message( + FATAL_ERROR + "Error running test executable: ${TEST_EXECUTABLE}\n" + "Output:\n" + " ${output}\n" + ) +endif() + +# Parse output +string(REPLACE "\n" ";" tests_list "${output}") +foreach(test_name ${tests_list}) + add_command( + add_test + "${TEST_PREFIX}${test_name}" + ${CMAKE_COMMAND} -E env ${TEST_ENV} + "${TEST_INTERPRETER}" "${TEST_EXECUTABLE}" "${test_name}" + ) + add_command( + set_tests_properties + "${TEST_PREFIX}${test_name}" + PROPERTIES + WORKING_DIRECTORY "${TEST_WORKING_DIR}" + ${TEST_PROPERTIES} + ) +endforeach() + +# Set a list of discovered tests in the parent scope, in case users +# want access to this list as a CMake variable +if(TEST_LIST) + add_command(set ${TEST_LIST} ${tests_list}) +endif() + +file(WRITE "${CTEST_FILE}" "${ctest_script}") diff --git a/build/fbcode_builder/CMake/FBThriftCppLibrary.cmake b/build/fbcode_builder/CMake/FBThriftCppLibrary.cmake new file mode 100644 index 000000000..670771a46 --- /dev/null +++ b/build/fbcode_builder/CMake/FBThriftCppLibrary.cmake @@ -0,0 +1,194 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +include(FBCMakeParseArgs) + +# Generate a C++ library from a thrift file +# +# Parameters: +# - SERVICES [ ...] +# The names of the services defined in the thrift file. +# - DEPENDS [ ...] +# A list of other thrift C++ libraries that this library depends on. +# - OPTIONS [ ...] +# A list of options to pass to the thrift compiler. +# - INCLUDE_DIR +# The sub-directory where generated headers will be installed. +# Defaults to "include" if not specified. The caller must still call +# install() to install the thrift library if desired. +# - THRIFT_INCLUDE_DIR +# The sub-directory where generated headers will be installed. +# Defaults to "${INCLUDE_DIR}/thrift-files" if not specified. +# The caller must still call install() to install the thrift library if +# desired. +function(add_fbthrift_cpp_library LIB_NAME THRIFT_FILE) + # Parse the arguments + set(one_value_args INCLUDE_DIR THRIFT_INCLUDE_DIR) + set(multi_value_args SERVICES DEPENDS OPTIONS) + fb_cmake_parse_args( + ARG "" "${one_value_args}" "${multi_value_args}" "${ARGN}" + ) + if(NOT DEFINED ARG_INCLUDE_DIR) + set(ARG_INCLUDE_DIR "include") + endif() + if(NOT DEFINED ARG_THRIFT_INCLUDE_DIR) + set(ARG_THRIFT_INCLUDE_DIR "${ARG_INCLUDE_DIR}/thrift-files") + endif() + + get_filename_component(base ${THRIFT_FILE} NAME_WE) + get_filename_component( + output_dir + ${CMAKE_CURRENT_BINARY_DIR}/${THRIFT_FILE} + DIRECTORY + ) + + # Generate relative paths in #includes + file( + RELATIVE_PATH include_prefix + "${CMAKE_SOURCE_DIR}" + "${CMAKE_CURRENT_SOURCE_DIR}/${THRIFT_FILE}" + ) + get_filename_component(include_prefix ${include_prefix} DIRECTORY) + + if (NOT "${include_prefix}" STREQUAL "") + list(APPEND ARG_OPTIONS "include_prefix=${include_prefix}") + endif() + # CMake 3.12 is finally getting a list(JOIN) function, but until then + # treating the list as a string and replacing the semicolons is good enough. + string(REPLACE ";" "," GEN_ARG_STR "${ARG_OPTIONS}") + + # Compute the list of generated files + list(APPEND generated_headers + "${output_dir}/gen-cpp2/${base}_constants.h" + "${output_dir}/gen-cpp2/${base}_types.h" + "${output_dir}/gen-cpp2/${base}_types.tcc" + "${output_dir}/gen-cpp2/${base}_types_custom_protocol.h" + "${output_dir}/gen-cpp2/${base}_metadata.h" + ) + list(APPEND generated_sources + "${output_dir}/gen-cpp2/${base}_constants.cpp" + "${output_dir}/gen-cpp2/${base}_data.h" + "${output_dir}/gen-cpp2/${base}_data.cpp" + "${output_dir}/gen-cpp2/${base}_types.cpp" + "${output_dir}/gen-cpp2/${base}_metadata.cpp" + ) + foreach(service IN LISTS ARG_SERVICES) + list(APPEND generated_headers + "${output_dir}/gen-cpp2/${service}.h" + "${output_dir}/gen-cpp2/${service}.tcc" + "${output_dir}/gen-cpp2/${service}AsyncClient.h" + "${output_dir}/gen-cpp2/${service}_custom_protocol.h" + ) + list(APPEND generated_sources + "${output_dir}/gen-cpp2/${service}.cpp" + "${output_dir}/gen-cpp2/${service}AsyncClient.cpp" + "${output_dir}/gen-cpp2/${service}_processmap_binary.cpp" + "${output_dir}/gen-cpp2/${service}_processmap_compact.cpp" + ) + endforeach() + + # This generator expression gets the list of include directories required + # for all of our dependencies. + # It requires using COMMAND_EXPAND_LISTS in the add_custom_command() call + # below. COMMAND_EXPAND_LISTS is only available in CMake 3.8+ + # If we really had to support older versions of CMake we would probably need + # to use a wrapper script around the thrift compiler that could take the + # include list as a single argument and split it up before invoking the + # thrift compiler. + if (NOT POLICY CMP0067) + message(FATAL_ERROR "add_fbthrift_cpp_library() requires CMake 3.8+") + endif() + set( + thrift_include_options + "-I;$,;-I;>" + ) + + # Emit the rule to run the thrift compiler + add_custom_command( + OUTPUT + ${generated_headers} + ${generated_sources} + COMMAND_EXPAND_LISTS + COMMAND + "${CMAKE_COMMAND}" -E make_directory "${output_dir}" + COMMAND + "${FBTHRIFT_COMPILER}" + --strict + --gen "mstch_cpp2:${GEN_ARG_STR}" + "${thrift_include_options}" + -o "${output_dir}" + "${CMAKE_CURRENT_SOURCE_DIR}/${THRIFT_FILE}" + WORKING_DIRECTORY + "${CMAKE_BINARY_DIR}" + MAIN_DEPENDENCY + "${THRIFT_FILE}" + DEPENDS + ${ARG_DEPENDS} + "${FBTHRIFT_COMPILER}" + ) + + # Now emit the library rule to compile the sources + if (BUILD_SHARED_LIBS) + set(LIB_TYPE SHARED) + else () + set(LIB_TYPE STATIC) + endif () + + add_library( + "${LIB_NAME}" ${LIB_TYPE} + ${generated_sources} + ) + + target_include_directories( + "${LIB_NAME}" + PUBLIC + "$" + "$" + ) + target_link_libraries( + "${LIB_NAME}" + PUBLIC + ${ARG_DEPENDS} + FBThrift::thriftcpp2 + Folly::folly + ) + + # Add ${generated_headers} to the PUBLIC_HEADER property for ${LIB_NAME} + # + # This allows callers to install it using + # "install(TARGETS ${LIB_NAME} PUBLIC_HEADER)" + # However, note that CMake's PUBLIC_HEADER behavior is rather inflexible, + # and does have any way to preserve header directory structure. Callers + # must be careful to use the correct PUBLIC_HEADER DESTINATION parameter + # when doing this, to put the files the correct directory themselves. + # We define a HEADER_INSTALL_DIR property with the include directory prefix, + # so typically callers should specify the PUBLIC_HEADER DESTINATION as + # "$" + set_property( + TARGET "${LIB_NAME}" + PROPERTY PUBLIC_HEADER ${generated_headers} + ) + + # Define a dummy interface library to help propagate the thrift include + # directories between dependencies. + add_library("${LIB_NAME}.thrift_includes" INTERFACE) + target_include_directories( + "${LIB_NAME}.thrift_includes" + INTERFACE + "$" + "$" + ) + foreach(dep IN LISTS ARG_DEPENDS) + target_link_libraries( + "${LIB_NAME}.thrift_includes" + INTERFACE "${dep}.thrift_includes" + ) + endforeach() + + set_target_properties( + "${LIB_NAME}" + PROPERTIES + EXPORT_PROPERTIES "THRIFT_INSTALL_DIR" + THRIFT_INSTALL_DIR "${ARG_THRIFT_INCLUDE_DIR}/${include_prefix}" + HEADER_INSTALL_DIR "${ARG_INCLUDE_DIR}/${include_prefix}/gen-cpp2" + ) +endfunction() diff --git a/build/fbcode_builder/CMake/FBThriftLibrary.cmake b/build/fbcode_builder/CMake/FBThriftLibrary.cmake new file mode 100644 index 000000000..e4280e2a4 --- /dev/null +++ b/build/fbcode_builder/CMake/FBThriftLibrary.cmake @@ -0,0 +1,77 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +include(FBCMakeParseArgs) +include(FBThriftPyLibrary) +include(FBThriftCppLibrary) + +# +# add_fbthrift_library() +# +# This is a convenience function that generates thrift libraries for multiple +# languages. +# +# For example: +# add_fbthrift_library( +# foo foo.thrift +# LANGUAGES cpp py +# SERVICES Foo +# DEPENDS bar) +# +# will be expanded into two separate calls: +# +# add_fbthrift_cpp_library(foo_cpp foo.thrift SERVICES Foo DEPENDS bar_cpp) +# add_fbthrift_py_library(foo_py foo.thrift SERVICES Foo DEPENDS bar_py) +# +function(add_fbthrift_library LIB_NAME THRIFT_FILE) + # Parse the arguments + set(one_value_args PY_NAMESPACE INCLUDE_DIR THRIFT_INCLUDE_DIR) + set(multi_value_args SERVICES DEPENDS LANGUAGES CPP_OPTIONS PY_OPTIONS) + fb_cmake_parse_args( + ARG "" "${one_value_args}" "${multi_value_args}" "${ARGN}" + ) + + if(NOT DEFINED ARG_INCLUDE_DIR) + set(ARG_INCLUDE_DIR "include") + endif() + if(NOT DEFINED ARG_THRIFT_INCLUDE_DIR) + set(ARG_THRIFT_INCLUDE_DIR "${ARG_INCLUDE_DIR}/thrift-files") + endif() + + # CMake 3.12+ adds list(TRANSFORM) which would be nice to use here, but for + # now we still want to support older versions of CMake. + set(CPP_DEPENDS) + set(PY_DEPENDS) + foreach(dep IN LISTS ARG_DEPENDS) + list(APPEND CPP_DEPENDS "${dep}_cpp") + list(APPEND PY_DEPENDS "${dep}_py") + endforeach() + + foreach(lang IN LISTS ARG_LANGUAGES) + if ("${lang}" STREQUAL "cpp") + add_fbthrift_cpp_library( + "${LIB_NAME}_cpp" "${THRIFT_FILE}" + SERVICES ${ARG_SERVICES} + DEPENDS ${CPP_DEPENDS} + OPTIONS ${ARG_CPP_OPTIONS} + INCLUDE_DIR "${ARG_INCLUDE_DIR}" + THRIFT_INCLUDE_DIR "${ARG_THRIFT_INCLUDE_DIR}" + ) + elseif ("${lang}" STREQUAL "py" OR "${lang}" STREQUAL "python") + if (DEFINED ARG_PY_NAMESPACE) + set(namespace_args NAMESPACE "${ARG_PY_NAMESPACE}") + endif() + add_fbthrift_py_library( + "${LIB_NAME}_py" "${THRIFT_FILE}" + SERVICES ${ARG_SERVICES} + ${namespace_args} + DEPENDS ${PY_DEPENDS} + OPTIONS ${ARG_PY_OPTIONS} + THRIFT_INCLUDE_DIR "${ARG_THRIFT_INCLUDE_DIR}" + ) + else() + message( + FATAL_ERROR "unknown language for thrift library ${LIB_NAME}: ${lang}" + ) + endif() + endforeach() +endfunction() diff --git a/build/fbcode_builder/CMake/FBThriftPyLibrary.cmake b/build/fbcode_builder/CMake/FBThriftPyLibrary.cmake new file mode 100644 index 000000000..7bd8879ee --- /dev/null +++ b/build/fbcode_builder/CMake/FBThriftPyLibrary.cmake @@ -0,0 +1,111 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +include(FBCMakeParseArgs) +include(FBPythonBinary) + +# Generate a Python library from a thrift file +function(add_fbthrift_py_library LIB_NAME THRIFT_FILE) + # Parse the arguments + set(one_value_args NAMESPACE THRIFT_INCLUDE_DIR) + set(multi_value_args SERVICES DEPENDS OPTIONS) + fb_cmake_parse_args( + ARG "" "${one_value_args}" "${multi_value_args}" "${ARGN}" + ) + + if(NOT DEFINED ARG_THRIFT_INCLUDE_DIR) + set(ARG_THRIFT_INCLUDE_DIR "include/thrift-files") + endif() + + get_filename_component(base ${THRIFT_FILE} NAME_WE) + set(output_dir "${CMAKE_CURRENT_BINARY_DIR}/${THRIFT_FILE}-py") + + # Parse the namespace value + if (NOT DEFINED ARG_NAMESPACE) + set(ARG_NAMESPACE "${base}") + endif() + + string(REPLACE "." "/" namespace_dir "${ARG_NAMESPACE}") + set(py_output_dir "${output_dir}/gen-py/${namespace_dir}") + list(APPEND generated_sources + "${py_output_dir}/__init__.py" + "${py_output_dir}/ttypes.py" + "${py_output_dir}/constants.py" + ) + foreach(service IN LISTS ARG_SERVICES) + list(APPEND generated_sources + ${py_output_dir}/${service}.py + ) + endforeach() + + # Define a dummy interface library to help propagate the thrift include + # directories between dependencies. + add_library("${LIB_NAME}.thrift_includes" INTERFACE) + target_include_directories( + "${LIB_NAME}.thrift_includes" + INTERFACE + "$" + "$" + ) + foreach(dep IN LISTS ARG_DEPENDS) + target_link_libraries( + "${LIB_NAME}.thrift_includes" + INTERFACE "${dep}.thrift_includes" + ) + endforeach() + + # This generator expression gets the list of include directories required + # for all of our dependencies. + # It requires using COMMAND_EXPAND_LISTS in the add_custom_command() call + # below. COMMAND_EXPAND_LISTS is only available in CMake 3.8+ + # If we really had to support older versions of CMake we would probably need + # to use a wrapper script around the thrift compiler that could take the + # include list as a single argument and split it up before invoking the + # thrift compiler. + if (NOT POLICY CMP0067) + message(FATAL_ERROR "add_fbthrift_py_library() requires CMake 3.8+") + endif() + set( + thrift_include_options + "-I;$,;-I;>" + ) + + # Always force generation of "new-style" python classes for Python 2 + list(APPEND ARG_OPTIONS "new_style") + # CMake 3.12 is finally getting a list(JOIN) function, but until then + # treating the list as a string and replacing the semicolons is good enough. + string(REPLACE ";" "," GEN_ARG_STR "${ARG_OPTIONS}") + + # Emit the rule to run the thrift compiler + add_custom_command( + OUTPUT + ${generated_sources} + COMMAND_EXPAND_LISTS + COMMAND + "${CMAKE_COMMAND}" -E make_directory "${output_dir}" + COMMAND + "${FBTHRIFT_COMPILER}" + --strict + --gen "py:${GEN_ARG_STR}" + "${thrift_include_options}" + -o "${output_dir}" + "${CMAKE_CURRENT_SOURCE_DIR}/${THRIFT_FILE}" + WORKING_DIRECTORY + "${CMAKE_BINARY_DIR}" + MAIN_DEPENDENCY + "${THRIFT_FILE}" + DEPENDS + "${FBTHRIFT_COMPILER}" + ) + + # We always want to pass the namespace as "" to this call: + # thrift will already emit the files with the desired namespace prefix under + # gen-py. We don't want add_fb_python_library() to prepend the namespace a + # second time. + add_fb_python_library( + "${LIB_NAME}" + BASE_DIR "${output_dir}/gen-py" + NAMESPACE "" + SOURCES ${generated_sources} + DEPENDS ${ARG_DEPENDS} FBThrift::thrift_py + ) +endfunction() diff --git a/build/fbcode_builder/CMake/FindGMock.cmake b/build/fbcode_builder/CMake/FindGMock.cmake new file mode 100644 index 000000000..cd042dd9c --- /dev/null +++ b/build/fbcode_builder/CMake/FindGMock.cmake @@ -0,0 +1,80 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Find libgmock +# +# LIBGMOCK_DEFINES - List of defines when using libgmock. +# LIBGMOCK_INCLUDE_DIR - where to find gmock/gmock.h, etc. +# LIBGMOCK_LIBRARIES - List of libraries when using libgmock. +# LIBGMOCK_FOUND - True if libgmock found. + +IF (LIBGMOCK_INCLUDE_DIR) + # Already in cache, be silent + SET(LIBGMOCK_FIND_QUIETLY TRUE) +ENDIF () + +find_package(GTest CONFIG QUIET) +if (TARGET GTest::gmock) + get_target_property(LIBGMOCK_DEFINES GTest::gtest INTERFACE_COMPILE_DEFINITIONS) + if (NOT ${LIBGMOCK_DEFINES}) + # Explicitly set to empty string if not found to avoid it being + # set to NOTFOUND and breaking compilation + set(LIBGMOCK_DEFINES "") + endif() + get_target_property(LIBGMOCK_INCLUDE_DIR GTest::gtest INTERFACE_INCLUDE_DIRECTORIES) + set(LIBGMOCK_LIBRARIES GTest::gmock_main GTest::gmock GTest::gtest) + set(LIBGMOCK_FOUND ON) + message(STATUS "Found gmock via config, defines=${LIBGMOCK_DEFINES}, include=${LIBGMOCK_INCLUDE_DIR}, libs=${LIBGMOCK_LIBRARIES}") +else() + + FIND_PATH(LIBGMOCK_INCLUDE_DIR gmock/gmock.h) + + FIND_LIBRARY(LIBGMOCK_MAIN_LIBRARY_DEBUG NAMES gmock_maind) + FIND_LIBRARY(LIBGMOCK_MAIN_LIBRARY_RELEASE NAMES gmock_main) + FIND_LIBRARY(LIBGMOCK_LIBRARY_DEBUG NAMES gmockd) + FIND_LIBRARY(LIBGMOCK_LIBRARY_RELEASE NAMES gmock) + FIND_LIBRARY(LIBGTEST_LIBRARY_DEBUG NAMES gtestd) + FIND_LIBRARY(LIBGTEST_LIBRARY_RELEASE NAMES gtest) + + find_package(Threads REQUIRED) + INCLUDE(SelectLibraryConfigurations) + SELECT_LIBRARY_CONFIGURATIONS(LIBGMOCK_MAIN) + SELECT_LIBRARY_CONFIGURATIONS(LIBGMOCK) + SELECT_LIBRARY_CONFIGURATIONS(LIBGTEST) + + set(LIBGMOCK_LIBRARIES + ${LIBGMOCK_MAIN_LIBRARY} + ${LIBGMOCK_LIBRARY} + ${LIBGTEST_LIBRARY} + Threads::Threads + ) + + if(CMAKE_SYSTEM_NAME STREQUAL "Windows") + # The GTEST_LINKED_AS_SHARED_LIBRARY macro must be set properly on Windows. + # + # There isn't currently an easy way to determine if a library was compiled as + # a shared library on Windows, so just assume we've been built against a + # shared build of gmock for now. + SET(LIBGMOCK_DEFINES "GTEST_LINKED_AS_SHARED_LIBRARY=1" CACHE STRING "") + endif() + + # handle the QUIETLY and REQUIRED arguments and set LIBGMOCK_FOUND to TRUE if + # all listed variables are TRUE + INCLUDE(FindPackageHandleStandardArgs) + FIND_PACKAGE_HANDLE_STANDARD_ARGS( + GMock + DEFAULT_MSG + LIBGMOCK_MAIN_LIBRARY + LIBGMOCK_LIBRARY + LIBGTEST_LIBRARY + LIBGMOCK_LIBRARIES + LIBGMOCK_INCLUDE_DIR + ) + + MARK_AS_ADVANCED( + LIBGMOCK_DEFINES + LIBGMOCK_MAIN_LIBRARY + LIBGMOCK_LIBRARY + LIBGTEST_LIBRARY + LIBGMOCK_LIBRARIES + LIBGMOCK_INCLUDE_DIR + ) +endif() diff --git a/build/fbcode_builder/CMake/FindGflags.cmake b/build/fbcode_builder/CMake/FindGflags.cmake new file mode 100644 index 000000000..c00896a34 --- /dev/null +++ b/build/fbcode_builder/CMake/FindGflags.cmake @@ -0,0 +1,105 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Find libgflags. +# There's a lot of compatibility cruft going on in here, both +# to deal with changes across the FB consumers of this and also +# to deal with variances in behavior of cmake itself. +# +# Since this file is named FindGflags.cmake the cmake convention +# is for the module to export both GFLAGS_FOUND and Gflags_FOUND. +# The convention expected by consumers is that we export the +# following variables, even though these do not match the cmake +# conventions: +# +# LIBGFLAGS_INCLUDE_DIR - where to find gflags/gflags.h, etc. +# LIBGFLAGS_LIBRARY - List of libraries when using libgflags. +# LIBGFLAGS_FOUND - True if libgflags found. +# +# We need to be able to locate gflags both from an installed +# cmake config file and just from the raw headers and libs, so +# test for the former and then the latter, and then stick +# the results together and export them into the variables +# listed above. +# +# For forwards compatibility, we export the following variables: +# +# gflags_INCLUDE_DIR - where to find gflags/gflags.h, etc. +# gflags_TARGET / GFLAGS_TARGET / gflags_LIBRARIES +# - List of libraries when using libgflags. +# gflags_FOUND - True if libgflags found. +# + +IF (LIBGFLAGS_INCLUDE_DIR) + # Already in cache, be silent + SET(Gflags_FIND_QUIETLY TRUE) +ENDIF () + +find_package(gflags CONFIG QUIET) +if (gflags_FOUND) + if (NOT Gflags_FIND_QUIETLY) + message(STATUS "Found gflags from package config ${gflags_CONFIG}") + endif() + # Re-export the config-specified libs with our local names + set(LIBGFLAGS_LIBRARY ${gflags_LIBRARIES}) + set(LIBGFLAGS_INCLUDE_DIR ${gflags_INCLUDE_DIR}) + if(NOT EXISTS "${gflags_INCLUDE_DIR}") + # The gflags-devel RPM on recent RedHat-based systems is somewhat broken. + # RedHat symlinks /lib64 to /usr/lib64, and this breaks some of the + # relative path computation performed in gflags-config.cmake. The package + # config file ends up being found via /lib64, but the relative path + # computation it does only works if it was found in /usr/lib64. + # If gflags_INCLUDE_DIR does not actually exist, simply default it to + # /usr/include on these systems. + set(LIBGFLAGS_INCLUDE_DIR "/usr/include") + endif() + set(LIBGFLAGS_FOUND ${gflags_FOUND}) + # cmake module compat + set(GFLAGS_FOUND ${gflags_FOUND}) + set(Gflags_FOUND ${gflags_FOUND}) +else() + FIND_PATH(LIBGFLAGS_INCLUDE_DIR gflags/gflags.h) + + FIND_LIBRARY(LIBGFLAGS_LIBRARY_DEBUG NAMES gflagsd gflags_staticd) + FIND_LIBRARY(LIBGFLAGS_LIBRARY_RELEASE NAMES gflags gflags_static) + + INCLUDE(SelectLibraryConfigurations) + SELECT_LIBRARY_CONFIGURATIONS(LIBGFLAGS) + + # handle the QUIETLY and REQUIRED arguments and set LIBGFLAGS_FOUND to TRUE if + # all listed variables are TRUE + INCLUDE(FindPackageHandleStandardArgs) + FIND_PACKAGE_HANDLE_STANDARD_ARGS(gflags DEFAULT_MSG LIBGFLAGS_LIBRARY LIBGFLAGS_INCLUDE_DIR) + # cmake module compat + set(Gflags_FOUND ${GFLAGS_FOUND}) + # compat with some existing FindGflags consumers + set(LIBGFLAGS_FOUND ${GFLAGS_FOUND}) + + # Compat with the gflags CONFIG based detection + set(gflags_FOUND ${GFLAGS_FOUND}) + set(gflags_INCLUDE_DIR ${LIBGFLAGS_INCLUDE_DIR}) + set(gflags_LIBRARIES ${LIBGFLAGS_LIBRARY}) + set(GFLAGS_TARGET ${LIBGFLAGS_LIBRARY}) + set(gflags_TARGET ${LIBGFLAGS_LIBRARY}) + + MARK_AS_ADVANCED(LIBGFLAGS_LIBRARY LIBGFLAGS_INCLUDE_DIR) +endif() + +# Compat with the gflags CONFIG based detection +if (LIBGFLAGS_FOUND AND NOT TARGET gflags) + add_library(gflags UNKNOWN IMPORTED) + if(TARGET gflags-shared) + # If the installed gflags CMake package config defines a gflags-shared + # target but not gflags, just make the gflags target that we define + # depend on the gflags-shared target. + target_link_libraries(gflags INTERFACE gflags-shared) + # Export LIBGFLAGS_LIBRARY as the gflags-shared target in this case. + set(LIBGFLAGS_LIBRARY gflags-shared) + else() + set_target_properties( + gflags + PROPERTIES + IMPORTED_LINK_INTERFACE_LANGUAGES "C" + IMPORTED_LOCATION "${LIBGFLAGS_LIBRARY}" + INTERFACE_INCLUDE_DIRECTORIES "${LIBGFLAGS_INCLUDE_DIR}" + ) + endif() +endif() diff --git a/build/fbcode_builder/CMake/FindGlog.cmake b/build/fbcode_builder/CMake/FindGlog.cmake new file mode 100644 index 000000000..752647cb3 --- /dev/null +++ b/build/fbcode_builder/CMake/FindGlog.cmake @@ -0,0 +1,37 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# - Try to find Glog +# Once done, this will define +# +# GLOG_FOUND - system has Glog +# GLOG_INCLUDE_DIRS - the Glog include directories +# GLOG_LIBRARIES - link these to use Glog + +include(FindPackageHandleStandardArgs) +include(SelectLibraryConfigurations) + +find_library(GLOG_LIBRARY_RELEASE glog + PATHS ${GLOG_LIBRARYDIR}) +find_library(GLOG_LIBRARY_DEBUG glogd + PATHS ${GLOG_LIBRARYDIR}) + +find_path(GLOG_INCLUDE_DIR glog/logging.h + PATHS ${GLOG_INCLUDEDIR}) + +select_library_configurations(GLOG) + +find_package_handle_standard_args(glog DEFAULT_MSG + GLOG_LIBRARY + GLOG_INCLUDE_DIR) + +mark_as_advanced( + GLOG_LIBRARY + GLOG_INCLUDE_DIR) + +set(GLOG_LIBRARIES ${GLOG_LIBRARY}) +set(GLOG_INCLUDE_DIRS ${GLOG_INCLUDE_DIR}) + +if (NOT TARGET glog::glog) + add_library(glog::glog UNKNOWN IMPORTED) + set_target_properties(glog::glog PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${GLOG_INCLUDE_DIRS}") + set_target_properties(glog::glog PROPERTIES IMPORTED_LINK_INTERFACE_LANGUAGES "C" IMPORTED_LOCATION "${GLOG_LIBRARIES}") +endif() diff --git a/build/fbcode_builder/CMake/FindLibEvent.cmake b/build/fbcode_builder/CMake/FindLibEvent.cmake new file mode 100644 index 000000000..dd11ebd84 --- /dev/null +++ b/build/fbcode_builder/CMake/FindLibEvent.cmake @@ -0,0 +1,77 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# - Find LibEvent (a cross event library) +# This module defines +# LIBEVENT_INCLUDE_DIR, where to find LibEvent headers +# LIBEVENT_LIB, LibEvent libraries +# LibEvent_FOUND, If false, do not try to use libevent + +set(LibEvent_EXTRA_PREFIXES /usr/local /opt/local "$ENV{HOME}") +foreach(prefix ${LibEvent_EXTRA_PREFIXES}) + list(APPEND LibEvent_INCLUDE_PATHS "${prefix}/include") + list(APPEND LibEvent_LIB_PATHS "${prefix}/lib") +endforeach() + +find_package(Libevent CONFIG QUIET) +if (TARGET event) + # Re-export the config under our own names + + # Somewhat gross, but some vcpkg installed libevents have a relative + # `include` path exported into LIBEVENT_INCLUDE_DIRS, which triggers + # a cmake error because it resolves to the `include` dir within the + # folly repo, which is not something cmake allows to be in the + # INTERFACE_INCLUDE_DIRECTORIES. Thankfully on such a system the + # actual include directory is already part of the global include + # directories, so we can just skip it. + if (NOT "${LIBEVENT_INCLUDE_DIRS}" STREQUAL "include") + set(LIBEVENT_INCLUDE_DIR ${LIBEVENT_INCLUDE_DIRS}) + else() + set(LIBEVENT_INCLUDE_DIR) + endif() + + # Unfortunately, with a bare target name `event`, downstream consumers + # of the package that depends on `Libevent` located via CONFIG end + # up exporting just a bare `event` in their libraries. This is problematic + # because this in interpreted as just `-levent` with no library path. + # When libevent is not installed in the default installation prefix + # this results in linker errors. + # To resolve this, we ask cmake to lookup the full path to the library + # and use that instead. + cmake_policy(PUSH) + if(POLICY CMP0026) + # Allow reading the LOCATION property + cmake_policy(SET CMP0026 OLD) + endif() + get_target_property(LIBEVENT_LIB event LOCATION) + cmake_policy(POP) + + set(LibEvent_FOUND ${Libevent_FOUND}) + if (NOT LibEvent_FIND_QUIETLY) + message(STATUS "Found libevent from package config include=${LIBEVENT_INCLUDE_DIRS} lib=${LIBEVENT_LIB}") + endif() +else() + find_path(LIBEVENT_INCLUDE_DIR event.h PATHS ${LibEvent_INCLUDE_PATHS}) + find_library(LIBEVENT_LIB NAMES event PATHS ${LibEvent_LIB_PATHS}) + + if (LIBEVENT_LIB AND LIBEVENT_INCLUDE_DIR) + set(LibEvent_FOUND TRUE) + set(LIBEVENT_LIB ${LIBEVENT_LIB}) + else () + set(LibEvent_FOUND FALSE) + endif () + + if (LibEvent_FOUND) + if (NOT LibEvent_FIND_QUIETLY) + message(STATUS "Found libevent: ${LIBEVENT_LIB}") + endif () + else () + if (LibEvent_FIND_REQUIRED) + message(FATAL_ERROR "Could NOT find libevent.") + endif () + message(STATUS "libevent NOT found.") + endif () + + mark_as_advanced( + LIBEVENT_LIB + LIBEVENT_INCLUDE_DIR + ) +endif() diff --git a/build/fbcode_builder/CMake/FindLibUnwind.cmake b/build/fbcode_builder/CMake/FindLibUnwind.cmake new file mode 100644 index 000000000..b01a674a5 --- /dev/null +++ b/build/fbcode_builder/CMake/FindLibUnwind.cmake @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +find_path(LIBUNWIND_INCLUDE_DIR NAMES libunwind.h) +mark_as_advanced(LIBUNWIND_INCLUDE_DIR) + +find_library(LIBUNWIND_LIBRARY NAMES unwind) +mark_as_advanced(LIBUNWIND_LIBRARY) + +include(FindPackageHandleStandardArgs) +FIND_PACKAGE_HANDLE_STANDARD_ARGS( + LIBUNWIND + REQUIRED_VARS LIBUNWIND_LIBRARY LIBUNWIND_INCLUDE_DIR) + +if(LIBUNWIND_FOUND) + set(LIBUNWIND_LIBRARIES ${LIBUNWIND_LIBRARY}) + set(LIBUNWIND_INCLUDE_DIRS ${LIBUNWIND_INCLUDE_DIR}) +endif() diff --git a/build/fbcode_builder/CMake/FindPCRE.cmake b/build/fbcode_builder/CMake/FindPCRE.cmake new file mode 100644 index 000000000..32ccb3725 --- /dev/null +++ b/build/fbcode_builder/CMake/FindPCRE.cmake @@ -0,0 +1,11 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +include(FindPackageHandleStandardArgs) +find_path(PCRE_INCLUDE_DIR NAMES pcre.h) +find_library(PCRE_LIBRARY NAMES pcre) +find_package_handle_standard_args( + PCRE + DEFAULT_MSG + PCRE_LIBRARY + PCRE_INCLUDE_DIR +) +mark_as_advanced(PCRE_INCLUDE_DIR PCRE_LIBRARY) diff --git a/build/fbcode_builder/CMake/FindRe2.cmake b/build/fbcode_builder/CMake/FindRe2.cmake new file mode 100644 index 000000000..013ae7761 --- /dev/null +++ b/build/fbcode_builder/CMake/FindRe2.cmake @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This software may be used and distributed according to the terms of the +# GNU General Public License version 2. + +find_library(RE2_LIBRARY re2) +mark_as_advanced(RE2_LIBRARY) + +find_path(RE2_INCLUDE_DIR NAMES re2/re2.h) +mark_as_advanced(RE2_INCLUDE_DIR) + +include(FindPackageHandleStandardArgs) +FIND_PACKAGE_HANDLE_STANDARD_ARGS( + RE2 + REQUIRED_VARS RE2_LIBRARY RE2_INCLUDE_DIR) + +if(RE2_FOUND) + set(RE2_LIBRARY ${RE2_LIBRARY}) + set(RE2_INCLUDE_DIR, ${RE2_INCLUDE_DIR}) +endif() diff --git a/build/fbcode_builder/CMake/FindSodium.cmake b/build/fbcode_builder/CMake/FindSodium.cmake new file mode 100644 index 000000000..3c3f1245c --- /dev/null +++ b/build/fbcode_builder/CMake/FindSodium.cmake @@ -0,0 +1,297 @@ +# Written in 2016 by Henrik Steffen Gaßmann +# +# To the extent possible under law, the author(s) have dedicated all +# copyright and related and neighboring rights to this software to the +# public domain worldwide. This software is distributed without any warranty. +# +# You should have received a copy of the CC0 Public Domain Dedication +# along with this software. If not, see +# +# http://creativecommons.org/publicdomain/zero/1.0/ +# +######################################################################## +# Tries to find the local libsodium installation. +# +# On Windows the sodium_DIR environment variable is used as a default +# hint which can be overridden by setting the corresponding cmake variable. +# +# Once done the following variables will be defined: +# +# sodium_FOUND +# sodium_INCLUDE_DIR +# sodium_LIBRARY_DEBUG +# sodium_LIBRARY_RELEASE +# +# +# Furthermore an imported "sodium" target is created. +# + +if (CMAKE_C_COMPILER_ID STREQUAL "GNU" + OR CMAKE_C_COMPILER_ID STREQUAL "Clang") + set(_GCC_COMPATIBLE 1) +endif() + +# static library option +if (NOT DEFINED sodium_USE_STATIC_LIBS) + option(sodium_USE_STATIC_LIBS "enable to statically link against sodium" OFF) +endif() +if(NOT (sodium_USE_STATIC_LIBS EQUAL sodium_USE_STATIC_LIBS_LAST)) + unset(sodium_LIBRARY CACHE) + unset(sodium_LIBRARY_DEBUG CACHE) + unset(sodium_LIBRARY_RELEASE CACHE) + unset(sodium_DLL_DEBUG CACHE) + unset(sodium_DLL_RELEASE CACHE) + set(sodium_USE_STATIC_LIBS_LAST ${sodium_USE_STATIC_LIBS} CACHE INTERNAL "internal change tracking variable") +endif() + + +######################################################################## +# UNIX +if (UNIX) + # import pkg-config + find_package(PkgConfig QUIET) + if (PKG_CONFIG_FOUND) + pkg_check_modules(sodium_PKG QUIET libsodium) + endif() + + if(sodium_USE_STATIC_LIBS) + foreach(_libname ${sodium_PKG_STATIC_LIBRARIES}) + if (NOT _libname MATCHES "^lib.*\\.a$") # ignore strings already ending with .a + list(INSERT sodium_PKG_STATIC_LIBRARIES 0 "lib${_libname}.a") + endif() + endforeach() + list(REMOVE_DUPLICATES sodium_PKG_STATIC_LIBRARIES) + + # if pkgconfig for libsodium doesn't provide + # static lib info, then override PKG_STATIC here.. + if (NOT sodium_PKG_STATIC_FOUND) + set(sodium_PKG_STATIC_LIBRARIES libsodium.a) + endif() + + set(XPREFIX sodium_PKG_STATIC) + else() + if (NOT sodium_PKG_FOUND) + set(sodium_PKG_LIBRARIES sodium) + endif() + + set(XPREFIX sodium_PKG) + endif() + + find_path(sodium_INCLUDE_DIR sodium.h + HINTS ${${XPREFIX}_INCLUDE_DIRS} + ) + find_library(sodium_LIBRARY_DEBUG NAMES ${${XPREFIX}_LIBRARIES} + HINTS ${${XPREFIX}_LIBRARY_DIRS} + ) + find_library(sodium_LIBRARY_RELEASE NAMES ${${XPREFIX}_LIBRARIES} + HINTS ${${XPREFIX}_LIBRARY_DIRS} + ) + + +######################################################################## +# Windows +elseif (WIN32) + set(sodium_DIR "$ENV{sodium_DIR}" CACHE FILEPATH "sodium install directory") + mark_as_advanced(sodium_DIR) + + find_path(sodium_INCLUDE_DIR sodium.h + HINTS ${sodium_DIR} + PATH_SUFFIXES include + ) + + if (MSVC) + # detect target architecture + file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/arch.cpp" [=[ + #if defined _M_IX86 + #error ARCH_VALUE x86_32 + #elif defined _M_X64 + #error ARCH_VALUE x86_64 + #endif + #error ARCH_VALUE unknown + ]=]) + try_compile(_UNUSED_VAR "${CMAKE_CURRENT_BINARY_DIR}" "${CMAKE_CURRENT_BINARY_DIR}/arch.cpp" + OUTPUT_VARIABLE _COMPILATION_LOG + ) + string(REGEX REPLACE ".*ARCH_VALUE ([a-zA-Z0-9_]+).*" "\\1" _TARGET_ARCH "${_COMPILATION_LOG}") + + # construct library path + if (_TARGET_ARCH STREQUAL "x86_32") + string(APPEND _PLATFORM_PATH "Win32") + elseif(_TARGET_ARCH STREQUAL "x86_64") + string(APPEND _PLATFORM_PATH "x64") + else() + message(FATAL_ERROR "the ${_TARGET_ARCH} architecture is not supported by Findsodium.cmake.") + endif() + string(APPEND _PLATFORM_PATH "/$$CONFIG$$") + + if (MSVC_VERSION LESS 1900) + math(EXPR _VS_VERSION "${MSVC_VERSION} / 10 - 60") + else() + math(EXPR _VS_VERSION "${MSVC_VERSION} / 10 - 50") + endif() + string(APPEND _PLATFORM_PATH "/v${_VS_VERSION}") + + if (sodium_USE_STATIC_LIBS) + string(APPEND _PLATFORM_PATH "/static") + else() + string(APPEND _PLATFORM_PATH "/dynamic") + endif() + + string(REPLACE "$$CONFIG$$" "Debug" _DEBUG_PATH_SUFFIX "${_PLATFORM_PATH}") + string(REPLACE "$$CONFIG$$" "Release" _RELEASE_PATH_SUFFIX "${_PLATFORM_PATH}") + + find_library(sodium_LIBRARY_DEBUG libsodium.lib + HINTS ${sodium_DIR} + PATH_SUFFIXES ${_DEBUG_PATH_SUFFIX} + ) + find_library(sodium_LIBRARY_RELEASE libsodium.lib + HINTS ${sodium_DIR} + PATH_SUFFIXES ${_RELEASE_PATH_SUFFIX} + ) + if (NOT sodium_USE_STATIC_LIBS) + set(CMAKE_FIND_LIBRARY_SUFFIXES_BCK ${CMAKE_FIND_LIBRARY_SUFFIXES}) + set(CMAKE_FIND_LIBRARY_SUFFIXES ".dll") + find_library(sodium_DLL_DEBUG libsodium + HINTS ${sodium_DIR} + PATH_SUFFIXES ${_DEBUG_PATH_SUFFIX} + ) + find_library(sodium_DLL_RELEASE libsodium + HINTS ${sodium_DIR} + PATH_SUFFIXES ${_RELEASE_PATH_SUFFIX} + ) + set(CMAKE_FIND_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES_BCK}) + endif() + + elseif(_GCC_COMPATIBLE) + if (sodium_USE_STATIC_LIBS) + find_library(sodium_LIBRARY_DEBUG libsodium.a + HINTS ${sodium_DIR} + PATH_SUFFIXES lib + ) + find_library(sodium_LIBRARY_RELEASE libsodium.a + HINTS ${sodium_DIR} + PATH_SUFFIXES lib + ) + else() + find_library(sodium_LIBRARY_DEBUG libsodium.dll.a + HINTS ${sodium_DIR} + PATH_SUFFIXES lib + ) + find_library(sodium_LIBRARY_RELEASE libsodium.dll.a + HINTS ${sodium_DIR} + PATH_SUFFIXES lib + ) + + file(GLOB _DLL + LIST_DIRECTORIES false + RELATIVE "${sodium_DIR}/bin" + "${sodium_DIR}/bin/libsodium*.dll" + ) + find_library(sodium_DLL_DEBUG ${_DLL} libsodium + HINTS ${sodium_DIR} + PATH_SUFFIXES bin + ) + find_library(sodium_DLL_RELEASE ${_DLL} libsodium + HINTS ${sodium_DIR} + PATH_SUFFIXES bin + ) + endif() + else() + message(FATAL_ERROR "this platform is not supported by FindSodium.cmake") + endif() + + +######################################################################## +# unsupported +else() + message(FATAL_ERROR "this platform is not supported by FindSodium.cmake") +endif() + + +######################################################################## +# common stuff + +# extract sodium version +if (sodium_INCLUDE_DIR) + set(_VERSION_HEADER "${_INCLUDE_DIR}/sodium/version.h") + if (EXISTS _VERSION_HEADER) + file(READ "${_VERSION_HEADER}" _VERSION_HEADER_CONTENT) + string(REGEX REPLACE ".*#[ \t]*define[ \t]*SODIUM_VERSION_STRING[ \t]*\"([^\n]*)\".*" "\\1" + sodium_VERSION "${_VERSION_HEADER_CONTENT}") + set(sodium_VERSION "${sodium_VERSION}" PARENT_SCOPE) + endif() +endif() + +# communicate results +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args( + Sodium # The name must be either uppercase or match the filename case. + REQUIRED_VARS + sodium_LIBRARY_RELEASE + sodium_LIBRARY_DEBUG + sodium_INCLUDE_DIR + VERSION_VAR + sodium_VERSION +) + +if(Sodium_FOUND) + set(sodium_LIBRARIES + optimized ${sodium_LIBRARY_RELEASE} debug ${sodium_LIBRARY_DEBUG}) +endif() + +# mark file paths as advanced +mark_as_advanced(sodium_INCLUDE_DIR) +mark_as_advanced(sodium_LIBRARY_DEBUG) +mark_as_advanced(sodium_LIBRARY_RELEASE) +if (WIN32) + mark_as_advanced(sodium_DLL_DEBUG) + mark_as_advanced(sodium_DLL_RELEASE) +endif() + +# create imported target +if(sodium_USE_STATIC_LIBS) + set(_LIB_TYPE STATIC) +else() + set(_LIB_TYPE SHARED) +endif() + +if(NOT TARGET sodium) + add_library(sodium ${_LIB_TYPE} IMPORTED) +endif() + +set_target_properties(sodium PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${sodium_INCLUDE_DIR}" + IMPORTED_LINK_INTERFACE_LANGUAGES "C" +) + +if (sodium_USE_STATIC_LIBS) + set_target_properties(sodium PROPERTIES + INTERFACE_COMPILE_DEFINITIONS "SODIUM_STATIC" + IMPORTED_LOCATION "${sodium_LIBRARY_RELEASE}" + IMPORTED_LOCATION_DEBUG "${sodium_LIBRARY_DEBUG}" + ) +else() + if (UNIX) + set_target_properties(sodium PROPERTIES + IMPORTED_LOCATION "${sodium_LIBRARY_RELEASE}" + IMPORTED_LOCATION_DEBUG "${sodium_LIBRARY_DEBUG}" + ) + elseif (WIN32) + set_target_properties(sodium PROPERTIES + IMPORTED_IMPLIB "${sodium_LIBRARY_RELEASE}" + IMPORTED_IMPLIB_DEBUG "${sodium_LIBRARY_DEBUG}" + ) + if (NOT (sodium_DLL_DEBUG MATCHES ".*-NOTFOUND")) + set_target_properties(sodium PROPERTIES + IMPORTED_LOCATION_DEBUG "${sodium_DLL_DEBUG}" + ) + endif() + if (NOT (sodium_DLL_RELEASE MATCHES ".*-NOTFOUND")) + set_target_properties(sodium PROPERTIES + IMPORTED_LOCATION_RELWITHDEBINFO "${sodium_DLL_RELEASE}" + IMPORTED_LOCATION_MINSIZEREL "${sodium_DLL_RELEASE}" + IMPORTED_LOCATION_RELEASE "${sodium_DLL_RELEASE}" + ) + endif() + endif() +endif() diff --git a/build/fbcode_builder/CMake/FindZstd.cmake b/build/fbcode_builder/CMake/FindZstd.cmake new file mode 100644 index 000000000..89300ddfd --- /dev/null +++ b/build/fbcode_builder/CMake/FindZstd.cmake @@ -0,0 +1,41 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# +# - Try to find Facebook zstd library +# This will define +# ZSTD_FOUND +# ZSTD_INCLUDE_DIR +# ZSTD_LIBRARY +# + +find_path(ZSTD_INCLUDE_DIR NAMES zstd.h) + +find_library(ZSTD_LIBRARY_DEBUG NAMES zstdd zstd_staticd) +find_library(ZSTD_LIBRARY_RELEASE NAMES zstd zstd_static) + +include(SelectLibraryConfigurations) +SELECT_LIBRARY_CONFIGURATIONS(ZSTD) + +include(FindPackageHandleStandardArgs) +FIND_PACKAGE_HANDLE_STANDARD_ARGS( + ZSTD DEFAULT_MSG + ZSTD_LIBRARY ZSTD_INCLUDE_DIR +) + +if (ZSTD_FOUND) + message(STATUS "Found Zstd: ${ZSTD_LIBRARY}") +endif() + +mark_as_advanced(ZSTD_INCLUDE_DIR ZSTD_LIBRARY) diff --git a/build/fbcode_builder/CMake/RustStaticLibrary.cmake b/build/fbcode_builder/CMake/RustStaticLibrary.cmake new file mode 100644 index 000000000..8546fe2fb --- /dev/null +++ b/build/fbcode_builder/CMake/RustStaticLibrary.cmake @@ -0,0 +1,291 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +include(FBCMakeParseArgs) + +set( + USE_CARGO_VENDOR AUTO CACHE STRING + "Download Rust Crates from an internally vendored location" +) +set_property(CACHE USE_CARGO_VENDOR PROPERTY STRINGS AUTO ON OFF) + +set(RUST_VENDORED_CRATES_DIR "$ENV{RUST_VENDORED_CRATES_DIR}") +if("${USE_CARGO_VENDOR}" STREQUAL "AUTO") + if(EXISTS "${RUST_VENDORED_CRATES_DIR}") + set(USE_CARGO_VENDOR ON) + else() + set(USE_CARGO_VENDOR OFF) + endif() +endif() + +if(USE_CARGO_VENDOR) + if(NOT EXISTS "${RUST_VENDORED_CRATES_DIR}") + message( + FATAL "vendored rust crates not present: " + "${RUST_VENDORED_CRATES_DIR}" + ) + endif() + + set(RUST_CARGO_HOME "${CMAKE_BINARY_DIR}/_cargo_home") + file(MAKE_DIRECTORY "${RUST_CARGO_HOME}") + + file( + TO_NATIVE_PATH "${RUST_VENDORED_CRATES_DIR}" + ESCAPED_RUST_VENDORED_CRATES_DIR + ) + string( + REPLACE "\\" "\\\\" + ESCAPED_RUST_VENDORED_CRATES_DIR + "${ESCAPED_RUST_VENDORED_CRATES_DIR}" + ) + file( + WRITE "${RUST_CARGO_HOME}/config" + "[source.crates-io]\n" + "replace-with = \"vendored-sources\"\n" + "\n" + "[source.vendored-sources]\n" + "directory = \"${ESCAPED_RUST_VENDORED_CRATES_DIR}\"\n" + ) +endif() + +# Cargo is a build system in itself, and thus will try to take advantage of all +# the cores on the system. Unfortunately, this conflicts with Ninja, since it +# also tries to utilize all the cores. This can lead to a system that is +# completely overloaded with compile jobs to the point where nothing else can +# be achieved on the system. +# +# Let's inform Ninja of this fact so it won't try to spawn other jobs while +# Rust being compiled. +set_property(GLOBAL APPEND PROPERTY JOB_POOLS rust_job_pool=1) + +# This function creates an interface library target based on the static library +# built by Cargo. It will call Cargo to build a staticlib and generate a CMake +# interface library with it. +# +# This function requires `find_package(Python COMPONENTS Interpreter)`. +# +# You need to set `lib:crate-type = ["staticlib"]` in your Cargo.toml to make +# Cargo build static library. +# +# ```cmake +# rust_static_library( [CRATE ]) +# ``` +# +# Parameters: +# - TARGET: +# Name of the target name. This function will create an interface library +# target with this name. +# - CRATE_NAME: +# Name of the crate. This parameter is optional. If unspecified, it will +# fallback to `${TARGET}`. +# +# This function creates two targets: +# - "${TARGET}": an interface library target contains the static library built +# from Cargo. +# - "${TARGET}.cargo": an internal custom target that invokes Cargo. +# +# If you are going to use this static library from C/C++, you will need to +# write header files for the library (or generate with cbindgen) and bind these +# headers with the interface library. +# +function(rust_static_library TARGET) + fb_cmake_parse_args(ARG "" "CRATE" "" "${ARGN}") + + if(DEFINED ARG_CRATE) + set(crate_name "${ARG_CRATE}") + else() + set(crate_name "${TARGET}") + endif() + + set(cargo_target "${TARGET}.cargo") + set(target_dir $,debug,release>) + set(staticlib_name "${CMAKE_STATIC_LIBRARY_PREFIX}${crate_name}${CMAKE_STATIC_LIBRARY_SUFFIX}") + set(rust_staticlib "${CMAKE_CURRENT_BINARY_DIR}/${target_dir}/${staticlib_name}") + + set(cargo_cmd cargo) + if(WIN32) + set(cargo_cmd cargo.exe) + endif() + + set(cargo_flags build $,,--release> -p ${crate_name}) + if(USE_CARGO_VENDOR) + set(extra_cargo_env "CARGO_HOME=${RUST_CARGO_HOME}") + set(cargo_flags ${cargo_flags}) + endif() + + add_custom_target( + ${cargo_target} + COMMAND + "${CMAKE_COMMAND}" -E remove -f "${CMAKE_CURRENT_SOURCE_DIR}/Cargo.lock" + COMMAND + "${CMAKE_COMMAND}" -E env + "CARGO_TARGET_DIR=${CMAKE_CURRENT_BINARY_DIR}" + ${extra_cargo_env} + ${cargo_cmd} + ${cargo_flags} + COMMENT "Building Rust crate '${crate_name}'..." + JOB_POOL rust_job_pool + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + BYPRODUCTS + "${CMAKE_CURRENT_BINARY_DIR}/debug/${staticlib_name}" + "${CMAKE_CURRENT_BINARY_DIR}/release/${staticlib_name}" + ) + + add_library(${TARGET} INTERFACE) + add_dependencies(${TARGET} ${cargo_target}) + set_target_properties( + ${TARGET} + PROPERTIES + INTERFACE_STATICLIB_OUTPUT_PATH "${rust_staticlib}" + INTERFACE_INSTALL_LIBNAME + "${CMAKE_STATIC_LIBRARY_PREFIX}${crate_name}_rs${CMAKE_STATIC_LIBRARY_SUFFIX}" + ) + target_link_libraries( + ${TARGET} + INTERFACE "$" + ) +endfunction() + +# This function instructs cmake to define a target that will use `cargo build` +# to build a bin crate referenced by the Cargo.toml file in the current source +# directory. +# It accepts a single `TARGET` parameter which will be passed as the package +# name to `cargo build -p TARGET`. If binary has different name as package, +# use optional flag BINARY_NAME to override it. +# The cmake target will be registered to build by default as part of the +# ALL target. +function(rust_executable TARGET) + fb_cmake_parse_args(ARG "" "BINARY_NAME" "" "${ARGN}") + + set(crate_name "${TARGET}") + set(cargo_target "${TARGET}.cargo") + set(target_dir $,debug,release>) + + if(DEFINED ARG_BINARY_NAME) + set(executable_name "${ARG_BINARY_NAME}${CMAKE_EXECUTABLE_SUFFIX}") + else() + set(executable_name "${crate_name}${CMAKE_EXECUTABLE_SUFFIX}") + endif() + + set(cargo_cmd cargo) + if(WIN32) + set(cargo_cmd cargo.exe) + endif() + + set(cargo_flags build $,,--release> -p ${crate_name}) + if(USE_CARGO_VENDOR) + set(extra_cargo_env "CARGO_HOME=${RUST_CARGO_HOME}") + set(cargo_flags ${cargo_flags}) + endif() + + add_custom_target( + ${cargo_target} + ALL + COMMAND + "${CMAKE_COMMAND}" -E remove -f "${CMAKE_CURRENT_SOURCE_DIR}/Cargo.lock" + COMMAND + "${CMAKE_COMMAND}" -E env + "CARGO_TARGET_DIR=${CMAKE_CURRENT_BINARY_DIR}" + ${extra_cargo_env} + ${cargo_cmd} + ${cargo_flags} + COMMENT "Building Rust executable '${crate_name}'..." + JOB_POOL rust_job_pool + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + BYPRODUCTS + "${CMAKE_CURRENT_BINARY_DIR}/debug/${executable_name}" + "${CMAKE_CURRENT_BINARY_DIR}/release/${executable_name}" + ) + + set_property(TARGET "${cargo_target}" + PROPERTY EXECUTABLE "${CMAKE_CURRENT_BINARY_DIR}/${target_dir}/${executable_name}") +endfunction() + +# This function can be used to install the executable generated by a prior +# call to the `rust_executable` function. +# It requires a `TARGET` parameter to identify the target to be installed, +# and an optional `DESTINATION` parameter to specify the installation +# directory. If DESTINATION is not specified then the `bin` directory +# will be assumed. +function(install_rust_executable TARGET) + # Parse the arguments + set(one_value_args DESTINATION) + set(multi_value_args) + fb_cmake_parse_args( + ARG "" "${one_value_args}" "${multi_value_args}" "${ARGN}" + ) + + if(NOT DEFINED ARG_DESTINATION) + set(ARG_DESTINATION bin) + endif() + + get_target_property(foo "${TARGET}.cargo" EXECUTABLE) + + install( + PROGRAMS "${foo}" + DESTINATION "${ARG_DESTINATION}" + ) +endfunction() + +# This function installs the interface target generated from the function +# `rust_static_library`. Use this function if you want to export your Rust +# target to external CMake targets. +# +# ```cmake +# install_rust_static_library( +# +# INSTALL_DIR +# [EXPORT ] +# ) +# ``` +# +# Parameters: +# - TARGET: Name of the Rust static library target. +# - EXPORT_NAME: Name of the exported target. +# - INSTALL_DIR: Path to the directory where this library will be installed. +# +function(install_rust_static_library TARGET) + fb_cmake_parse_args(ARG "" "EXPORT;INSTALL_DIR" "" "${ARGN}") + + get_property( + staticlib_output_path + TARGET "${TARGET}" + PROPERTY INTERFACE_STATICLIB_OUTPUT_PATH + ) + get_property( + staticlib_output_name + TARGET "${TARGET}" + PROPERTY INTERFACE_INSTALL_LIBNAME + ) + + if(NOT DEFINED staticlib_output_path) + message(FATAL_ERROR "Not a rust_static_library target.") + endif() + + if(NOT DEFINED ARG_INSTALL_DIR) + message(FATAL_ERROR "Missing required argument.") + endif() + + if(DEFINED ARG_EXPORT) + set(install_export_args EXPORT "${ARG_EXPORT}") + endif() + + set(install_interface_dir "${ARG_INSTALL_DIR}") + if(NOT IS_ABSOLUTE "${install_interface_dir}") + set(install_interface_dir "\${_IMPORT_PREFIX}/${install_interface_dir}") + endif() + + target_link_libraries( + ${TARGET} INTERFACE + "$" + ) + install( + TARGETS ${TARGET} + ${install_export_args} + LIBRARY DESTINATION ${ARG_INSTALL_DIR} + ) + install( + FILES ${staticlib_output_path} + RENAME ${staticlib_output_name} + DESTINATION ${ARG_INSTALL_DIR} + ) +endfunction() diff --git a/build/fbcode_builder/CMake/fb_py_test_main.py b/build/fbcode_builder/CMake/fb_py_test_main.py new file mode 100644 index 000000000..1f3563aff --- /dev/null +++ b/build/fbcode_builder/CMake/fb_py_test_main.py @@ -0,0 +1,820 @@ +#!/usr/bin/env python +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +""" +This file contains the main module code for Python test programs. +""" + +from __future__ import print_function + +import contextlib +import ctypes +import fnmatch +import json +import logging +import optparse +import os +import platform +import re +import sys +import tempfile +import time +import traceback +import unittest +import warnings + +# Hide warning about importing "imp"; remove once python2 is gone. +with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + import imp + +try: + from StringIO import StringIO +except ImportError: + from io import StringIO +try: + import coverage +except ImportError: + coverage = None # type: ignore +try: + from importlib.machinery import SourceFileLoader +except ImportError: + SourceFileLoader = None # type: ignore + + +class get_cpu_instr_counter(object): + def read(self): + # TODO + return 0 + + +EXIT_CODE_SUCCESS = 0 +EXIT_CODE_TEST_FAILURE = 70 + + +class TestStatus(object): + + ABORTED = "FAILURE" + PASSED = "SUCCESS" + FAILED = "FAILURE" + EXPECTED_FAILURE = "SUCCESS" + UNEXPECTED_SUCCESS = "FAILURE" + SKIPPED = "ASSUMPTION_VIOLATION" + + +class PathMatcher(object): + def __init__(self, include_patterns, omit_patterns): + self.include_patterns = include_patterns + self.omit_patterns = omit_patterns + + def omit(self, path): + """ + Omit iff matches any of the omit_patterns or the include patterns are + not empty and none is matched + """ + path = os.path.realpath(path) + return any(fnmatch.fnmatch(path, p) for p in self.omit_patterns) or ( + self.include_patterns + and not any(fnmatch.fnmatch(path, p) for p in self.include_patterns) + ) + + def include(self, path): + return not self.omit(path) + + +class DebugWipeFinder(object): + """ + PEP 302 finder that uses a DebugWipeLoader for all files which do not need + coverage + """ + + def __init__(self, matcher): + self.matcher = matcher + + def find_module(self, fullname, path=None): + _, _, basename = fullname.rpartition(".") + try: + fd, pypath, (_, _, kind) = imp.find_module(basename, path) + except Exception: + # Finding without hooks using the imp module failed. One reason + # could be that there is a zip file on sys.path. The imp module + # does not support loading from there. Leave finding this module to + # the others finders in sys.meta_path. + return None + + if hasattr(fd, "close"): + fd.close() + if kind != imp.PY_SOURCE: + return None + if self.matcher.include(pypath): + return None + + """ + This is defined to match CPython's PyVarObject struct + """ + + class PyVarObject(ctypes.Structure): + _fields_ = [ + ("ob_refcnt", ctypes.c_long), + ("ob_type", ctypes.c_void_p), + ("ob_size", ctypes.c_ulong), + ] + + class DebugWipeLoader(SourceFileLoader): + """ + PEP302 loader that zeros out debug information before execution + """ + + def get_code(self, fullname): + code = super(DebugWipeLoader, self).get_code(fullname) + if code: + # Ideally we'd do + # code.co_lnotab = b'' + # But code objects are READONLY. Not to worry though; we'll + # directly modify CPython's object + code_impl = PyVarObject.from_address(id(code.co_lnotab)) + code_impl.ob_size = 0 + return code + + return DebugWipeLoader(fullname, pypath) + + +def optimize_for_coverage(cov, include_patterns, omit_patterns): + """ + We get better performance if we zero out debug information for files which + we're not interested in. Only available in CPython 3.3+ + """ + matcher = PathMatcher(include_patterns, omit_patterns) + if SourceFileLoader and platform.python_implementation() == "CPython": + sys.meta_path.insert(0, DebugWipeFinder(matcher)) + + +class TeeStream(object): + def __init__(self, *streams): + self._streams = streams + + def write(self, data): + for stream in self._streams: + stream.write(data) + + def flush(self): + for stream in self._streams: + stream.flush() + + def isatty(self): + return False + + +class CallbackStream(object): + def __init__(self, callback, bytes_callback=None, orig=None): + self._callback = callback + self._fileno = orig.fileno() if orig else None + + # Python 3 APIs: + # - `encoding` is a string holding the encoding name + # - `errors` is a string holding the error-handling mode for encoding + # - `buffer` should look like an io.BufferedIOBase object + + self.errors = orig.errors if orig else None + if bytes_callback: + # those members are only on the io.TextIOWrapper + self.encoding = orig.encoding if orig else "UTF-8" + self.buffer = CallbackStream(bytes_callback, orig=orig) + + def write(self, data): + self._callback(data) + + def flush(self): + pass + + def isatty(self): + return False + + def fileno(self): + return self._fileno + + +class BuckTestResult(unittest._TextTestResult): + """ + Our own TestResult class that outputs data in a format that can be easily + parsed by buck's test runner. + """ + + _instr_counter = get_cpu_instr_counter() + + def __init__( + self, stream, descriptions, verbosity, show_output, main_program, suite + ): + super(BuckTestResult, self).__init__(stream, descriptions, verbosity) + self._main_program = main_program + self._suite = suite + self._results = [] + self._current_test = None + self._saved_stdout = sys.stdout + self._saved_stderr = sys.stderr + self._show_output = show_output + + def getResults(self): + return self._results + + def startTest(self, test): + super(BuckTestResult, self).startTest(test) + + # Pass in the real stdout and stderr filenos. We can't really do much + # here to intercept callers who directly operate on these fileno + # objects. + sys.stdout = CallbackStream( + self.addStdout, self.addStdoutBytes, orig=sys.stdout + ) + sys.stderr = CallbackStream( + self.addStderr, self.addStderrBytes, orig=sys.stderr + ) + self._current_test = test + self._test_start_time = time.time() + self._current_status = TestStatus.ABORTED + self._messages = [] + self._stacktrace = None + self._stdout = "" + self._stderr = "" + self._start_instr_count = self._instr_counter.read() + + def _find_next_test(self, suite): + """ + Find the next test that has not been run. + """ + + for test in suite: + + # We identify test suites by test that are iterable (as is done in + # the builtin python test harness). If we see one, recurse on it. + if hasattr(test, "__iter__"): + test = self._find_next_test(test) + + # The builtin python test harness sets test references to `None` + # after they have run, so we know we've found the next test up + # if it's not `None`. + if test is not None: + return test + + def stopTest(self, test): + sys.stdout = self._saved_stdout + sys.stderr = self._saved_stderr + + super(BuckTestResult, self).stopTest(test) + + # If a failure occured during module/class setup, then this "test" may + # actually be a `_ErrorHolder`, which doesn't contain explicit info + # about the upcoming test. Since we really only care about the test + # name field (i.e. `_testMethodName`), we use that to detect an actual + # test cases, and fall back to looking the test up from the suite + # otherwise. + if not hasattr(test, "_testMethodName"): + test = self._find_next_test(self._suite) + + result = { + "testCaseName": "{0}.{1}".format( + test.__class__.__module__, test.__class__.__name__ + ), + "testCase": test._testMethodName, + "type": self._current_status, + "time": int((time.time() - self._test_start_time) * 1000), + "message": os.linesep.join(self._messages), + "stacktrace": self._stacktrace, + "stdOut": self._stdout, + "stdErr": self._stderr, + } + + # TestPilot supports an instruction count field. + if "TEST_PILOT" in os.environ: + result["instrCount"] = ( + int(self._instr_counter.read() - self._start_instr_count), + ) + + self._results.append(result) + self._current_test = None + + def stopTestRun(self): + cov = self._main_program.get_coverage() + if cov is not None: + self._results.append({"coverage": cov}) + + @contextlib.contextmanager + def _withTest(self, test): + self.startTest(test) + yield + self.stopTest(test) + + def _setStatus(self, test, status, message=None, stacktrace=None): + assert test == self._current_test + self._current_status = status + self._stacktrace = stacktrace + if message is not None: + if message.endswith(os.linesep): + message = message[:-1] + self._messages.append(message) + + def setStatus(self, test, status, message=None, stacktrace=None): + # addError() may be called outside of a test if one of the shared + # fixtures (setUpClass/tearDownClass/setUpModule/tearDownModule) + # throws an error. + # + # In this case, create a fake test result to record the error. + if self._current_test is None: + with self._withTest(test): + self._setStatus(test, status, message, stacktrace) + else: + self._setStatus(test, status, message, stacktrace) + + def setException(self, test, status, excinfo): + exctype, value, tb = excinfo + self.setStatus( + test, + status, + "{0}: {1}".format(exctype.__name__, value), + "".join(traceback.format_tb(tb)), + ) + + def addSuccess(self, test): + super(BuckTestResult, self).addSuccess(test) + self.setStatus(test, TestStatus.PASSED) + + def addError(self, test, err): + super(BuckTestResult, self).addError(test, err) + self.setException(test, TestStatus.ABORTED, err) + + def addFailure(self, test, err): + super(BuckTestResult, self).addFailure(test, err) + self.setException(test, TestStatus.FAILED, err) + + def addSkip(self, test, reason): + super(BuckTestResult, self).addSkip(test, reason) + self.setStatus(test, TestStatus.SKIPPED, "Skipped: %s" % (reason,)) + + def addExpectedFailure(self, test, err): + super(BuckTestResult, self).addExpectedFailure(test, err) + self.setException(test, TestStatus.EXPECTED_FAILURE, err) + + def addUnexpectedSuccess(self, test): + super(BuckTestResult, self).addUnexpectedSuccess(test) + self.setStatus(test, TestStatus.UNEXPECTED_SUCCESS, "Unexpected success") + + def addStdout(self, val): + self._stdout += val + if self._show_output: + self._saved_stdout.write(val) + self._saved_stdout.flush() + + def addStdoutBytes(self, val): + string = val.decode("utf-8", errors="backslashreplace") + self.addStdout(string) + + def addStderr(self, val): + self._stderr += val + if self._show_output: + self._saved_stderr.write(val) + self._saved_stderr.flush() + + def addStderrBytes(self, val): + string = val.decode("utf-8", errors="backslashreplace") + self.addStderr(string) + + +class BuckTestRunner(unittest.TextTestRunner): + def __init__(self, main_program, suite, show_output=True, **kwargs): + super(BuckTestRunner, self).__init__(**kwargs) + self.show_output = show_output + self._main_program = main_program + self._suite = suite + + def _makeResult(self): + return BuckTestResult( + self.stream, + self.descriptions, + self.verbosity, + self.show_output, + self._main_program, + self._suite, + ) + + +def _format_test_name(test_class, attrname): + return "{0}.{1}.{2}".format(test_class.__module__, test_class.__name__, attrname) + + +class StderrLogHandler(logging.StreamHandler): + """ + This class is very similar to logging.StreamHandler, except that it + always uses the current sys.stderr object. + + StreamHandler caches the current sys.stderr object when it is constructed. + This makes it behave poorly in unit tests, which may replace sys.stderr + with a StringIO buffer during tests. The StreamHandler will continue using + the old sys.stderr object instead of the desired StringIO buffer. + """ + + def __init__(self): + logging.Handler.__init__(self) + + @property + def stream(self): + return sys.stderr + + +class RegexTestLoader(unittest.TestLoader): + def __init__(self, regex=None): + self.regex = regex + super(RegexTestLoader, self).__init__() + + def getTestCaseNames(self, testCaseClass): + """ + Return a sorted sequence of method names found within testCaseClass + """ + + testFnNames = super(RegexTestLoader, self).getTestCaseNames(testCaseClass) + if self.regex is None: + return testFnNames + robj = re.compile(self.regex) + matched = [] + for attrname in testFnNames: + fullname = _format_test_name(testCaseClass, attrname) + if robj.search(fullname): + matched.append(attrname) + return matched + + +class Loader(object): + + suiteClass = unittest.TestSuite + + def __init__(self, modules, regex=None): + self.modules = modules + self.regex = regex + + def load_all(self): + loader = RegexTestLoader(self.regex) + test_suite = self.suiteClass() + for module_name in self.modules: + __import__(module_name, level=0) + module = sys.modules[module_name] + module_suite = loader.loadTestsFromModule(module) + test_suite.addTest(module_suite) + return test_suite + + def load_args(self, args): + loader = RegexTestLoader(self.regex) + + suites = [] + for arg in args: + suite = loader.loadTestsFromName(arg) + # loadTestsFromName() can only process names that refer to + # individual test functions or modules. It can't process package + # names. If there were no module/function matches, check to see if + # this looks like a package name. + if suite.countTestCases() != 0: + suites.append(suite) + continue + + # Load all modules whose name is . + prefix = arg + "." + for module in self.modules: + if module.startswith(prefix): + suite = loader.loadTestsFromName(module) + suites.append(suite) + + return loader.suiteClass(suites) + + +_COVERAGE_INI = """\ +[report] +exclude_lines = + pragma: no cover + pragma: nocover + pragma:.*no${PLATFORM} + pragma:.*no${PY_IMPL}${PY_MAJOR}${PY_MINOR} + pragma:.*no${PY_IMPL}${PY_MAJOR} + pragma:.*nopy${PY_MAJOR} + pragma:.*nopy${PY_MAJOR}${PY_MINOR} +""" + + +class MainProgram(object): + """ + This class implements the main program. It can be subclassed by + users who wish to customize some parts of the main program. + (Adding additional command line options, customizing test loading, etc.) + """ + + DEFAULT_VERBOSITY = 2 + + def __init__(self, argv): + self.init_option_parser() + self.parse_options(argv) + self.setup_logging() + + def init_option_parser(self): + usage = "%prog [options] [TEST] ..." + op = optparse.OptionParser(usage=usage, add_help_option=False) + self.option_parser = op + + op.add_option( + "--hide-output", + dest="show_output", + action="store_false", + default=True, + help="Suppress data that tests print to stdout/stderr, and only " + "show it if the test fails.", + ) + op.add_option( + "-o", + "--output", + help="Write results to a file in a JSON format to be read by Buck", + ) + op.add_option( + "-f", + "--failfast", + action="store_true", + default=False, + help="Stop after the first failure", + ) + op.add_option( + "-l", + "--list-tests", + action="store_true", + dest="list", + default=False, + help="List tests and exit", + ) + op.add_option( + "-r", + "--regex", + default=None, + help="Regex to apply to tests, to only run those tests", + ) + op.add_option( + "--collect-coverage", + action="store_true", + default=False, + help="Collect test coverage information", + ) + op.add_option( + "--coverage-include", + default="*", + help='File globs to include in converage (split by ",")', + ) + op.add_option( + "--coverage-omit", + default="", + help='File globs to omit from converage (split by ",")', + ) + op.add_option( + "--logger", + action="append", + metavar="=", + default=[], + help="Configure log levels for specific logger categories", + ) + op.add_option( + "-q", + "--quiet", + action="count", + default=0, + help="Decrease the verbosity (may be specified multiple times)", + ) + op.add_option( + "-v", + "--verbosity", + action="count", + default=self.DEFAULT_VERBOSITY, + help="Increase the verbosity (may be specified multiple times)", + ) + op.add_option( + "-?", "--help", action="help", help="Show this help message and exit" + ) + + def parse_options(self, argv): + self.options, self.test_args = self.option_parser.parse_args(argv[1:]) + self.options.verbosity -= self.options.quiet + + if self.options.collect_coverage and coverage is None: + self.option_parser.error("coverage module is not available") + self.options.coverage_include = self.options.coverage_include.split(",") + if self.options.coverage_omit == "": + self.options.coverage_omit = [] + else: + self.options.coverage_omit = self.options.coverage_omit.split(",") + + def setup_logging(self): + # Configure the root logger to log at INFO level. + # This is similar to logging.basicConfig(), but uses our + # StderrLogHandler instead of a StreamHandler. + fmt = logging.Formatter("%(pathname)s:%(lineno)s: %(message)s") + log_handler = StderrLogHandler() + log_handler.setFormatter(fmt) + root_logger = logging.getLogger() + root_logger.addHandler(log_handler) + root_logger.setLevel(logging.INFO) + + level_names = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warn": logging.WARNING, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, + "fatal": logging.FATAL, + } + + for value in self.options.logger: + parts = value.rsplit("=", 1) + if len(parts) != 2: + self.option_parser.error( + "--logger argument must be of the " + "form =: %s" % value + ) + name = parts[0] + level_name = parts[1].lower() + level = level_names.get(level_name) + if level is None: + self.option_parser.error( + "invalid log level %r for log " "category %s" % (parts[1], name) + ) + logging.getLogger(name).setLevel(level) + + def create_loader(self): + import __test_modules__ + + return Loader(__test_modules__.TEST_MODULES, self.options.regex) + + def load_tests(self): + loader = self.create_loader() + if self.options.collect_coverage: + self.start_coverage() + include = self.options.coverage_include + omit = self.options.coverage_omit + if include and "*" not in include: + optimize_for_coverage(self.cov, include, omit) + + if self.test_args: + suite = loader.load_args(self.test_args) + else: + suite = loader.load_all() + if self.options.collect_coverage: + self.cov.start() + return suite + + def get_tests(self, test_suite): + tests = [] + + for test in test_suite: + if isinstance(test, unittest.TestSuite): + tests.extend(self.get_tests(test)) + else: + tests.append(test) + + return tests + + def run(self): + test_suite = self.load_tests() + + if self.options.list: + for test in self.get_tests(test_suite): + method_name = getattr(test, "_testMethodName", "") + name = _format_test_name(test.__class__, method_name) + print(name) + return EXIT_CODE_SUCCESS + else: + result = self.run_tests(test_suite) + if self.options.output is not None: + with open(self.options.output, "w") as f: + json.dump(result.getResults(), f, indent=4, sort_keys=True) + if not result.wasSuccessful(): + return EXIT_CODE_TEST_FAILURE + return EXIT_CODE_SUCCESS + + def run_tests(self, test_suite): + # Install a signal handler to catch Ctrl-C and display the results + # (but only if running >2.6). + if sys.version_info[0] > 2 or sys.version_info[1] > 6: + unittest.installHandler() + + # Run the tests + runner = BuckTestRunner( + self, + test_suite, + verbosity=self.options.verbosity, + show_output=self.options.show_output, + ) + result = runner.run(test_suite) + + if self.options.collect_coverage and self.options.show_output: + self.cov.stop() + try: + self.cov.report(file=sys.stdout) + except coverage.misc.CoverageException: + print("No lines were covered, potentially restricted by file filters") + + return result + + def get_abbr_impl(self): + """Return abbreviated implementation name.""" + impl = platform.python_implementation() + if impl == "PyPy": + return "pp" + elif impl == "Jython": + return "jy" + elif impl == "IronPython": + return "ip" + elif impl == "CPython": + return "cp" + else: + raise RuntimeError("unknown python runtime") + + def start_coverage(self): + if not self.options.collect_coverage: + return + + with tempfile.NamedTemporaryFile("w", delete=False) as coverage_ini: + coverage_ini.write(_COVERAGE_INI) + self._coverage_ini_path = coverage_ini.name + + # Keep the original working dir in case tests use os.chdir + self._original_working_dir = os.getcwd() + + # for coverage config ignores by platform/python version + os.environ["PLATFORM"] = sys.platform + os.environ["PY_IMPL"] = self.get_abbr_impl() + os.environ["PY_MAJOR"] = str(sys.version_info.major) + os.environ["PY_MINOR"] = str(sys.version_info.minor) + + self.cov = coverage.Coverage( + include=self.options.coverage_include, + omit=self.options.coverage_omit, + config_file=coverage_ini.name, + ) + self.cov.erase() + self.cov.start() + + def get_coverage(self): + if not self.options.collect_coverage: + return None + + try: + os.remove(self._coverage_ini_path) + except OSError: + pass # Better to litter than to fail the test + + # Switch back to the original working directory. + os.chdir(self._original_working_dir) + + result = {} + + self.cov.stop() + + try: + f = StringIO() + self.cov.report(file=f) + lines = f.getvalue().split("\n") + except coverage.misc.CoverageException: + # Nothing was covered. That's fine by us + return result + + # N.B.: the format of the coverage library's output differs + # depending on whether one or more files are in the results + for line in lines[2:]: + if line.strip("-") == "": + break + r = line.split()[0] + analysis = self.cov.analysis2(r) + covString = self.convert_to_diff_cov_str(analysis) + if covString: + result[r] = covString + + return result + + def convert_to_diff_cov_str(self, analysis): + # Info on the format of analysis: + # http://nedbatchelder.com/code/coverage/api.html + if not analysis: + return None + numLines = max( + analysis[1][-1] if len(analysis[1]) else 0, + analysis[2][-1] if len(analysis[2]) else 0, + analysis[3][-1] if len(analysis[3]) else 0, + ) + lines = ["N"] * numLines + for l in analysis[1]: + lines[l - 1] = "C" + for l in analysis[2]: + lines[l - 1] = "X" + for l in analysis[3]: + lines[l - 1] = "U" + return "".join(lines) + + +def main(argv): + return MainProgram(sys.argv).run() + + +if __name__ == "__main__": + sys.exit(main(sys.argv)) diff --git a/build/fbcode_builder/CMake/fb_py_win_main.c b/build/fbcode_builder/CMake/fb_py_win_main.c new file mode 100644 index 000000000..8905c3602 --- /dev/null +++ b/build/fbcode_builder/CMake/fb_py_win_main.c @@ -0,0 +1,126 @@ +// Copyright (c) Facebook, Inc. and its affiliates. + +#define WIN32_LEAN_AND_MEAN + +#include +#include +#include + +#define PATH_SIZE 32768 + +typedef int (*Py_Main)(int, wchar_t**); + +// Add the given path to Windows's DLL search path. +// For Windows DLL search path resolution, see: +// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-search-order +void add_search_path(const wchar_t* path) { + wchar_t buffer[PATH_SIZE]; + wchar_t** lppPart = NULL; + + if (!GetFullPathNameW(path, PATH_SIZE, buffer, lppPart)) { + fwprintf(stderr, L"warning: %d unable to expand path %s\n", GetLastError(), path); + return; + } + + if (!AddDllDirectory(buffer)) { + DWORD error = GetLastError(); + if (error != ERROR_FILE_NOT_FOUND) { + fwprintf(stderr, L"warning: %d unable to set DLL search path for %s\n", GetLastError(), path); + } + } +} + +int locate_py_main(int argc, wchar_t **argv) { + /* + * We have to dynamically locate Python3.dll because we may be loading a + * Python native module while running. If that module is built with a + * different Python version, we will end up a DLL import error. To resolve + * this, we can either ship an embedded version of Python with us or + * dynamically look up existing Python distribution installed on user's + * machine. This way, we should be able to get a consistent version of + * Python3.dll and .pyd modules. + */ + HINSTANCE python_dll; + Py_Main pymain; + + // last added directory has highest priority + add_search_path(L"C:\\Python36\\"); + add_search_path(L"C:\\Python37\\"); + add_search_path(L"C:\\Python38\\"); + + python_dll = LoadLibraryExW(L"python3.dll", NULL, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS); + + int returncode = 0; + if (python_dll != NULL) { + pymain = (Py_Main) GetProcAddress(python_dll, "Py_Main"); + + if (pymain != NULL) { + returncode = (pymain)(argc, argv); + } else { + fprintf(stderr, "error: %d unable to load Py_Main\n", GetLastError()); + } + + FreeLibrary(python_dll); + } else { + fprintf(stderr, "error: %d unable to locate python3.dll\n", GetLastError()); + return 1; + } + return returncode; +} + +int wmain() { + /* + * This executable will be prepended to the start of a Python ZIP archive. + * Python will be able to directly execute the ZIP archive, so we simply + * need to tell Py_Main() to run our own file. Duplicate the argument list + * and add our file name to the beginning to tell Python what file to invoke. + */ + wchar_t** pyargv = malloc(sizeof(wchar_t*) * (__argc + 1)); + if (!pyargv) { + fprintf(stderr, "error: failed to allocate argument vector\n"); + return 1; + } + + /* Py_Main wants the wide character version of the argv so we pull those + * values from the global __wargv array that has been prepared by MSVCRT. + * + * In order for the zipapp to run we need to insert an extra argument in + * the front of the argument vector that points to ourselves. + * + * An additional complication is that, depending on who prepared the argument + * string used to start our process, the computed __wargv[0] can be a simple + * shell word like `watchman-wait` which is normally resolved together with + * the PATH by the shell. + * That unresolved path isn't sufficient to start the zipapp on windows; + * we need the fully qualified path. + * + * Given: + * __wargv == {"watchman-wait", "-h"} + * + * we want to pass the following to Py_Main: + * + * { + * "z:\build\watchman\python\watchman-wait.exe", + * "z:\build\watchman\python\watchman-wait.exe", + * "-h" + * } + */ + wchar_t full_path_to_argv0[PATH_SIZE]; + DWORD len = GetModuleFileNameW(NULL, full_path_to_argv0, PATH_SIZE); + if (len == 0 || + len == PATH_SIZE && GetLastError() == ERROR_INSUFFICIENT_BUFFER) { + fprintf( + stderr, + "error: %d while retrieving full path to this executable\n", + GetLastError()); + return 1; + } + + for (int n = 1; n < __argc; ++n) { + pyargv[n + 1] = __wargv[n]; + } + pyargv[0] = full_path_to_argv0; + pyargv[1] = full_path_to_argv0; + + return locate_py_main(__argc + 1, pyargv); +} diff --git a/build/fbcode_builder/CMake/make_fbpy_archive.py b/build/fbcode_builder/CMake/make_fbpy_archive.py new file mode 100755 index 000000000..3724feb21 --- /dev/null +++ b/build/fbcode_builder/CMake/make_fbpy_archive.py @@ -0,0 +1,327 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +import argparse +import collections +import errno +import os +import shutil +import sys +import tempfile +import zipapp + +MANIFEST_SEPARATOR = " :: " +MANIFEST_HEADER_V1 = "FBPY_MANIFEST 1\n" + + +class UsageError(Exception): + def __init__(self, message): + self.message = message + + def __str__(self): + return self.message + + +class BadManifestError(UsageError): + def __init__(self, path, line_num, message): + full_msg = "%s:%s: %s" % (path, line_num, message) + super().__init__(full_msg) + self.path = path + self.line_num = line_num + self.raw_message = message + + +PathInfo = collections.namedtuple( + "PathInfo", ("src", "dest", "manifest_path", "manifest_line") +) + + +def parse_manifest(manifest, path_map): + bad_prefix = ".." + os.path.sep + manifest_dir = os.path.dirname(manifest) + with open(manifest, "r") as f: + line_num = 1 + line = f.readline() + if line != MANIFEST_HEADER_V1: + raise BadManifestError( + manifest, line_num, "Unexpected manifest file header" + ) + + for line in f: + line_num += 1 + if line.startswith("#"): + continue + line = line.rstrip("\n") + parts = line.split(MANIFEST_SEPARATOR) + if len(parts) != 2: + msg = "line must be of the form SRC %s DEST" % MANIFEST_SEPARATOR + raise BadManifestError(manifest, line_num, msg) + src, dest = parts + dest = os.path.normpath(dest) + if dest.startswith(bad_prefix): + msg = "destination path starts with %s: %s" % (bad_prefix, dest) + raise BadManifestError(manifest, line_num, msg) + + if not os.path.isabs(src): + src = os.path.normpath(os.path.join(manifest_dir, src)) + + if dest in path_map: + prev_info = path_map[dest] + msg = ( + "multiple source paths specified for destination " + "path %s. Previous source was %s from %s:%s" + % ( + dest, + prev_info.src, + prev_info.manifest_path, + prev_info.manifest_line, + ) + ) + raise BadManifestError(manifest, line_num, msg) + + info = PathInfo( + src=src, + dest=dest, + manifest_path=manifest, + manifest_line=line_num, + ) + path_map[dest] = info + + +def populate_install_tree(inst_dir, path_map): + os.mkdir(inst_dir) + dest_dirs = {"": False} + + def make_dest_dir(path): + if path in dest_dirs: + return + parent = os.path.dirname(path) + make_dest_dir(parent) + abs_path = os.path.join(inst_dir, path) + os.mkdir(abs_path) + dest_dirs[path] = False + + def install_file(info): + dir_name, base_name = os.path.split(info.dest) + make_dest_dir(dir_name) + if base_name == "__init__.py": + dest_dirs[dir_name] = True + abs_dest = os.path.join(inst_dir, info.dest) + shutil.copy2(info.src, abs_dest) + + # Copy all of the destination files + for info in path_map.values(): + install_file(info) + + # Create __init__ files in any directories that don't have them. + for dir_path, has_init in dest_dirs.items(): + if has_init: + continue + init_path = os.path.join(inst_dir, dir_path, "__init__.py") + with open(init_path, "w"): + pass + + +def build_zipapp(args, path_map): + """Create a self executing python binary using Python 3's built-in + zipapp module. + + This type of Python binary is relatively simple, as zipapp is part of the + standard library, but it does not support native language extensions + (.so/.dll files). + """ + dest_dir = os.path.dirname(args.output) + with tempfile.TemporaryDirectory(prefix="make_fbpy.", dir=dest_dir) as tmpdir: + inst_dir = os.path.join(tmpdir, "tree") + populate_install_tree(inst_dir, path_map) + + tmp_output = os.path.join(tmpdir, "output.exe") + zipapp.create_archive( + inst_dir, target=tmp_output, interpreter=args.python, main=args.main + ) + os.replace(tmp_output, args.output) + + +def create_main_module(args, inst_dir, path_map): + if not args.main: + assert "__main__.py" in path_map + return + + dest_path = os.path.join(inst_dir, "__main__.py") + main_module, main_fn = args.main.split(":") + main_contents = """\ +#!{python} + +if __name__ == "__main__": + import {main_module} + {main_module}.{main_fn}() +""".format( + python=args.python, main_module=main_module, main_fn=main_fn + ) + with open(dest_path, "w") as f: + f.write(main_contents) + os.chmod(dest_path, 0o755) + + +def build_install_dir(args, path_map): + """Create a directory that contains all of the sources, with a __main__ + module to run the program. + """ + # Populate a temporary directory first, then rename to the destination + # location. This ensures that we don't ever leave a halfway-built + # directory behind at the output path if something goes wrong. + dest_dir = os.path.dirname(args.output) + with tempfile.TemporaryDirectory(prefix="make_fbpy.", dir=dest_dir) as tmpdir: + inst_dir = os.path.join(tmpdir, "tree") + populate_install_tree(inst_dir, path_map) + create_main_module(args, inst_dir, path_map) + os.rename(inst_dir, args.output) + + +def ensure_directory(path): + try: + os.makedirs(path) + except OSError as ex: + if ex.errno != errno.EEXIST: + raise + + +def install_library(args, path_map): + """Create an installation directory a python library.""" + out_dir = args.output + out_manifest = args.output + ".manifest" + + install_dir = args.install_dir + if not install_dir: + install_dir = out_dir + + os.makedirs(out_dir) + with open(out_manifest, "w") as manifest: + manifest.write(MANIFEST_HEADER_V1) + for info in path_map.values(): + abs_dest = os.path.join(out_dir, info.dest) + ensure_directory(os.path.dirname(abs_dest)) + print("copy %r --> %r" % (info.src, abs_dest)) + shutil.copy2(info.src, abs_dest) + installed_dest = os.path.join(install_dir, info.dest) + manifest.write("%s%s%s\n" % (installed_dest, MANIFEST_SEPARATOR, info.dest)) + + +def parse_manifests(args): + # Process args.manifest_separator to help support older versions of CMake + if args.manifest_separator: + manifests = [] + for manifest_arg in args.manifests: + split_arg = manifest_arg.split(args.manifest_separator) + manifests.extend(split_arg) + args.manifests = manifests + + path_map = {} + for manifest in args.manifests: + parse_manifest(manifest, path_map) + + return path_map + + +def check_main_module(args, path_map): + # Translate an empty string in the --main argument to None, + # just to allow the CMake logic to be slightly simpler and pass in an + # empty string when it really wants the default __main__.py module to be + # used. + if args.main == "": + args.main = None + + if args.type == "lib-install": + if args.main is not None: + raise UsageError("cannot specify a --main argument with --type=lib-install") + return + + main_info = path_map.get("__main__.py") + if args.main: + if main_info is not None: + msg = ( + "specified an explicit main module with --main, " + "but the file listing already includes __main__.py" + ) + raise BadManifestError( + main_info.manifest_path, main_info.manifest_line, msg + ) + parts = args.main.split(":") + if len(parts) != 2: + raise UsageError( + "argument to --main must be of the form MODULE:CALLABLE " + "(received %s)" % (args.main,) + ) + else: + if main_info is None: + raise UsageError( + "no main module specified with --main, " + "and no __main__.py module present" + ) + + +BUILD_TYPES = { + "zipapp": build_zipapp, + "dir": build_install_dir, + "lib-install": install_library, +} + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("-o", "--output", required=True, help="The output file path") + ap.add_argument( + "--install-dir", + help="When used with --type=lib-install, this parameter specifies the " + "final location where the library where be installed. This can be " + "used to generate the library in one directory first, when you plan " + "to move or copy it to another final location later.", + ) + ap.add_argument( + "--manifest-separator", + help="Split manifest arguments around this separator. This is used " + "to support older versions of CMake that cannot supply the manifests " + "as separate arguments.", + ) + ap.add_argument( + "--main", + help="The main module to run, specified as :. " + "This must be specified if and only if the archive does not contain " + "a __main__.py file.", + ) + ap.add_argument( + "--python", + help="Explicitly specify the python interpreter to use for the " "executable.", + ) + ap.add_argument( + "--type", choices=BUILD_TYPES.keys(), help="The type of output to build." + ) + ap.add_argument( + "manifests", + nargs="+", + help="The manifest files specifying how to construct the archive", + ) + args = ap.parse_args() + + if args.python is None: + args.python = sys.executable + + if args.type is None: + # In the future we might want different default output types + # for different platforms. + args.type = "zipapp" + build_fn = BUILD_TYPES[args.type] + + try: + path_map = parse_manifests(args) + check_main_module(args, path_map) + except UsageError as ex: + print("error: %s" % (ex,), file=sys.stderr) + sys.exit(1) + + build_fn(args, path_map) + + +if __name__ == "__main__": + main() diff --git a/build/fbcode_builder/LICENSE b/build/fbcode_builder/LICENSE new file mode 100644 index 000000000..b96dcb048 --- /dev/null +++ b/build/fbcode_builder/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) Facebook, Inc. and its affiliates. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/build/fbcode_builder/README.docker b/build/fbcode_builder/README.docker new file mode 100644 index 000000000..4e9fa8a29 --- /dev/null +++ b/build/fbcode_builder/README.docker @@ -0,0 +1,44 @@ +## Debugging Docker builds + +To debug a a build failure, start up a shell inside the just-failed image as +follows: + +``` +docker ps -a | head # Grab the container ID +docker commit CONTAINER_ID # Grab the SHA string +docker run -it SHA_STRING /bin/bash +# Debug as usual, e.g. `./run-cmake.sh Debug`, `make`, `apt-get install gdb` +``` + +## A note on Docker security + +While the Dockerfile generated above is quite simple, you must be aware that +using Docker to run arbitrary code can present significant security risks: + + - Code signature validation is off by default (as of 2016), exposing you to + man-in-the-middle malicious code injection. + + - You implicitly trust the world -- a Dockerfile cannot annotate that + you trust the image `debian:8.6` because you trust a particular + certificate -- rather, you trust the name, and that it will never be + hijacked. + + - Sandboxing in the Linux kernel is not perfect, and the builds run code as + root. Any compromised code can likely escalate to the host system. + +Specifically, you must be very careful only to add trusted OS images to the +build flow. + +Consider setting this variable before running any Docker container -- this +will validate a signature on the base image before running code from it: + +``` +export DOCKER_CONTENT_TRUST=1 +``` + +Note that unless you go through the extra steps of notarizing the resulting +images, you will have to disable trust to enter intermediate images, e.g. + +``` +DOCKER_CONTENT_TRUST= docker run -it YOUR_IMAGE_ID /bin/bash +``` diff --git a/build/fbcode_builder/README.md b/build/fbcode_builder/README.md new file mode 100644 index 000000000..d47dd41c0 --- /dev/null +++ b/build/fbcode_builder/README.md @@ -0,0 +1,43 @@ +# Easy builds for Facebook projects + +This directory contains tools designed to simplify continuous-integration +(and other builds) of Facebook open source projects. In particular, this helps +manage builds for cross-project dependencies. + +The main entry point is the `getdeps.py` script. This script has several +subcommands, but the most notable is the `build` command. This will download +and build all dependencies for a project, and then build the project itself. + +## Deployment + +This directory is copied literally into a number of different Facebook open +source repositories. Any change made to code in this directory will be +automatically be replicated by our open source tooling into all GitHub hosted +repositories that use `fbcode_builder`. Typically this directory is copied +into the open source repositories as `build/fbcode_builder/`. + + +# Project Configuration Files + +The `manifests` subdirectory contains configuration files for many different +projects, describing how to build each project. These files also list +dependencies between projects, enabling `getdeps.py` to build all dependencies +for a project before building the project itself. + + +# Shared CMake utilities + +Since this directory is copied into many Facebook open source repositories, +it is also used to help share some CMake utility files across projects. The +`CMake/` subdirectory contains a number of `.cmake` files that are shared by +the CMake-based build systems across several different projects. + + +# Older Build Scripts + +This directory also still contains a handful of older build scripts that +pre-date the current `getdeps.py` build system. Most of the other `.py` files +in this top directory, apart from `getdeps.py` itself, are from this older +build system. This older system is only used by a few remaining projects, and +new projects should generally use the newer `getdeps.py` script, by adding a +new configuration file in the `manifests/` subdirectory. diff --git a/build/fbcode_builder/docker_build_with_ccache.sh b/build/fbcode_builder/docker_build_with_ccache.sh new file mode 100755 index 000000000..e922810d5 --- /dev/null +++ b/build/fbcode_builder/docker_build_with_ccache.sh @@ -0,0 +1,219 @@ +#!/bin/bash -uex +# Copyright (c) Facebook, Inc. and its affiliates. +set -o pipefail # Be sure to `|| :` commands that are allowed to fail. + +# +# Future: port this to Python if you are making significant changes. +# + +# Parse command-line arguments +build_timeout="" # Default to no time-out +print_usage() { + echo "Usage: $0 [--build-timeout TIMEOUT_VAL] SAVE-CCACHE-TO-DIR" + echo "SAVE-CCACHE-TO-DIR is required. An empty string discards the ccache." +} +while [[ $# -gt 0 ]]; do + case "$1" in + --build-timeout) + shift + build_timeout="$1" + if [[ "$build_timeout" != "" ]] ; then + timeout "$build_timeout" true # fail early on invalid timeouts + fi + ;; + -h|--help) + print_usage + exit + ;; + *) + break + ;; + esac + shift +done +# There is one required argument, but an empty string is allowed. +if [[ "$#" != 1 ]] ; then + print_usage + exit 1 +fi +save_ccache_to_dir="$1" +if [[ "$save_ccache_to_dir" != "" ]] ; then + mkdir -p "$save_ccache_to_dir" # fail early if there's nowhere to save +else + echo "WARNING: Will not save /ccache from inside the Docker container" +fi + +rand_guid() { + echo "$(date +%s)_${RANDOM}_${RANDOM}_${RANDOM}_${RANDOM}" +} + +id=fbcode_builder_image_id=$(rand_guid) +logfile=$(mktemp) + +echo " + + +Running build with timeout '$build_timeout', label $id, and log in $logfile + + +" + +if [[ "$build_timeout" != "" ]] ; then + # Kill the container after $build_timeout. Using `/bin/timeout` would cause + # Docker to destroy the most recent container and lose its cache. + ( + sleep "$build_timeout" + echo "Build timed out after $build_timeout" 1>&2 + while true; do + maybe_container=$( + grep -E '^( ---> Running in [0-9a-f]+|FBCODE_BUILDER_EXIT)$' "$logfile" | + tail -n 1 | awk '{print $NF}' + ) + if [[ "$maybe_container" == "FBCODE_BUILDER_EXIT" ]] ; then + echo "Time-out successfully terminated build" 1>&2 + break + fi + echo "Time-out: trying to kill $maybe_container" 1>&2 + # This kill fail if we get unlucky, try again soon. + docker kill "$maybe_container" || sleep 5 + done + ) & +fi + +build_exit_code=0 +# `docker build` is allowed to fail, and `pipefail` means we must check the +# failure explicitly. +if ! docker build --label="$id" . 2>&1 | tee "$logfile" ; then + build_exit_code="${PIPESTATUS[0]}" + # NB: We are going to deliberately forge ahead even if `tee` failed. + # If it did, we have a problem with tempfile creation, and all is sad. + echo "Build failed with code $build_exit_code, trying to save ccache" 1>&2 +fi +# Stop trying to kill the container. +echo $'\nFBCODE_BUILDER_EXIT' >> "$logfile" + +if [[ "$save_ccache_to_dir" == "" ]] ; then + echo "Not inspecting Docker build, since saving the ccache wasn't requested." + exit "$build_exit_code" +fi + +img=$(docker images --filter "label=$id" -a -q) +if [[ "$img" == "" ]] ; then + docker images -a + echo "In the above list, failed to find most recent image with $id" 1>&2 + # Usually, the above `docker kill` will leave us with an up-to-the-second + # container, from which we can extract the cache. However, if that fails + # for any reason, this loop will instead grab the latest available image. + # + # It's possible for this log search to get confused due to the output of + # the build command itself, but since our builds aren't **trying** to + # break cache, we probably won't randomly hit an ID from another build. + img=$( + grep -E '^ ---> (Running in [0-9a-f]+|[0-9a-f]+)$' "$logfile" | tac | + sed 's/Running in /container_/;s/ ---> //;' | ( + while read -r x ; do + # Both docker commands below print an image ID to stdout on + # success, so we just need to know when to stop. + if [[ "$x" =~ container_.* ]] ; then + if docker commit "${x#container_}" ; then + break + fi + elif docker inspect --type image -f '{{.Id}}' "$x" ; then + break + fi + done + ) + ) + if [[ "$img" == "" ]] ; then + echo "Failed to find valid container or image ID in log $logfile" 1>&2 + exit 1 + fi +elif [[ "$(echo "$img" | wc -l)" != 1 ]] ; then + # Shouldn't really happen, but be explicit if it does. + echo "Multiple images with label $id, taking the latest of:" + echo "$img" + img=$(echo "$img" | head -n 1) +fi + +container_name="fbcode_builder_container_$(rand_guid)" +echo "Starting $container_name from latest image of the build with $id --" +echo "$img" + +# ccache collection must be done outside of the Docker build steps because +# we need to be able to kill it on timeout. +# +# This step grows the max cache size to slightly exceed than the working set +# of a successful build. This simple design persists the max size in the +# cache directory itself (the env var CCACHE_MAXSIZE does not even work with +# older ccaches like the one on 14.04). +# +# Future: copy this script into the Docker image via Dockerfile. +( + # By default, fbcode_builder creates an unsigned image, so the `docker + # run` below would fail if DOCKER_CONTENT_TRUST were set. So we unset it + # just for this one run. + export DOCKER_CONTENT_TRUST= + # CAUTION: The inner bash runs without -uex, so code accordingly. + docker run --user root --name "$container_name" "$img" /bin/bash -c ' + build_exit_code='"$build_exit_code"' + + # Might be useful if debugging whether max cache size is too small? + grep " Cleaning up cache directory " /tmp/ccache.log + + export CCACHE_DIR=/ccache + ccache -s + + echo "Total bytes in /ccache:"; + total_bytes=$(du -sb /ccache | awk "{print \$1}") + echo "$total_bytes" + + echo "Used bytes in /ccache:"; + used_bytes=$( + du -sb $(find /ccache -type f -newermt @$( + cat /FBCODE_BUILDER_CCACHE_START_TIME + )) | awk "{t += \$1} END {print t}" + ) + echo "$used_bytes" + + # Goal: set the max cache to 750MB over 125% of the usage of a + # successful build. If this is too small, it takes too long to get a + # cache fully warmed up. Plus, ccache cleans 100-200MB before reaching + # the max cache size, so a large margin is essential to prevent misses. + desired_mb=$(( 750 + used_bytes / 800000 )) # 125% in decimal MB: 1e6/1.25 + if [[ "$build_exit_code" != "0" ]] ; then + # For a bad build, disallow shrinking the max cache size. Instead of + # the max cache size, we use on-disk size, which ccache keeps at least + # 150MB under the actual max size, hence the 400MB safety margin. + cur_max_mb=$(( 400 + total_bytes / 1000000 )) # ccache uses decimal MB + if [[ "$desired_mb" -le "$cur_max_mb" ]] ; then + desired_mb="" + fi + fi + + if [[ "$desired_mb" != "" ]] ; then + echo "Updating cache size to $desired_mb MB" + ccache -M "${desired_mb}M" + ccache -s + fi + + # Subshell because `time` the binary may not be installed. + if (time tar czf /ccache.tgz /ccache) ; then + ls -l /ccache.tgz + else + # This `else` ensures we never overwrite the current cache with + # partial data in case of error, even if somebody adds code below. + rm /ccache.tgz + exit 1 + fi + ' +) + +echo "Updating $save_ccache_to_dir/ccache.tgz" +# This will not delete the existing cache if `docker run` didn't make one +docker cp "$container_name:/ccache.tgz" "$save_ccache_to_dir/" + +# Future: it'd be nice if Travis allowed us to retry if the build timed out, +# since we'll make more progress thanks to the cache. As-is, we have to +# wait for the next commit to land. +echo "Build exited with code $build_exit_code" +exit "$build_exit_code" diff --git a/build/fbcode_builder/docker_builder.py b/build/fbcode_builder/docker_builder.py new file mode 100644 index 000000000..83df7137c --- /dev/null +++ b/build/fbcode_builder/docker_builder.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +""" + +Extends FBCodeBuilder to produce Docker context directories. + +In order to get the largest iteration-time savings from Docker's build +caching, you will want to: + - Use fine-grained steps as appropriate (e.g. separate make & make install), + - Start your action sequence with the lowest-risk steps, and with the steps + that change the least often, and + - Put the steps that you are debugging towards the very end. + +""" +import logging +import os +import shutil +import tempfile + +from fbcode_builder import FBCodeBuilder +from shell_quoting import raw_shell, shell_comment, shell_join, ShellQuoted, path_join +from utils import recursively_flatten_list, run_command + + +class DockerFBCodeBuilder(FBCodeBuilder): + def _user(self): + return self.option("user", "root") + + def _change_user(self): + return ShellQuoted("USER {u}").format(u=self._user()) + + def setup(self): + # Please add RPM-based OSes here as appropriate. + # + # To allow exercising non-root installs -- we change users after the + # system packages are installed. TODO: For users not defined in the + # image, we should probably `useradd`. + return self.step( + "Setup", + [ + # Docker's FROM does not understand shell quoting. + ShellQuoted("FROM {}".format(self.option("os_image"))), + # /bin/sh syntax is a pain + ShellQuoted('SHELL ["/bin/bash", "-c"]'), + ] + + self.install_debian_deps() + + [self._change_user()] + + [self.workdir(self.option("prefix"))] + + self.create_python_venv() + + self.python_venv() + + self.rust_toolchain(), + ) + + def python_venv(self): + # To both avoid calling venv activate on each RUN command AND to ensure + # it is present when the resulting container is run add to PATH + actions = [] + if self.option("PYTHON_VENV", "OFF") == "ON": + actions = ShellQuoted("ENV PATH={p}:$PATH").format( + p=path_join(self.option("prefix"), "venv", "bin") + ) + return actions + + def step(self, name, actions): + assert "\n" not in name, "Name {0} would span > 1 line".format(name) + b = ShellQuoted("") + return [ShellQuoted("### {0} ###".format(name)), b] + actions + [b] + + def run(self, shell_cmd): + return ShellQuoted("RUN {cmd}").format(cmd=shell_cmd) + + def set_env(self, key, value): + return ShellQuoted("ENV {key}={val}").format(key=key, val=value) + + def workdir(self, dir): + return [ + # As late as Docker 1.12.5, this results in `build` being owned + # by root:root -- the explicit `mkdir` works around the bug: + # USER nobody + # WORKDIR build + ShellQuoted("USER root"), + ShellQuoted("RUN mkdir -p {d} && chown {u} {d}").format( + d=dir, u=self._user() + ), + self._change_user(), + ShellQuoted("WORKDIR {dir}").format(dir=dir), + ] + + def comment(self, comment): + # This should not be a command since we don't want comment changes + # to invalidate the Docker build cache. + return shell_comment(comment) + + def copy_local_repo(self, repo_dir, dest_name): + fd, archive_path = tempfile.mkstemp( + prefix="local_repo_{0}_".format(dest_name), + suffix=".tgz", + dir=os.path.abspath(self.option("docker_context_dir")), + ) + os.close(fd) + run_command("tar", "czf", archive_path, ".", cwd=repo_dir) + return [ + ShellQuoted("ADD {archive} {dest_name}").format( + archive=os.path.basename(archive_path), dest_name=dest_name + ), + # Docker permissions make very little sense... see also workdir() + ShellQuoted("USER root"), + ShellQuoted("RUN chown -R {u} {d}").format(d=dest_name, u=self._user()), + self._change_user(), + ] + + def _render_impl(self, steps): + return raw_shell(shell_join("\n", recursively_flatten_list(steps))) + + def debian_ccache_setup_steps(self): + source_ccache_tgz = self.option("ccache_tgz", "") + if not source_ccache_tgz: + logging.info("Docker ccache not enabled") + return [] + + dest_ccache_tgz = os.path.join(self.option("docker_context_dir"), "ccache.tgz") + + try: + try: + os.link(source_ccache_tgz, dest_ccache_tgz) + except OSError: + logging.exception( + "Hard-linking {s} to {d} failed, falling back to copy".format( + s=source_ccache_tgz, d=dest_ccache_tgz + ) + ) + shutil.copyfile(source_ccache_tgz, dest_ccache_tgz) + except Exception: + logging.exception( + "Failed to copy or link {s} to {d}, aborting".format( + s=source_ccache_tgz, d=dest_ccache_tgz + ) + ) + raise + + return [ + # Separate layer so that in development we avoid re-downloads. + self.run(ShellQuoted("apt-get install -yq ccache")), + ShellQuoted("ADD ccache.tgz /"), + ShellQuoted( + # Set CCACHE_DIR before the `ccache` invocations below. + "ENV CCACHE_DIR=/ccache " + # No clang support for now, so it's easiest to hardcode gcc. + 'CC="ccache gcc" CXX="ccache g++" ' + # Always log for ease of debugging. For real FB projects, + # this log is several megabytes, so dumping it to stdout + # would likely exceed the Travis log limit of 4MB. + # + # On a local machine, `docker cp` will get you the data. To + # get the data out from Travis, I would compress and dump + # uuencoded bytes to the log -- for Bistro this was about + # 600kb or 8000 lines: + # + # apt-get install sharutils + # bzip2 -9 < /tmp/ccache.log | uuencode -m ccache.log.bz2 + "CCACHE_LOGFILE=/tmp/ccache.log" + ), + self.run( + ShellQuoted( + # Future: Skipping this part made this Docker step instant, + # saving ~1min of build time. It's unclear if it is the + # chown or the du, but probably the chown -- since a large + # part of the cost is incurred at image save time. + # + # ccache.tgz may be empty, or may have the wrong + # permissions. + "mkdir -p /ccache && time chown -R nobody /ccache && " + "time du -sh /ccache && " + # Reset stats so `docker_build_with_ccache.sh` can print + # useful values at the end of the run. + "echo === Prev run stats === && ccache -s && ccache -z && " + # Record the current time to let travis_build.sh figure out + # the number of bytes in the cache that are actually used -- + # this is crucial for tuning the maximum cache size. + "date +%s > /FBCODE_BUILDER_CCACHE_START_TIME && " + # The build running as `nobody` should be able to write here + "chown nobody /tmp/ccache.log" + ) + ), + ] diff --git a/build/fbcode_builder/docker_enable_ipv6.sh b/build/fbcode_builder/docker_enable_ipv6.sh new file mode 100755 index 000000000..3752f6f5e --- /dev/null +++ b/build/fbcode_builder/docker_enable_ipv6.sh @@ -0,0 +1,13 @@ +#!/bin/sh +# Copyright (c) Facebook, Inc. and its affiliates. + + +# `daemon.json` is normally missing, but let's log it in case that changes. +touch /etc/docker/daemon.json +service docker stop +echo '{"ipv6": true, "fixed-cidr-v6": "2001:db8:1::/64"}' > /etc/docker/daemon.json +service docker start +# Fail early if docker failed on start -- add `- sudo dockerd` to debug. +docker info +# Paranoia log: what if our config got overwritten? +cat /etc/docker/daemon.json diff --git a/build/fbcode_builder/fbcode_builder.py b/build/fbcode_builder/fbcode_builder.py new file mode 100644 index 000000000..742099321 --- /dev/null +++ b/build/fbcode_builder/fbcode_builder.py @@ -0,0 +1,536 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +""" + +This is a small DSL to describe builds of Facebook's open-source projects +that are published to Github from a single internal repo, including projects +that depend on folly, wangle, proxygen, fbthrift, etc. + +This file defines the interface of the DSL, and common utilieis, but you +will have to instantiate a specific builder, with specific options, in +order to get work done -- see e.g. make_docker_context.py. + +== Design notes == + +Goals: + + - A simple declarative language for what needs to be checked out & built, + how, in what order. + + - The same specification should work for external continuous integration + builds (e.g. Travis + Docker) and for internal VM-based continuous + integration builds. + + - One should be able to build without root, and to install to a prefix. + +Non-goals: + + - General usefulness. The only point of this is to make it easier to build + and test Facebook's open-source services. + +Ideas for the future -- these may not be very good :) + + - Especially on Ubuntu 14.04 the current initial setup is inefficient: + we add PPAs after having installed a bunch of packages -- this prompts + reinstalls of large amounts of code. We also `apt-get update` a few + times. + + - A "shell script" builder. Like DockerFBCodeBuilder, but outputs a + shell script that runs outside of a container. Or maybe even + synchronously executes the shell commands, `make`-style. + + - A "Makefile" generator. That might make iterating on builds even quicker + than what you can currently get with Docker build caching. + + - Generate a rebuild script that can be run e.g. inside the built Docker + container by tagging certain steps with list-inheriting Python objects: + * do change directories + * do NOT `git clone` -- if we want to update code this should be a + separate script that e.g. runs rebase on top of specific targets + across all the repos. + * do NOT install software (most / all setup can be skipped) + * do NOT `autoreconf` or `configure` + * do `make` and `cmake` + + - If we get non-Debian OSes, part of ccache setup should be factored out. +""" + +import os +import re + +from shell_quoting import path_join, shell_join, ShellQuoted + + +def _read_project_github_hashes(): + base_dir = "deps/github_hashes/" # trailing slash used in regex below + for dirname, _, files in os.walk(base_dir): + for filename in files: + path = os.path.join(dirname, filename) + with open(path) as f: + m_proj = re.match("^" + base_dir + "(.*)-rev\.txt$", path) + if m_proj is None: + raise RuntimeError("Not a hash file? {0}".format(path)) + m_hash = re.match("^Subproject commit ([0-9a-f]+)\n$", f.read()) + if m_hash is None: + raise RuntimeError("No hash in {0}".format(path)) + yield m_proj.group(1), m_hash.group(1) + + +class FBCodeBuilder(object): + def __init__(self, **kwargs): + self._options_do_not_access = kwargs # Use .option() instead. + # This raises upon detecting options that are specified but unused, + # because otherwise it is very easy to make a typo in option names. + self.options_used = set() + # Mark 'projects_dir' used even if the build installs no github + # projects. This is needed because driver programs like + # `shell_builder.py` unconditionally set this for all builds. + self._github_dir = self.option("projects_dir") + self._github_hashes = dict(_read_project_github_hashes()) + + def __repr__(self): + return "{0}({1})".format( + self.__class__.__name__, + ", ".join( + "{0}={1}".format(k, repr(v)) + for k, v in self._options_do_not_access.items() + ), + ) + + def option(self, name, default=None): + value = self._options_do_not_access.get(name, default) + if value is None: + raise RuntimeError("Option {0} is required".format(name)) + self.options_used.add(name) + return value + + def has_option(self, name): + return name in self._options_do_not_access + + def add_option(self, name, value): + if name in self._options_do_not_access: + raise RuntimeError("Option {0} already set".format(name)) + self._options_do_not_access[name] = value + + # + # Abstract parts common to every installation flow + # + + def render(self, steps): + """ + + Converts nested actions to your builder's expected output format. + Typically takes the output of build(). + + """ + res = self._render_impl(steps) # Implementation-dependent + # Now that the output is rendered, we expect all options to have + # been used. + unused_options = set(self._options_do_not_access) + unused_options -= self.options_used + if unused_options: + raise RuntimeError( + "Unused options: {0} -- please check if you made a typo " + "in any of them. Those that are truly not useful should " + "be not be set so that this typo detection can be useful.".format( + unused_options + ) + ) + return res + + def build(self, steps): + if not steps: + raise RuntimeError( + "Please ensure that the config you are passing " "contains steps" + ) + return [self.setup(), self.diagnostics()] + steps + + def setup(self): + "Your builder may want to install packages here." + raise NotImplementedError + + def diagnostics(self): + "Log some system diagnostics before/after setup for ease of debugging" + # The builder's repr is not used in a command to avoid pointlessly + # invalidating Docker's build cache. + return self.step( + "Diagnostics", + [ + self.comment("Builder {0}".format(repr(self))), + self.run(ShellQuoted("hostname")), + self.run(ShellQuoted("cat /etc/issue || echo no /etc/issue")), + self.run(ShellQuoted("g++ --version || echo g++ not installed")), + self.run(ShellQuoted("cmake --version || echo cmake not installed")), + ], + ) + + def step(self, name, actions): + "A labeled collection of actions or other steps" + raise NotImplementedError + + def run(self, shell_cmd): + "Run this bash command" + raise NotImplementedError + + def set_env(self, key, value): + 'Set the environment "key" to value "value"' + raise NotImplementedError + + def workdir(self, dir): + "Create this directory if it does not exist, and change into it" + raise NotImplementedError + + def copy_local_repo(self, dir, dest_name): + """ + Copy the local repo at `dir` into this step's `workdir()`, analog of: + cp -r /path/to/folly folly + """ + raise NotImplementedError + + def python_deps(self): + return [ + "wheel", + "cython==0.28.6", + ] + + def debian_deps(self): + return [ + "autoconf-archive", + "bison", + "build-essential", + "cmake", + "curl", + "flex", + "git", + "gperf", + "joe", + "libboost-all-dev", + "libcap-dev", + "libdouble-conversion-dev", + "libevent-dev", + "libgflags-dev", + "libgoogle-glog-dev", + "libkrb5-dev", + "libpcre3-dev", + "libpthread-stubs0-dev", + "libnuma-dev", + "libsasl2-dev", + "libsnappy-dev", + "libsqlite3-dev", + "libssl-dev", + "libtool", + "netcat-openbsd", + "pkg-config", + "sudo", + "unzip", + "wget", + "python3-venv", + ] + + # + # Specific build helpers + # + + def install_debian_deps(self): + actions = [ + self.run( + ShellQuoted("apt-get update && apt-get install -yq {deps}").format( + deps=shell_join( + " ", (ShellQuoted(dep) for dep in self.debian_deps()) + ) + ) + ), + ] + gcc_version = self.option("gcc_version") + + # Make the selected GCC the default before building anything + actions.extend( + [ + self.run( + ShellQuoted("apt-get install -yq {c} {cpp}").format( + c=ShellQuoted("gcc-{v}").format(v=gcc_version), + cpp=ShellQuoted("g++-{v}").format(v=gcc_version), + ) + ), + self.run( + ShellQuoted( + "update-alternatives --install /usr/bin/gcc gcc {c} 40 " + "--slave /usr/bin/g++ g++ {cpp}" + ).format( + c=ShellQuoted("/usr/bin/gcc-{v}").format(v=gcc_version), + cpp=ShellQuoted("/usr/bin/g++-{v}").format(v=gcc_version), + ) + ), + self.run(ShellQuoted("update-alternatives --config gcc")), + ] + ) + + actions.extend(self.debian_ccache_setup_steps()) + + return self.step("Install packages for Debian-based OS", actions) + + def create_python_venv(self): + actions = [] + if self.option("PYTHON_VENV", "OFF") == "ON": + actions.append( + self.run( + ShellQuoted("python3 -m venv {p}").format( + p=path_join(self.option("prefix"), "venv") + ) + ) + ) + return actions + + def python_venv(self): + actions = [] + if self.option("PYTHON_VENV", "OFF") == "ON": + actions.append( + ShellQuoted("source {p}").format( + p=path_join(self.option("prefix"), "venv", "bin", "activate") + ) + ) + + actions.append( + self.run( + ShellQuoted("python3 -m pip install {deps}").format( + deps=shell_join( + " ", (ShellQuoted(dep) for dep in self.python_deps()) + ) + ) + ) + ) + return actions + + def enable_rust_toolchain(self, toolchain="stable", is_bootstrap=True): + choices = set(["stable", "beta", "nightly"]) + + assert toolchain in choices, ( + "while enabling rust toolchain: {} is not in {}" + ).format(toolchain, choices) + + rust_toolchain_opt = (toolchain, is_bootstrap) + prev_opt = self.option("rust_toolchain", rust_toolchain_opt) + assert prev_opt == rust_toolchain_opt, ( + "while enabling rust toolchain: previous toolchain already set to" + " {}, but trying to set it to {} now" + ).format(prev_opt, rust_toolchain_opt) + + self.add_option("rust_toolchain", rust_toolchain_opt) + + def rust_toolchain(self): + actions = [] + if self.option("rust_toolchain", False): + (toolchain, is_bootstrap) = self.option("rust_toolchain") + rust_dir = path_join(self.option("prefix"), "rust") + actions = [ + self.set_env("CARGO_HOME", rust_dir), + self.set_env("RUSTUP_HOME", rust_dir), + self.set_env("RUSTC_BOOTSTRAP", "1" if is_bootstrap else "0"), + self.run( + ShellQuoted( + "curl -sSf https://build.travis-ci.com/files/rustup-init.sh" + " | sh -s --" + " --default-toolchain={r} " + " --profile=minimal" + " --no-modify-path" + " -y" + ).format(p=rust_dir, r=toolchain) + ), + self.set_env( + "PATH", + ShellQuoted("{p}:$PATH").format(p=path_join(rust_dir, "bin")), + ), + self.run(ShellQuoted("rustup update")), + self.run(ShellQuoted("rustc --version")), + self.run(ShellQuoted("rustup --version")), + self.run(ShellQuoted("cargo --version")), + ] + return actions + + def debian_ccache_setup_steps(self): + return [] # It's ok to ship a renderer without ccache support. + + def github_project_workdir(self, project, path): + # Only check out a non-default branch if requested. This especially + # makes sense when building from a local repo. + git_hash = self.option( + "{0}:git_hash".format(project), + # Any repo that has a hash in deps/github_hashes defaults to + # that, with the goal of making builds maximally consistent. + self._github_hashes.get(project, ""), + ) + maybe_change_branch = ( + [ + self.run(ShellQuoted("git checkout {hash}").format(hash=git_hash)), + ] + if git_hash + else [] + ) + + local_repo_dir = self.option("{0}:local_repo_dir".format(project), "") + return self.step( + "Check out {0}, workdir {1}".format(project, path), + [ + self.workdir(self._github_dir), + self.run( + ShellQuoted("git clone {opts} https://github.com/{p}").format( + p=project, + opts=ShellQuoted( + self.option("{}:git_clone_opts".format(project), "") + ), + ) + ) + if not local_repo_dir + else self.copy_local_repo(local_repo_dir, os.path.basename(project)), + self.workdir( + path_join(self._github_dir, os.path.basename(project), path), + ), + ] + + maybe_change_branch, + ) + + def fb_github_project_workdir(self, project_and_path, github_org="facebook"): + "This helper lets Facebook-internal CI special-cases FB projects" + project, path = project_and_path.split("/", 1) + return self.github_project_workdir(github_org + "/" + project, path) + + def _make_vars(self, make_vars): + return shell_join( + " ", + ( + ShellQuoted("{k}={v}").format(k=k, v=v) + for k, v in ({} if make_vars is None else make_vars).items() + ), + ) + + def parallel_make(self, make_vars=None): + return self.run( + ShellQuoted("make -j {n} VERBOSE=1 {vars}").format( + n=self.option("make_parallelism"), + vars=self._make_vars(make_vars), + ) + ) + + def make_and_install(self, make_vars=None): + return [ + self.parallel_make(make_vars), + self.run( + ShellQuoted("make install VERBOSE=1 {vars}").format( + vars=self._make_vars(make_vars), + ) + ), + ] + + def configure(self, name=None): + autoconf_options = {} + if name is not None: + autoconf_options.update( + self.option("{0}:autoconf_options".format(name), {}) + ) + return [ + self.run( + ShellQuoted( + 'LDFLAGS="$LDFLAGS -L"{p}"/lib -Wl,-rpath="{p}"/lib" ' + 'CFLAGS="$CFLAGS -I"{p}"/include" ' + 'CPPFLAGS="$CPPFLAGS -I"{p}"/include" ' + "PY_PREFIX={p} " + "./configure --prefix={p} {args}" + ).format( + p=self.option("prefix"), + args=shell_join( + " ", + ( + ShellQuoted("{k}={v}").format(k=k, v=v) + for k, v in autoconf_options.items() + ), + ), + ) + ), + ] + + def autoconf_install(self, name): + return self.step( + "Build and install {0}".format(name), + [ + self.run(ShellQuoted("autoreconf -ivf")), + ] + + self.configure() + + self.make_and_install(), + ) + + def cmake_configure(self, name, cmake_path=".."): + cmake_defines = { + "BUILD_SHARED_LIBS": "ON", + "CMAKE_INSTALL_PREFIX": self.option("prefix"), + } + + # Hacks to add thriftpy3 support + if "BUILD_THRIFT_PY3" in os.environ and "folly" in name: + cmake_defines["PYTHON_EXTENSIONS"] = "True" + + if "BUILD_THRIFT_PY3" in os.environ and "fbthrift" in name: + cmake_defines["thriftpy3"] = "ON" + + cmake_defines.update(self.option("{0}:cmake_defines".format(name), {})) + return [ + self.run( + ShellQuoted( + 'CXXFLAGS="$CXXFLAGS -fPIC -isystem "{p}"/include" ' + 'CFLAGS="$CFLAGS -fPIC -isystem "{p}"/include" ' + "cmake {args} {cmake_path}" + ).format( + p=self.option("prefix"), + args=shell_join( + " ", + ( + ShellQuoted("-D{k}={v}").format(k=k, v=v) + for k, v in cmake_defines.items() + ), + ), + cmake_path=cmake_path, + ) + ), + ] + + def cmake_install(self, name, cmake_path=".."): + return self.step( + "Build and install {0}".format(name), + self.cmake_configure(name, cmake_path) + self.make_and_install(), + ) + + def cargo_build(self, name): + return self.step( + "Build {0}".format(name), + [ + self.run( + ShellQuoted("cargo build -j {n}").format( + n=self.option("make_parallelism") + ) + ) + ], + ) + + def fb_github_autoconf_install(self, project_and_path, github_org="facebook"): + return [ + self.fb_github_project_workdir(project_and_path, github_org), + self.autoconf_install(project_and_path), + ] + + def fb_github_cmake_install( + self, project_and_path, cmake_path="..", github_org="facebook" + ): + return [ + self.fb_github_project_workdir(project_and_path, github_org), + self.cmake_install(project_and_path, cmake_path), + ] + + def fb_github_cargo_build(self, project_and_path, github_org="facebook"): + return [ + self.fb_github_project_workdir(project_and_path, github_org), + self.cargo_build(project_and_path), + ] diff --git a/build/fbcode_builder/fbcode_builder_config.py b/build/fbcode_builder/fbcode_builder_config.py new file mode 100644 index 000000000..5ba6e607a --- /dev/null +++ b/build/fbcode_builder/fbcode_builder_config.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +"Demo config, so that `make_docker_context.py --help` works in this directory." + +config = { + "fbcode_builder_spec": lambda _builder: { + "depends_on": [], + "steps": [], + }, + "github_project": "demo/project", +} diff --git a/build/fbcode_builder/getdeps.py b/build/fbcode_builder/getdeps.py new file mode 100755 index 000000000..1b539735f --- /dev/null +++ b/build/fbcode_builder/getdeps.py @@ -0,0 +1,1071 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import argparse +import json +import os +import shutil +import subprocess +import sys +import tarfile +import tempfile + +# We don't import cache.create_cache directly as the facebook +# specific import below may monkey patch it, and we want to +# observe the patched version of this function! +import getdeps.cache as cache_module +from getdeps.buildopts import setup_build_options +from getdeps.dyndeps import create_dyn_dep_munger +from getdeps.errors import TransientFailure +from getdeps.fetcher import ( + SystemPackageFetcher, + file_name_is_cmake_file, + list_files_under_dir_newer_than_timestamp, +) +from getdeps.load import ManifestLoader +from getdeps.manifest import ManifestParser +from getdeps.platform import HostType +from getdeps.runcmd import run_cmd +from getdeps.subcmd import SubCmd, add_subcommands, cmd + + +try: + import getdeps.facebook # noqa: F401 +except ImportError: + # we don't ship the facebook specific subdir, + # so allow that to fail silently + pass + + +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "getdeps")) + + +class UsageError(Exception): + pass + + +@cmd("validate-manifest", "parse a manifest and validate that it is correct") +class ValidateManifest(SubCmd): + def run(self, args): + try: + ManifestParser(file_name=args.file_name) + print("OK", file=sys.stderr) + return 0 + except Exception as exc: + print("ERROR: %s" % str(exc), file=sys.stderr) + return 1 + + def setup_parser(self, parser): + parser.add_argument("file_name", help="path to the manifest file") + + +@cmd("show-host-type", "outputs the host type tuple for the host machine") +class ShowHostType(SubCmd): + def run(self, args): + host = HostType() + print("%s" % host.as_tuple_string()) + return 0 + + +class ProjectCmdBase(SubCmd): + def run(self, args): + opts = setup_build_options(args) + + if args.current_project is not None: + opts.repo_project = args.current_project + if args.project is None: + if opts.repo_project is None: + raise UsageError( + "no project name specified, and no .projectid file found" + ) + if opts.repo_project == "fbsource": + # The fbsource repository is a little special. There is no project + # manifest file for it. A specific project must always be explicitly + # specified when building from fbsource. + raise UsageError( + "no project name specified (required when building in fbsource)" + ) + args.project = opts.repo_project + + ctx_gen = opts.get_context_generator(facebook_internal=args.facebook_internal) + if args.test_dependencies: + ctx_gen.set_value_for_all_projects("test", "on") + if args.enable_tests: + ctx_gen.set_value_for_project(args.project, "test", "on") + else: + ctx_gen.set_value_for_project(args.project, "test", "off") + + loader = ManifestLoader(opts, ctx_gen) + self.process_project_dir_arguments(args, loader) + + manifest = loader.load_manifest(args.project) + + self.run_project_cmd(args, loader, manifest) + + def process_project_dir_arguments(self, args, loader): + def parse_project_arg(arg, arg_type): + parts = arg.split(":") + if len(parts) == 2: + project, path = parts + elif len(parts) == 1: + project = args.project + path = parts[0] + # On Windows path contains colon, e.g. C:\open + elif os.name == "nt" and len(parts) == 3: + project = parts[0] + path = parts[1] + ":" + parts[2] + else: + raise UsageError( + "invalid %s argument; too many ':' characters: %s" % (arg_type, arg) + ) + + return project, os.path.abspath(path) + + # If we are currently running from a project repository, + # use the current repository for the project sources. + build_opts = loader.build_opts + if build_opts.repo_project is not None and build_opts.repo_root is not None: + loader.set_project_src_dir(build_opts.repo_project, build_opts.repo_root) + + for arg in args.src_dir: + project, path = parse_project_arg(arg, "--src-dir") + loader.set_project_src_dir(project, path) + + for arg in args.build_dir: + project, path = parse_project_arg(arg, "--build-dir") + loader.set_project_build_dir(project, path) + + for arg in args.install_dir: + project, path = parse_project_arg(arg, "--install-dir") + loader.set_project_install_dir(project, path) + + for arg in args.project_install_prefix: + project, path = parse_project_arg(arg, "--install-prefix") + loader.set_project_install_prefix(project, path) + + def setup_parser(self, parser): + parser.add_argument( + "project", + nargs="?", + help=( + "name of the project or path to a manifest " + "file describing the project" + ), + ) + parser.add_argument( + "--no-tests", + action="store_false", + dest="enable_tests", + default=True, + help="Disable building tests for this project.", + ) + parser.add_argument( + "--test-dependencies", + action="store_true", + help="Enable building tests for dependencies as well.", + ) + parser.add_argument( + "--current-project", + help="Specify the name of the fbcode_builder manifest file for the " + "current repository. If not specified, the code will attempt to find " + "this in a .projectid file in the repository root.", + ) + parser.add_argument( + "--src-dir", + default=[], + action="append", + help="Specify a local directory to use for the project source, " + "rather than fetching it.", + ) + parser.add_argument( + "--build-dir", + default=[], + action="append", + help="Explicitly specify the build directory to use for the " + "project, instead of the default location in the scratch path. " + "This only affects the project specified, and not its dependencies.", + ) + parser.add_argument( + "--install-dir", + default=[], + action="append", + help="Explicitly specify the install directory to use for the " + "project, instead of the default location in the scratch path. " + "This only affects the project specified, and not its dependencies.", + ) + parser.add_argument( + "--project-install-prefix", + default=[], + action="append", + help="Specify the final deployment installation path for a project", + ) + + self.setup_project_cmd_parser(parser) + + def setup_project_cmd_parser(self, parser): + pass + + +class CachedProject(object): + """A helper that allows calling the cache logic for a project + from both the build and the fetch code""" + + def __init__(self, cache, loader, m): + self.m = m + self.inst_dir = loader.get_project_install_dir(m) + self.project_hash = loader.get_project_hash(m) + self.ctx = loader.ctx_gen.get_context(m.name) + self.loader = loader + self.cache = cache + + self.cache_file_name = "-".join( + ( + m.name, + self.ctx.get("os"), + self.ctx.get("distro") or "none", + self.ctx.get("distro_vers") or "none", + self.project_hash, + "buildcache.tgz", + ) + ) + + def is_cacheable(self): + """We only cache third party projects""" + return self.cache and self.m.shipit_project is None + + def was_cached(self): + cached_marker = os.path.join(self.inst_dir, ".getdeps-cached-build") + return os.path.exists(cached_marker) + + def download(self): + if self.is_cacheable() and not os.path.exists(self.inst_dir): + print("check cache for %s" % self.cache_file_name) + dl_dir = os.path.join(self.loader.build_opts.scratch_dir, "downloads") + if not os.path.exists(dl_dir): + os.makedirs(dl_dir) + try: + target_file_name = os.path.join(dl_dir, self.cache_file_name) + if self.cache.download_to_file(self.cache_file_name, target_file_name): + tf = tarfile.open(target_file_name, "r") + print( + "Extracting %s -> %s..." % (self.cache_file_name, self.inst_dir) + ) + tf.extractall(self.inst_dir) + + cached_marker = os.path.join(self.inst_dir, ".getdeps-cached-build") + with open(cached_marker, "w") as f: + f.write("\n") + + return True + except Exception as exc: + print("%s" % str(exc)) + + return False + + def upload(self): + if self.is_cacheable(): + # We can prepare an archive and stick it in LFS + tempdir = tempfile.mkdtemp() + tarfilename = os.path.join(tempdir, self.cache_file_name) + print("Archiving for cache: %s..." % tarfilename) + tf = tarfile.open(tarfilename, "w:gz") + tf.add(self.inst_dir, arcname=".") + tf.close() + try: + self.cache.upload_from_file(self.cache_file_name, tarfilename) + except Exception as exc: + print( + "Failed to upload to cache (%s), continue anyway" % str(exc), + file=sys.stderr, + ) + shutil.rmtree(tempdir) + + +@cmd("fetch", "fetch the code for a given project") +class FetchCmd(ProjectCmdBase): + def setup_project_cmd_parser(self, parser): + parser.add_argument( + "--recursive", + help="fetch the transitive deps also", + action="store_true", + default=False, + ) + parser.add_argument( + "--host-type", + help=( + "When recursively fetching, fetch deps for " + "this host type rather than the current system" + ), + ) + + def run_project_cmd(self, args, loader, manifest): + if args.recursive: + projects = loader.manifests_in_dependency_order() + else: + projects = [manifest] + + cache = cache_module.create_cache() + for m in projects: + cached_project = CachedProject(cache, loader, m) + if cached_project.download(): + continue + + inst_dir = loader.get_project_install_dir(m) + built_marker = os.path.join(inst_dir, ".built-by-getdeps") + if os.path.exists(built_marker): + with open(built_marker, "r") as f: + built_hash = f.read().strip() + + project_hash = loader.get_project_hash(m) + if built_hash == project_hash: + continue + + # We need to fetch the sources + fetcher = loader.create_fetcher(m) + fetcher.update() + + +@cmd("install-system-deps", "Install system packages to satisfy the deps for a project") +class InstallSysDepsCmd(ProjectCmdBase): + def setup_project_cmd_parser(self, parser): + parser.add_argument( + "--recursive", + help="install the transitive deps also", + action="store_true", + default=False, + ) + + def run_project_cmd(self, args, loader, manifest): + if args.recursive: + projects = loader.manifests_in_dependency_order() + else: + projects = [manifest] + + cache = cache_module.create_cache() + all_packages = {} + for m in projects: + ctx = loader.ctx_gen.get_context(m.name) + packages = m.get_required_system_packages(ctx) + for k, v in packages.items(): + merged = all_packages.get(k, []) + merged += v + all_packages[k] = merged + + manager = loader.build_opts.host_type.get_package_manager() + if manager == "rpm": + packages = sorted(list(set(all_packages["rpm"]))) + if packages: + run_cmd(["dnf", "install", "-y"] + packages) + elif manager == "deb": + packages = sorted(list(set(all_packages["deb"]))) + if packages: + run_cmd(["apt", "install", "-y"] + packages) + else: + print("I don't know how to install any packages on this system") + + +@cmd("list-deps", "lists the transitive deps for a given project") +class ListDepsCmd(ProjectCmdBase): + def run_project_cmd(self, args, loader, manifest): + for m in loader.manifests_in_dependency_order(): + print(m.name) + return 0 + + def setup_project_cmd_parser(self, parser): + parser.add_argument( + "--host-type", + help=( + "Produce the list for the specified host type, " + "rather than that of the current system" + ), + ) + + +def clean_dirs(opts): + for d in ["build", "installed", "extracted", "shipit"]: + d = os.path.join(opts.scratch_dir, d) + print("Cleaning %s..." % d) + if os.path.exists(d): + shutil.rmtree(d) + + +@cmd("clean", "clean up the scratch dir") +class CleanCmd(SubCmd): + def run(self, args): + opts = setup_build_options(args) + clean_dirs(opts) + + +@cmd("show-build-dir", "print the build dir for a given project") +class ShowBuildDirCmd(ProjectCmdBase): + def run_project_cmd(self, args, loader, manifest): + if args.recursive: + manifests = loader.manifests_in_dependency_order() + else: + manifests = [manifest] + + for m in manifests: + inst_dir = loader.get_project_build_dir(m) + print(inst_dir) + + def setup_project_cmd_parser(self, parser): + parser.add_argument( + "--recursive", + help="print the transitive deps also", + action="store_true", + default=False, + ) + + +@cmd("show-inst-dir", "print the installation dir for a given project") +class ShowInstDirCmd(ProjectCmdBase): + def run_project_cmd(self, args, loader, manifest): + if args.recursive: + manifests = loader.manifests_in_dependency_order() + else: + manifests = [manifest] + + for m in manifests: + inst_dir = loader.get_project_install_dir_respecting_install_prefix(m) + print(inst_dir) + + def setup_project_cmd_parser(self, parser): + parser.add_argument( + "--recursive", + help="print the transitive deps also", + action="store_true", + default=False, + ) + + +@cmd("show-source-dir", "print the source dir for a given project") +class ShowSourceDirCmd(ProjectCmdBase): + def run_project_cmd(self, args, loader, manifest): + if args.recursive: + manifests = loader.manifests_in_dependency_order() + else: + manifests = [manifest] + + for m in manifests: + fetcher = loader.create_fetcher(m) + print(fetcher.get_src_dir()) + + def setup_project_cmd_parser(self, parser): + parser.add_argument( + "--recursive", + help="print the transitive deps also", + action="store_true", + default=False, + ) + + +@cmd("build", "build a given project") +class BuildCmd(ProjectCmdBase): + def run_project_cmd(self, args, loader, manifest): + if args.clean: + clean_dirs(loader.build_opts) + + print("Building on %s" % loader.ctx_gen.get_context(args.project)) + projects = loader.manifests_in_dependency_order() + + cache = cache_module.create_cache() if args.use_build_cache else None + + # Accumulate the install directories so that the build steps + # can find their dep installation + install_dirs = [] + + for m in projects: + fetcher = loader.create_fetcher(m) + + if isinstance(fetcher, SystemPackageFetcher): + # We are guaranteed that if the fetcher is set to + # SystemPackageFetcher then this item is completely + # satisfied by the appropriate system packages + continue + + if args.clean: + fetcher.clean() + + build_dir = loader.get_project_build_dir(m) + inst_dir = loader.get_project_install_dir(m) + + if ( + m == manifest + and not args.only_deps + or m != manifest + and not args.no_deps + ): + print("Assessing %s..." % m.name) + project_hash = loader.get_project_hash(m) + ctx = loader.ctx_gen.get_context(m.name) + built_marker = os.path.join(inst_dir, ".built-by-getdeps") + + cached_project = CachedProject(cache, loader, m) + + reconfigure, sources_changed = self.compute_source_change_status( + cached_project, fetcher, m, built_marker, project_hash + ) + + if os.path.exists(built_marker) and not cached_project.was_cached(): + # We've previously built this. We may need to reconfigure if + # our deps have changed, so let's check them. + dep_reconfigure, dep_build = self.compute_dep_change_status( + m, built_marker, loader + ) + if dep_reconfigure: + reconfigure = True + if dep_build: + sources_changed = True + + extra_cmake_defines = ( + json.loads(args.extra_cmake_defines) + if args.extra_cmake_defines + else {} + ) + + if sources_changed or reconfigure or not os.path.exists(built_marker): + if os.path.exists(built_marker): + os.unlink(built_marker) + src_dir = fetcher.get_src_dir() + builder = m.create_builder( + loader.build_opts, + src_dir, + build_dir, + inst_dir, + ctx, + loader, + final_install_prefix=loader.get_project_install_prefix(m), + extra_cmake_defines=extra_cmake_defines, + ) + builder.build(install_dirs, reconfigure=reconfigure) + + with open(built_marker, "w") as f: + f.write(project_hash) + + # Only populate the cache from continuous build runs + if args.schedule_type == "continuous": + cached_project.upload() + + install_dirs.append(inst_dir) + + def compute_dep_change_status(self, m, built_marker, loader): + reconfigure = False + sources_changed = False + st = os.lstat(built_marker) + + ctx = loader.ctx_gen.get_context(m.name) + dep_list = sorted(m.get_section_as_dict("dependencies", ctx).keys()) + for dep in dep_list: + if reconfigure and sources_changed: + break + + dep_manifest = loader.load_manifest(dep) + dep_root = loader.get_project_install_dir(dep_manifest) + for dep_file in list_files_under_dir_newer_than_timestamp( + dep_root, st.st_mtime + ): + if os.path.basename(dep_file) == ".built-by-getdeps": + continue + if file_name_is_cmake_file(dep_file): + if not reconfigure: + reconfigure = True + print( + f"Will reconfigure cmake because {dep_file} is newer than {built_marker}" + ) + else: + if not sources_changed: + sources_changed = True + print( + f"Will run build because {dep_file} is newer than {built_marker}" + ) + + if reconfigure and sources_changed: + break + + return reconfigure, sources_changed + + def compute_source_change_status( + self, cached_project, fetcher, m, built_marker, project_hash + ): + reconfigure = False + sources_changed = False + if not cached_project.download(): + check_fetcher = True + if os.path.exists(built_marker): + check_fetcher = False + with open(built_marker, "r") as f: + built_hash = f.read().strip() + if built_hash == project_hash: + if cached_project.is_cacheable(): + # We can blindly trust the build status + reconfigure = False + sources_changed = False + else: + # Otherwise, we may have changed the source, so let's + # check in with the fetcher layer + check_fetcher = True + else: + # Some kind of inconsistency with a prior build, + # let's run it again to be sure + os.unlink(built_marker) + reconfigure = True + sources_changed = True + # While we don't need to consult the fetcher for the + # status in this case, we may still need to have eg: shipit + # run in order to have a correct source tree. + fetcher.update() + + if check_fetcher: + change_status = fetcher.update() + reconfigure = change_status.build_changed() + sources_changed = change_status.sources_changed() + + return reconfigure, sources_changed + + def setup_project_cmd_parser(self, parser): + parser.add_argument( + "--clean", + action="store_true", + default=False, + help=( + "Clean up the build and installation area prior to building, " + "causing the projects to be built from scratch" + ), + ) + parser.add_argument( + "--no-deps", + action="store_true", + default=False, + help=( + "Only build the named project, not its deps. " + "This is most useful after you've built all of the deps, " + "and helps to avoid waiting for relatively " + "slow up-to-date-ness checks" + ), + ) + parser.add_argument( + "--only-deps", + action="store_true", + default=False, + help=( + "Only build the named project's deps. " + "This is most useful when you want to separate out building " + "of all of the deps and your project" + ), + ) + parser.add_argument( + "--no-build-cache", + action="store_false", + default=True, + dest="use_build_cache", + help="Do not attempt to use the build cache.", + ) + parser.add_argument( + "--schedule-type", help="Indicates how the build was activated" + ) + parser.add_argument( + "--extra-cmake-defines", + help=( + "Input json map that contains extra cmake defines to be used " + "when compiling the current project and all its deps. " + 'e.g: \'{"CMAKE_CXX_FLAGS": "--bla"}\'' + ), + ) + + +@cmd("fixup-dyn-deps", "Adjusts dynamic dependencies for packaging purposes") +class FixupDeps(ProjectCmdBase): + def run_project_cmd(self, args, loader, manifest): + projects = loader.manifests_in_dependency_order() + + # Accumulate the install directories so that the build steps + # can find their dep installation + install_dirs = [] + + for m in projects: + inst_dir = loader.get_project_install_dir_respecting_install_prefix(m) + install_dirs.append(inst_dir) + + if m == manifest: + dep_munger = create_dyn_dep_munger( + loader.build_opts, install_dirs, args.strip + ) + dep_munger.process_deps(args.destdir, args.final_install_prefix) + + def setup_project_cmd_parser(self, parser): + parser.add_argument("destdir", help="Where to copy the fixed up executables") + parser.add_argument( + "--final-install-prefix", help="specify the final installation prefix" + ) + parser.add_argument( + "--strip", + action="store_true", + default=False, + help="Strip debug info while processing executables", + ) + + +@cmd("test", "test a given project") +class TestCmd(ProjectCmdBase): + def run_project_cmd(self, args, loader, manifest): + projects = loader.manifests_in_dependency_order() + + # Accumulate the install directories so that the test steps + # can find their dep installation + install_dirs = [] + + for m in projects: + inst_dir = loader.get_project_install_dir(m) + + if m == manifest or args.test_dependencies: + built_marker = os.path.join(inst_dir, ".built-by-getdeps") + if not os.path.exists(built_marker): + print("project %s has not been built" % m.name) + # TODO: we could just go ahead and build it here, but I + # want to tackle that as part of adding build-for-test + # support. + return 1 + fetcher = loader.create_fetcher(m) + src_dir = fetcher.get_src_dir() + ctx = loader.ctx_gen.get_context(m.name) + build_dir = loader.get_project_build_dir(m) + builder = m.create_builder( + loader.build_opts, src_dir, build_dir, inst_dir, ctx, loader + ) + + builder.run_tests( + install_dirs, + schedule_type=args.schedule_type, + owner=args.test_owner, + test_filter=args.filter, + retry=args.retry, + no_testpilot=args.no_testpilot, + ) + + install_dirs.append(inst_dir) + + def setup_project_cmd_parser(self, parser): + parser.add_argument( + "--schedule-type", help="Indicates how the build was activated" + ) + parser.add_argument("--test-owner", help="Owner for testpilot") + parser.add_argument("--filter", help="Only run the tests matching the regex") + parser.add_argument( + "--retry", + type=int, + default=3, + help="Number of immediate retries for failed tests " + "(noop in continuous and testwarden runs)", + ) + parser.add_argument( + "--no-testpilot", + help="Do not use Test Pilot even when available", + action="store_true", + ) + + +@cmd("generate-github-actions", "generate a GitHub actions configuration") +class GenerateGitHubActionsCmd(ProjectCmdBase): + RUN_ON_ALL = """ [push, pull_request]""" + RUN_ON_DEFAULT = """ + push: + branches: + - master + pull_request: + branches: + - master""" + + def run_project_cmd(self, args, loader, manifest): + platforms = [ + HostType("linux", "ubuntu", "18"), + HostType("darwin", None, None), + HostType("windows", None, None), + ] + + for p in platforms: + self.write_job_for_platform(p, args) + + # TODO: Break up complex function + def write_job_for_platform(self, platform, args): # noqa: C901 + build_opts = setup_build_options(args, platform) + ctx_gen = build_opts.get_context_generator(facebook_internal=False) + loader = ManifestLoader(build_opts, ctx_gen) + manifest = loader.load_manifest(args.project) + manifest_ctx = loader.ctx_gen.get_context(manifest.name) + run_on = self.RUN_ON_ALL if args.run_on_all_branches else self.RUN_ON_DEFAULT + + # Some projects don't do anything "useful" as a leaf project, only + # as a dep for a leaf project. Check for those here; we don't want + # to waste the effort scheduling them on CI. + # We do this by looking at the builder type in the manifest file + # rather than creating a builder and checking its type because we + # don't know enough to create the full builder instance here. + if manifest.get("build", "builder", ctx=manifest_ctx) == "nop": + return None + + # We want to be sure that we're running things with python 3 + # but python versioning is honestly a bit of a frustrating mess. + # `python` may be version 2 or version 3 depending on the system. + # python3 may not be a thing at all! + # Assume an optimistic default + py3 = "python3" + + if build_opts.is_linux(): + job_name = "linux" + runs_on = f"ubuntu-{args.ubuntu_version}" + elif build_opts.is_windows(): + # We're targeting the windows-2016 image because it has + # Visual Studio 2017 installed, and at the time of writing, + # the version of boost in the manifests (1.69) is not + # buildable with Visual Studio 2019 + job_name = "windows" + runs_on = "windows-2016" + # The windows runners are python 3 by default; python2.exe + # is available if needed. + py3 = "python" + else: + job_name = "mac" + runs_on = "macOS-latest" + + os.makedirs(args.output_dir, exist_ok=True) + output_file = os.path.join(args.output_dir, f"getdeps_{job_name}.yml") + with open(output_file, "w") as out: + # Deliberate line break here because the @ and the generated + # symbols are meaningful to our internal tooling when they + # appear in a single token + out.write("# This file was @") + out.write("generated by getdeps.py\n") + out.write( + f""" +name: {job_name} + +on:{run_on} + +jobs: +""" + ) + + getdeps = f"{py3} build/fbcode_builder/getdeps.py" + + out.write(" build:\n") + out.write(" runs-on: %s\n" % runs_on) + out.write(" steps:\n") + out.write(" - uses: actions/checkout@v1\n") + + if build_opts.is_windows(): + # cmake relies on BOOST_ROOT but GH deliberately don't set it in order + # to avoid versioning issues: + # https://github.com/actions/virtual-environments/issues/319 + # Instead, set the version we think we need; this is effectively + # coupled with the boost manifest + # This is the unusual syntax for setting an env var for the rest of + # the steps in a workflow: + # https://github.blog/changelog/2020-10-01-github-actions-deprecating-set-env-and-add-path-commands/ + out.write(" - name: Export boost environment\n") + out.write( + ' run: "echo BOOST_ROOT=%BOOST_ROOT_1_69_0% >> %GITHUB_ENV%"\n' + ) + out.write(" shell: cmd\n") + + # The git installation may not like long filenames, so tell it + # that we want it to use them! + out.write(" - name: Fix Git config\n") + out.write(" run: git config --system core.longpaths true\n") + + projects = loader.manifests_in_dependency_order() + + for m in projects: + if m != manifest: + out.write(" - name: Fetch %s\n" % m.name) + out.write(f" run: {getdeps} fetch --no-tests {m.name}\n") + + for m in projects: + if m != manifest: + out.write(" - name: Build %s\n" % m.name) + out.write(f" run: {getdeps} build --no-tests {m.name}\n") + + out.write(" - name: Build %s\n" % manifest.name) + + project_prefix = "" + if not build_opts.is_windows(): + project_prefix = ( + " --project-install-prefix %s:/usr/local" % manifest.name + ) + + out.write( + f" run: {getdeps} build --src-dir=. {manifest.name} {project_prefix}\n" + ) + + out.write(" - name: Copy artifacts\n") + if build_opts.is_linux(): + # Strip debug info from the binaries, but only on linux. + # While the `strip` utility is also available on macOS, + # attempting to strip there results in an error. + # The `strip` utility is not available on Windows. + strip = " --strip" + else: + strip = "" + + out.write( + f" run: {getdeps} fixup-dyn-deps{strip} " + f"--src-dir=. {manifest.name} _artifacts/{job_name} {project_prefix} " + f"--final-install-prefix /usr/local\n" + ) + + out.write(" - uses: actions/upload-artifact@master\n") + out.write(" with:\n") + out.write(" name: %s\n" % manifest.name) + out.write(" path: _artifacts\n") + + out.write(" - name: Test %s\n" % manifest.name) + out.write( + f" run: {getdeps} test --src-dir=. {manifest.name} {project_prefix}\n" + ) + + def setup_project_cmd_parser(self, parser): + parser.add_argument( + "--disallow-system-packages", + help="Disallow satisfying third party deps from installed system packages", + action="store_true", + default=False, + ) + parser.add_argument( + "--output-dir", help="The directory that will contain the yml files" + ) + parser.add_argument( + "--run-on-all-branches", + action="store_true", + help="Allow CI to fire on all branches - Handy for testing", + ) + parser.add_argument( + "--ubuntu-version", default="18.04", help="Version of Ubuntu to use" + ) + + +def get_arg_var_name(args): + for arg in args: + if arg.startswith("--"): + return arg[2:].replace("-", "_") + + raise Exception("unable to determine argument variable name from %r" % (args,)) + + +def parse_args(): + # We want to allow common arguments to be specified either before or after + # the subcommand name. In order to do this we add them to the main parser + # and to subcommand parsers. In order for this to work, we need to tell + # argparse that the default value is SUPPRESS, so that the default values + # from the subparser arguments won't override values set by the user from + # the main parser. We maintain our own list of desired defaults in the + # common_defaults dictionary, and manually set those if the argument wasn't + # present at all. + common_args = argparse.ArgumentParser(add_help=False) + common_defaults = {} + + def add_common_arg(*args, **kwargs): + var_name = get_arg_var_name(args) + default_value = kwargs.pop("default", None) + common_defaults[var_name] = default_value + kwargs["default"] = argparse.SUPPRESS + common_args.add_argument(*args, **kwargs) + + add_common_arg("--scratch-path", help="Where to maintain checkouts and build dirs") + add_common_arg( + "--vcvars-path", default=None, help="Path to the vcvarsall.bat on Windows." + ) + add_common_arg( + "--install-prefix", + help=( + "Where the final build products will be installed " + "(default is [scratch-path]/installed)" + ), + ) + add_common_arg( + "--num-jobs", + type=int, + help=( + "Number of concurrent jobs to use while building. " + "(default=number of cpu cores)" + ), + ) + add_common_arg( + "--use-shipit", + help="use the real ShipIt instead of the simple shipit transformer", + action="store_true", + default=False, + ) + add_common_arg( + "--facebook-internal", + help="Setup the build context as an FB internal build", + action="store_true", + default=None, + ) + add_common_arg( + "--no-facebook-internal", + help="Perform a non-FB internal build, even when in an fbsource repository", + action="store_false", + dest="facebook_internal", + ) + add_common_arg( + "--allow-system-packages", + help="Allow satisfying third party deps from installed system packages", + action="store_true", + default=False, + ) + add_common_arg( + "--lfs-path", + help="Provide a parent directory for lfs when fbsource is unavailable", + default=None, + ) + + ap = argparse.ArgumentParser( + description="Get and build dependencies and projects", parents=[common_args] + ) + sub = ap.add_subparsers( + # metavar suppresses the long and ugly default list of subcommands on a + # single line. We still render the nicer list below where we would + # have shown the nasty one. + metavar="", + title="Available commands", + help="", + ) + + add_subcommands(sub, common_args) + + args = ap.parse_args() + for var_name, default_value in common_defaults.items(): + if not hasattr(args, var_name): + setattr(args, var_name, default_value) + + return ap, args + + +def main(): + ap, args = parse_args() + if getattr(args, "func", None) is None: + ap.print_help() + return 0 + try: + return args.func(args) + except UsageError as exc: + ap.error(str(exc)) + return 1 + except TransientFailure as exc: + print("TransientFailure: %s" % str(exc)) + # This return code is treated as a retryable transient infrastructure + # error by Facebook's internal CI, rather than eg: a build or code + # related error that needs to be fixed before progress can be made. + return 128 + except subprocess.CalledProcessError as exc: + print("%s" % str(exc), file=sys.stderr) + print("!! Failed", file=sys.stderr) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/yarpl/include/yarpl/flowable/FlowableOperator_Merge.h b/build/fbcode_builder/getdeps/__init__.py similarity index 100% rename from yarpl/include/yarpl/flowable/FlowableOperator_Merge.h rename to build/fbcode_builder/getdeps/__init__.py diff --git a/build/fbcode_builder/getdeps/builder.py b/build/fbcode_builder/getdeps/builder.py new file mode 100644 index 000000000..4e523c2dc --- /dev/null +++ b/build/fbcode_builder/getdeps/builder.py @@ -0,0 +1,1400 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import json +import os +import shutil +import stat +import subprocess +import sys + +from .dyndeps import create_dyn_dep_munger +from .envfuncs import Env, add_path_entry, path_search +from .fetcher import copy_if_different +from .runcmd import run_cmd + + +class BuilderBase(object): + def __init__( + self, + build_opts, + ctx, + manifest, + src_dir, + build_dir, + inst_dir, + env=None, + final_install_prefix=None, + ): + self.env = Env() + if env: + self.env.update(env) + + subdir = manifest.get("build", "subdir", ctx=ctx) + if subdir: + src_dir = os.path.join(src_dir, subdir) + + self.ctx = ctx + self.src_dir = src_dir + self.build_dir = build_dir or src_dir + self.inst_dir = inst_dir + self.build_opts = build_opts + self.manifest = manifest + self.final_install_prefix = final_install_prefix + + def _get_cmd_prefix(self): + if self.build_opts.is_windows(): + vcvarsall = self.build_opts.get_vcvars_path() + if vcvarsall is not None: + # Since it sets rather a large number of variables we mildly abuse + # the cmd quoting rules to assemble a command that calls the script + # to prep the environment and then triggers the actual command that + # we wanted to run. + return [vcvarsall, "amd64", "&&"] + return [] + + def _run_cmd(self, cmd, cwd=None, env=None, use_cmd_prefix=True, allow_fail=False): + if env: + e = self.env.copy() + e.update(env) + env = e + else: + env = self.env + + if use_cmd_prefix: + cmd_prefix = self._get_cmd_prefix() + if cmd_prefix: + cmd = cmd_prefix + cmd + + log_file = os.path.join(self.build_dir, "getdeps_build.log") + return run_cmd( + cmd=cmd, + env=env, + cwd=cwd or self.build_dir, + log_file=log_file, + allow_fail=allow_fail, + ) + + def build(self, install_dirs, reconfigure): + print("Building %s..." % self.manifest.name) + + if self.build_dir is not None: + if not os.path.isdir(self.build_dir): + os.makedirs(self.build_dir) + reconfigure = True + + self._build(install_dirs=install_dirs, reconfigure=reconfigure) + + # On Windows, emit a wrapper script that can be used to run build artifacts + # directly from the build directory, without installing them. On Windows $PATH + # needs to be updated to include all of the directories containing the runtime + # library dependencies in order to run the binaries. + if self.build_opts.is_windows(): + script_path = self.get_dev_run_script_path() + dep_munger = create_dyn_dep_munger(self.build_opts, install_dirs) + dep_dirs = self.get_dev_run_extra_path_dirs(install_dirs, dep_munger) + dep_munger.emit_dev_run_script(script_path, dep_dirs) + + def run_tests( + self, install_dirs, schedule_type, owner, test_filter, retry, no_testpilot + ): + """Execute any tests that we know how to run. If they fail, + raise an exception.""" + pass + + def _build(self, install_dirs, reconfigure): + """Perform the build. + install_dirs contains the list of installation directories for + the dependencies of this project. + reconfigure will be set to true if the fetcher determined + that the sources have changed in such a way that the build + system needs to regenerate its rules.""" + pass + + def _compute_env(self, install_dirs): + # CMAKE_PREFIX_PATH is only respected when passed through the + # environment, so we construct an appropriate path to pass down + return self.build_opts.compute_env_for_install_dirs( + install_dirs, env=self.env, manifest=self.manifest + ) + + def get_dev_run_script_path(self): + assert self.build_opts.is_windows() + return os.path.join(self.build_dir, "run.ps1") + + def get_dev_run_extra_path_dirs(self, install_dirs, dep_munger=None): + assert self.build_opts.is_windows() + if dep_munger is None: + dep_munger = create_dyn_dep_munger(self.build_opts, install_dirs) + return dep_munger.compute_dependency_paths(self.build_dir) + + +class MakeBuilder(BuilderBase): + def __init__( + self, + build_opts, + ctx, + manifest, + src_dir, + build_dir, + inst_dir, + build_args, + install_args, + test_args, + ): + super(MakeBuilder, self).__init__( + build_opts, ctx, manifest, src_dir, build_dir, inst_dir + ) + self.build_args = build_args or [] + self.install_args = install_args or [] + self.test_args = test_args + + def _get_prefix(self): + return ["PREFIX=" + self.inst_dir, "prefix=" + self.inst_dir] + + def _build(self, install_dirs, reconfigure): + env = self._compute_env(install_dirs) + + # Need to ensure that PREFIX is set prior to install because + # libbpf uses it when generating its pkg-config file. + # The lowercase prefix is used by some projects. + cmd = ( + ["make", "-j%s" % self.build_opts.num_jobs] + + self.build_args + + self._get_prefix() + ) + self._run_cmd(cmd, env=env) + + install_cmd = ["make"] + self.install_args + self._get_prefix() + self._run_cmd(install_cmd, env=env) + + def run_tests( + self, install_dirs, schedule_type, owner, test_filter, retry, no_testpilot + ): + if not self.test_args: + return + + env = self._compute_env(install_dirs) + + cmd = ["make"] + self.test_args + self._get_prefix() + self._run_cmd(cmd, env=env) + + +class CMakeBootStrapBuilder(MakeBuilder): + def _build(self, install_dirs, reconfigure): + self._run_cmd(["./bootstrap", "--prefix=" + self.inst_dir]) + super(CMakeBootStrapBuilder, self)._build(install_dirs, reconfigure) + + +class AutoconfBuilder(BuilderBase): + def __init__(self, build_opts, ctx, manifest, src_dir, build_dir, inst_dir, args): + super(AutoconfBuilder, self).__init__( + build_opts, ctx, manifest, src_dir, build_dir, inst_dir + ) + self.args = args or [] + + def _build(self, install_dirs, reconfigure): + configure_path = os.path.join(self.src_dir, "configure") + autogen_path = os.path.join(self.src_dir, "autogen.sh") + + env = self._compute_env(install_dirs) + + if not os.path.exists(configure_path): + print("%s doesn't exist, so reconfiguring" % configure_path) + # This libtoolize call is a bit gross; the issue is that + # `autoreconf` as invoked by libsodium's `autogen.sh` doesn't + # seem to realize that it should invoke libtoolize and then + # error out when the configure script references a libtool + # related symbol. + self._run_cmd(["libtoolize"], cwd=self.src_dir, env=env) + + # We generally prefer to call the `autogen.sh` script provided + # by the project on the basis that it may know more than plain + # autoreconf does. + if os.path.exists(autogen_path): + self._run_cmd(["bash", autogen_path], cwd=self.src_dir, env=env) + else: + self._run_cmd(["autoreconf", "-ivf"], cwd=self.src_dir, env=env) + configure_cmd = [configure_path, "--prefix=" + self.inst_dir] + self.args + self._run_cmd(configure_cmd, env=env) + self._run_cmd(["make", "-j%s" % self.build_opts.num_jobs], env=env) + self._run_cmd(["make", "install"], env=env) + + +class Iproute2Builder(BuilderBase): + # ./configure --prefix does not work for iproute2. + # Thus, explicitly copy sources from src_dir to build_dir, bulid, + # and then install to inst_dir using DESTDIR + # lastly, also copy include from build_dir to inst_dir + def __init__(self, build_opts, ctx, manifest, src_dir, build_dir, inst_dir): + super(Iproute2Builder, self).__init__( + build_opts, ctx, manifest, src_dir, build_dir, inst_dir + ) + + def _patch(self): + # FBOSS build currently depends on an old version of iproute2 (commit + # 7ca63aef7d1b0c808da0040c6b366ef7a61f38c1). This is missing a commit + # (ae717baf15fb4d30749ada3948d9445892bac239) needed to build iproute2 + # successfully. Apply it viz.: include stdint.h + # Reference: https://fburl.com/ilx9g5xm + with open(self.build_dir + "/tc/tc_core.c", "r") as f: + data = f.read() + + with open(self.build_dir + "/tc/tc_core.c", "w") as f: + f.write("#include \n") + f.write(data) + + def _build(self, install_dirs, reconfigure): + configure_path = os.path.join(self.src_dir, "configure") + + env = self.env.copy() + self._run_cmd([configure_path], env=env) + shutil.rmtree(self.build_dir) + shutil.copytree(self.src_dir, self.build_dir) + self._patch() + self._run_cmd(["make", "-j%s" % self.build_opts.num_jobs], env=env) + install_cmd = ["make", "install", "DESTDIR=" + self.inst_dir] + + for d in ["include", "lib"]: + if not os.path.isdir(os.path.join(self.inst_dir, d)): + shutil.copytree( + os.path.join(self.build_dir, d), os.path.join(self.inst_dir, d) + ) + + self._run_cmd(install_cmd, env=env) + + +class BistroBuilder(BuilderBase): + def _build(self, install_dirs, reconfigure): + p = os.path.join(self.src_dir, "bistro", "bistro") + env = self._compute_env(install_dirs) + env["PATH"] = env["PATH"] + ":" + os.path.join(p, "bin") + env["TEMPLATES_PATH"] = os.path.join(p, "include", "thrift", "templates") + self._run_cmd( + [ + os.path.join(".", "cmake", "run-cmake.sh"), + "Release", + "-DCMAKE_INSTALL_PREFIX=" + self.inst_dir, + ], + cwd=p, + env=env, + ) + self._run_cmd( + [ + "make", + "install", + "-j", + str(self.build_opts.num_jobs), + ], + cwd=os.path.join(p, "cmake", "Release"), + env=env, + ) + + def run_tests( + self, install_dirs, schedule_type, owner, test_filter, retry, no_testpilot + ): + env = self._compute_env(install_dirs) + build_dir = os.path.join(self.src_dir, "bistro", "bistro", "cmake", "Release") + NUM_RETRIES = 5 + for i in range(NUM_RETRIES): + cmd = ["ctest", "--output-on-failure"] + if i > 0: + cmd.append("--rerun-failed") + cmd.append(build_dir) + try: + self._run_cmd( + cmd, + cwd=build_dir, + env=env, + ) + except Exception: + print(f"Tests failed... retrying ({i+1}/{NUM_RETRIES})") + else: + return + raise Exception(f"Tests failed even after {NUM_RETRIES} retries") + + +class CMakeBuilder(BuilderBase): + MANUAL_BUILD_SCRIPT = """\ +#!{sys.executable} + +from __future__ import absolute_import, division, print_function, unicode_literals + +import argparse +import subprocess +import sys + +CMAKE = {cmake!r} +CTEST = {ctest!r} +SRC_DIR = {src_dir!r} +BUILD_DIR = {build_dir!r} +INSTALL_DIR = {install_dir!r} +CMD_PREFIX = {cmd_prefix!r} +CMAKE_ENV = {env_str} +CMAKE_DEFINE_ARGS = {define_args_str} + + +def get_jobs_argument(num_jobs_arg: int) -> str: + if num_jobs_arg > 0: + return "-j" + str(num_jobs_arg) + + import multiprocessing + num_jobs = multiprocessing.cpu_count() // 2 + return "-j" + str(num_jobs) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument( + "cmake_args", + nargs=argparse.REMAINDER, + help='Any extra arguments after an "--" argument will be passed ' + "directly to CMake." + ) + ap.add_argument( + "--mode", + choices=["configure", "build", "install", "test"], + default="configure", + help="The mode to run: configure, build, or install. " + "Defaults to configure", + ) + ap.add_argument( + "--build", + action="store_const", + const="build", + dest="mode", + help="An alias for --mode=build", + ) + ap.add_argument( + "-j", + "--num-jobs", + action="store", + type=int, + default=0, + help="Run the build or tests with the specified number of parallel jobs", + ) + ap.add_argument( + "--install", + action="store_const", + const="install", + dest="mode", + help="An alias for --mode=install", + ) + ap.add_argument( + "--test", + action="store_const", + const="test", + dest="mode", + help="An alias for --mode=test", + ) + args = ap.parse_args() + + # Strip off a leading "--" from the additional CMake arguments + if args.cmake_args and args.cmake_args[0] == "--": + args.cmake_args = args.cmake_args[1:] + + env = CMAKE_ENV + + if args.mode == "configure": + full_cmd = CMD_PREFIX + [CMAKE, SRC_DIR] + CMAKE_DEFINE_ARGS + args.cmake_args + elif args.mode in ("build", "install"): + target = "all" if args.mode == "build" else "install" + full_cmd = CMD_PREFIX + [ + CMAKE, + "--build", + BUILD_DIR, + "--target", + target, + "--config", + "Release", + get_jobs_argument(args.num_jobs), + ] + args.cmake_args + elif args.mode == "test": + full_cmd = CMD_PREFIX + [ + {dev_run_script}CTEST, + "--output-on-failure", + get_jobs_argument(args.num_jobs), + ] + args.cmake_args + else: + ap.error("unknown invocation mode: %s" % (args.mode,)) + + cmd_str = " ".join(full_cmd) + print("Running: %r" % (cmd_str,)) + proc = subprocess.run(full_cmd, env=env, cwd=BUILD_DIR) + sys.exit(proc.returncode) + + +if __name__ == "__main__": + main() +""" + + def __init__( + self, + build_opts, + ctx, + manifest, + src_dir, + build_dir, + inst_dir, + defines, + final_install_prefix=None, + extra_cmake_defines=None, + ): + super(CMakeBuilder, self).__init__( + build_opts, + ctx, + manifest, + src_dir, + build_dir, + inst_dir, + final_install_prefix=final_install_prefix, + ) + self.defines = defines or {} + if extra_cmake_defines: + self.defines.update(extra_cmake_defines) + + def _invalidate_cache(self): + for name in [ + "CMakeCache.txt", + "CMakeFiles/CMakeError.log", + "CMakeFiles/CMakeOutput.log", + ]: + name = os.path.join(self.build_dir, name) + if os.path.isdir(name): + shutil.rmtree(name) + elif os.path.exists(name): + os.unlink(name) + + def _needs_reconfigure(self): + for name in ["CMakeCache.txt", "build.ninja"]: + name = os.path.join(self.build_dir, name) + if not os.path.exists(name): + return True + return False + + def _write_build_script(self, **kwargs): + env_lines = [" {!r}: {!r},".format(k, v) for k, v in kwargs["env"].items()] + kwargs["env_str"] = "\n".join(["{"] + env_lines + ["}"]) + + if self.build_opts.is_windows(): + kwargs["dev_run_script"] = '"powershell.exe", {!r}, '.format( + self.get_dev_run_script_path() + ) + else: + kwargs["dev_run_script"] = "" + + define_arg_lines = ["["] + for arg in kwargs["define_args"]: + # Replace the CMAKE_INSTALL_PREFIX argument to use the INSTALL_DIR + # variable that we define in the MANUAL_BUILD_SCRIPT code. + if arg.startswith("-DCMAKE_INSTALL_PREFIX="): + value = " {!r}.format(INSTALL_DIR),".format( + "-DCMAKE_INSTALL_PREFIX={}" + ) + else: + value = " {!r},".format(arg) + define_arg_lines.append(value) + define_arg_lines.append("]") + kwargs["define_args_str"] = "\n".join(define_arg_lines) + + # In order to make it easier for developers to manually run builds for + # CMake-based projects, write out some build scripts that can be used to invoke + # CMake manually. + build_script_path = os.path.join(self.build_dir, "run_cmake.py") + script_contents = self.MANUAL_BUILD_SCRIPT.format(**kwargs) + with open(build_script_path, "wb") as f: + f.write(script_contents.encode()) + os.chmod(build_script_path, 0o755) + + def _compute_cmake_define_args(self, env): + defines = { + "CMAKE_INSTALL_PREFIX": self.final_install_prefix or self.inst_dir, + "BUILD_SHARED_LIBS": "OFF", + # Some of the deps (rsocket) default to UBSAN enabled if left + # unspecified. Some of the deps fail to compile in release mode + # due to warning->error promotion. RelWithDebInfo is the happy + # medium. + "CMAKE_BUILD_TYPE": "RelWithDebInfo", + } + if "SANDCASTLE" not in os.environ: + # We sometimes see intermittent ccache related breakages on some + # of the FB internal CI hosts, so we prefer to disable ccache + # when running in that environment. + ccache = path_search(env, "ccache") + if ccache: + defines["CMAKE_CXX_COMPILER_LAUNCHER"] = ccache + else: + # rocksdb does its own probing for ccache. + # Ensure that it is disabled on sandcastle + env["CCACHE_DISABLE"] = "1" + # Some sandcastle hosts have broken ccache related dirs, and + # even though we've asked for it to be disabled ccache is + # still invoked by rocksdb's cmake. + # Redirect its config directory to somewhere that is guaranteed + # fresh to us, and that won't have any ccache data inside. + env["CCACHE_DIR"] = f"{self.build_opts.scratch_dir}/ccache" + + if "GITHUB_ACTIONS" in os.environ and self.build_opts.is_windows(): + # GitHub actions: the host has both gcc and msvc installed, and + # the default behavior of cmake is to prefer gcc. + # Instruct cmake that we want it to use cl.exe; this is important + # because Boost prefers cl.exe and the mismatch results in cmake + # with gcc not being able to find boost built with cl.exe. + defines["CMAKE_C_COMPILER"] = "cl.exe" + defines["CMAKE_CXX_COMPILER"] = "cl.exe" + + if self.build_opts.is_darwin(): + # Try to persuade cmake to set the rpath to match the lib + # dirs of the dependencies. This isn't automatic, and to + # make things more interesting, cmake uses `;` as the path + # separator, so translate the runtime path to something + # that cmake will parse + defines["CMAKE_INSTALL_RPATH"] = ";".join( + env.get("DYLD_LIBRARY_PATH", "").split(":") + ) + # Tell cmake that we want to set the rpath in the tree + # at build time. Without this the rpath is only set + # at the moment that the binaries are installed. That + # default is problematic for example when using the + # gtest integration in cmake which runs the built test + # executables during the build to discover the set of + # tests. + defines["CMAKE_BUILD_WITH_INSTALL_RPATH"] = "ON" + + defines.update(self.defines) + define_args = ["-D%s=%s" % (k, v) for (k, v) in defines.items()] + + # if self.build_opts.is_windows(): + # define_args += ["-G", "Visual Studio 15 2017 Win64"] + define_args += ["-G", "Ninja"] + + return define_args + + def _build(self, install_dirs, reconfigure): + reconfigure = reconfigure or self._needs_reconfigure() + + env = self._compute_env(install_dirs) + if not self.build_opts.is_windows() and self.final_install_prefix: + env["DESTDIR"] = self.inst_dir + + # Resolve the cmake that we installed + cmake = path_search(env, "cmake") + if cmake is None: + raise Exception("Failed to find CMake") + + if reconfigure: + define_args = self._compute_cmake_define_args(env) + self._write_build_script( + cmd_prefix=self._get_cmd_prefix(), + cmake=cmake, + ctest=path_search(env, "ctest"), + env=env, + define_args=define_args, + src_dir=self.src_dir, + build_dir=self.build_dir, + install_dir=self.inst_dir, + sys=sys, + ) + + self._invalidate_cache() + self._run_cmd([cmake, self.src_dir] + define_args, env=env) + + self._run_cmd( + [ + cmake, + "--build", + self.build_dir, + "--target", + "install", + "--config", + "Release", + "-j", + str(self.build_opts.num_jobs), + ], + env=env, + ) + + def run_tests( + self, install_dirs, schedule_type, owner, test_filter, retry, no_testpilot + ): + env = self._compute_env(install_dirs) + ctest = path_search(env, "ctest") + cmake = path_search(env, "cmake") + + # On Windows, we also need to update $PATH to include the directories that + # contain runtime library dependencies. This is not needed on other platforms + # since CMake will emit RPATH properly in the binary so they can find these + # dependencies. + if self.build_opts.is_windows(): + path_entries = self.get_dev_run_extra_path_dirs(install_dirs) + path = env.get("PATH") + if path: + path_entries.insert(0, path) + env["PATH"] = ";".join(path_entries) + + # Don't use the cmd_prefix when running tests. This is vcvarsall.bat on + # Windows. vcvarsall.bat is only needed for the build, not tests. It + # unfortunately fails if invoked with a long PATH environment variable when + # running the tests. + use_cmd_prefix = False + + def get_property(test, propname, defval=None): + """extracts a named property from a cmake test info json blob. + The properties look like: + [{"name": "WORKING_DIRECTORY"}, + {"value": "something"}] + We assume that it is invalid for the same named property to be + listed more than once. + """ + props = test.get("properties", []) + for p in props: + if p.get("name", None) == propname: + return p.get("value", defval) + return defval + + def list_tests(): + output = subprocess.check_output( + [ctest, "--show-only=json-v1"], env=env, cwd=self.build_dir + ) + try: + data = json.loads(output.decode("utf-8")) + except ValueError as exc: + raise Exception( + "Failed to decode cmake test info using %s: %s. Output was: %r" + % (ctest, str(exc), output) + ) + + tests = [] + machine_suffix = self.build_opts.host_type.as_tuple_string() + for test in data["tests"]: + working_dir = get_property(test, "WORKING_DIRECTORY") + labels = [] + machine_suffix = self.build_opts.host_type.as_tuple_string() + labels.append("tpx_test_config::buildsystem=getdeps") + labels.append("tpx_test_config::platform={}".format(machine_suffix)) + + if get_property(test, "DISABLED"): + labels.append("disabled") + command = test["command"] + if working_dir: + command = [cmake, "-E", "chdir", working_dir] + command + + import os + + tests.append( + { + "type": "custom", + "target": "%s-%s-getdeps-%s" + % (self.manifest.name, test["name"], machine_suffix), + "command": command, + "labels": labels, + "env": {}, + "required_paths": [], + "contacts": [], + "cwd": os.getcwd(), + } + ) + return tests + + if schedule_type == "continuous" or schedule_type == "testwarden": + # for continuous and testwarden runs, disabling retry can give up + # better signals for flaky tests. + retry = 0 + + from sys import platform + + testpilot = path_search(env, "testpilot") + tpx = path_search(env, "tpx") + if (tpx or testpilot) and not no_testpilot: + buck_test_info = list_tests() + import os + + buck_test_info_name = os.path.join(self.build_dir, ".buck-test-info.json") + with open(buck_test_info_name, "w") as f: + json.dump(buck_test_info, f) + + env.set("http_proxy", "") + env.set("https_proxy", "") + runs = [] + from sys import platform + + if platform == "win32": + machine_suffix = self.build_opts.host_type.as_tuple_string() + testpilot_args = [ + "parexec-testinfra.exe", + "C:/tools/testpilot/sc_testpilot.par", + # Need to force the repo type otherwise testpilot on windows + # can be confused (presumably sparse profile related) + "--force-repo", + "fbcode", + "--force-repo-root", + self.build_opts.fbsource_dir, + "--buck-test-info", + buck_test_info_name, + "--retry=%d" % retry, + "-j=%s" % str(self.build_opts.num_jobs), + "--test-config", + "platform=%s" % machine_suffix, + "buildsystem=getdeps", + "--return-nonzero-on-failures", + ] + else: + testpilot_args = [ + tpx, + "--buck-test-info", + buck_test_info_name, + "--retry=%d" % retry, + "-j=%s" % str(self.build_opts.num_jobs), + "--print-long-results", + ] + + if owner: + testpilot_args += ["--contacts", owner] + + if tpx and env: + testpilot_args.append("--env") + testpilot_args.extend(f"{key}={val}" for key, val in env.items()) + + if test_filter: + testpilot_args += ["--", test_filter] + + if schedule_type == "continuous": + runs.append( + [ + "--tag-new-tests", + "--collection", + "oss-continuous", + "--purpose", + "continuous", + ] + ) + elif schedule_type == "testwarden": + # One run to assess new tests + runs.append( + [ + "--tag-new-tests", + "--collection", + "oss-new-test-stress", + "--stress-runs", + "10", + "--purpose", + "stress-run-new-test", + ] + ) + # And another for existing tests + runs.append( + [ + "--tag-new-tests", + "--collection", + "oss-existing-test-stress", + "--stress-runs", + "10", + "--purpose", + "stress-run", + ] + ) + else: + runs.append(["--collection", "oss-diff", "--purpose", "diff"]) + + for run in runs: + self._run_cmd( + testpilot_args + run, + cwd=self.build_opts.fbcode_builder_dir, + env=env, + use_cmd_prefix=use_cmd_prefix, + ) + else: + args = [ctest, "--output-on-failure", "-j", str(self.build_opts.num_jobs)] + if test_filter: + args += ["-R", test_filter] + + count = 0 + while count <= retry: + retcode = self._run_cmd( + args, env=env, use_cmd_prefix=use_cmd_prefix, allow_fail=True + ) + + if retcode == 0: + break + if count == 0: + # Only add this option in the second run. + args += ["--rerun-failed"] + count += 1 + if retcode != 0: + # Allow except clause in getdeps.main to catch and exit gracefully + # This allows non-testpilot runs to fail through the same logic as failed testpilot runs, which may become handy in case if post test processing is needed in the future + raise subprocess.CalledProcessError(retcode, args) + + +class NinjaBootstrap(BuilderBase): + def __init__(self, build_opts, ctx, manifest, build_dir, src_dir, inst_dir): + super(NinjaBootstrap, self).__init__( + build_opts, ctx, manifest, src_dir, build_dir, inst_dir + ) + + def _build(self, install_dirs, reconfigure): + self._run_cmd([sys.executable, "configure.py", "--bootstrap"], cwd=self.src_dir) + src_ninja = os.path.join(self.src_dir, "ninja") + dest_ninja = os.path.join(self.inst_dir, "bin/ninja") + bin_dir = os.path.dirname(dest_ninja) + if not os.path.exists(bin_dir): + os.makedirs(bin_dir) + shutil.copyfile(src_ninja, dest_ninja) + shutil.copymode(src_ninja, dest_ninja) + + +class OpenSSLBuilder(BuilderBase): + def __init__(self, build_opts, ctx, manifest, build_dir, src_dir, inst_dir): + super(OpenSSLBuilder, self).__init__( + build_opts, ctx, manifest, src_dir, build_dir, inst_dir + ) + + def _build(self, install_dirs, reconfigure): + configure = os.path.join(self.src_dir, "Configure") + + # prefer to resolve the perl that we installed from + # our manifest on windows, but fall back to the system + # path on eg: darwin + env = self.env.copy() + for d in install_dirs: + bindir = os.path.join(d, "bin") + add_path_entry(env, "PATH", bindir, append=False) + + perl = path_search(env, "perl", "perl") + + if self.build_opts.is_windows(): + make = "nmake.exe" + args = ["VC-WIN64A-masm", "-utf-8"] + elif self.build_opts.is_darwin(): + make = "make" + args = ["darwin64-x86_64-cc"] + elif self.build_opts.is_linux(): + make = "make" + args = ( + ["linux-x86_64"] if not self.build_opts.is_arm() else ["linux-aarch64"] + ) + else: + raise Exception("don't know how to build openssl for %r" % self.ctx) + + self._run_cmd( + [ + perl, + configure, + "--prefix=%s" % self.inst_dir, + "--openssldir=%s" % self.inst_dir, + ] + + args + + [ + "enable-static-engine", + "enable-capieng", + "no-makedepend", + "no-unit-test", + "no-tests", + ] + ) + self._run_cmd([make, "install_sw", "install_ssldirs"]) + + +class Boost(BuilderBase): + def __init__( + self, build_opts, ctx, manifest, src_dir, build_dir, inst_dir, b2_args + ): + children = os.listdir(src_dir) + assert len(children) == 1, "expected a single directory entry: %r" % (children,) + boost_src = children[0] + assert boost_src.startswith("boost") + src_dir = os.path.join(src_dir, children[0]) + super(Boost, self).__init__( + build_opts, ctx, manifest, src_dir, build_dir, inst_dir + ) + self.b2_args = b2_args + + def _build(self, install_dirs, reconfigure): + env = self._compute_env(install_dirs) + linkage = ["static"] + if self.build_opts.is_windows(): + linkage.append("shared") + + args = [] + if self.build_opts.is_darwin(): + clang = subprocess.check_output(["xcrun", "--find", "clang"]) + user_config = os.path.join(self.build_dir, "project-config.jam") + with open(user_config, "w") as jamfile: + jamfile.write("using clang : : %s ;\n" % clang.decode().strip()) + args.append("--user-config=%s" % user_config) + + for link in linkage: + if self.build_opts.is_windows(): + bootstrap = os.path.join(self.src_dir, "bootstrap.bat") + self._run_cmd([bootstrap], cwd=self.src_dir, env=env) + args += ["address-model=64"] + else: + bootstrap = os.path.join(self.src_dir, "bootstrap.sh") + self._run_cmd( + [bootstrap, "--prefix=%s" % self.inst_dir], + cwd=self.src_dir, + env=env, + ) + + b2 = os.path.join(self.src_dir, "b2") + self._run_cmd( + [ + b2, + "-j%s" % self.build_opts.num_jobs, + "--prefix=%s" % self.inst_dir, + "--builddir=%s" % self.build_dir, + ] + + args + + self.b2_args + + [ + "link=%s" % link, + "runtime-link=shared", + "variant=release", + "threading=multi", + "debug-symbols=on", + "visibility=global", + "-d2", + "install", + ], + cwd=self.src_dir, + env=env, + ) + + +class NopBuilder(BuilderBase): + def __init__(self, build_opts, ctx, manifest, src_dir, inst_dir): + super(NopBuilder, self).__init__( + build_opts, ctx, manifest, src_dir, None, inst_dir + ) + + def build(self, install_dirs, reconfigure): + print("Installing %s -> %s" % (self.src_dir, self.inst_dir)) + parent = os.path.dirname(self.inst_dir) + if not os.path.exists(parent): + os.makedirs(parent) + + install_files = self.manifest.get_section_as_ordered_pairs( + "install.files", self.ctx + ) + if install_files: + for src_name, dest_name in self.manifest.get_section_as_ordered_pairs( + "install.files", self.ctx + ): + full_dest = os.path.join(self.inst_dir, dest_name) + full_src = os.path.join(self.src_dir, src_name) + + dest_parent = os.path.dirname(full_dest) + if not os.path.exists(dest_parent): + os.makedirs(dest_parent) + if os.path.isdir(full_src): + if not os.path.exists(full_dest): + shutil.copytree(full_src, full_dest) + else: + shutil.copyfile(full_src, full_dest) + shutil.copymode(full_src, full_dest) + # This is a bit gross, but the mac ninja.zip doesn't + # give ninja execute permissions, so force them on + # for things that look like they live in a bin dir + if os.path.dirname(dest_name) == "bin": + st = os.lstat(full_dest) + os.chmod(full_dest, st.st_mode | stat.S_IXUSR) + else: + if not os.path.exists(self.inst_dir): + shutil.copytree(self.src_dir, self.inst_dir) + + +class OpenNSABuilder(NopBuilder): + # OpenNSA libraries are stored with git LFS. As a result, fetcher fetches + # LFS pointers and not the contents. Use git-lfs to pull the real contents + # before copying to install dir using NoopBuilder. + # In future, if more builders require git-lfs, we would consider installing + # git-lfs as part of the sandcastle infra as against repeating similar + # logic for each builder that requires git-lfs. + def __init__(self, build_opts, ctx, manifest, src_dir, inst_dir): + super(OpenNSABuilder, self).__init__( + build_opts, ctx, manifest, src_dir, inst_dir + ) + + def build(self, install_dirs, reconfigure): + env = self._compute_env(install_dirs) + self._run_cmd(["git", "lfs", "install", "--local"], cwd=self.src_dir, env=env) + self._run_cmd(["git", "lfs", "pull"], cwd=self.src_dir, env=env) + + super(OpenNSABuilder, self).build(install_dirs, reconfigure) + + +class SqliteBuilder(BuilderBase): + def __init__(self, build_opts, ctx, manifest, src_dir, build_dir, inst_dir): + super(SqliteBuilder, self).__init__( + build_opts, ctx, manifest, src_dir, build_dir, inst_dir + ) + + def _build(self, install_dirs, reconfigure): + for f in ["sqlite3.c", "sqlite3.h", "sqlite3ext.h"]: + src = os.path.join(self.src_dir, f) + dest = os.path.join(self.build_dir, f) + copy_if_different(src, dest) + + cmake_lists = """ +cmake_minimum_required(VERSION 3.1.3 FATAL_ERROR) +project(sqlite3 C) +add_library(sqlite3 STATIC sqlite3.c) +# These options are taken from the defaults in Makefile.msc in +# the sqlite distribution +target_compile_definitions(sqlite3 PRIVATE + -DSQLITE_ENABLE_COLUMN_METADATA=1 + -DSQLITE_ENABLE_FTS3=1 + -DSQLITE_ENABLE_RTREE=1 + -DSQLITE_ENABLE_GEOPOLY=1 + -DSQLITE_ENABLE_JSON1=1 + -DSQLITE_ENABLE_STMTVTAB=1 + -DSQLITE_ENABLE_DBPAGE_VTAB=1 + -DSQLITE_ENABLE_DBSTAT_VTAB=1 + -DSQLITE_INTROSPECTION_PRAGMAS=1 + -DSQLITE_ENABLE_DESERIALIZE=1 +) +install(TARGETS sqlite3) +install(FILES sqlite3.h sqlite3ext.h DESTINATION include) + """ + + with open(os.path.join(self.build_dir, "CMakeLists.txt"), "w") as f: + f.write(cmake_lists) + + defines = { + "CMAKE_INSTALL_PREFIX": self.inst_dir, + "BUILD_SHARED_LIBS": "OFF", + "CMAKE_BUILD_TYPE": "RelWithDebInfo", + } + define_args = ["-D%s=%s" % (k, v) for (k, v) in defines.items()] + define_args += ["-G", "Ninja"] + + env = self._compute_env(install_dirs) + + # Resolve the cmake that we installed + cmake = path_search(env, "cmake") + + self._run_cmd([cmake, self.build_dir] + define_args, env=env) + self._run_cmd( + [ + cmake, + "--build", + self.build_dir, + "--target", + "install", + "--config", + "Release", + "-j", + str(self.build_opts.num_jobs), + ], + env=env, + ) + + +class CargoBuilder(BuilderBase): + def __init__( + self, + build_opts, + ctx, + manifest, + src_dir, + build_dir, + inst_dir, + build_doc, + workspace_dir, + manifests_to_build, + loader, + ): + super(CargoBuilder, self).__init__( + build_opts, ctx, manifest, src_dir, build_dir, inst_dir + ) + self.build_doc = build_doc + self.ws_dir = workspace_dir + self.manifests_to_build = manifests_to_build and manifests_to_build.split(",") + self.loader = loader + + def run_cargo(self, install_dirs, operation, args=None): + args = args or [] + env = self._compute_env(install_dirs) + # Enable using nightly features with stable compiler + env["RUSTC_BOOTSTRAP"] = "1" + env["LIBZ_SYS_STATIC"] = "1" + cmd = [ + "cargo", + operation, + "--workspace", + "-j%s" % self.build_opts.num_jobs, + ] + args + self._run_cmd(cmd, cwd=self.workspace_dir(), env=env) + + def build_source_dir(self): + return os.path.join(self.build_dir, "source") + + def workspace_dir(self): + return os.path.join(self.build_source_dir(), self.ws_dir or "") + + def manifest_dir(self, manifest): + return os.path.join(self.build_source_dir(), manifest) + + def recreate_dir(self, src, dst): + if os.path.isdir(dst): + shutil.rmtree(dst) + shutil.copytree(src, dst) + + def _build(self, install_dirs, reconfigure): + build_source_dir = self.build_source_dir() + self.recreate_dir(self.src_dir, build_source_dir) + + dot_cargo_dir = os.path.join(build_source_dir, ".cargo") + if not os.path.isdir(dot_cargo_dir): + os.mkdir(dot_cargo_dir) + + with open(os.path.join(dot_cargo_dir, "config"), "w+") as f: + f.write( + """\ +[build] +target-dir = '''{}''' + +[net] +git-fetch-with-cli = true + +[profile.dev] +debug = false +incremental = false +""".format( + self.build_dir.replace("\\", "\\\\") + ) + ) + + if self.ws_dir is not None: + self._patchup_workspace() + + try: + from getdeps.facebook.rust import vendored_crates + + vendored_crates(self.build_opts, build_source_dir) + except ImportError: + # This FB internal module isn't shippped to github, + # so just rely on cargo downloading crates on it's own + pass + + if self.manifests_to_build is None: + self.run_cargo( + install_dirs, + "build", + ["--out-dir", os.path.join(self.inst_dir, "bin"), "-Zunstable-options"], + ) + else: + for manifest in self.manifests_to_build: + self.run_cargo( + install_dirs, + "build", + [ + "--out-dir", + os.path.join(self.inst_dir, "bin"), + "-Zunstable-options", + "--manifest-path", + self.manifest_dir(manifest), + ], + ) + + self.recreate_dir(build_source_dir, os.path.join(self.inst_dir, "source")) + + def run_tests( + self, install_dirs, schedule_type, owner, test_filter, retry, no_testpilot + ): + if test_filter: + args = ["--", test_filter] + else: + args = [] + + if self.manifests_to_build is None: + self.run_cargo(install_dirs, "test", args) + if self.build_doc: + self.run_cargo(install_dirs, "doc", ["--no-deps"]) + else: + for manifest in self.manifests_to_build: + margs = ["--manifest-path", self.manifest_dir(manifest)] + self.run_cargo(install_dirs, "test", args + margs) + if self.build_doc: + self.run_cargo(install_dirs, "doc", ["--no-deps"] + margs) + + def _patchup_workspace(self): + """ + This method makes some assumptions about the state of the project and + its cargo dependendies: + 1. Crates from cargo dependencies can be extracted from Cargo.toml files + using _extract_crates function. It is using a heuristic so check its + code to understand how it is done. + 2. The extracted cargo dependencies crates can be found in the + dependency's install dir using _resolve_crate_to_path function + which again is using a heuristic. + + Notice that many things might go wrong here. E.g. if someone depends + on another getdeps crate by writing in their Cargo.toml file: + + my-rename-of-crate = { package = "crate", git = "..." } + + they can count themselves lucky because the code will raise an + Exception. There migh be more cases where the code will silently pass + producing bad results. + """ + workspace_dir = self.workspace_dir() + config = self._resolve_config() + if config: + with open(os.path.join(workspace_dir, "Cargo.toml"), "r+") as f: + manifest_content = f.read() + if "[package]" not in manifest_content: + # A fake manifest has to be crated to change the virtual + # manifest into a non-virtual. The virtual manifests are limited + # in many ways and the inability to define patches on them is + # one. Check https://github.com/rust-lang/cargo/issues/4934 to + # see if it is resolved. + f.write( + """ + [package] + name = "fake_manifest_of_{}" + version = "0.0.0" + [lib] + path = "/dev/null" + """.format( + self.manifest.name + ) + ) + else: + f.write("\n") + f.write(config) + + def _resolve_config(self): + """ + Returns a configuration to be put inside root Cargo.toml file which + patches the dependencies git code with local getdeps versions. + See https://doc.rust-lang.org/cargo/reference/manifest.html#the-patch-section + """ + dep_to_git = self._resolve_dep_to_git() + dep_to_crates = CargoBuilder._resolve_dep_to_crates( + self.build_source_dir(), dep_to_git + ) + + config = [] + for name in sorted(dep_to_git.keys()): + git_conf = dep_to_git[name] + crates = sorted(dep_to_crates.get(name, [])) + if not crates: + continue # nothing to patch, move along + crates_patches = [ + '{} = {{ path = "{}" }}'.format( + crate, + CargoBuilder._resolve_crate_to_path(crate, git_conf).replace( + "\\", "\\\\" + ), + ) + for crate in crates + ] + + config.append( + '[patch."{0}"]\n'.format(git_conf["repo_url"]) + + "\n".join(crates_patches) + ) + return "\n".join(config) + + def _resolve_dep_to_git(self): + """ + For each direct dependency of the currently build manifest check if it + is also cargo-builded and if yes then extract it's git configs and + install dir + """ + dependencies = self.manifest.get_section_as_dict("dependencies", ctx=self.ctx) + if not dependencies: + return [] + + dep_to_git = {} + for dep in dependencies.keys(): + dep_manifest = self.loader.load_manifest(dep) + dep_builder = dep_manifest.get("build", "builder", ctx=self.ctx) + if dep_builder not in ["cargo", "nop"] or dep == "rust": + # This is a direct dependency, but it is not build with cargo + # and it is not simply copying files with nop, so ignore it. + # The "rust" dependency is an exception since it contains the + # toolchain. + continue + + git_conf = dep_manifest.get_section_as_dict("git", ctx=self.ctx) + if "repo_url" not in git_conf: + raise Exception( + "A cargo dependency requires git.repo_url to be defined." + ) + source_dir = self.loader.get_project_install_dir(dep_manifest) + if dep_builder == "cargo": + source_dir = os.path.join(source_dir, "source") + git_conf["source_dir"] = source_dir + dep_to_git[dep] = git_conf + return dep_to_git + + @staticmethod + def _resolve_dep_to_crates(build_source_dir, dep_to_git): + """ + This function traverse the build_source_dir in search of Cargo.toml + files, extracts the crate names from them using _extract_crates + function and returns a merged result containing crate names per + dependency name from all Cargo.toml files in the project. + """ + if not dep_to_git: + return {} # no deps, so don't waste time traversing files + + dep_to_crates = {} + for root, _, files in os.walk(build_source_dir): + for f in files: + if f == "Cargo.toml": + more_dep_to_crates = CargoBuilder._extract_crates( + os.path.join(root, f), dep_to_git + ) + for name, crates in more_dep_to_crates.items(): + dep_to_crates.setdefault(name, set()).update(crates) + return dep_to_crates + + @staticmethod + def _extract_crates(cargo_toml_file, dep_to_git): + """ + This functions reads content of provided cargo toml file and extracts + crate names per each dependency. The extraction is done by a heuristic + so it might be incorrect. + """ + deps_to_crates = {} + with open(cargo_toml_file, "r") as f: + for line in f.readlines(): + if line.startswith("#") or "git = " not in line: + continue # filter out commented lines and ones without git deps + for name, conf in dep_to_git.items(): + if 'git = "{}"'.format(conf["repo_url"]) in line: + pkg_template = ' package = "' + if pkg_template in line: + crate_name, _, _ = line.partition(pkg_template)[ + 2 + ].partition('"') + else: + crate_name, _, _ = line.partition("=") + deps_to_crates.setdefault(name, set()).add(crate_name.strip()) + return deps_to_crates + + @staticmethod + def _resolve_crate_to_path(crate, git_conf): + """ + Tries to find in git_conf["inst_dir"] by searching a [package] + keyword followed by name = "". + """ + source_dir = git_conf["source_dir"] + search_pattern = '[package]\nname = "{}"'.format(crate) + + for root, _, files in os.walk(source_dir): + for fname in files: + if fname == "Cargo.toml": + with open(os.path.join(root, fname), "r") as f: + if search_pattern in f.read(): + return root + + raise Exception("Failed to found crate {} in path {}".format(crate, source_dir)) diff --git a/build/fbcode_builder/getdeps/buildopts.py b/build/fbcode_builder/getdeps/buildopts.py new file mode 100644 index 000000000..bc6d2da87 --- /dev/null +++ b/build/fbcode_builder/getdeps/buildopts.py @@ -0,0 +1,458 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import errno +import glob +import ntpath +import os +import subprocess +import sys +import tempfile + +from .copytree import containing_repo_type +from .envfuncs import Env, add_path_entry +from .fetcher import get_fbsource_repo_data +from .manifest import ContextGenerator +from .platform import HostType, is_windows + + +try: + import typing # noqa: F401 +except ImportError: + pass + + +def detect_project(path): + repo_type, repo_root = containing_repo_type(path) + if repo_type is None: + return None, None + + # Look for a .projectid file. If it exists, read the project name from it. + project_id_path = os.path.join(repo_root, ".projectid") + try: + with open(project_id_path, "r") as f: + project_name = f.read().strip() + return repo_root, project_name + except EnvironmentError as ex: + if ex.errno != errno.ENOENT: + raise + + return repo_root, None + + +class BuildOptions(object): + def __init__( + self, + fbcode_builder_dir, + scratch_dir, + host_type, + install_dir=None, + num_jobs=0, + use_shipit=False, + vcvars_path=None, + allow_system_packages=False, + lfs_path=None, + ): + """fbcode_builder_dir - the path to either the in-fbsource fbcode_builder dir, + or for shipit-transformed repos, the build dir that + has been mapped into that dir. + scratch_dir - a place where we can store repos and build bits. + This path should be stable across runs and ideally + should not be in the repo of the project being built, + but that is ultimately where we generally fall back + for builds outside of FB + install_dir - where the project will ultimately be installed + num_jobs - the level of concurrency to use while building + use_shipit - use real shipit instead of the simple shipit transformer + vcvars_path - Path to external VS toolchain's vsvarsall.bat + """ + if not num_jobs: + import multiprocessing + + num_jobs = multiprocessing.cpu_count() // 2 + + if not install_dir: + install_dir = os.path.join(scratch_dir, "installed") + + self.project_hashes = None + for p in ["../deps/github_hashes", "../project_hashes"]: + hashes = os.path.join(fbcode_builder_dir, p) + if os.path.exists(hashes): + self.project_hashes = hashes + break + + # Detect what repository and project we are being run from. + self.repo_root, self.repo_project = detect_project(os.getcwd()) + + # If we are running from an fbsource repository, set self.fbsource_dir + # to allow the ShipIt-based fetchers to use it. + if self.repo_project == "fbsource": + self.fbsource_dir = self.repo_root + else: + self.fbsource_dir = None + + self.num_jobs = num_jobs + self.scratch_dir = scratch_dir + self.install_dir = install_dir + self.fbcode_builder_dir = fbcode_builder_dir + self.host_type = host_type + self.use_shipit = use_shipit + self.allow_system_packages = allow_system_packages + self.lfs_path = lfs_path + if vcvars_path is None and is_windows(): + + # On Windows, the compiler is not available in the PATH by + # default so we need to run the vcvarsall script to populate the + # environment. We use a glob to find some version of this script + # as deployed with Visual Studio 2017. This logic can also + # locate Visual Studio 2019 but note that at the time of writing + # the version of boost in our manifest cannot be built with + # VS 2019, so we're effectively tied to VS 2017 until we upgrade + # the boost dependency. + vcvarsall = [] + for year in ["2017", "2019"]: + vcvarsall += glob.glob( + os.path.join( + os.environ["ProgramFiles(x86)"], + "Microsoft Visual Studio", + year, + "*", + "VC", + "Auxiliary", + "Build", + "vcvarsall.bat", + ) + ) + vcvars_path = vcvarsall[0] + + self.vcvars_path = vcvars_path + + @property + def manifests_dir(self): + return os.path.join(self.fbcode_builder_dir, "manifests") + + def is_darwin(self): + return self.host_type.is_darwin() + + def is_windows(self): + return self.host_type.is_windows() + + def is_arm(self): + return self.host_type.is_arm() + + def get_vcvars_path(self): + return self.vcvars_path + + def is_linux(self): + return self.host_type.is_linux() + + def get_context_generator(self, host_tuple=None, facebook_internal=None): + """Create a manifest ContextGenerator for the specified target platform.""" + if host_tuple is None: + host_type = self.host_type + elif isinstance(host_tuple, HostType): + host_type = host_tuple + else: + host_type = HostType.from_tuple_string(host_tuple) + + # facebook_internal is an Optional[bool] + # If it is None, default to assuming this is a Facebook-internal build if + # we are running in an fbsource repository. + if facebook_internal is None: + facebook_internal = self.fbsource_dir is not None + + return ContextGenerator( + { + "os": host_type.ostype, + "distro": host_type.distro, + "distro_vers": host_type.distrovers, + "fb": "on" if facebook_internal else "off", + "test": "off", + } + ) + + def compute_env_for_install_dirs(self, install_dirs, env=None, manifest=None): + if env is not None: + env = env.copy() + else: + env = Env() + + env["GETDEPS_BUILD_DIR"] = os.path.join(self.scratch_dir, "build") + env["GETDEPS_INSTALL_DIR"] = self.install_dir + + # On macOS we need to set `SDKROOT` when we use clang for system + # header files. + if self.is_darwin() and "SDKROOT" not in env: + sdkroot = subprocess.check_output(["xcrun", "--show-sdk-path"]) + env["SDKROOT"] = sdkroot.decode().strip() + + if self.fbsource_dir: + env["YARN_YARN_OFFLINE_MIRROR"] = os.path.join( + self.fbsource_dir, "xplat/third-party/yarn/offline-mirror" + ) + yarn_exe = "yarn.bat" if self.is_windows() else "yarn" + env["YARN_PATH"] = os.path.join( + self.fbsource_dir, "xplat/third-party/yarn/", yarn_exe + ) + node_exe = "node-win-x64.exe" if self.is_windows() else "node" + env["NODE_BIN"] = os.path.join( + self.fbsource_dir, "xplat/third-party/node/bin/", node_exe + ) + env["RUST_VENDORED_CRATES_DIR"] = os.path.join( + self.fbsource_dir, "third-party/rust/vendor" + ) + hash_data = get_fbsource_repo_data(self) + env["FBSOURCE_HASH"] = hash_data.hash + env["FBSOURCE_DATE"] = hash_data.date + + lib_path = None + if self.is_darwin(): + lib_path = "DYLD_LIBRARY_PATH" + elif self.is_linux(): + lib_path = "LD_LIBRARY_PATH" + elif self.is_windows(): + lib_path = "PATH" + else: + lib_path = None + + for d in install_dirs: + bindir = os.path.join(d, "bin") + + if not ( + manifest and manifest.get("build", "disable_env_override_pkgconfig") + ): + pkgconfig = os.path.join(d, "lib/pkgconfig") + if os.path.exists(pkgconfig): + add_path_entry(env, "PKG_CONFIG_PATH", pkgconfig) + + pkgconfig = os.path.join(d, "lib64/pkgconfig") + if os.path.exists(pkgconfig): + add_path_entry(env, "PKG_CONFIG_PATH", pkgconfig) + + if not (manifest and manifest.get("build", "disable_env_override_path")): + add_path_entry(env, "CMAKE_PREFIX_PATH", d) + + # Allow resolving shared objects built earlier (eg: zstd + # doesn't include the full path to the dylib in its linkage + # so we need to give it an assist) + if lib_path: + for lib in ["lib", "lib64"]: + libdir = os.path.join(d, lib) + if os.path.exists(libdir): + add_path_entry(env, lib_path, libdir) + + # Allow resolving binaries (eg: cmake, ninja) and dlls + # built by earlier steps + if os.path.exists(bindir): + add_path_entry(env, "PATH", bindir, append=False) + + # If rustc is present in the `bin` directory, set RUSTC to prevent + # cargo uses the rustc installed in the system. + if self.is_windows(): + cargo_path = os.path.join(bindir, "cargo.exe") + rustc_path = os.path.join(bindir, "rustc.exe") + rustdoc_path = os.path.join(bindir, "rustdoc.exe") + else: + cargo_path = os.path.join(bindir, "cargo") + rustc_path = os.path.join(bindir, "rustc") + rustdoc_path = os.path.join(bindir, "rustdoc") + + if os.path.isfile(rustc_path): + env["CARGO_BIN"] = cargo_path + env["RUSTC"] = rustc_path + env["RUSTDOC"] = rustdoc_path + + openssl_include = os.path.join(d, "include/openssl") + if os.path.isdir(openssl_include) and any( + os.path.isfile(os.path.join(d, "lib", libcrypto)) + for libcrypto in ("libcrypto.lib", "libcrypto.so", "libcrypto.a") + ): + # This must be the openssl library, let Rust know about it + env["OPENSSL_DIR"] = d + + return env + + +def list_win32_subst_letters(): + output = subprocess.check_output(["subst"]).decode("utf-8") + # The output is a set of lines like: `F:\: => C:\open\some\where` + lines = output.strip().split("\r\n") + mapping = {} + for line in lines: + fields = line.split(": => ") + if len(fields) != 2: + continue + letter = fields[0] + path = fields[1] + mapping[letter] = path + + return mapping + + +def find_existing_win32_subst_for_path( + path, # type: str + subst_mapping, # type: typing.Mapping[str, str] +): + # type: (...) -> typing.Optional[str] + path = ntpath.normcase(ntpath.normpath(path)) + for letter, target in subst_mapping.items(): + if ntpath.normcase(target) == path: + return letter + return None + + +def find_unused_drive_letter(): + import ctypes + + buffer_len = 256 + blen = ctypes.c_uint(buffer_len) + rv = ctypes.c_uint() + bufs = ctypes.create_string_buffer(buffer_len) + rv = ctypes.windll.kernel32.GetLogicalDriveStringsA(blen, bufs) + if rv > buffer_len: + raise Exception("GetLogicalDriveStringsA result too large for buffer") + nul = "\x00".encode("ascii") + + used = [drive.decode("ascii")[0] for drive in bufs.raw.strip(nul).split(nul)] + possible = [c for c in "ABCDEFGHIJKLMNOPQRSTUVWXYZ"] + available = sorted(list(set(possible) - set(used))) + if len(available) == 0: + return None + # Prefer to assign later letters rather than earlier letters + return available[-1] + + +def create_subst_path(path): + for _attempt in range(0, 24): + drive = find_existing_win32_subst_for_path( + path, subst_mapping=list_win32_subst_letters() + ) + if drive: + return drive + available = find_unused_drive_letter() + if available is None: + raise Exception( + ( + "unable to make shorter subst mapping for %s; " + "no available drive letters" + ) + % path + ) + + # Try to set up a subst mapping; note that we may be racing with + # other processes on the same host, so this may not succeed. + try: + subprocess.check_call(["subst", "%s:" % available, path]) + return "%s:\\" % available + except Exception: + print("Failed to map %s -> %s" % (available, path)) + + raise Exception("failed to set up a subst path for %s" % path) + + +def _check_host_type(args, host_type): + if host_type is None: + host_tuple_string = getattr(args, "host_type", None) + if host_tuple_string: + host_type = HostType.from_tuple_string(host_tuple_string) + else: + host_type = HostType() + + assert isinstance(host_type, HostType) + return host_type + + +def setup_build_options(args, host_type=None): + """Create a BuildOptions object based on the arguments""" + + fbcode_builder_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + scratch_dir = args.scratch_path + if not scratch_dir: + # TODO: `mkscratch` doesn't currently know how best to place things on + # sandcastle, so whip up something reasonable-ish + if "SANDCASTLE" in os.environ: + if "DISK_TEMP" not in os.environ: + raise Exception( + ( + "I need DISK_TEMP to be set in the sandcastle environment " + "so that I can store build products somewhere sane" + ) + ) + scratch_dir = os.path.join( + os.environ["DISK_TEMP"], "fbcode_builder_getdeps" + ) + if not scratch_dir: + try: + scratch_dir = ( + subprocess.check_output( + ["mkscratch", "path", "--subdir", "fbcode_builder_getdeps"] + ) + .strip() + .decode("utf-8") + ) + except OSError as exc: + if exc.errno != errno.ENOENT: + # A legit failure; don't fall back, surface the error + raise + # This system doesn't have mkscratch so we fall back to + # something local. + munged = fbcode_builder_dir.replace("Z", "zZ") + for s in ["/", "\\", ":"]: + munged = munged.replace(s, "Z") + + if is_windows() and os.path.isdir("c:/open"): + temp = "c:/open/scratch" + else: + temp = tempfile.gettempdir() + + scratch_dir = os.path.join(temp, "fbcode_builder_getdeps-%s" % munged) + if not is_windows() and os.geteuid() == 0: + # Running as root; in the case where someone runs + # sudo getdeps.py install-system-deps + # and then runs as build without privs, we want to avoid creating + # a scratch dir that the second stage cannot write to. + # So we generate a different path if we are root. + scratch_dir += "-root" + + if not os.path.exists(scratch_dir): + os.makedirs(scratch_dir) + + if is_windows(): + subst = create_subst_path(scratch_dir) + print( + "Mapping scratch dir %s -> %s" % (scratch_dir, subst), file=sys.stderr + ) + scratch_dir = subst + else: + if not os.path.exists(scratch_dir): + os.makedirs(scratch_dir) + + # Make sure we normalize the scratch path. This path is used as part of the hash + # computation for detecting if projects have been updated, so we need to always + # use the exact same string to refer to a given directory. + # But! realpath in some combinations of Windows/Python3 versions can expand the + # drive substitutions on Windows, so avoid that! + if not is_windows(): + scratch_dir = os.path.realpath(scratch_dir) + + # Save any extra cmake defines passed by the user in an env variable, so it + # can be used while hashing this build. + os.environ["GETDEPS_CMAKE_DEFINES"] = getattr(args, "extra_cmake_defines", "") or "" + + host_type = _check_host_type(args, host_type) + + return BuildOptions( + fbcode_builder_dir, + scratch_dir, + host_type, + install_dir=args.install_prefix, + num_jobs=args.num_jobs, + use_shipit=args.use_shipit, + vcvars_path=args.vcvars_path, + allow_system_packages=args.allow_system_packages, + lfs_path=args.lfs_path, + ) diff --git a/build/fbcode_builder/getdeps/cache.py b/build/fbcode_builder/getdeps/cache.py new file mode 100644 index 000000000..a261541c7 --- /dev/null +++ b/build/fbcode_builder/getdeps/cache.py @@ -0,0 +1,39 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + + +class ArtifactCache(object): + """The ArtifactCache is a small abstraction that allows caching + named things in some external storage mechanism. + The primary use case is for storing the build products on CI + systems to accelerate the build""" + + def download_to_file(self, name, dest_file_name): + """If `name` exists in the cache, download it and place it + in the specified `dest_file_name` location on the filesystem. + If a transient issue was encountered a TransientFailure shall + be raised. + If `name` doesn't exist in the cache `False` shall be returned. + If `dest_file_name` was successfully updated `True` shall be + returned. + All other conditions shall raise an appropriate exception.""" + return False + + def upload_from_file(self, name, source_file_name): + """Causes `name` to be populated in the cache by uploading + the contents of `source_file_name` to the storage system. + If a transient issue was encountered a TransientFailure shall + be raised. + If the upload failed for some other reason, an appropriate + exception shall be raised.""" + pass + + +def create_cache(): + """This function is monkey patchable to provide an actual + implementation""" + return None diff --git a/build/fbcode_builder/getdeps/copytree.py b/build/fbcode_builder/getdeps/copytree.py new file mode 100644 index 000000000..2790bc0d9 --- /dev/null +++ b/build/fbcode_builder/getdeps/copytree.py @@ -0,0 +1,78 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import os +import shutil +import subprocess + +from .platform import is_windows + + +PREFETCHED_DIRS = set() + + +def containing_repo_type(path): + while True: + if os.path.exists(os.path.join(path, ".git")): + return ("git", path) + if os.path.exists(os.path.join(path, ".hg")): + return ("hg", path) + + parent = os.path.dirname(path) + if parent == path: + return None, None + path = parent + + +def find_eden_root(dirpath): + """If the specified directory is inside an EdenFS checkout, returns + the canonical absolute path to the root of that checkout. + + Returns None if the specified directory is not in an EdenFS checkout. + """ + if is_windows(): + repo_type, repo_root = containing_repo_type(dirpath) + if repo_root is not None: + if os.path.exists(os.path.join(repo_root, ".eden", "config")): + return os.path.realpath(repo_root) + return None + + try: + return os.readlink(os.path.join(dirpath, ".eden", "root")) + except OSError: + return None + + +def prefetch_dir_if_eden(dirpath): + """After an amend/rebase, Eden may need to fetch a large number + of trees from the servers. The simplistic single threaded walk + performed by copytree makes this more expensive than is desirable + so we help accelerate things by performing a prefetch on the + source directory""" + global PREFETCHED_DIRS + if dirpath in PREFETCHED_DIRS: + return + root = find_eden_root(dirpath) + if root is None: + return + glob = f"{os.path.relpath(dirpath, root).replace(os.sep, '/')}/**" + print(f"Prefetching {glob}") + subprocess.call(["edenfsctl", "prefetch", "--repo", root, "--silent", glob]) + PREFETCHED_DIRS.add(dirpath) + + +def copytree(src_dir, dest_dir, ignore=None): + """Recursively copy the src_dir to the dest_dir, filtering + out entries using the ignore lambda. The behavior of the + ignore lambda must match that described by `shutil.copytree`. + This `copytree` function knows how to prefetch data when + running in an eden repo. + TODO: I'd like to either extend this or add a variant that + uses watchman to mirror src_dir into dest_dir. + """ + prefetch_dir_if_eden(src_dir) + return shutil.copytree(src_dir, dest_dir, ignore=ignore) diff --git a/build/fbcode_builder/getdeps/dyndeps.py b/build/fbcode_builder/getdeps/dyndeps.py new file mode 100644 index 000000000..216f26c46 --- /dev/null +++ b/build/fbcode_builder/getdeps/dyndeps.py @@ -0,0 +1,430 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import errno +import glob +import os +import re +import shutil +import stat +import subprocess +import sys +from struct import unpack + +from .envfuncs import path_search + + +OBJECT_SUBDIRS = ("bin", "lib", "lib64") + + +def copyfile(src, dest): + shutil.copyfile(src, dest) + shutil.copymode(src, dest) + + +class DepBase(object): + def __init__(self, buildopts, install_dirs, strip): + self.buildopts = buildopts + self.env = buildopts.compute_env_for_install_dirs(install_dirs) + self.install_dirs = install_dirs + self.strip = strip + self.processed_deps = set() + + def list_dynamic_deps(self, objfile): + raise RuntimeError("list_dynamic_deps not implemented") + + def interesting_dep(self, d): + return True + + # final_install_prefix must be the equivalent path to `destdir` on the + # installed system. For example, if destdir is `/tmp/RANDOM/usr/local' which + # is intended to map to `/usr/local` in the install image, then + # final_install_prefix='/usr/local'. + # If left unspecified, destdir will be used. + def process_deps(self, destdir, final_install_prefix=None): + if self.buildopts.is_windows(): + lib_dir = "bin" + else: + lib_dir = "lib" + self.munged_lib_dir = os.path.join(destdir, lib_dir) + + final_lib_dir = os.path.join(final_install_prefix or destdir, lib_dir) + + if not os.path.isdir(self.munged_lib_dir): + os.makedirs(self.munged_lib_dir) + + # Look only at the things that got installed in the leaf package, + # which will be the last entry in the install dirs list + inst_dir = self.install_dirs[-1] + print("Process deps under %s" % inst_dir, file=sys.stderr) + + for dir in OBJECT_SUBDIRS: + src_dir = os.path.join(inst_dir, dir) + if not os.path.isdir(src_dir): + continue + dest_dir = os.path.join(destdir, dir) + if not os.path.exists(dest_dir): + os.makedirs(dest_dir) + + for objfile in self.list_objs_in_dir(src_dir): + print("Consider %s/%s" % (dir, objfile)) + dest_obj = os.path.join(dest_dir, objfile) + copyfile(os.path.join(src_dir, objfile), dest_obj) + self.munge_in_place(dest_obj, final_lib_dir) + + def find_all_dependencies(self, build_dir): + all_deps = set() + for objfile in self.list_objs_in_dir( + build_dir, recurse=True, output_prefix=build_dir + ): + for d in self.list_dynamic_deps(objfile): + all_deps.add(d) + + interesting_deps = {d for d in all_deps if self.interesting_dep(d)} + dep_paths = [] + for dep in interesting_deps: + dep_path = self.resolve_loader_path(dep) + if dep_path: + dep_paths.append(dep_path) + + return dep_paths + + def munge_in_place(self, objfile, final_lib_dir): + print("Munging %s" % objfile) + for d in self.list_dynamic_deps(objfile): + if not self.interesting_dep(d): + continue + + # Resolve this dep: does it exist in any of our installation + # directories? If so, then it is a candidate for processing + dep = self.resolve_loader_path(d) + print("dep: %s -> %s" % (d, dep)) + if dep: + dest_dep = os.path.join(self.munged_lib_dir, os.path.basename(dep)) + if dep not in self.processed_deps: + self.processed_deps.add(dep) + copyfile(dep, dest_dep) + self.munge_in_place(dest_dep, final_lib_dir) + + self.rewrite_dep(objfile, d, dep, dest_dep, final_lib_dir) + + if self.strip: + self.strip_debug_info(objfile) + + def rewrite_dep(self, objfile, depname, old_dep, new_dep, final_lib_dir): + raise RuntimeError("rewrite_dep not implemented") + + def resolve_loader_path(self, dep): + if os.path.isabs(dep): + return dep + d = os.path.basename(dep) + for inst_dir in self.install_dirs: + for libdir in OBJECT_SUBDIRS: + candidate = os.path.join(inst_dir, libdir, d) + if os.path.exists(candidate): + return candidate + return None + + def list_objs_in_dir(self, dir, recurse=False, output_prefix=""): + for entry in os.listdir(dir): + entry_path = os.path.join(dir, entry) + st = os.lstat(entry_path) + if stat.S_ISREG(st.st_mode): + if self.is_objfile(entry_path): + relative_result = os.path.join(output_prefix, entry) + yield os.path.normcase(relative_result) + elif recurse and stat.S_ISDIR(st.st_mode): + child_prefix = os.path.join(output_prefix, entry) + for result in self.list_objs_in_dir( + entry_path, recurse=recurse, output_prefix=child_prefix + ): + yield result + + def is_objfile(self, objfile): + return True + + def strip_debug_info(self, objfile): + """override this to define how to remove debug information + from an object file""" + pass + + +class WinDeps(DepBase): + def __init__(self, buildopts, install_dirs, strip): + super(WinDeps, self).__init__(buildopts, install_dirs, strip) + self.dumpbin = self.find_dumpbin() + + def find_dumpbin(self): + # Looking for dumpbin in the following hardcoded paths. + # The registry option to find the install dir doesn't work anymore. + globs = [ + ( + "C:/Program Files (x86)/" + "Microsoft Visual Studio/" + "*/*/VC/Tools/" + "MSVC/*/bin/Hostx64/x64/dumpbin.exe" + ), + ( + "C:/Program Files (x86)/" + "Common Files/" + "Microsoft/Visual C++ for Python/*/" + "VC/bin/dumpbin.exe" + ), + ("c:/Program Files (x86)/Microsoft Visual Studio */VC/bin/dumpbin.exe"), + ] + for pattern in globs: + for exe in glob.glob(pattern): + return exe + + raise RuntimeError("could not find dumpbin.exe") + + def list_dynamic_deps(self, exe): + deps = [] + print("Resolve deps for %s" % exe) + output = subprocess.check_output( + [self.dumpbin, "/nologo", "/dependents", exe] + ).decode("utf-8") + + lines = output.split("\n") + for line in lines: + m = re.match("\\s+(\\S+.dll)", line, re.IGNORECASE) + if m: + deps.append(m.group(1).lower()) + + return deps + + def rewrite_dep(self, objfile, depname, old_dep, new_dep, final_lib_dir): + # We can't rewrite on windows, but we will + # place the deps alongside the exe so that + # they end up in the search path + pass + + # These are the Windows system dll, which we don't want to copy while + # packaging. + SYSTEM_DLLS = set( # noqa: C405 + [ + "advapi32.dll", + "dbghelp.dll", + "kernel32.dll", + "msvcp140.dll", + "vcruntime140.dll", + "ws2_32.dll", + "ntdll.dll", + "shlwapi.dll", + ] + ) + + def interesting_dep(self, d): + if "api-ms-win-crt" in d: + return False + if d in self.SYSTEM_DLLS: + return False + return True + + def is_objfile(self, objfile): + if not os.path.isfile(objfile): + return False + if objfile.lower().endswith(".exe"): + return True + return False + + def emit_dev_run_script(self, script_path, dep_dirs): + """Emit a script that can be used to run build artifacts directly from the + build directory, without installing them. + + The dep_dirs parameter should be a list of paths that need to be added to $PATH. + This can be computed by calling compute_dependency_paths() or + compute_dependency_paths_fast(). + + This is only necessary on Windows, which does not have RPATH, and instead + requires the $PATH environment variable be updated in order to find the proper + library dependencies. + """ + contents = self._get_dev_run_script_contents(dep_dirs) + with open(script_path, "w") as f: + f.write(contents) + + def compute_dependency_paths(self, build_dir): + """Return a list of all directories that need to be added to $PATH to ensure + that library dependencies can be found correctly. This is computed by scanning + binaries to determine exactly the right list of dependencies. + + The compute_dependency_paths_fast() is a alternative function that runs faster + but may return additional extraneous paths. + """ + dep_dirs = set() + # Find paths by scanning the binaries. + for dep in self.find_all_dependencies(build_dir): + dep_dirs.add(os.path.dirname(dep)) + + dep_dirs.update(self.read_custom_dep_dirs(build_dir)) + return sorted(dep_dirs) + + def compute_dependency_paths_fast(self, build_dir): + """Similar to compute_dependency_paths(), but rather than actually scanning + binaries, just add all library paths from the specified installation + directories. This is much faster than scanning the binaries, but may result in + more paths being returned than actually necessary. + """ + dep_dirs = set() + for inst_dir in self.install_dirs: + for subdir in OBJECT_SUBDIRS: + path = os.path.join(inst_dir, subdir) + if os.path.exists(path): + dep_dirs.add(path) + + dep_dirs.update(self.read_custom_dep_dirs(build_dir)) + return sorted(dep_dirs) + + def read_custom_dep_dirs(self, build_dir): + # The build system may also have included libraries from other locations that + # we might not be able to find normally in find_all_dependencies(). + # To handle this situation we support reading additional library paths + # from a LIBRARY_DEP_DIRS.txt file that may have been generated in the build + # output directory. + dep_dirs = set() + try: + explicit_dep_dirs_path = os.path.join(build_dir, "LIBRARY_DEP_DIRS.txt") + with open(explicit_dep_dirs_path, "r") as f: + for line in f.read().splitlines(): + dep_dirs.add(line) + except OSError as ex: + if ex.errno != errno.ENOENT: + raise + + return dep_dirs + + def _get_dev_run_script_contents(self, path_dirs): + path_entries = ["$env:PATH"] + path_dirs + path_str = ";".join(path_entries) + return """\ +$orig_env = $env:PATH +$env:PATH = "{path_str}" + +try {{ + $cmd_args = $args[1..$args.length] + & $args[0] @cmd_args +}} finally {{ + $env:PATH = $orig_env +}} +""".format( + path_str=path_str + ) + + +class ElfDeps(DepBase): + def __init__(self, buildopts, install_dirs, strip): + super(ElfDeps, self).__init__(buildopts, install_dirs, strip) + + # We need patchelf to rewrite deps, so ensure that it is built... + subprocess.check_call([sys.executable, sys.argv[0], "build", "patchelf"]) + # ... and that we know where it lives + self.patchelf = os.path.join( + os.fsdecode( + subprocess.check_output( + [sys.executable, sys.argv[0], "show-inst-dir", "patchelf"] + ).strip() + ), + "bin/patchelf", + ) + + def list_dynamic_deps(self, objfile): + out = ( + subprocess.check_output( + [self.patchelf, "--print-needed", objfile], env=dict(self.env.items()) + ) + .decode("utf-8") + .strip() + ) + lines = out.split("\n") + return lines + + def rewrite_dep(self, objfile, depname, old_dep, new_dep, final_lib_dir): + final_dep = os.path.join( + final_lib_dir, os.path.relpath(new_dep, self.munged_lib_dir) + ) + subprocess.check_call( + [self.patchelf, "--replace-needed", depname, final_dep, objfile] + ) + + def is_objfile(self, objfile): + if not os.path.isfile(objfile): + return False + with open(objfile, "rb") as f: + # https://en.wikipedia.org/wiki/Executable_and_Linkable_Format#File_header + magic = f.read(4) + return magic == b"\x7fELF" + + def strip_debug_info(self, objfile): + subprocess.check_call(["strip", objfile]) + + +# MACH-O magic number +MACH_MAGIC = 0xFEEDFACF + + +class MachDeps(DepBase): + def interesting_dep(self, d): + if d.startswith("/usr/lib/") or d.startswith("/System/"): + return False + return True + + def is_objfile(self, objfile): + if not os.path.isfile(objfile): + return False + with open(objfile, "rb") as f: + # mach stores the magic number in native endianness, + # so unpack as native here and compare + header = f.read(4) + if len(header) != 4: + return False + magic = unpack("I", header)[0] + return magic == MACH_MAGIC + + def list_dynamic_deps(self, objfile): + if not self.interesting_dep(objfile): + return + out = ( + subprocess.check_output( + ["otool", "-L", objfile], env=dict(self.env.items()) + ) + .decode("utf-8") + .strip() + ) + lines = out.split("\n") + deps = [] + for line in lines: + m = re.match("\t(\\S+)\\s", line) + if m: + if os.path.basename(m.group(1)) != os.path.basename(objfile): + deps.append(os.path.normcase(m.group(1))) + return deps + + def rewrite_dep(self, objfile, depname, old_dep, new_dep, final_lib_dir): + if objfile.endswith(".dylib"): + # Erase the original location from the id of the shared + # object. It doesn't appear to hurt to retain it, but + # it does look weird, so let's rewrite it to be sure. + subprocess.check_call( + ["install_name_tool", "-id", os.path.basename(objfile), objfile] + ) + final_dep = os.path.join( + final_lib_dir, os.path.relpath(new_dep, self.munged_lib_dir) + ) + + subprocess.check_call( + ["install_name_tool", "-change", depname, final_dep, objfile] + ) + + +def create_dyn_dep_munger(buildopts, install_dirs, strip=False): + if buildopts.is_linux(): + return ElfDeps(buildopts, install_dirs, strip) + if buildopts.is_darwin(): + return MachDeps(buildopts, install_dirs, strip) + if buildopts.is_windows(): + return WinDeps(buildopts, install_dirs, strip) diff --git a/build/fbcode_builder/getdeps/envfuncs.py b/build/fbcode_builder/getdeps/envfuncs.py new file mode 100644 index 000000000..f2e13f16f --- /dev/null +++ b/build/fbcode_builder/getdeps/envfuncs.py @@ -0,0 +1,195 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import os +import shlex +import sys + + +class Env(object): + def __init__(self, src=None): + self._dict = {} + if src is None: + self.update(os.environ) + else: + self.update(src) + + def update(self, src): + for k, v in src.items(): + self.set(k, v) + + def copy(self): + return Env(self._dict) + + def _key(self, key): + # The `str` cast may not appear to be needed, but without it we run + # into issues when passing the environment to subprocess. The main + # issue is that in python2 `os.environ` (which is the initial source + # of data for the environment) uses byte based strings, but this + # project uses `unicode_literals`. `subprocess` will raise an error + # if the environment that it is passed has a mixture of byte and + # unicode strings. + # It is simplest to force everthing to be `str` for the sake of + # consistency. + key = str(key) + if sys.platform.startswith("win"): + # Windows env var names are case insensitive but case preserving. + # An implementation of PAR files on windows gets confused if + # the env block contains keys with conflicting case, so make a + # pass over the contents to remove any. + # While this O(n) scan is technically expensive and gross, it + # is practically not a problem because the volume of calls is + # relatively low and the cost of manipulating the env is dwarfed + # by the cost of spawning a process on windows. In addition, + # since the processes that we run are expensive anyway, this + # overhead is not the worst thing to worry about. + for k in list(self._dict.keys()): + if str(k).lower() == key.lower(): + return k + elif key in self._dict: + return key + return None + + def get(self, key, defval=None): + key = self._key(key) + if key is None: + return defval + return self._dict[key] + + def __getitem__(self, key): + val = self.get(key) + if key is None: + raise KeyError(key) + return val + + def unset(self, key): + if key is None: + raise KeyError("attempting to unset env[None]") + + key = self._key(key) + if key: + del self._dict[key] + + def __delitem__(self, key): + self.unset(key) + + def __repr__(self): + return repr(self._dict) + + def set(self, key, value): + if key is None: + raise KeyError("attempting to assign env[None] = %r" % value) + + if value is None: + raise ValueError("attempting to assign env[%s] = None" % key) + + # The `str` conversion is important to avoid triggering errors + # with subprocess if we pass in a unicode value; see commentary + # in the `_key` method. + key = str(key) + value = str(value) + + # The `unset` call is necessary on windows where the keys are + # case insensitive. Since this dict is case sensitive, simply + # assigning the value to the new key is not sufficient to remove + # the old value. The `unset` call knows how to match keys and + # remove any potential duplicates. + self.unset(key) + self._dict[key] = value + + def __setitem__(self, key, value): + self.set(key, value) + + def __iter__(self): + return self._dict.__iter__() + + def __len__(self): + return len(self._dict) + + def keys(self): + return self._dict.keys() + + def values(self): + return self._dict.values() + + def items(self): + return self._dict.items() + + +def add_path_entry(env, name, item, append=True, separator=os.pathsep): + """Cause `item` to be added to the path style env var named + `name` held in the `env` dict. `append` specifies whether + the item is added to the end (the default) or should be + prepended if `name` already exists.""" + val = env.get(name, "") + if len(val) > 0: + val = val.split(separator) + else: + val = [] + if append: + val.append(item) + else: + val.insert(0, item) + env.set(name, separator.join(val)) + + +def add_flag(env, name, flag, append=True): + """Cause `flag` to be added to the CXXFLAGS-style env var named + `name` held in the `env` dict. `append` specifies whether the + flag is added to the end (the default) or should be prepended if + `name` already exists.""" + val = shlex.split(env.get(name, "")) + if append: + val.append(flag) + else: + val.insert(0, flag) + env.set(name, " ".join(val)) + + +_path_search_cache = {} +_not_found = object() + + +def tpx_path(): + return "xplat/testinfra/tpx/ctp.tpx" + + +def path_search(env, exename, defval=None): + """Search for exename in the PATH specified in env. + exename is eg: `ninja` and this function knows to append a .exe + to the end on windows. + Returns the path to the exe if found, or None if either no + PATH is set in env or no executable is found.""" + + path = env.get("PATH", None) + if path is None: + return defval + + # The project hash computation code searches for C++ compilers (g++, clang, etc) + # repeatedly. Cache the result so we don't end up searching for these over and over + # again. + cache_key = (path, exename) + result = _path_search_cache.get(cache_key, _not_found) + if result is _not_found: + result = _perform_path_search(path, exename) + _path_search_cache[cache_key] = result + return result + + +def _perform_path_search(path, exename): + is_win = sys.platform.startswith("win") + if is_win: + exename = "%s.exe" % exename + + for bindir in path.split(os.pathsep): + full_name = os.path.join(bindir, exename) + if os.path.exists(full_name) and os.path.isfile(full_name): + if not is_win and not os.access(full_name, os.X_OK): + continue + return full_name + + return None diff --git a/build/fbcode_builder/getdeps/errors.py b/build/fbcode_builder/getdeps/errors.py new file mode 100644 index 000000000..3fad1a1de --- /dev/null +++ b/build/fbcode_builder/getdeps/errors.py @@ -0,0 +1,19 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + + +class TransientFailure(Exception): + """Raising this error causes getdeps to return with an error code + that Sandcastle will consider to be a retryable transient + infrastructure error""" + + pass + + +class ManifestNotFound(Exception): + def __init__(self, manifest_name): + super(Exception, self).__init__("Unable to find manifest '%s'" % manifest_name) diff --git a/build/fbcode_builder/getdeps/expr.py b/build/fbcode_builder/getdeps/expr.py new file mode 100644 index 000000000..6c0485d03 --- /dev/null +++ b/build/fbcode_builder/getdeps/expr.py @@ -0,0 +1,184 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import re +import shlex + + +def parse_expr(expr_text, valid_variables): + """parses the simple criteria expression syntax used in + dependency specifications. + Returns an ExprNode instance that can be evaluated like this: + + ``` + expr = parse_expr("os=windows") + ok = expr.eval({ + "os": "windows" + }) + ``` + + Whitespace is allowed between tokens. The following terms + are recognized: + + KEY = VALUE # Evaluates to True if ctx[KEY] == VALUE + not(EXPR) # Evaluates to True if EXPR evaluates to False + # and vice versa + all(EXPR1, EXPR2, ...) # Evaluates True if all of the supplied + # EXPR's also evaluate True + any(EXPR1, EXPR2, ...) # Evaluates True if any of the supplied + # EXPR's also evaluate True, False if + # none of them evaluated true. + """ + + p = Parser(expr_text, valid_variables) + return p.parse() + + +class ExprNode(object): + def eval(self, ctx): + return False + + +class TrueExpr(ExprNode): + def eval(self, ctx): + return True + + def __str__(self): + return "true" + + +class NotExpr(ExprNode): + def __init__(self, node): + self._node = node + + def eval(self, ctx): + return not self._node.eval(ctx) + + def __str__(self): + return "not(%s)" % self._node + + +class AllExpr(ExprNode): + def __init__(self, nodes): + self._nodes = nodes + + def eval(self, ctx): + for node in self._nodes: + if not node.eval(ctx): + return False + return True + + def __str__(self): + items = [] + for node in self._nodes: + items.append(str(node)) + return "all(%s)" % ",".join(items) + + +class AnyExpr(ExprNode): + def __init__(self, nodes): + self._nodes = nodes + + def eval(self, ctx): + for node in self._nodes: + if node.eval(ctx): + return True + return False + + def __str__(self): + items = [] + for node in self._nodes: + items.append(str(node)) + return "any(%s)" % ",".join(items) + + +class EqualExpr(ExprNode): + def __init__(self, key, value): + self._key = key + self._value = value + + def eval(self, ctx): + return ctx.get(self._key) == self._value + + def __str__(self): + return "%s=%s" % (self._key, self._value) + + +class Parser(object): + def __init__(self, text, valid_variables): + self.text = text + self.lex = shlex.shlex(text) + self.valid_variables = valid_variables + + def parse(self): + expr = self.top() + garbage = self.lex.get_token() + if garbage != "": + raise Exception( + "Unexpected token %s after EqualExpr in %s" % (garbage, self.text) + ) + return expr + + def top(self): + name = self.ident() + op = self.lex.get_token() + + if op == "(": + parsers = { + "not": self.parse_not, + "any": self.parse_any, + "all": self.parse_all, + } + func = parsers.get(name) + if not func: + raise Exception("invalid term %s in %s" % (name, self.text)) + return func() + + if op == "=": + if name not in self.valid_variables: + raise Exception("unknown variable %r in expression" % (name,)) + return EqualExpr(name, self.lex.get_token()) + + raise Exception( + "Unexpected token sequence '%s %s' in %s" % (name, op, self.text) + ) + + def ident(self): + ident = self.lex.get_token() + if not re.match("[a-zA-Z]+", ident): + raise Exception("expected identifier found %s" % ident) + return ident + + def parse_not(self): + node = self.top() + expr = NotExpr(node) + tok = self.lex.get_token() + if tok != ")": + raise Exception("expected ')' found %s" % tok) + return expr + + def parse_any(self): + nodes = [] + while True: + nodes.append(self.top()) + tok = self.lex.get_token() + if tok == ")": + break + if tok != ",": + raise Exception("expected ',' or ')' but found %s" % tok) + return AnyExpr(nodes) + + def parse_all(self): + nodes = [] + while True: + nodes.append(self.top()) + tok = self.lex.get_token() + if tok == ")": + break + if tok != ",": + raise Exception("expected ',' or ')' but found %s" % tok) + return AllExpr(nodes) diff --git a/build/fbcode_builder/getdeps/fetcher.py b/build/fbcode_builder/getdeps/fetcher.py new file mode 100644 index 000000000..041549ad7 --- /dev/null +++ b/build/fbcode_builder/getdeps/fetcher.py @@ -0,0 +1,771 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import errno +import hashlib +import os +import re +import shutil +import stat +import subprocess +import sys +import tarfile +import time +import zipfile +from datetime import datetime +from typing import Dict, NamedTuple + +from .copytree import prefetch_dir_if_eden +from .envfuncs import Env +from .errors import TransientFailure +from .platform import is_windows +from .runcmd import run_cmd + + +try: + from urllib import urlretrieve + from urlparse import urlparse +except ImportError: + from urllib.parse import urlparse + from urllib.request import urlretrieve + + +def file_name_is_cmake_file(file_name): + file_name = file_name.lower() + base = os.path.basename(file_name) + return ( + base.endswith(".cmake") + or base.endswith(".cmake.in") + or base == "cmakelists.txt" + ) + + +class ChangeStatus(object): + """Indicates the nature of changes that happened while updating + the source directory. There are two broad uses: + * When extracting archives for third party software we want to + know that we did something (eg: we either extracted code or + we didn't do anything) + * For 1st party code where we use shipit to transform the code, + we want to know if we changed anything so that we can perform + a build, but we generally want to be a little more nuanced + and be able to distinguish between just changing a source file + and whether we might need to reconfigure the build system. + """ + + def __init__(self, all_changed=False): + """Construct a ChangeStatus object. The default is to create + a status that indicates no changes, but passing all_changed=True + will create one that indicates that everything changed""" + if all_changed: + self.source_files = 1 + self.make_files = 1 + else: + self.source_files = 0 + self.make_files = 0 + + def record_change(self, file_name): + """Used by the shipit fetcher to record changes as it updates + files in the destination. If the file name might be one used + in the cmake build system that we use for 1st party code, then + record that as a "make file" change. We could broaden this + to match any file used by various build systems, but it is + only really useful for our internal cmake stuff at this time. + If the file isn't a build file and is under the `fbcode_builder` + dir then we don't class that as an interesting change that we + might need to rebuild, so we ignore it. + Otherwise we record the file as a source file change.""" + + file_name = file_name.lower() + if file_name_is_cmake_file(file_name): + self.make_files += 1 + elif "/fbcode_builder/cmake" in file_name: + self.source_files += 1 + elif "/fbcode_builder/" not in file_name: + self.source_files += 1 + + def sources_changed(self): + """Returns true if any source files were changed during + an update operation. This will typically be used to decide + that the build system to be run on the source dir in an + incremental mode""" + return self.source_files > 0 + + def build_changed(self): + """Returns true if any build files were changed during + an update operation. This will typically be used to decidfe + that the build system should be reconfigured and re-run + as a full build""" + return self.make_files > 0 + + +class Fetcher(object): + """The Fetcher is responsible for fetching and extracting the + sources for project. The Fetcher instance defines where the + extracted data resides and reports this to the consumer via + its `get_src_dir` method.""" + + def update(self): + """Brings the src dir up to date, ideally minimizing + changes so that a subsequent build doesn't over-build. + Returns a ChangeStatus object that helps the caller to + understand the nature of the changes required during + the update.""" + return ChangeStatus() + + def clean(self): + """Reverts any changes that might have been made to + the src dir""" + pass + + def hash(self): + """Returns a hash that identifies the version of the code in the + working copy. For a git repo this is commit hash for the working + copy. For other Fetchers this should relate to the version of + the code in the src dir. The intent is that if a manifest + changes the version/rev of a project that the hash be different. + Importantly, this should be computable without actually fetching + the code, as we want this to factor into a hash used to download + a pre-built version of the code, without having to first download + and extract its sources (eg: boost on windows is pretty painful). + """ + pass + + def get_src_dir(self): + """Returns the source directory that the project was + extracted into""" + pass + + +class LocalDirFetcher(object): + """This class exists to override the normal fetching behavior, and + use an explicit user-specified directory for the project sources. + + This fetcher cannot update or track changes. It always reports that the + project has changed, forcing it to always be built.""" + + def __init__(self, path): + self.path = os.path.realpath(path) + + def update(self): + return ChangeStatus(all_changed=True) + + def hash(self): + return "0" * 40 + + def get_src_dir(self): + return self.path + + +class SystemPackageFetcher(object): + def __init__(self, build_options, packages): + self.manager = build_options.host_type.get_package_manager() + self.packages = packages.get(self.manager) + if self.packages: + self.installed = None + else: + self.installed = False + + def packages_are_installed(self): + if self.installed is not None: + return self.installed + + if self.manager == "rpm": + result = run_cmd(["rpm", "-q"] + self.packages, allow_fail=True) + self.installed = result == 0 + elif self.manager == "deb": + result = run_cmd(["dpkg", "-s"] + self.packages, allow_fail=True) + self.installed = result == 0 + else: + self.installed = False + + return self.installed + + def update(self): + assert self.installed + return ChangeStatus(all_changed=False) + + def hash(self): + return "0" * 40 + + def get_src_dir(self): + return None + + +class PreinstalledNopFetcher(SystemPackageFetcher): + def __init__(self): + self.installed = True + + +class GitFetcher(Fetcher): + DEFAULT_DEPTH = 1 + + def __init__(self, build_options, manifest, repo_url, rev, depth): + # Extract the host/path portions of the URL and generate a flattened + # directory name. eg: + # github.com/facebook/folly.git -> github.com-facebook-folly.git + url = urlparse(repo_url) + directory = "%s%s" % (url.netloc, url.path) + for s in ["/", "\\", ":"]: + directory = directory.replace(s, "-") + + # Place it in a repos dir in the scratch space + repos_dir = os.path.join(build_options.scratch_dir, "repos") + if not os.path.exists(repos_dir): + os.makedirs(repos_dir) + self.repo_dir = os.path.join(repos_dir, directory) + + if not rev and build_options.project_hashes: + hash_file = os.path.join( + build_options.project_hashes, + re.sub("\\.git$", "-rev.txt", url.path[1:]), + ) + if os.path.exists(hash_file): + with open(hash_file, "r") as f: + data = f.read() + m = re.match("Subproject commit ([a-fA-F0-9]{40})", data) + if not m: + raise Exception("Failed to parse rev from %s" % hash_file) + rev = m.group(1) + print("Using pinned rev %s for %s" % (rev, repo_url)) + + self.rev = rev or "master" + self.origin_repo = repo_url + self.manifest = manifest + self.depth = depth if depth else GitFetcher.DEFAULT_DEPTH + + def _update(self): + current_hash = ( + subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=self.repo_dir) + .strip() + .decode("utf-8") + ) + target_hash = ( + subprocess.check_output(["git", "rev-parse", self.rev], cwd=self.repo_dir) + .strip() + .decode("utf-8") + ) + if target_hash == current_hash: + # It's up to date, so there are no changes. This doesn't detect eg: + # if origin/master moved and rev='master', but that's ok for our purposes; + # we should be using explicit hashes or eg: a stable branch for the cases + # that we care about, and it isn't unreasonable to require that the user + # explicitly perform a clean build if those have moved. For the most + # part we prefer that folks build using a release tarball from github + # rather than use the git protocol, as it is generally a bit quicker + # to fetch and easier to hash and verify tarball downloads. + return ChangeStatus() + + print("Updating %s -> %s" % (self.repo_dir, self.rev)) + run_cmd(["git", "fetch", "origin", self.rev], cwd=self.repo_dir) + run_cmd(["git", "checkout", self.rev], cwd=self.repo_dir) + run_cmd(["git", "submodule", "update", "--init"], cwd=self.repo_dir) + + return ChangeStatus(True) + + def update(self): + if os.path.exists(self.repo_dir): + return self._update() + self._clone() + return ChangeStatus(True) + + def _clone(self): + print("Cloning %s..." % self.origin_repo) + # The basename/dirname stuff allows us to dance around issues where + # eg: this python process is native win32, but the git.exe is cygwin + # or msys and doesn't like the absolute windows path that we'd otherwise + # pass to it. Careful use of cwd helps avoid headaches with cygpath. + run_cmd( + [ + "git", + "clone", + "--depth=" + str(self.depth), + "--", + self.origin_repo, + os.path.basename(self.repo_dir), + ], + cwd=os.path.dirname(self.repo_dir), + ) + self._update() + + def clean(self): + if os.path.exists(self.repo_dir): + run_cmd(["git", "clean", "-fxd"], cwd=self.repo_dir) + + def hash(self): + return self.rev + + def get_src_dir(self): + return self.repo_dir + + +def does_file_need_update(src_name, src_st, dest_name): + try: + target_st = os.lstat(dest_name) + except OSError as exc: + if exc.errno != errno.ENOENT: + raise + return True + + if src_st.st_size != target_st.st_size: + return True + + if stat.S_IFMT(src_st.st_mode) != stat.S_IFMT(target_st.st_mode): + return True + if stat.S_ISLNK(src_st.st_mode): + return os.readlink(src_name) != os.readlink(dest_name) + if not stat.S_ISREG(src_st.st_mode): + return True + + # They might have the same content; compare. + with open(src_name, "rb") as sf, open(dest_name, "rb") as df: + chunk_size = 8192 + while True: + src_data = sf.read(chunk_size) + dest_data = df.read(chunk_size) + if src_data != dest_data: + return True + if len(src_data) < chunk_size: + # EOF + break + return False + + +def copy_if_different(src_name, dest_name): + """Copy src_name -> dest_name, but only touch dest_name + if src_name is different from dest_name, making this a + more build system friendly way to copy.""" + src_st = os.lstat(src_name) + if not does_file_need_update(src_name, src_st, dest_name): + return False + + dest_parent = os.path.dirname(dest_name) + if not os.path.exists(dest_parent): + os.makedirs(dest_parent) + if stat.S_ISLNK(src_st.st_mode): + try: + os.unlink(dest_name) + except OSError as exc: + if exc.errno != errno.ENOENT: + raise + target = os.readlink(src_name) + print("Symlinking %s -> %s" % (dest_name, target)) + os.symlink(target, dest_name) + else: + print("Copying %s -> %s" % (src_name, dest_name)) + shutil.copy2(src_name, dest_name) + + return True + + +def list_files_under_dir_newer_than_timestamp(dir_to_scan, ts): + for root, _dirs, files in os.walk(dir_to_scan): + for src_file in files: + full_name = os.path.join(root, src_file) + st = os.lstat(full_name) + if st.st_mtime > ts: + yield full_name + + +class ShipitPathMap(object): + def __init__(self): + self.roots = [] + self.mapping = [] + self.exclusion = [] + + def add_mapping(self, fbsource_dir, target_dir): + """Add a posix path or pattern. We cannot normpath the input + here because that would change the paths from posix to windows + form and break the logic throughout this class.""" + self.roots.append(fbsource_dir) + self.mapping.append((fbsource_dir, target_dir)) + + def add_exclusion(self, pattern): + self.exclusion.append(re.compile(pattern)) + + def _minimize_roots(self): + """compute the de-duplicated set of roots within fbsource. + We take the shortest common directory prefix to make this + determination""" + self.roots.sort(key=len) + minimized = [] + + for r in self.roots: + add_this_entry = True + for existing in minimized: + if r.startswith(existing + "/"): + add_this_entry = False + break + if add_this_entry: + minimized.append(r) + + self.roots = minimized + + def _sort_mapping(self): + self.mapping.sort(reverse=True, key=lambda x: len(x[0])) + + def _map_name(self, norm_name, dest_root): + if norm_name.endswith(".pyc") or norm_name.endswith(".swp"): + # Ignore some incidental garbage while iterating + return None + + for excl in self.exclusion: + if excl.match(norm_name): + return None + + for src_name, dest_name in self.mapping: + if norm_name == src_name or norm_name.startswith(src_name + "/"): + rel_name = os.path.relpath(norm_name, src_name) + # We can have "." as a component of some paths, depending + # on the contents of the shipit transformation section. + # normpath doesn't always remove `.` as the final component + # of the path, which be problematic when we later mkdir + # the dirname of the path that we return. Take care to avoid + # returning a path with a `.` in it. + rel_name = os.path.normpath(rel_name) + if dest_name == ".": + return os.path.normpath(os.path.join(dest_root, rel_name)) + dest_name = os.path.normpath(dest_name) + return os.path.normpath(os.path.join(dest_root, dest_name, rel_name)) + + raise Exception("%s did not match any rules" % norm_name) + + def mirror(self, fbsource_root, dest_root): + self._minimize_roots() + self._sort_mapping() + + change_status = ChangeStatus() + + # Record the full set of files that should be in the tree + full_file_list = set() + + for fbsource_subdir in self.roots: + dir_to_mirror = os.path.join(fbsource_root, fbsource_subdir) + prefetch_dir_if_eden(dir_to_mirror) + if not os.path.exists(dir_to_mirror): + raise Exception( + "%s doesn't exist; check your sparse profile!" % dir_to_mirror + ) + for root, _dirs, files in os.walk(dir_to_mirror): + for src_file in files: + full_name = os.path.join(root, src_file) + rel_name = os.path.relpath(full_name, fbsource_root) + norm_name = rel_name.replace("\\", "/") + + target_name = self._map_name(norm_name, dest_root) + if target_name: + full_file_list.add(target_name) + if copy_if_different(full_name, target_name): + change_status.record_change(target_name) + + # Compare the list of previously shipped files; if a file is + # in the old list but not the new list then it has been + # removed from the source and should be removed from the + # destination. + # Why don't we simply create this list by walking dest_root? + # Some builds currently have to be in-source builds and + # may legitimately need to keep some state in the source tree :-/ + installed_name = os.path.join(dest_root, ".shipit_shipped") + if os.path.exists(installed_name): + with open(installed_name, "rb") as f: + for name in f.read().decode("utf-8").splitlines(): + name = name.strip() + if name not in full_file_list: + print("Remove %s" % name) + os.unlink(name) + change_status.record_change(name) + + with open(installed_name, "wb") as f: + for name in sorted(list(full_file_list)): + f.write(("%s\n" % name).encode("utf-8")) + + return change_status + + +class FbsourceRepoData(NamedTuple): + hash: str + date: str + + +FBSOURCE_REPO_DATA: Dict[str, FbsourceRepoData] = {} + + +def get_fbsource_repo_data(build_options): + """Returns the commit metadata for the fbsource repo. + Since we may have multiple first party projects to + hash, and because we don't mutate the repo, we cache + this hash in a global.""" + cached_data = FBSOURCE_REPO_DATA.get(build_options.fbsource_dir) + if cached_data: + return cached_data + + cmd = ["hg", "log", "-r.", "-T{node}\n{date|hgdate}"] + env = Env() + env.set("HGPLAIN", "1") + log_data = subprocess.check_output( + cmd, cwd=build_options.fbsource_dir, env=dict(env.items()) + ).decode("ascii") + + (hash, datestr) = log_data.split("\n") + + # datestr is like "seconds fractionalseconds" + # We want "20200324.113140" + (unixtime, _fractional) = datestr.split(" ") + date = datetime.fromtimestamp(int(unixtime)).strftime("%Y%m%d.%H%M%S") + cached_data = FbsourceRepoData(hash=hash, date=date) + + FBSOURCE_REPO_DATA[build_options.fbsource_dir] = cached_data + + return cached_data + + +class SimpleShipitTransformerFetcher(Fetcher): + def __init__(self, build_options, manifest): + self.build_options = build_options + self.manifest = manifest + self.repo_dir = os.path.join(build_options.scratch_dir, "shipit", manifest.name) + + def clean(self): + if os.path.exists(self.repo_dir): + shutil.rmtree(self.repo_dir) + + def update(self): + mapping = ShipitPathMap() + for src, dest in self.manifest.get_section_as_ordered_pairs("shipit.pathmap"): + mapping.add_mapping(src, dest) + if self.manifest.shipit_fbcode_builder: + mapping.add_mapping( + "fbcode/opensource/fbcode_builder", "build/fbcode_builder" + ) + for pattern in self.manifest.get_section_as_args("shipit.strip"): + mapping.add_exclusion(pattern) + + return mapping.mirror(self.build_options.fbsource_dir, self.repo_dir) + + def hash(self): + # We return a fixed non-hash string for in-fbsource builds. + # We're relying on the `update` logic to correctly invalidate + # the build in the case that files have changed. + return "fbsource" + + def get_src_dir(self): + return self.repo_dir + + +class ShipitTransformerFetcher(Fetcher): + SHIPIT = "/var/www/scripts/opensource/shipit/run_shipit.php" + + def __init__(self, build_options, project_name): + self.build_options = build_options + self.project_name = project_name + self.repo_dir = os.path.join(build_options.scratch_dir, "shipit", project_name) + + def update(self): + if os.path.exists(self.repo_dir): + return ChangeStatus() + self.run_shipit() + return ChangeStatus(True) + + def clean(self): + if os.path.exists(self.repo_dir): + shutil.rmtree(self.repo_dir) + + @classmethod + def available(cls): + return os.path.exists(cls.SHIPIT) + + def run_shipit(self): + tmp_path = self.repo_dir + ".new" + try: + if os.path.exists(tmp_path): + shutil.rmtree(tmp_path) + + # Run shipit + run_cmd( + [ + "php", + ShipitTransformerFetcher.SHIPIT, + "--project=" + self.project_name, + "--create-new-repo", + "--source-repo-dir=" + self.build_options.fbsource_dir, + "--source-branch=.", + "--skip-source-init", + "--skip-source-pull", + "--skip-source-clean", + "--skip-push", + "--skip-reset", + "--destination-use-anonymous-https", + "--create-new-repo-output-path=" + tmp_path, + ] + ) + + # Remove the .git directory from the repository it generated. + # There is no need to commit this. + repo_git_dir = os.path.join(tmp_path, ".git") + shutil.rmtree(repo_git_dir) + os.rename(tmp_path, self.repo_dir) + except Exception: + # Clean up after a failed extraction + if os.path.exists(tmp_path): + shutil.rmtree(tmp_path) + self.clean() + raise + + def hash(self): + # We return a fixed non-hash string for in-fbsource builds. + return "fbsource" + + def get_src_dir(self): + return self.repo_dir + + +def download_url_to_file_with_progress(url, file_name): + print("Download %s -> %s ..." % (url, file_name)) + + class Progress(object): + last_report = 0 + + def progress(self, count, block, total): + if total == -1: + total = "(Unknown)" + amount = count * block + + if sys.stdout.isatty(): + sys.stdout.write("\r downloading %s of %s " % (amount, total)) + else: + # When logging to CI logs, avoid spamming the logs and print + # status every few seconds + now = time.time() + if now - self.last_report > 5: + sys.stdout.write(".. %s of %s " % (amount, total)) + self.last_report = now + sys.stdout.flush() + + progress = Progress() + start = time.time() + try: + (_filename, headers) = urlretrieve(url, file_name, reporthook=progress.progress) + except (OSError, IOError) as exc: # noqa: B014 + raise TransientFailure( + "Failed to download %s to %s: %s" % (url, file_name, str(exc)) + ) + + end = time.time() + sys.stdout.write(" [Complete in %f seconds]\n" % (end - start)) + sys.stdout.flush() + print(f"{headers}") + + +class ArchiveFetcher(Fetcher): + def __init__(self, build_options, manifest, url, sha256): + self.manifest = manifest + self.url = url + self.sha256 = sha256 + self.build_options = build_options + + url = urlparse(self.url) + basename = "%s-%s" % (manifest.name, os.path.basename(url.path)) + self.file_name = os.path.join(build_options.scratch_dir, "downloads", basename) + self.src_dir = os.path.join(build_options.scratch_dir, "extracted", basename) + self.hash_file = self.src_dir + ".hash" + + def _verify_hash(self): + h = hashlib.sha256() + with open(self.file_name, "rb") as f: + while True: + block = f.read(8192) + if not block: + break + h.update(block) + digest = h.hexdigest() + if digest != self.sha256: + os.unlink(self.file_name) + raise Exception( + "%s: expected sha256 %s but got %s" % (self.url, self.sha256, digest) + ) + + def _download_dir(self): + """returns the download dir, creating it if it doesn't already exist""" + download_dir = os.path.dirname(self.file_name) + if not os.path.exists(download_dir): + os.makedirs(download_dir) + return download_dir + + def _download(self): + self._download_dir() + download_url_to_file_with_progress(self.url, self.file_name) + self._verify_hash() + + def clean(self): + if os.path.exists(self.src_dir): + shutil.rmtree(self.src_dir) + + def update(self): + try: + with open(self.hash_file, "r") as f: + saved_hash = f.read().strip() + if saved_hash == self.sha256 and os.path.exists(self.src_dir): + # Everything is up to date + return ChangeStatus() + print( + "saved hash %s doesn't match expected hash %s, re-validating" + % (saved_hash, self.sha256) + ) + os.unlink(self.hash_file) + except EnvironmentError: + pass + + # If we got here we know the contents of src_dir are either missing + # or wrong, so blow away whatever happened to be there first. + if os.path.exists(self.src_dir): + shutil.rmtree(self.src_dir) + + # If we already have a file here, make sure it looks legit before + # proceeding: any errors and we just remove it and re-download + if os.path.exists(self.file_name): + try: + self._verify_hash() + except Exception: + if os.path.exists(self.file_name): + os.unlink(self.file_name) + + if not os.path.exists(self.file_name): + self._download() + + if tarfile.is_tarfile(self.file_name): + opener = tarfile.open + elif zipfile.is_zipfile(self.file_name): + opener = zipfile.ZipFile + else: + raise Exception("don't know how to extract %s" % self.file_name) + os.makedirs(self.src_dir) + print("Extract %s -> %s" % (self.file_name, self.src_dir)) + t = opener(self.file_name) + if is_windows(): + # Ensure that we don't fall over when dealing with long paths + # on windows + src = r"\\?\%s" % os.path.normpath(self.src_dir) + else: + src = self.src_dir + # The `str` here is necessary to ensure that we don't pass a unicode + # object down to tarfile.extractall on python2. When extracting + # the boost tarball it makes some assumptions and tries to convert + # a non-ascii path to ascii and throws. + src = str(src) + t.extractall(src) + + with open(self.hash_file, "w") as f: + f.write(self.sha256) + + return ChangeStatus(True) + + def hash(self): + return self.sha256 + + def get_src_dir(self): + return self.src_dir diff --git a/build/fbcode_builder/getdeps/load.py b/build/fbcode_builder/getdeps/load.py new file mode 100644 index 000000000..c5f40d2fa --- /dev/null +++ b/build/fbcode_builder/getdeps/load.py @@ -0,0 +1,354 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import base64 +import hashlib +import os + +from . import fetcher +from .envfuncs import path_search +from .errors import ManifestNotFound +from .manifest import ManifestParser + + +class Loader(object): + """The loader allows our tests to patch the load operation""" + + def _list_manifests(self, build_opts): + """Returns a generator that iterates all the available manifests""" + for (path, _, files) in os.walk(build_opts.manifests_dir): + for name in files: + # skip hidden files + if name.startswith("."): + continue + + yield os.path.join(path, name) + + def _load_manifest(self, path): + return ManifestParser(path) + + def load_project(self, build_opts, project_name): + if "/" in project_name or "\\" in project_name: + # Assume this is a path already + return ManifestParser(project_name) + + for manifest in self._list_manifests(build_opts): + if os.path.basename(manifest) == project_name: + return ManifestParser(manifest) + + raise ManifestNotFound(project_name) + + def load_all(self, build_opts): + manifests_by_name = {} + + for manifest in self._list_manifests(build_opts): + m = self._load_manifest(manifest) + + if m.name in manifests_by_name: + raise Exception("found duplicate manifest '%s'" % m.name) + + manifests_by_name[m.name] = m + + return manifests_by_name + + +class ResourceLoader(Loader): + def __init__(self, namespace, manifests_dir): + self.namespace = namespace + self.manifests_dir = manifests_dir + + def _list_manifests(self, _build_opts): + import pkg_resources + + dirs = [self.manifests_dir] + + while dirs: + current = dirs.pop(0) + for name in pkg_resources.resource_listdir(self.namespace, current): + path = "%s/%s" % (current, name) + + if pkg_resources.resource_isdir(self.namespace, path): + dirs.append(path) + else: + yield "%s/%s" % (current, name) + + def _find_manifest(self, project_name): + for name in self._list_manifests(): + if name.endswith("/%s" % project_name): + return name + + raise ManifestNotFound(project_name) + + def _load_manifest(self, path): + import pkg_resources + + contents = pkg_resources.resource_string(self.namespace, path).decode("utf8") + return ManifestParser(file_name=path, fp=contents) + + def load_project(self, build_opts, project_name): + project_name = self._find_manifest(project_name) + return self._load_resource_manifest(project_name) + + +LOADER = Loader() + + +def patch_loader(namespace, manifests_dir="manifests"): + global LOADER + LOADER = ResourceLoader(namespace, manifests_dir) + + +def load_project(build_opts, project_name): + """given the name of a project or a path to a manifest file, + load up the ManifestParser instance for it and return it""" + return LOADER.load_project(build_opts, project_name) + + +def load_all_manifests(build_opts): + return LOADER.load_all(build_opts) + + +class ManifestLoader(object): + """ManifestLoader stores information about project manifest relationships for a + given set of (build options + platform) configuration. + + The ManifestLoader class primarily serves as a location to cache project dependency + relationships and project hash values for this build configuration. + """ + + def __init__(self, build_opts, ctx_gen=None): + self._loader = LOADER + self.build_opts = build_opts + if ctx_gen is None: + self.ctx_gen = self.build_opts.get_context_generator() + else: + self.ctx_gen = ctx_gen + + self.manifests_by_name = {} + self._loaded_all = False + self._project_hashes = {} + self._fetcher_overrides = {} + self._build_dir_overrides = {} + self._install_dir_overrides = {} + self._install_prefix_overrides = {} + + def load_manifest(self, name): + manifest = self.manifests_by_name.get(name) + if manifest is None: + manifest = self._loader.load_project(self.build_opts, name) + self.manifests_by_name[name] = manifest + return manifest + + def load_all_manifests(self): + if not self._loaded_all: + all_manifests_by_name = self._loader.load_all(self.build_opts) + if self.manifests_by_name: + # To help ensure that we only ever have a single manifest object for a + # given project, and that it can't change once we have loaded it, + # only update our mapping for projects that weren't already loaded. + for name, manifest in all_manifests_by_name.items(): + self.manifests_by_name.setdefault(name, manifest) + else: + self.manifests_by_name = all_manifests_by_name + self._loaded_all = True + + return self.manifests_by_name + + def manifests_in_dependency_order(self, manifest=None): + """Compute all dependencies of the specified project. Returns a list of the + dependencies plus the project itself, in topologically sorted order. + + Each entry in the returned list only depends on projects that appear before it + in the list. + + If the input manifest is None, the dependencies for all currently loaded + projects will be computed. i.e., if you call load_all_manifests() followed by + manifests_in_dependency_order() this will return a global dependency ordering of + all projects.""" + # The list of deps that have been fully processed + seen = set() + # The list of deps which have yet to be evaluated. This + # can potentially contain duplicates. + if manifest is None: + deps = list(self.manifests_by_name.values()) + else: + assert manifest.name in self.manifests_by_name + deps = [manifest] + # The list of manifests in dependency order + dep_order = [] + + while len(deps) > 0: + m = deps.pop(0) + if m.name in seen: + continue + + # Consider its deps, if any. + # We sort them for increased determinism; we'll produce + # a correct order even if they aren't sorted, but we prefer + # to produce the same order regardless of how they are listed + # in the project manifest files. + ctx = self.ctx_gen.get_context(m.name) + dep_list = sorted(m.get_section_as_dict("dependencies", ctx).keys()) + builder = m.get("build", "builder", ctx=ctx) + if builder in ("cmake", "python-wheel"): + dep_list.append("cmake") + elif builder == "autoconf" and m.name not in ( + "autoconf", + "libtool", + "automake", + ): + # they need libtool and its deps (automake, autoconf) so add + # those as deps (but obviously not if we're building those + # projects themselves) + dep_list.append("libtool") + + dep_count = 0 + for dep_name in dep_list: + # If we're not sure whether it is done, queue it up + if dep_name not in seen: + dep = self.manifests_by_name.get(dep_name) + if dep is None: + dep = self._loader.load_project(self.build_opts, dep_name) + self.manifests_by_name[dep.name] = dep + + deps.append(dep) + dep_count += 1 + + if dep_count > 0: + # If we queued anything, re-queue this item, as it depends + # those new item(s) and their transitive deps. + deps.append(m) + continue + + # Its deps are done, so we can emit it + seen.add(m.name) + dep_order.append(m) + + return dep_order + + def set_project_src_dir(self, project_name, path): + self._fetcher_overrides[project_name] = fetcher.LocalDirFetcher(path) + + def set_project_build_dir(self, project_name, path): + self._build_dir_overrides[project_name] = path + + def set_project_install_dir(self, project_name, path): + self._install_dir_overrides[project_name] = path + + def set_project_install_prefix(self, project_name, path): + self._install_prefix_overrides[project_name] = path + + def create_fetcher(self, manifest): + override = self._fetcher_overrides.get(manifest.name) + if override is not None: + return override + + ctx = self.ctx_gen.get_context(manifest.name) + return manifest.create_fetcher(self.build_opts, ctx) + + def get_project_hash(self, manifest): + h = self._project_hashes.get(manifest.name) + if h is None: + h = self._compute_project_hash(manifest) + self._project_hashes[manifest.name] = h + return h + + def _compute_project_hash(self, manifest): + """This recursive function computes a hash for a given manifest. + The hash takes into account some environmental factors on the + host machine and includes the hashes of its dependencies. + No caching of the computation is performed, which is theoretically + wasteful but the computation is fast enough that it is not required + to cache across multiple invocations.""" + ctx = self.ctx_gen.get_context(manifest.name) + + hasher = hashlib.sha256() + # Some environmental and configuration things matter + env = {} + env["install_dir"] = self.build_opts.install_dir + env["scratch_dir"] = self.build_opts.scratch_dir + env["vcvars_path"] = self.build_opts.vcvars_path + env["os"] = self.build_opts.host_type.ostype + env["distro"] = self.build_opts.host_type.distro + env["distro_vers"] = self.build_opts.host_type.distrovers + for name in [ + "CXXFLAGS", + "CPPFLAGS", + "LDFLAGS", + "CXX", + "CC", + "GETDEPS_CMAKE_DEFINES", + ]: + env[name] = os.environ.get(name) + for tool in ["cc", "c++", "gcc", "g++", "clang", "clang++"]: + env["tool-%s" % tool] = path_search(os.environ, tool) + for name in manifest.get_section_as_args("depends.environment", ctx): + env[name] = os.environ.get(name) + + fetcher = self.create_fetcher(manifest) + env["fetcher.hash"] = fetcher.hash() + + for name in sorted(env.keys()): + hasher.update(name.encode("utf-8")) + value = env.get(name) + if value is not None: + try: + hasher.update(value.encode("utf-8")) + except AttributeError as exc: + raise AttributeError("name=%r, value=%r: %s" % (name, value, exc)) + + manifest.update_hash(hasher, ctx) + + dep_list = sorted(manifest.get_section_as_dict("dependencies", ctx).keys()) + for dep in dep_list: + dep_manifest = self.load_manifest(dep) + dep_hash = self.get_project_hash(dep_manifest) + hasher.update(dep_hash.encode("utf-8")) + + # Use base64 to represent the hash, rather than the simple hex digest, + # so that the string is shorter. Use the URL-safe encoding so that + # the hash can also be safely used as a filename component. + h = base64.urlsafe_b64encode(hasher.digest()).decode("ascii") + # ... and because cmd.exe is troublesome with `=` signs, nerf those. + # They tend to be padding characters at the end anyway, so we can + # safely discard them. + h = h.replace("=", "") + + return h + + def _get_project_dir_name(self, manifest): + if manifest.is_first_party_project(): + return manifest.name + else: + project_hash = self.get_project_hash(manifest) + return "%s-%s" % (manifest.name, project_hash) + + def get_project_install_dir(self, manifest): + override = self._install_dir_overrides.get(manifest.name) + if override: + return override + + project_dir_name = self._get_project_dir_name(manifest) + return os.path.join(self.build_opts.install_dir, project_dir_name) + + def get_project_build_dir(self, manifest): + override = self._build_dir_overrides.get(manifest.name) + if override: + return override + + project_dir_name = self._get_project_dir_name(manifest) + return os.path.join(self.build_opts.scratch_dir, "build", project_dir_name) + + def get_project_install_prefix(self, manifest): + return self._install_prefix_overrides.get(manifest.name) + + def get_project_install_dir_respecting_install_prefix(self, manifest): + inst_dir = self.get_project_install_dir(manifest) + prefix = self.get_project_install_prefix(manifest) + if prefix: + return inst_dir + prefix + return inst_dir diff --git a/build/fbcode_builder/getdeps/manifest.py b/build/fbcode_builder/getdeps/manifest.py new file mode 100644 index 000000000..71566d659 --- /dev/null +++ b/build/fbcode_builder/getdeps/manifest.py @@ -0,0 +1,606 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import io +import os + +from .builder import ( + AutoconfBuilder, + Boost, + CargoBuilder, + CMakeBuilder, + BistroBuilder, + Iproute2Builder, + MakeBuilder, + NinjaBootstrap, + NopBuilder, + OpenNSABuilder, + OpenSSLBuilder, + SqliteBuilder, + CMakeBootStrapBuilder, +) +from .expr import parse_expr +from .fetcher import ( + ArchiveFetcher, + GitFetcher, + PreinstalledNopFetcher, + ShipitTransformerFetcher, + SimpleShipitTransformerFetcher, + SystemPackageFetcher, +) +from .py_wheel_builder import PythonWheelBuilder + + +try: + import configparser +except ImportError: + import ConfigParser as configparser + +REQUIRED = "REQUIRED" +OPTIONAL = "OPTIONAL" + +SCHEMA = { + "manifest": { + "optional_section": False, + "fields": { + "name": REQUIRED, + "fbsource_path": OPTIONAL, + "shipit_project": OPTIONAL, + "shipit_fbcode_builder": OPTIONAL, + }, + }, + "dependencies": {"optional_section": True, "allow_values": False}, + "depends.environment": {"optional_section": True}, + "git": { + "optional_section": True, + "fields": {"repo_url": REQUIRED, "rev": OPTIONAL, "depth": OPTIONAL}, + }, + "download": { + "optional_section": True, + "fields": {"url": REQUIRED, "sha256": REQUIRED}, + }, + "build": { + "optional_section": True, + "fields": { + "builder": REQUIRED, + "subdir": OPTIONAL, + "build_in_src_dir": OPTIONAL, + "disable_env_override_pkgconfig": OPTIONAL, + "disable_env_override_path": OPTIONAL, + }, + }, + "msbuild": {"optional_section": True, "fields": {"project": REQUIRED}}, + "cargo": { + "optional_section": True, + "fields": { + "build_doc": OPTIONAL, + "workspace_dir": OPTIONAL, + "manifests_to_build": OPTIONAL, + }, + }, + "cmake.defines": {"optional_section": True}, + "autoconf.args": {"optional_section": True}, + "rpms": {"optional_section": True}, + "debs": {"optional_section": True}, + "preinstalled.env": {"optional_section": True}, + "b2.args": {"optional_section": True}, + "make.build_args": {"optional_section": True}, + "make.install_args": {"optional_section": True}, + "make.test_args": {"optional_section": True}, + "header-only": {"optional_section": True, "fields": {"includedir": REQUIRED}}, + "shipit.pathmap": {"optional_section": True}, + "shipit.strip": {"optional_section": True}, + "install.files": {"optional_section": True}, +} + +# These sections are allowed to vary for different platforms +# using the expression syntax to enable/disable sections +ALLOWED_EXPR_SECTIONS = [ + "autoconf.args", + "build", + "cmake.defines", + "dependencies", + "make.build_args", + "make.install_args", + "b2.args", + "download", + "git", + "install.files", +] + + +def parse_conditional_section_name(name, section_def): + expr = name[len(section_def) + 1 :] + return parse_expr(expr, ManifestContext.ALLOWED_VARIABLES) + + +def validate_allowed_fields(file_name, section, config, allowed_fields): + for field in config.options(section): + if not allowed_fields.get(field): + raise Exception( + ("manifest file %s section '%s' contains " "unknown field '%s'") + % (file_name, section, field) + ) + + for field in allowed_fields: + if allowed_fields[field] == REQUIRED and not config.has_option(section, field): + raise Exception( + ("manifest file %s section '%s' is missing " "required field '%s'") + % (file_name, section, field) + ) + + +def validate_allow_values(file_name, section, config): + for field in config.options(section): + value = config.get(section, field) + if value is not None: + raise Exception( + ( + "manifest file %s section '%s' has '%s = %s' but " + "this section doesn't allow specifying values " + "for its entries" + ) + % (file_name, section, field, value) + ) + + +def validate_section(file_name, section, config): + section_def = SCHEMA.get(section) + if not section_def: + for name in ALLOWED_EXPR_SECTIONS: + if section.startswith(name + "."): + # Verify that the conditional parses, but discard it + try: + parse_conditional_section_name(section, name) + except Exception as exc: + raise Exception( + ("manifest file %s section '%s' has invalid " "conditional: %s") + % (file_name, section, str(exc)) + ) + section_def = SCHEMA.get(name) + canonical_section_name = name + break + if not section_def: + raise Exception( + "manifest file %s contains unknown section '%s'" % (file_name, section) + ) + else: + canonical_section_name = section + + allowed_fields = section_def.get("fields") + if allowed_fields: + validate_allowed_fields(file_name, section, config, allowed_fields) + elif not section_def.get("allow_values", True): + validate_allow_values(file_name, section, config) + return canonical_section_name + + +class ManifestParser(object): + def __init__(self, file_name, fp=None): + # allow_no_value enables listing parameters in the + # autoconf.args section one per line + config = configparser.RawConfigParser(allow_no_value=True) + config.optionxform = str # make it case sensitive + + if fp is None: + with open(file_name, "r") as fp: + config.read_file(fp) + elif isinstance(fp, type("")): + # For testing purposes, parse from a string (str + # or unicode) + config.read_file(io.StringIO(fp)) + else: + config.read_file(fp) + + # validate against the schema + seen_sections = set() + + for section in config.sections(): + seen_sections.add(validate_section(file_name, section, config)) + + for section in SCHEMA.keys(): + section_def = SCHEMA[section] + if ( + not section_def.get("optional_section", False) + and section not in seen_sections + ): + raise Exception( + "manifest file %s is missing required section %s" + % (file_name, section) + ) + + self._config = config + self.name = config.get("manifest", "name") + self.fbsource_path = self.get("manifest", "fbsource_path") + self.shipit_project = self.get("manifest", "shipit_project") + self.shipit_fbcode_builder = self.get("manifest", "shipit_fbcode_builder") + + if self.name != os.path.basename(file_name): + raise Exception( + "filename of the manifest '%s' does not match the manifest name '%s'" + % (file_name, self.name) + ) + + def get(self, section, key, defval=None, ctx=None): + ctx = ctx or {} + + for s in self._config.sections(): + if s == section: + if self._config.has_option(s, key): + return self._config.get(s, key) + return defval + + if s.startswith(section + "."): + expr = parse_conditional_section_name(s, section) + if not expr.eval(ctx): + continue + + if self._config.has_option(s, key): + return self._config.get(s, key) + + return defval + + def get_section_as_args(self, section, ctx=None): + """Intended for use with the make.[build_args/install_args] and + autoconf.args sections, this method collects the entries and returns an + array of strings. + If the manifest contains conditional sections, ctx is used to + evaluate the condition and merge in the values. + """ + args = [] + ctx = ctx or {} + + for s in self._config.sections(): + if s != section: + if not s.startswith(section + "."): + continue + expr = parse_conditional_section_name(s, section) + if not expr.eval(ctx): + continue + for field in self._config.options(s): + value = self._config.get(s, field) + if value is None: + args.append(field) + else: + args.append("%s=%s" % (field, value)) + return args + + def get_section_as_ordered_pairs(self, section, ctx=None): + """Used for eg: shipit.pathmap which has strong + ordering requirements""" + res = [] + ctx = ctx or {} + + for s in self._config.sections(): + if s != section: + if not s.startswith(section + "."): + continue + expr = parse_conditional_section_name(s, section) + if not expr.eval(ctx): + continue + + for key in self._config.options(s): + value = self._config.get(s, key) + res.append((key, value)) + return res + + def get_section_as_dict(self, section, ctx=None): + d = {} + ctx = ctx or {} + + for s in self._config.sections(): + if s != section: + if not s.startswith(section + "."): + continue + expr = parse_conditional_section_name(s, section) + if not expr.eval(ctx): + continue + for field in self._config.options(s): + value = self._config.get(s, field) + d[field] = value + return d + + def update_hash(self, hasher, ctx): + """Compute a hash over the configuration for the given + context. The goal is for the hash to change if the config + for that context changes, but not if a change is made to + the config only for a different platform than that expressed + by ctx. The hash is intended to be used to help invalidate + a future cache for the third party build products. + The hasher argument is a hash object returned from hashlib.""" + for section in sorted(SCHEMA.keys()): + hasher.update(section.encode("utf-8")) + + # Note: at the time of writing, nothing in the implementation + # relies on keys in any config section being ordered. + # In theory we could have conflicting flags in different + # config sections and later flags override earlier flags. + # For the purposes of computing a hash we're not super + # concerned about this: manifest changes should be rare + # enough and we'd rather that this trigger an invalidation + # than strive for a cache hit at this time. + pairs = self.get_section_as_ordered_pairs(section, ctx) + pairs.sort(key=lambda pair: pair[0]) + for key, value in pairs: + hasher.update(key.encode("utf-8")) + if value is not None: + hasher.update(value.encode("utf-8")) + + def is_first_party_project(self): + """returns true if this is an FB first-party project""" + return self.shipit_project is not None + + def get_required_system_packages(self, ctx): + """Returns dictionary of packager system -> list of packages""" + return { + "rpm": self.get_section_as_args("rpms", ctx), + "deb": self.get_section_as_args("debs", ctx), + } + + def _is_satisfied_by_preinstalled_environment(self, ctx): + envs = self.get_section_as_args("preinstalled.env", ctx) + if not envs: + return False + for key in envs: + val = os.environ.get(key, None) + print(f"Testing ENV[{key}]: {repr(val)}") + if val is None: + return False + if len(val) == 0: + return False + + return True + + def create_fetcher(self, build_options, ctx): + use_real_shipit = ( + ShipitTransformerFetcher.available() and build_options.use_shipit + ) + if ( + not use_real_shipit + and self.fbsource_path + and build_options.fbsource_dir + and self.shipit_project + ): + return SimpleShipitTransformerFetcher(build_options, self) + + if ( + self.fbsource_path + and build_options.fbsource_dir + and self.shipit_project + and ShipitTransformerFetcher.available() + ): + # We can use the code from fbsource + return ShipitTransformerFetcher(build_options, self.shipit_project) + + # Can we satisfy this dep with system packages? + if build_options.allow_system_packages: + if self._is_satisfied_by_preinstalled_environment(ctx): + return PreinstalledNopFetcher() + + packages = self.get_required_system_packages(ctx) + package_fetcher = SystemPackageFetcher(build_options, packages) + if package_fetcher.packages_are_installed(): + return package_fetcher + + repo_url = self.get("git", "repo_url", ctx=ctx) + if repo_url: + rev = self.get("git", "rev") + depth = self.get("git", "depth") + return GitFetcher(build_options, self, repo_url, rev, depth) + + url = self.get("download", "url", ctx=ctx) + if url: + # We need to defer this import until now to avoid triggering + # a cycle when the facebook/__init__.py is loaded. + try: + from getdeps.facebook.lfs import LFSCachingArchiveFetcher + + return LFSCachingArchiveFetcher( + build_options, self, url, self.get("download", "sha256", ctx=ctx) + ) + except ImportError: + # This FB internal module isn't shippped to github, + # so just use its base class + return ArchiveFetcher( + build_options, self, url, self.get("download", "sha256", ctx=ctx) + ) + + raise KeyError( + "project %s has no fetcher configuration matching %s" % (self.name, ctx) + ) + + def create_builder( # noqa:C901 + self, + build_options, + src_dir, + build_dir, + inst_dir, + ctx, + loader, + final_install_prefix=None, + extra_cmake_defines=None, + ): + builder = self.get("build", "builder", ctx=ctx) + if not builder: + raise Exception("project %s has no builder for %r" % (self.name, ctx)) + build_in_src_dir = self.get("build", "build_in_src_dir", "false", ctx=ctx) + if build_in_src_dir == "true": + # Some scripts don't work when they are configured and build in + # a different directory than source (or when the build directory + # is not a subdir of source). + build_dir = src_dir + subdir = self.get("build", "subdir", None, ctx=ctx) + if subdir is not None: + build_dir = os.path.join(build_dir, subdir) + print("build_dir is %s" % build_dir) # just to quiet lint + + if builder == "make" or builder == "cmakebootstrap": + build_args = self.get_section_as_args("make.build_args", ctx) + install_args = self.get_section_as_args("make.install_args", ctx) + test_args = self.get_section_as_args("make.test_args", ctx) + if builder == "cmakebootstrap": + return CMakeBootStrapBuilder( + build_options, + ctx, + self, + src_dir, + None, + inst_dir, + build_args, + install_args, + test_args, + ) + else: + return MakeBuilder( + build_options, + ctx, + self, + src_dir, + None, + inst_dir, + build_args, + install_args, + test_args, + ) + + if builder == "autoconf": + args = self.get_section_as_args("autoconf.args", ctx) + return AutoconfBuilder( + build_options, ctx, self, src_dir, build_dir, inst_dir, args + ) + + if builder == "boost": + args = self.get_section_as_args("b2.args", ctx) + return Boost(build_options, ctx, self, src_dir, build_dir, inst_dir, args) + + if builder == "bistro": + return BistroBuilder( + build_options, + ctx, + self, + src_dir, + build_dir, + inst_dir, + ) + + if builder == "cmake": + defines = self.get_section_as_dict("cmake.defines", ctx) + return CMakeBuilder( + build_options, + ctx, + self, + src_dir, + build_dir, + inst_dir, + defines, + final_install_prefix, + extra_cmake_defines, + ) + + if builder == "python-wheel": + return PythonWheelBuilder( + build_options, ctx, self, src_dir, build_dir, inst_dir + ) + + if builder == "sqlite": + return SqliteBuilder(build_options, ctx, self, src_dir, build_dir, inst_dir) + + if builder == "ninja_bootstrap": + return NinjaBootstrap( + build_options, ctx, self, build_dir, src_dir, inst_dir + ) + + if builder == "nop": + return NopBuilder(build_options, ctx, self, src_dir, inst_dir) + + if builder == "openssl": + return OpenSSLBuilder( + build_options, ctx, self, build_dir, src_dir, inst_dir + ) + + if builder == "iproute2": + return Iproute2Builder( + build_options, ctx, self, src_dir, build_dir, inst_dir + ) + + if builder == "cargo": + build_doc = self.get("cargo", "build_doc", False, ctx) + workspace_dir = self.get("cargo", "workspace_dir", None, ctx) + manifests_to_build = self.get("cargo", "manifests_to_build", None, ctx) + return CargoBuilder( + build_options, + ctx, + self, + src_dir, + build_dir, + inst_dir, + build_doc, + workspace_dir, + manifests_to_build, + loader, + ) + + if builder == "OpenNSA": + return OpenNSABuilder(build_options, ctx, self, src_dir, inst_dir) + + raise KeyError("project %s has no known builder" % (self.name)) + + +class ManifestContext(object): + """ProjectContext contains a dictionary of values to use when evaluating boolean + expressions in a project manifest. + + This object should be passed as the `ctx` parameter in ManifestParser.get() calls. + """ + + ALLOWED_VARIABLES = {"os", "distro", "distro_vers", "fb", "test"} + + def __init__(self, ctx_dict): + assert set(ctx_dict.keys()) == self.ALLOWED_VARIABLES + self.ctx_dict = ctx_dict + + def get(self, key): + return self.ctx_dict[key] + + def set(self, key, value): + assert key in self.ALLOWED_VARIABLES + self.ctx_dict[key] = value + + def copy(self): + return ManifestContext(dict(self.ctx_dict)) + + def __str__(self): + s = ", ".join( + "%s=%s" % (key, value) for key, value in sorted(self.ctx_dict.items()) + ) + return "{" + s + "}" + + +class ContextGenerator(object): + """ContextGenerator allows creating ManifestContext objects on a per-project basis. + This allows us to evaluate different projects with slightly different contexts. + + For instance, this can be used to only enable tests for some projects.""" + + def __init__(self, default_ctx): + self.default_ctx = ManifestContext(default_ctx) + self.ctx_by_project = {} + + def set_value_for_project(self, project_name, key, value): + project_ctx = self.ctx_by_project.get(project_name) + if project_ctx is None: + project_ctx = self.default_ctx.copy() + self.ctx_by_project[project_name] = project_ctx + project_ctx.set(key, value) + + def set_value_for_all_projects(self, key, value): + self.default_ctx.set(key, value) + for ctx in self.ctx_by_project.values(): + ctx.set(key, value) + + def get_context(self, project_name): + return self.ctx_by_project.get(project_name, self.default_ctx) diff --git a/build/fbcode_builder/getdeps/platform.py b/build/fbcode_builder/getdeps/platform.py new file mode 100644 index 000000000..fd8382e73 --- /dev/null +++ b/build/fbcode_builder/getdeps/platform.py @@ -0,0 +1,118 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import platform +import re +import shlex +import sys + + +def is_windows(): + """Returns true if the system we are currently running on + is a Windows system""" + return sys.platform.startswith("win") + + +def get_linux_type(): + try: + with open("/etc/os-release") as f: + data = f.read() + except EnvironmentError: + return (None, None) + + os_vars = {} + for line in data.splitlines(): + parts = line.split("=", 1) + if len(parts) != 2: + continue + key = parts[0].strip() + value_parts = shlex.split(parts[1].strip()) + if not value_parts: + value = "" + else: + value = value_parts[0] + os_vars[key] = value + + name = os_vars.get("NAME") + if name: + name = name.lower() + name = re.sub("linux", "", name) + name = name.strip() + + version_id = os_vars.get("VERSION_ID") + if version_id: + version_id = version_id.lower() + + return "linux", name, version_id + + +class HostType(object): + def __init__(self, ostype=None, distro=None, distrovers=None): + if ostype is None: + distro = None + distrovers = None + if sys.platform.startswith("linux"): + ostype, distro, distrovers = get_linux_type() + elif sys.platform.startswith("darwin"): + ostype = "darwin" + elif is_windows(): + ostype = "windows" + distrovers = str(sys.getwindowsversion().major) + else: + ostype = sys.platform + + # The operating system type + self.ostype = ostype + # The distribution, if applicable + self.distro = distro + # The OS/distro version if known + self.distrovers = distrovers + machine = platform.machine().lower() + if "arm" in machine or "aarch" in machine: + self.isarm = True + else: + self.isarm = False + + def is_windows(self): + return self.ostype == "windows" + + def is_arm(self): + return self.isarm + + def is_darwin(self): + return self.ostype == "darwin" + + def is_linux(self): + return self.ostype == "linux" + + def as_tuple_string(self): + return "%s-%s-%s" % ( + self.ostype, + self.distro or "none", + self.distrovers or "none", + ) + + def get_package_manager(self): + if not self.is_linux(): + return None + if self.distro in ("fedora", "centos"): + return "rpm" + if self.distro in ("debian", "ubuntu"): + return "deb" + return None + + @staticmethod + def from_tuple_string(s): + ostype, distro, distrovers = s.split("-") + return HostType(ostype=ostype, distro=distro, distrovers=distrovers) + + def __eq__(self, b): + return ( + self.ostype == b.ostype + and self.distro == b.distro + and self.distrovers == b.distrovers + ) diff --git a/build/fbcode_builder/getdeps/py_wheel_builder.py b/build/fbcode_builder/getdeps/py_wheel_builder.py new file mode 100644 index 000000000..82ad8b807 --- /dev/null +++ b/build/fbcode_builder/getdeps/py_wheel_builder.py @@ -0,0 +1,289 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import codecs +import collections +import email +import os +import re +import stat + +from .builder import BuilderBase, CMakeBuilder + + +WheelNameInfo = collections.namedtuple( + "WheelNameInfo", ("distribution", "version", "build", "python", "abi", "platform") +) + +CMAKE_HEADER = """ +cmake_minimum_required(VERSION 3.8) + +project("{manifest_name}" LANGUAGES C) + +set(CMAKE_MODULE_PATH + "{cmake_dir}" + ${{CMAKE_MODULE_PATH}} +) +include(FBPythonBinary) + +set(CMAKE_INSTALL_DIR lib/cmake/{manifest_name} CACHE STRING + "The subdirectory where CMake package config files should be installed") +""" + +CMAKE_FOOTER = """ +install_fb_python_library({lib_name} EXPORT all) +install( + EXPORT all + FILE {manifest_name}-targets.cmake + NAMESPACE {namespace}:: + DESTINATION ${{CMAKE_INSTALL_DIR}} +) + +include(CMakePackageConfigHelpers) +configure_package_config_file( + ${{CMAKE_BINARY_DIR}}/{manifest_name}-config.cmake.in + {manifest_name}-config.cmake + INSTALL_DESTINATION ${{CMAKE_INSTALL_DIR}} + PATH_VARS + CMAKE_INSTALL_DIR +) +install( + FILES ${{CMAKE_CURRENT_BINARY_DIR}}/{manifest_name}-config.cmake + DESTINATION ${{CMAKE_INSTALL_DIR}} +) +""" + +CMAKE_CONFIG_FILE = """ +@PACKAGE_INIT@ + +include(CMakeFindDependencyMacro) + +set_and_check({upper_name}_CMAKE_DIR "@PACKAGE_CMAKE_INSTALL_DIR@") + +if (NOT TARGET {namespace}::{lib_name}) + include("${{{upper_name}_CMAKE_DIR}}/{manifest_name}-targets.cmake") +endif() + +set({upper_name}_LIBRARIES {namespace}::{lib_name}) + +{find_dependency_lines} + +if (NOT {manifest_name}_FIND_QUIETLY) + message(STATUS "Found {manifest_name}: ${{PACKAGE_PREFIX_DIR}}") +endif() +""" + + +# Note: for now we are manually manipulating the wheel packet contents. +# The wheel format is documented here: +# https://www.python.org/dev/peps/pep-0491/#file-format +# +# We currently aren't particularly smart about correctly handling the full wheel +# functionality, but this is good enough to handle simple pure-python wheels, +# which is the main thing we care about right now. +# +# We could potentially use pip to install the wheel to a temporary location and +# then copy its "installed" files, but this has its own set of complications. +# This would require pip to already be installed and available, and we would +# need to correctly find the right version of pip or pip3 to use. +# If we did ever want to go down that path, we would probably want to use +# something like the following pip3 command: +# pip3 --isolated install --no-cache-dir --no-index --system \ +# --target +class PythonWheelBuilder(BuilderBase): + """This Builder can take Python wheel archives and install them as python libraries + that can be used by add_fb_python_library()/add_fb_python_executable() CMake rules. + """ + + def _build(self, install_dirs, reconfigure): + # type: (List[str], bool) -> None + + # When we are invoked, self.src_dir contains the unpacked wheel contents. + # + # Since a wheel file is just a zip file, the Fetcher code recognizes it as such + # and goes ahead and unpacks it. (We could disable that Fetcher behavior in the + # future if we ever wanted to, say if we wanted to call pip here.) + wheel_name = self._parse_wheel_name() + name_version_prefix = "-".join((wheel_name.distribution, wheel_name.version)) + dist_info_name = name_version_prefix + ".dist-info" + data_dir_name = name_version_prefix + ".data" + self.dist_info_dir = os.path.join(self.src_dir, dist_info_name) + wheel_metadata = self._read_wheel_metadata(wheel_name) + + # Check that we can understand the wheel version. + # We don't really care about wheel_metadata["Root-Is-Purelib"] since + # we are generating our own standalone python archives rather than installing + # into site-packages. + version = wheel_metadata["Wheel-Version"] + if not version.startswith("1."): + raise Exception("unsupported wheel version %s" % (version,)) + + # Add a find_dependency() call for each of our dependencies. + # The dependencies are also listed in the wheel METADATA file, but it is simpler + # to pull this directly from the getdeps manifest. + dep_list = sorted( + self.manifest.get_section_as_dict("dependencies", self.ctx).keys() + ) + find_dependency_lines = ["find_dependency({})".format(dep) for dep in dep_list] + + getdeps_cmake_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "CMake" + ) + self.template_format_dict = { + # Note that CMake files always uses forward slash separators in path names, + # even on Windows. Therefore replace path separators here. + "cmake_dir": _to_cmake_path(getdeps_cmake_dir), + "lib_name": self.manifest.name, + "manifest_name": self.manifest.name, + "namespace": self.manifest.name, + "upper_name": self.manifest.name.upper().replace("-", "_"), + "find_dependency_lines": "\n".join(find_dependency_lines), + } + + # Find sources from the root directory + path_mapping = {} + for entry in os.listdir(self.src_dir): + if entry in (dist_info_name, data_dir_name): + continue + self._add_sources(path_mapping, os.path.join(self.src_dir, entry), entry) + + # Files under the .data directory also need to be installed in the correct + # locations + if os.path.exists(data_dir_name): + # TODO: process the subdirectories of data_dir_name + # This isn't implemented yet since for now we have only needed dependencies + # on some simple pure Python wheels, so I haven't tested against wheels with + # additional files in the .data directory. + raise Exception( + "handling of the subdirectories inside %s is not implemented yet" + % data_dir_name + ) + + # Emit CMake files + self._write_cmakelists(path_mapping, dep_list) + self._write_cmake_config_template() + + # Run the build + self._run_cmake_build(install_dirs, reconfigure) + + def _run_cmake_build(self, install_dirs, reconfigure): + # type: (List[str], bool) -> None + + cmake_builder = CMakeBuilder( + build_opts=self.build_opts, + ctx=self.ctx, + manifest=self.manifest, + # Note that we intentionally supply src_dir=build_dir, + # since we wrote out our generated CMakeLists.txt in the build directory + src_dir=self.build_dir, + build_dir=self.build_dir, + inst_dir=self.inst_dir, + defines={}, + final_install_prefix=None, + ) + cmake_builder.build(install_dirs=install_dirs, reconfigure=reconfigure) + + def _write_cmakelists(self, path_mapping, dependencies): + # type: (List[str]) -> None + + cmake_path = os.path.join(self.build_dir, "CMakeLists.txt") + with open(cmake_path, "w") as f: + f.write(CMAKE_HEADER.format(**self.template_format_dict)) + for dep in dependencies: + f.write("find_package({0} REQUIRED)\n".format(dep)) + + f.write( + "add_fb_python_library({lib_name}\n".format(**self.template_format_dict) + ) + f.write(' BASE_DIR "%s"\n' % _to_cmake_path(self.src_dir)) + f.write(" SOURCES\n") + for src_path, install_path in path_mapping.items(): + f.write( + ' "%s=%s"\n' + % (_to_cmake_path(src_path), _to_cmake_path(install_path)) + ) + if dependencies: + f.write(" DEPENDS\n") + for dep in dependencies: + f.write(' "{0}::{0}"\n'.format(dep)) + f.write(")\n") + + f.write(CMAKE_FOOTER.format(**self.template_format_dict)) + + def _write_cmake_config_template(self): + config_path_name = self.manifest.name + "-config.cmake.in" + output_path = os.path.join(self.build_dir, config_path_name) + + with open(output_path, "w") as f: + f.write(CMAKE_CONFIG_FILE.format(**self.template_format_dict)) + + def _add_sources(self, path_mapping, src_path, install_path): + # type: (List[str], str, str) -> None + + s = os.lstat(src_path) + if not stat.S_ISDIR(s.st_mode): + path_mapping[src_path] = install_path + return + + for entry in os.listdir(src_path): + self._add_sources( + path_mapping, + os.path.join(src_path, entry), + os.path.join(install_path, entry), + ) + + def _parse_wheel_name(self): + # type: () -> WheelNameInfo + + # The ArchiveFetcher prepends "manifest_name-", so strip that off first. + wheel_name = os.path.basename(self.src_dir) + prefix = self.manifest.name + "-" + if not wheel_name.startswith(prefix): + raise Exception( + "expected wheel source directory to be of the form %s-NAME.whl" + % (prefix,) + ) + wheel_name = wheel_name[len(prefix) :] + + wheel_name_re = re.compile( + r"(?P[^-]+)" + r"-(?P\d+[^-]*)" + r"(-(?P\d+[^-]*))?" + r"-(?P\w+\d+(\.\w+\d+)*)" + r"-(?P\w+)" + r"-(?P\w+(\.\w+)*)" + r"\.whl" + ) + match = wheel_name_re.match(wheel_name) + if not match: + raise Exception( + "bad python wheel name %s: expected to have the form " + "DISTRIBUTION-VERSION-[-BUILD]-PYTAG-ABI-PLATFORM" + ) + + return WheelNameInfo( + distribution=match.group("distribution"), + version=match.group("version"), + build=match.group("build"), + python=match.group("python"), + abi=match.group("abi"), + platform=match.group("platform"), + ) + + def _read_wheel_metadata(self, wheel_name): + metadata_path = os.path.join(self.dist_info_dir, "WHEEL") + with codecs.open(metadata_path, "r", encoding="utf-8") as f: + return email.message_from_file(f) + + +def _to_cmake_path(path): + # CMake always uses forward slashes to separate paths in CMakeLists.txt files, + # even on Windows. It treats backslashes as character escapes, so using + # backslashes in the path will cause problems. Therefore replace all path + # separators with forward slashes to make sure the paths are correct on Windows. + # e.g. "C:\foo\bar.txt" becomes "C:/foo/bar.txt" + return path.replace(os.path.sep, "/") diff --git a/build/fbcode_builder/getdeps/runcmd.py b/build/fbcode_builder/getdeps/runcmd.py new file mode 100644 index 000000000..44e7994aa --- /dev/null +++ b/build/fbcode_builder/getdeps/runcmd.py @@ -0,0 +1,169 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import os +import select +import subprocess +import sys + +from .envfuncs import Env +from .platform import is_windows + + +try: + from shlex import quote as shellquote +except ImportError: + from pipes import quote as shellquote + + +class RunCommandError(Exception): + pass + + +def _print_env_diff(env, log_fn): + current_keys = set(os.environ.keys()) + wanted_env = set(env.keys()) + + unset_keys = current_keys.difference(wanted_env) + for k in sorted(unset_keys): + log_fn("+ unset %s\n" % k) + + added_keys = wanted_env.difference(current_keys) + for k in wanted_env.intersection(current_keys): + if os.environ[k] != env[k]: + added_keys.add(k) + + for k in sorted(added_keys): + if ("PATH" in k) and (os.pathsep in env[k]): + log_fn("+ %s=\\\n" % k) + for elem in env[k].split(os.pathsep): + log_fn("+ %s%s\\\n" % (shellquote(elem), os.pathsep)) + else: + log_fn("+ %s=%s \\\n" % (k, shellquote(env[k]))) + + +def run_cmd(cmd, env=None, cwd=None, allow_fail=False, log_file=None): + def log_to_stdout(msg): + sys.stdout.buffer.write(msg.encode(errors="surrogateescape")) + + if log_file is not None: + with open(log_file, "a", encoding="utf-8", errors="surrogateescape") as log: + + def log_function(msg): + log.write(msg) + log_to_stdout(msg) + + return _run_cmd( + cmd, env=env, cwd=cwd, allow_fail=allow_fail, log_fn=log_function + ) + else: + return _run_cmd( + cmd, env=env, cwd=cwd, allow_fail=allow_fail, log_fn=log_to_stdout + ) + + +def _run_cmd(cmd, env, cwd, allow_fail, log_fn): + log_fn("---\n") + try: + cmd_str = " \\\n+ ".join(shellquote(arg) for arg in cmd) + except TypeError: + # eg: one of the elements is None + raise RunCommandError("problem quoting cmd: %r" % cmd) + + if env: + assert isinstance(env, Env) + _print_env_diff(env, log_fn) + + # Convert from our Env type to a regular dict. + # This is needed because python3 looks up b'PATH' and 'PATH' + # and emits an error if both are present. In our Env type + # we'll return the same value for both requests, but we don't + # have duplicate potentially conflicting values which is the + # spirit of the check. + env = dict(env.items()) + + if cwd: + log_fn("+ cd %s && \\\n" % shellquote(cwd)) + # Our long path escape sequence may confuse cmd.exe, so if the cwd + # is short enough, strip that off. + if is_windows() and (len(cwd) < 250) and cwd.startswith("\\\\?\\"): + cwd = cwd[4:] + + log_fn("+ %s\n" % cmd_str) + + isinteractive = os.isatty(sys.stdout.fileno()) + if isinteractive: + stdout = None + sys.stdout.buffer.flush() + else: + stdout = subprocess.PIPE + + try: + p = subprocess.Popen( + cmd, env=env, cwd=cwd, stdout=stdout, stderr=subprocess.STDOUT + ) + except (TypeError, ValueError, OSError) as exc: + log_fn("error running `%s`: %s" % (cmd_str, exc)) + raise RunCommandError( + "%s while running `%s` with env=%r\nos.environ=%r" + % (str(exc), cmd_str, env, os.environ) + ) + + if not isinteractive: + _pipe_output(p, log_fn) + + p.wait() + if p.returncode != 0 and not allow_fail: + raise subprocess.CalledProcessError(p.returncode, cmd) + + return p.returncode + + +if hasattr(select, "poll"): + + def _pipe_output(p, log_fn): + """Read output from p.stdout and call log_fn() with each chunk of data as it + becomes available.""" + # Perform non-blocking reads + import fcntl + + fcntl.fcntl(p.stdout.fileno(), fcntl.F_SETFL, os.O_NONBLOCK) + poll = select.poll() + poll.register(p.stdout.fileno(), select.POLLIN) + + buffer_size = 4096 + while True: + poll.poll() + data = p.stdout.read(buffer_size) + if not data: + break + # log_fn() accepts arguments as str (binary in Python 2, unicode in + # Python 3). In Python 3 the subprocess output will be plain bytes, + # and need to be decoded. + if not isinstance(data, str): + data = data.decode("utf-8", errors="surrogateescape") + log_fn(data) + + +else: + + def _pipe_output(p, log_fn): + """Read output from p.stdout and call log_fn() with each chunk of data as it + becomes available.""" + # Perform blocking reads. Use a smaller buffer size to avoid blocking + # for very long when data is available. + buffer_size = 64 + while True: + data = p.stdout.read(buffer_size) + if not data: + break + # log_fn() accepts arguments as str (binary in Python 2, unicode in + # Python 3). In Python 3 the subprocess output will be plain bytes, + # and need to be decoded. + if not isinstance(data, str): + data = data.decode("utf-8", errors="surrogateescape") + log_fn(data) diff --git a/build/fbcode_builder/getdeps/subcmd.py b/build/fbcode_builder/getdeps/subcmd.py new file mode 100644 index 000000000..95f9a07ca --- /dev/null +++ b/build/fbcode_builder/getdeps/subcmd.py @@ -0,0 +1,58 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + + +class SubCmd(object): + NAME = None + HELP = None + + def run(self, args): + """perform the command""" + return 0 + + def setup_parser(self, parser): + # Subclasses should override setup_parser() if they have any + # command line options or arguments. + pass + + +CmdTable = [] + + +def add_subcommands(parser, common_args, cmd_table=CmdTable): + """Register parsers for the defined commands with the provided parser""" + for cls in cmd_table: + command = cls() + command_parser = parser.add_parser( + command.NAME, help=command.HELP, parents=[common_args] + ) + command.setup_parser(command_parser) + command_parser.set_defaults(func=command.run) + + +def cmd(name, help=None, cmd_table=CmdTable): + """ + @cmd() is a decorator that can be used to help define Subcmd instances + + Example usage: + + @subcmd('list', 'Show the result list') + class ListCmd(Subcmd): + def run(self, args): + # Perform the command actions here... + pass + """ + + def wrapper(cls): + class SubclassedCmd(cls): + NAME = name + HELP = help + + cmd_table.append(SubclassedCmd) + return SubclassedCmd + + return wrapper diff --git a/build/fbcode_builder/getdeps/test/expr_test.py b/build/fbcode_builder/getdeps/test/expr_test.py new file mode 100644 index 000000000..59d66a943 --- /dev/null +++ b/build/fbcode_builder/getdeps/test/expr_test.py @@ -0,0 +1,49 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import unittest + +from ..expr import parse_expr + + +class ExprTest(unittest.TestCase): + def test_equal(self): + valid_variables = {"foo", "some_var", "another_var"} + e = parse_expr("foo=bar", valid_variables) + self.assertTrue(e.eval({"foo": "bar"})) + self.assertFalse(e.eval({"foo": "not-bar"})) + self.assertFalse(e.eval({"not-foo": "bar"})) + + def test_not_equal(self): + valid_variables = {"foo"} + e = parse_expr("not(foo=bar)", valid_variables) + self.assertFalse(e.eval({"foo": "bar"})) + self.assertTrue(e.eval({"foo": "not-bar"})) + + def test_bad_not(self): + valid_variables = {"foo"} + with self.assertRaises(Exception): + parse_expr("foo=not(bar)", valid_variables) + + def test_bad_variable(self): + valid_variables = {"bar"} + with self.assertRaises(Exception): + parse_expr("foo=bar", valid_variables) + + def test_all(self): + valid_variables = {"foo", "baz"} + e = parse_expr("all(foo = bar, baz = qux)", valid_variables) + self.assertTrue(e.eval({"foo": "bar", "baz": "qux"})) + self.assertFalse(e.eval({"foo": "bar", "baz": "nope"})) + self.assertFalse(e.eval({"foo": "nope", "baz": "nope"})) + + def test_any(self): + valid_variables = {"foo", "baz"} + e = parse_expr("any(foo = bar, baz = qux)", valid_variables) + self.assertTrue(e.eval({"foo": "bar", "baz": "qux"})) + self.assertTrue(e.eval({"foo": "bar", "baz": "nope"})) + self.assertFalse(e.eval({"foo": "nope", "baz": "nope"})) diff --git a/build/fbcode_builder/getdeps/test/fixtures/duplicate/foo b/build/fbcode_builder/getdeps/test/fixtures/duplicate/foo new file mode 100644 index 000000000..a0384ee3b --- /dev/null +++ b/build/fbcode_builder/getdeps/test/fixtures/duplicate/foo @@ -0,0 +1,2 @@ +[manifest] +name = foo diff --git a/build/fbcode_builder/getdeps/test/fixtures/duplicate/subdir/foo b/build/fbcode_builder/getdeps/test/fixtures/duplicate/subdir/foo new file mode 100644 index 000000000..a0384ee3b --- /dev/null +++ b/build/fbcode_builder/getdeps/test/fixtures/duplicate/subdir/foo @@ -0,0 +1,2 @@ +[manifest] +name = foo diff --git a/build/fbcode_builder/getdeps/test/manifest_test.py b/build/fbcode_builder/getdeps/test/manifest_test.py new file mode 100644 index 000000000..8be9896d8 --- /dev/null +++ b/build/fbcode_builder/getdeps/test/manifest_test.py @@ -0,0 +1,233 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import sys +import unittest + +from ..load import load_all_manifests, patch_loader +from ..manifest import ManifestParser + + +class ManifestTest(unittest.TestCase): + def test_missing_section(self): + with self.assertRaisesRegex( + Exception, "manifest file test is missing required section manifest" + ): + ManifestParser("test", "") + + def test_missing_name(self): + with self.assertRaisesRegex( + Exception, + "manifest file test section 'manifest' is missing required field 'name'", + ): + ManifestParser( + "test", + """ +[manifest] +""", + ) + + def test_minimal(self): + p = ManifestParser( + "test", + """ +[manifest] +name = test +""", + ) + self.assertEqual(p.name, "test") + self.assertEqual(p.fbsource_path, None) + + def test_minimal_with_fbsource_path(self): + p = ManifestParser( + "test", + """ +[manifest] +name = test +fbsource_path = fbcode/wat +""", + ) + self.assertEqual(p.name, "test") + self.assertEqual(p.fbsource_path, "fbcode/wat") + + def test_unknown_field(self): + with self.assertRaisesRegex( + Exception, + ( + "manifest file test section 'manifest' contains " + "unknown field 'invalid.field'" + ), + ): + ManifestParser( + "test", + """ +[manifest] +name = test +invalid.field = woot +""", + ) + + def test_invalid_section_name(self): + with self.assertRaisesRegex( + Exception, "manifest file test contains unknown section 'invalid.section'" + ): + ManifestParser( + "test", + """ +[manifest] +name = test + +[invalid.section] +foo = bar +""", + ) + + def test_value_in_dependencies_section(self): + with self.assertRaisesRegex( + Exception, + ( + "manifest file test section 'dependencies' has " + "'foo = bar' but this section doesn't allow " + "specifying values for its entries" + ), + ): + ManifestParser( + "test", + """ +[manifest] +name = test + +[dependencies] +foo = bar +""", + ) + + def test_invalid_conditional_section_name(self): + with self.assertRaisesRegex( + Exception, + ( + "manifest file test section 'dependencies.=' " + "has invalid conditional: expected " + "identifier found =" + ), + ): + ManifestParser( + "test", + """ +[manifest] +name = test + +[dependencies.=] +""", + ) + + def test_section_as_args(self): + p = ManifestParser( + "test", + """ +[manifest] +name = test + +[dependencies] +a +b +c + +[dependencies.test=on] +foo +""", + ) + self.assertEqual(p.get_section_as_args("dependencies"), ["a", "b", "c"]) + self.assertEqual( + p.get_section_as_args("dependencies", {"test": "off"}), ["a", "b", "c"] + ) + self.assertEqual( + p.get_section_as_args("dependencies", {"test": "on"}), + ["a", "b", "c", "foo"], + ) + + p2 = ManifestParser( + "test", + """ +[manifest] +name = test + +[autoconf.args] +--prefix=/foo +--with-woot +""", + ) + self.assertEqual( + p2.get_section_as_args("autoconf.args"), ["--prefix=/foo", "--with-woot"] + ) + + def test_section_as_dict(self): + p = ManifestParser( + "test", + """ +[manifest] +name = test + +[cmake.defines] +foo = bar + +[cmake.defines.test=on] +foo = baz +""", + ) + self.assertEqual(p.get_section_as_dict("cmake.defines"), {"foo": "bar"}) + self.assertEqual( + p.get_section_as_dict("cmake.defines", {"test": "on"}), {"foo": "baz"} + ) + + p2 = ManifestParser( + "test", + """ +[manifest] +name = test + +[cmake.defines.test=on] +foo = baz + +[cmake.defines] +foo = bar +""", + ) + self.assertEqual( + p2.get_section_as_dict("cmake.defines", {"test": "on"}), + {"foo": "bar"}, + msg="sections cascade in the order they appear in the manifest", + ) + + def test_parse_common_manifests(self): + patch_loader(__name__) + manifests = load_all_manifests(None) + self.assertNotEqual(0, len(manifests), msg="parsed some number of manifests") + + def test_mismatch_name(self): + with self.assertRaisesRegex( + Exception, + "filename of the manifest 'foo' does not match the manifest name 'bar'", + ): + ManifestParser( + "foo", + """ +[manifest] +name = bar +""", + ) + + def test_duplicate_manifest(self): + patch_loader(__name__, "fixtures/duplicate") + + with self.assertRaisesRegex(Exception, "found duplicate manifest 'foo'"): + load_all_manifests(None) + + if sys.version_info < (3, 2): + + def assertRaisesRegex(self, *args, **kwargs): + return self.assertRaisesRegexp(*args, **kwargs) diff --git a/build/fbcode_builder/getdeps/test/platform_test.py b/build/fbcode_builder/getdeps/test/platform_test.py new file mode 100644 index 000000000..311e9c76c --- /dev/null +++ b/build/fbcode_builder/getdeps/test/platform_test.py @@ -0,0 +1,40 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import unittest + +from ..platform import HostType + + +class PlatformTest(unittest.TestCase): + def test_create(self): + p = HostType() + self.assertNotEqual(p.ostype, None, msg="probed and returned something") + + tuple_string = p.as_tuple_string() + round_trip = HostType.from_tuple_string(tuple_string) + self.assertEqual(round_trip, p) + + def test_rendering_of_none(self): + p = HostType(ostype="foo") + self.assertEqual(p.as_tuple_string(), "foo-none-none") + + def test_is_methods(self): + p = HostType(ostype="windows") + self.assertTrue(p.is_windows()) + self.assertFalse(p.is_darwin()) + self.assertFalse(p.is_linux()) + + p = HostType(ostype="darwin") + self.assertFalse(p.is_windows()) + self.assertTrue(p.is_darwin()) + self.assertFalse(p.is_linux()) + + p = HostType(ostype="linux") + self.assertFalse(p.is_windows()) + self.assertFalse(p.is_darwin()) + self.assertTrue(p.is_linux()) diff --git a/build/fbcode_builder/getdeps/test/scratch_test.py b/build/fbcode_builder/getdeps/test/scratch_test.py new file mode 100644 index 000000000..1f43c5951 --- /dev/null +++ b/build/fbcode_builder/getdeps/test/scratch_test.py @@ -0,0 +1,80 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function + +import unittest + +from ..buildopts import find_existing_win32_subst_for_path + + +class Win32SubstTest(unittest.TestCase): + def test_no_existing_subst(self): + self.assertIsNone( + find_existing_win32_subst_for_path( + r"C:\users\alice\appdata\local\temp\fbcode_builder_getdeps", + subst_mapping={}, + ) + ) + self.assertIsNone( + find_existing_win32_subst_for_path( + r"C:\users\alice\appdata\local\temp\fbcode_builder_getdeps", + subst_mapping={"X:\\": r"C:\users\alice\appdata\local\temp\other"}, + ) + ) + + def test_exact_match_returns_drive_path(self): + self.assertEqual( + find_existing_win32_subst_for_path( + r"C:\temp\fbcode_builder_getdeps", + subst_mapping={"X:\\": r"C:\temp\fbcode_builder_getdeps"}, + ), + "X:\\", + ) + self.assertEqual( + find_existing_win32_subst_for_path( + r"C:/temp/fbcode_builder_getdeps", + subst_mapping={"X:\\": r"C:/temp/fbcode_builder_getdeps"}, + ), + "X:\\", + ) + + def test_multiple_exact_matches_returns_arbitrary_drive_path(self): + self.assertIn( + find_existing_win32_subst_for_path( + r"C:\temp\fbcode_builder_getdeps", + subst_mapping={ + "X:\\": r"C:\temp\fbcode_builder_getdeps", + "Y:\\": r"C:\temp\fbcode_builder_getdeps", + "Z:\\": r"C:\temp\fbcode_builder_getdeps", + }, + ), + ("X:\\", "Y:\\", "Z:\\"), + ) + + def test_drive_letter_is_case_insensitive(self): + self.assertEqual( + find_existing_win32_subst_for_path( + r"C:\temp\fbcode_builder_getdeps", + subst_mapping={"X:\\": r"c:\temp\fbcode_builder_getdeps"}, + ), + "X:\\", + ) + + def test_path_components_are_case_insensitive(self): + self.assertEqual( + find_existing_win32_subst_for_path( + r"C:\TEMP\FBCODE_builder_getdeps", + subst_mapping={"X:\\": r"C:\temp\fbcode_builder_getdeps"}, + ), + "X:\\", + ) + self.assertEqual( + find_existing_win32_subst_for_path( + r"C:\temp\fbcode_builder_getdeps", + subst_mapping={"X:\\": r"C:\TEMP\FBCODE_builder_getdeps"}, + ), + "X:\\", + ) diff --git a/build/fbcode_builder/make_docker_context.py b/build/fbcode_builder/make_docker_context.py new file mode 100755 index 000000000..d4b0f0a89 --- /dev/null +++ b/build/fbcode_builder/make_docker_context.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +""" +Reads `fbcode_builder_config.py` from the current directory, and prepares a +Docker context directory to build this project. Prints to stdout the path +to the context directory. + +Try `.../make_docker_context.py --help` from a project's `build/` directory. + +By default, the Docker context directory will be in /tmp. It will always +contain a Dockerfile, and might also contain copies of your local repos, and +other data needed for the build container. +""" + +import os +import tempfile +import textwrap + +from docker_builder import DockerFBCodeBuilder +from parse_args import parse_args_to_fbcode_builder_opts + + +def make_docker_context( + get_steps_fn, github_project, opts=None, default_context_dir=None +): + """ + Returns a path to the Docker context directory. See parse_args.py. + + Helper for making a command-line utility that writes your project's + Dockerfile and associated data into a (temporary) directory. Your main + program might look something like this: + + print(make_docker_context( + lambda builder: [builder.step(...), ...], + 'facebook/your_project', + )) + """ + + if opts is None: + opts = {} + + valid_versions = ( + ("ubuntu:16.04", "5"), + ("ubuntu:18.04", "7"), + ) + + def add_args(parser): + parser.add_argument( + "--docker-context-dir", + metavar="DIR", + default=default_context_dir, + help="Write the Dockerfile and its context into this directory. " + "If empty, make a temporary directory. Default: %(default)s.", + ) + parser.add_argument( + "--user", + metavar="NAME", + default=opts.get("user", "nobody"), + help="Build and install as this user. Default: %(default)s.", + ) + parser.add_argument( + "--prefix", + metavar="DIR", + default=opts.get("prefix", "/home/install"), + help="Install all libraries in this prefix. Default: %(default)s.", + ) + parser.add_argument( + "--projects-dir", + metavar="DIR", + default=opts.get("projects_dir", "/home"), + help="Place project code directories here. Default: %(default)s.", + ) + parser.add_argument( + "--os-image", + metavar="IMG", + choices=zip(*valid_versions)[0], + default=opts.get("os_image", valid_versions[0][0]), + help="Docker OS image -- be sure to use only ones you trust (See " + "README.docker). Choices: %(choices)s. Default: %(default)s.", + ) + parser.add_argument( + "--gcc-version", + metavar="VER", + choices=set(zip(*valid_versions)[1]), + default=opts.get("gcc_version", valid_versions[0][1]), + help="Choices: %(choices)s. Default: %(default)s.", + ) + parser.add_argument( + "--make-parallelism", + metavar="NUM", + type=int, + default=opts.get("make_parallelism", 1), + help="Use `make -j` on multi-CPU systems with lots of RAM. " + "Default: %(default)s.", + ) + parser.add_argument( + "--local-repo-dir", + metavar="DIR", + help="If set, build {0} from a local directory instead of Github.".format( + github_project + ), + ) + parser.add_argument( + "--ccache-tgz", + metavar="PATH", + help="If set, enable ccache for the build. To initialize the " + "cache, first try to hardlink, then to copy --cache-tgz " + "as ccache.tgz into the --docker-context-dir.", + ) + + opts = parse_args_to_fbcode_builder_opts( + add_args, + # These have add_argument() calls, others are set via --option. + ( + "docker_context_dir", + "user", + "prefix", + "projects_dir", + "os_image", + "gcc_version", + "make_parallelism", + "local_repo_dir", + "ccache_tgz", + ), + opts, + help=textwrap.dedent( + """ + + Reads `fbcode_builder_config.py` from the current directory, and + prepares a Docker context directory to build {github_project} and + its dependencies. Prints to stdout the path to the context + directory. + + Pass --option {github_project}:git_hash SHA1 to build something + other than the master branch from Github. + + Or, pass --option {github_project}:local_repo_dir LOCAL_PATH to + build from a local repo instead of cloning from Github. + + Usage: + (cd $(./make_docker_context.py) && docker build . 2>&1 | tee log) + + """.format( + github_project=github_project + ) + ), + ) + + # This allows travis_docker_build.sh not to know the main Github project. + local_repo_dir = opts.pop("local_repo_dir", None) + if local_repo_dir is not None: + opts["{0}:local_repo_dir".format(github_project)] = local_repo_dir + + if (opts.get("os_image"), opts.get("gcc_version")) not in valid_versions: + raise Exception( + "Due to 4/5 ABI changes (std::string), we can only use {0}".format( + " / ".join("GCC {1} on {0}".format(*p) for p in valid_versions) + ) + ) + + if opts.get("docker_context_dir") is None: + opts["docker_context_dir"] = tempfile.mkdtemp(prefix="docker-context-") + elif not os.path.exists(opts.get("docker_context_dir")): + os.makedirs(opts.get("docker_context_dir")) + + builder = DockerFBCodeBuilder(**opts) + context_dir = builder.option("docker_context_dir") # Mark option "in-use" + # The renderer may also populate some files into the context_dir. + dockerfile = builder.render(get_steps_fn(builder)) + + with os.fdopen( + os.open( + os.path.join(context_dir, "Dockerfile"), + os.O_RDWR | os.O_CREAT | os.O_EXCL, # Do not overwrite existing files + 0o644, + ), + "w", + ) as f: + f.write(dockerfile) + + return context_dir + + +if __name__ == "__main__": + from utils import read_fbcode_builder_config, build_fbcode_builder_config + + # Load a spec from the current directory + config = read_fbcode_builder_config("fbcode_builder_config.py") + print( + make_docker_context( + build_fbcode_builder_config(config), + config["github_project"], + ) + ) diff --git a/build/fbcode_builder/manifests/CLI11 b/build/fbcode_builder/manifests/CLI11 new file mode 100644 index 000000000..14cb2332a --- /dev/null +++ b/build/fbcode_builder/manifests/CLI11 @@ -0,0 +1,14 @@ +[manifest] +name = CLI11 + +[download] +url = https://github.com/CLIUtils/CLI11/archive/v2.0.0.tar.gz +sha256 = 2c672f17bf56e8e6223a3bfb74055a946fa7b1ff376510371902adb9cb0ab6a3 + +[build] +builder = cmake +subdir = CLI11-2.0.0 + +[cmake.defines] +CLI11_BUILD_TESTS = OFF +CLI11_BUILD_EXAMPLES = OFF diff --git a/build/fbcode_builder/manifests/OpenNSA b/build/fbcode_builder/manifests/OpenNSA new file mode 100644 index 000000000..62354c997 --- /dev/null +++ b/build/fbcode_builder/manifests/OpenNSA @@ -0,0 +1,17 @@ +[manifest] +name = OpenNSA + +[download] +url = https://docs.broadcom.com/docs-and-downloads/csg/opennsa-6.5.22.tgz +sha256 = 74bfbdaebb6bfe9ebb0deac3aff624385cdcf5aa416ba63706c36538b3c3c46c + +[build] +builder = nop +subdir = opennsa-6.5.22 + +[install.files] +lib/x86-64 = lib +include = include +src/gpl-modules/systems/bde/linux/include = include/systems/bde/linux +src/gpl-modules/include/ibde.h = include/ibde.h +src/gpl-modules = src/gpl-modules diff --git a/build/fbcode_builder/manifests/autoconf b/build/fbcode_builder/manifests/autoconf new file mode 100644 index 000000000..35963096c --- /dev/null +++ b/build/fbcode_builder/manifests/autoconf @@ -0,0 +1,16 @@ +[manifest] +name = autoconf + +[rpms] +autoconf + +[debs] +autoconf + +[download] +url = http://ftp.gnu.org/gnu/autoconf/autoconf-2.69.tar.gz +sha256 = 954bd69b391edc12d6a4a51a2dd1476543da5c6bbf05a95b59dc0dd6fd4c2969 + +[build] +builder = autoconf +subdir = autoconf-2.69 diff --git a/build/fbcode_builder/manifests/automake b/build/fbcode_builder/manifests/automake new file mode 100644 index 000000000..71115068a --- /dev/null +++ b/build/fbcode_builder/manifests/automake @@ -0,0 +1,19 @@ +[manifest] +name = automake + +[rpms] +automake + +[debs] +automake + +[download] +url = http://ftp.gnu.org/gnu/automake/automake-1.16.1.tar.gz +sha256 = 608a97523f97db32f1f5d5615c98ca69326ced2054c9f82e65bade7fc4c9dea8 + +[build] +builder = autoconf +subdir = automake-1.16.1 + +[dependencies] +autoconf diff --git a/build/fbcode_builder/manifests/bison b/build/fbcode_builder/manifests/bison new file mode 100644 index 000000000..6e355d052 --- /dev/null +++ b/build/fbcode_builder/manifests/bison @@ -0,0 +1,27 @@ +[manifest] +name = bison + +[rpms] +bison + +[debs] +bison + +[download.not(os=windows)] +url = https://mirrors.kernel.org/gnu/bison/bison-3.3.tar.gz +sha256 = fdeafb7fffade05604a61e66b8c040af4b2b5cbb1021dcfe498ed657ac970efd + +[download.os=windows] +url = https://github.com/lexxmark/winflexbison/releases/download/v2.5.17/winflexbison-2.5.17.zip +sha256 = 3dc27a16c21b717bcc5de8590b564d4392a0b8577170c058729d067d95ded825 + +[build.not(os=windows)] +builder = autoconf +subdir = bison-3.3 + +[build.os=windows] +builder = nop + +[install.files.os=windows] +data = bin/data +win_bison.exe = bin/bison.exe diff --git a/build/fbcode_builder/manifests/bistro b/build/fbcode_builder/manifests/bistro new file mode 100644 index 000000000..d93839275 --- /dev/null +++ b/build/fbcode_builder/manifests/bistro @@ -0,0 +1,28 @@ +[manifest] +name = bistro +fbsource_path = fbcode/bistro +shipit_project = bistro +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebook/bistro.git + +[build.os=linux] +builder = bistro + +# Bistro is Linux-specific +[build.not(os=linux)] +builder = nop + +[dependencies] +fmt +folly +proxygen +fbthrift +libsodium +googletest_1_8 +sqlite3 + +[shipit.pathmap] +fbcode/bistro/public_tld = . +fbcode/bistro = bistro diff --git a/build/fbcode_builder/manifests/boost b/build/fbcode_builder/manifests/boost new file mode 100644 index 000000000..4b254e308 --- /dev/null +++ b/build/fbcode_builder/manifests/boost @@ -0,0 +1,86 @@ +[manifest] +name = boost + +[download.not(os=windows)] +url = https://versaweb.dl.sourceforge.net/project/boost/boost/1.69.0/boost_1_69_0.tar.bz2 +sha256 = 8f32d4617390d1c2d16f26a27ab60d97807b35440d45891fa340fc2648b04406 + +[download.os=windows] +url = https://versaweb.dl.sourceforge.net/project/boost/boost/1.69.0/boost_1_69_0.zip +sha256 = d074bcbcc0501c4917b965fc890e303ee70d8b01ff5712bae4a6c54f2b6b4e52 + +[preinstalled.env] +BOOST_ROOT_1_69_0 + +[debs] +libboost-all-dev + +[rpms] +boost +boost-math +boost-test +boost-fiber +boost-graph +boost-log +boost-openmpi +boost-timer +boost-chrono +boost-locale +boost-thread +boost-atomic +boost-random +boost-static +boost-contract +boost-date-time +boost-iostreams +boost-container +boost-coroutine +boost-filesystem +boost-system +boost-stacktrace +boost-regex +boost-devel +boost-context +boost-python3-devel +boost-type_erasure +boost-wave +boost-python3 +boost-serialization +boost-program-options + +[build] +builder = boost + +[b2.args] +--with-atomic +--with-chrono +--with-container +--with-context +--with-contract +--with-coroutine +--with-date_time +--with-exception +--with-fiber +--with-filesystem +--with-graph +--with-graph_parallel +--with-iostreams +--with-locale +--with-log +--with-math +--with-mpi +--with-program_options +--with-python +--with-random +--with-regex +--with-serialization +--with-stacktrace +--with-system +--with-test +--with-thread +--with-timer +--with-type_erasure +--with-wave + +[b2.args.os=darwin] +toolset=clang diff --git a/build/fbcode_builder/manifests/cmake b/build/fbcode_builder/manifests/cmake new file mode 100644 index 000000000..f756caed0 --- /dev/null +++ b/build/fbcode_builder/manifests/cmake @@ -0,0 +1,43 @@ +[manifest] +name = cmake + +[rpms] +cmake + +# All current deb based distros have a cmake that is too old +#[debs] +#cmake + +[dependencies] +ninja + +[download.os=windows] +url = https://github.com/Kitware/CMake/releases/download/v3.14.0/cmake-3.14.0-win64-x64.zip +sha256 = 40e8140d68120378262322bbc8c261db8d184d7838423b2e5bf688a6209d3807 + +[download.os=darwin] +url = https://github.com/Kitware/CMake/releases/download/v3.14.0/cmake-3.14.0-Darwin-x86_64.tar.gz +sha256 = a02ad0d5b955dfad54c095bd7e937eafbbbfe8a99860107025cc442290a3e903 + +[download.os=linux] +url = https://github.com/Kitware/CMake/releases/download/v3.14.0/cmake-3.14.0.tar.gz +sha256 = aa76ba67b3c2af1946701f847073f4652af5cbd9f141f221c97af99127e75502 + +[build.os=windows] +builder = nop +subdir = cmake-3.14.0-win64-x64 + +[build.os=darwin] +builder = nop +subdir = cmake-3.14.0-Darwin-x86_64 + +[install.files.os=darwin] +CMake.app/Contents/bin = bin +CMake.app/Contents/share = share + +[build.os=linux] +builder = cmakebootstrap +subdir = cmake-3.14.0 + +[make.install_args.os=linux] +install diff --git a/build/fbcode_builder/manifests/cpptoml b/build/fbcode_builder/manifests/cpptoml new file mode 100644 index 000000000..5a3c781dc --- /dev/null +++ b/build/fbcode_builder/manifests/cpptoml @@ -0,0 +1,10 @@ +[manifest] +name = cpptoml + +[download] +url = https://github.com/skystrife/cpptoml/archive/v0.1.1.tar.gz +sha256 = 23af72468cfd4040984d46a0dd2a609538579c78ddc429d6b8fd7a10a6e24403 + +[build] +builder = cmake +subdir = cpptoml-0.1.1 diff --git a/build/fbcode_builder/manifests/delos_core b/build/fbcode_builder/manifests/delos_core new file mode 100644 index 000000000..1de6c3342 --- /dev/null +++ b/build/fbcode_builder/manifests/delos_core @@ -0,0 +1,25 @@ +[manifest] +name = delos_core +fbsource_path = fbcode/delos_core +shipit_project = delos_core +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebookincubator/delos_core.git + +[build.os=linux] +builder = cmake + +[build.not(os=linux)] +builder = nop + +[dependencies] +glog +googletest +folly +fbthrift +fb303 +re2 + +[shipit.pathmap] +fbcode/delos_core = . diff --git a/build/fbcode_builder/manifests/double-conversion b/build/fbcode_builder/manifests/double-conversion new file mode 100644 index 000000000..e27c7ae06 --- /dev/null +++ b/build/fbcode_builder/manifests/double-conversion @@ -0,0 +1,11 @@ +[manifest] +name = double-conversion + +[download] +url = https://github.com/google/double-conversion/archive/v3.1.4.tar.gz +sha256 = 95004b65e43fefc6100f337a25da27bb99b9ef8d4071a36a33b5e83eb1f82021 + +[build] +builder = cmake +subdir = double-conversion-3.1.4 + diff --git a/build/fbcode_builder/manifests/eden b/build/fbcode_builder/manifests/eden new file mode 100644 index 000000000..700cc82ec --- /dev/null +++ b/build/fbcode_builder/manifests/eden @@ -0,0 +1,70 @@ +[manifest] +name = eden +fbsource_path = fbcode/eden +shipit_project = eden +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebookexperimental/eden.git + +[build] +builder = cmake + +[dependencies] +googletest +folly +fbthrift +fb303 +cpptoml +rocksdb +re2 +libgit2 +lz4 +pexpect +python-toml + +[dependencies.fb=on] +rust + +# macOS ships with sqlite3, and some of the core system +# frameworks require that that version be linked rather +# than the one we might build for ourselves here, so we +# skip building it on macos. +[dependencies.not(os=darwin)] +sqlite3 + +[dependencies.os=darwin] +osxfuse + +# TODO: teach getdeps to compile curl on Windows. +# Enabling curl on Windows requires us to find a way to compile libcurl with +# msvc. +[dependencies.not(os=windows)] +libcurl + +[shipit.pathmap] +fbcode/common/rust/shed/hostcaps = common/rust/shed/hostcaps +fbcode/eden/oss = . +fbcode/eden = eden +fbcode/tools/lfs = tools/lfs +fbcode/thrift/lib/rust = thrift/lib/rust + +[shipit.strip] +^fbcode/eden/fs/eden-config\.h$ +^fbcode/eden/fs/py/eden/config\.py$ +^fbcode/eden/hg/.*$ +^fbcode/eden/mononoke/(?!lfs_protocol) +^fbcode/eden/scm/build/.*$ +^fbcode/eden/scm/lib/third-party/rust/.*/Cargo.toml$ +^fbcode/eden/.*/\.cargo/.*$ +/Cargo\.lock$ +\.pyc$ + +[cmake.defines.all(fb=on,os=windows)] +INSTALL_PYTHON_LIB=ON + +[cmake.defines.fb=on] +USE_CARGO_VENDOR=ON + +[depends.environment] +EDEN_VERSION_OVERRIDE diff --git a/build/fbcode_builder/manifests/eden_scm b/build/fbcode_builder/manifests/eden_scm new file mode 100644 index 000000000..cfe9c7096 --- /dev/null +++ b/build/fbcode_builder/manifests/eden_scm @@ -0,0 +1,57 @@ +[manifest] +name = eden_scm +fbsource_path = fbcode/eden +shipit_project = eden +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebookexperimental/eden.git + +[build.not(os=windows)] +builder = make +subdir = eden/scm +disable_env_override_pkgconfig = 1 +disable_env_override_path = 1 + +[build.os=windows] +# For now the biggest blocker is missing "make" on windows, but there are bound +# to be more +builder = nop + +[make.build_args] +getdepsbuild + +[make.install_args] +install-getdeps + +[make.test_args] +test-getdeps + +[shipit.pathmap] +fbcode/common/rust = common/rust +fbcode/eden/oss = . +fbcode/eden = eden +fbcode/tools/lfs = tools/lfs +fbcode/fboss/common = common + +[shipit.strip] +^fbcode/eden/fs/eden-config\.h$ +^fbcode/eden/fs/py/eden/config\.py$ +^fbcode/eden/hg/.*$ +^fbcode/eden/mononoke/(?!lfs_protocol) +^fbcode/eden/scm/build/.*$ +^fbcode/eden/scm/lib/third-party/rust/.*/Cargo.toml$ +^fbcode/eden/.*/\.cargo/.*$ +^.*/fb/.*$ +/Cargo\.lock$ +\.pyc$ + +[dependencies] +fb303-source +fbthrift +fbthrift-source +openssl +rust-shed + +[dependencies.fb=on] +rust diff --git a/build/fbcode_builder/manifests/eden_scm_lib_edenapi_tools b/build/fbcode_builder/manifests/eden_scm_lib_edenapi_tools new file mode 100644 index 000000000..be29d70f8 --- /dev/null +++ b/build/fbcode_builder/manifests/eden_scm_lib_edenapi_tools @@ -0,0 +1,36 @@ +[manifest] +name = eden_scm_lib_edenapi_tools +fbsource_path = fbcode/eden +shipit_project = eden +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebookexperimental/eden.git + +[build] +builder = cargo + +[cargo] +build_doc = true +manifests_to_build = eden/scm/lib/edenapi/tools/make_req/Cargo.toml,eden/scm/lib/edenapi/tools/read_res/Cargo.toml + +[shipit.pathmap] +fbcode/eden/oss = . +fbcode/eden = eden +fbcode/tools/lfs = tools/lfs +fbcode/fboss/common = common + +[shipit.strip] +^fbcode/eden/fs/eden-config\.h$ +^fbcode/eden/fs/py/eden/config\.py$ +^fbcode/eden/hg/.*$ +^fbcode/eden/mononoke/(?!lfs_protocol) +^fbcode/eden/scm/build/.*$ +^fbcode/eden/scm/lib/third-party/rust/.*/Cargo.toml$ +^fbcode/eden/.*/\.cargo/.*$ +^.*/fb/.*$ +/Cargo\.lock$ +\.pyc$ + +[dependencies.fb=on] +rust diff --git a/build/fbcode_builder/manifests/f4d b/build/fbcode_builder/manifests/f4d new file mode 100644 index 000000000..db30894c7 --- /dev/null +++ b/build/fbcode_builder/manifests/f4d @@ -0,0 +1,29 @@ +[manifest] +name = f4d +fbsource_path = fbcode/f4d +shipit_project = f4d +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebookexternal/f4d.git + +[build.os=windows] +builder = nop + +[build.not(os=windows)] +builder = cmake + +[dependencies] +double-conversion +folly +glog +googletest +boost +protobuf +lzo +libicu +re2 + +[shipit.pathmap] +fbcode/f4d/public_tld = . +fbcode/f4d = f4d diff --git a/build/fbcode_builder/manifests/fatal b/build/fbcode_builder/manifests/fatal new file mode 100644 index 000000000..3c333561f --- /dev/null +++ b/build/fbcode_builder/manifests/fatal @@ -0,0 +1,15 @@ +[manifest] +name = fatal +fbsource_path = fbcode/fatal +shipit_project = fatal + +[git] +repo_url = https://github.com/facebook/fatal.git + +[shipit.pathmap] +fbcode/fatal = . +fbcode/fatal/public_tld = . + +[build] +builder = nop +subdir = . diff --git a/build/fbcode_builder/manifests/fb303 b/build/fbcode_builder/manifests/fb303 new file mode 100644 index 000000000..743aca01e --- /dev/null +++ b/build/fbcode_builder/manifests/fb303 @@ -0,0 +1,27 @@ +[manifest] +name = fb303 +fbsource_path = fbcode/fb303 +shipit_project = fb303 +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebookincubator/fb303.git + +[build] +builder = cmake + +[dependencies] +folly +gflags +glog +fbthrift + +[cmake.defines.test=on] +BUILD_TESTS=ON + +[cmake.defines.test=off] +BUILD_TESTS=OFF + +[shipit.pathmap] +fbcode/fb303/github = . +fbcode/fb303 = fb303 diff --git a/build/fbcode_builder/manifests/fb303-source b/build/fbcode_builder/manifests/fb303-source new file mode 100644 index 000000000..ea160c500 --- /dev/null +++ b/build/fbcode_builder/manifests/fb303-source @@ -0,0 +1,15 @@ +[manifest] +name = fb303-source +fbsource_path = fbcode/fb303 +shipit_project = fb303 +shipit_fbcode_builder = false + +[git] +repo_url = https://github.com/facebook/fb303.git + +[build] +builder = nop + +[shipit.pathmap] +fbcode/fb303/github = . +fbcode/fb303 = fb303 diff --git a/build/fbcode_builder/manifests/fboss b/build/fbcode_builder/manifests/fboss new file mode 100644 index 000000000..f29873e72 --- /dev/null +++ b/build/fbcode_builder/manifests/fboss @@ -0,0 +1,42 @@ +[manifest] +name = fboss +fbsource_path = fbcode/fboss +shipit_project = fboss +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebook/fboss.git + +[build.os=linux] +builder = cmake + +[build.not(os=linux)] +builder = nop + +[dependencies] +folly +fb303 +wangle +fizz +fmt +libsodium +googletest +zstd +fbthrift +iproute2 +libmnl +libusb +libcurl +libnl +libsai +OpenNSA +re2 +python +yaml-cpp +libyaml +CLI11 + +[shipit.pathmap] +fbcode/fboss/github = . +fbcode/fboss/common = common +fbcode/fboss = fboss diff --git a/build/fbcode_builder/manifests/fbthrift b/build/fbcode_builder/manifests/fbthrift new file mode 100644 index 000000000..072dd4512 --- /dev/null +++ b/build/fbcode_builder/manifests/fbthrift @@ -0,0 +1,33 @@ +[manifest] +name = fbthrift +fbsource_path = fbcode/thrift +shipit_project = fbthrift +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebook/fbthrift.git + +[build] +builder = cmake + +[dependencies] +bison +flex +folly +wangle +fizz +fmt +googletest +libsodium +python-six +zstd + +[shipit.pathmap] +fbcode/thrift/public_tld = . +fbcode/thrift = thrift + +[shipit.strip] +^fbcode/thrift/thrift-config\.h$ +^fbcode/thrift/perf/canary.py$ +^fbcode/thrift/perf/loadtest.py$ +^fbcode/thrift/.castle/.* diff --git a/build/fbcode_builder/manifests/fbthrift-source b/build/fbcode_builder/manifests/fbthrift-source new file mode 100644 index 000000000..7af0d6dda --- /dev/null +++ b/build/fbcode_builder/manifests/fbthrift-source @@ -0,0 +1,21 @@ +[manifest] +name = fbthrift-source +fbsource_path = fbcode/thrift +shipit_project = fbthrift +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebook/fbthrift.git + +[build] +builder = nop + +[shipit.pathmap] +fbcode/thrift/public_tld = . +fbcode/thrift = thrift + +[shipit.strip] +^fbcode/thrift/thrift-config\.h$ +^fbcode/thrift/perf/canary.py$ +^fbcode/thrift/perf/loadtest.py$ +^fbcode/thrift/.castle/.* diff --git a/build/fbcode_builder/manifests/fbzmq b/build/fbcode_builder/manifests/fbzmq new file mode 100644 index 000000000..5739016c8 --- /dev/null +++ b/build/fbcode_builder/manifests/fbzmq @@ -0,0 +1,29 @@ +[manifest] +name = fbzmq +fbsource_path = facebook/fbzmq +shipit_project = fbzmq +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebook/fbzmq.git + +[build.os=linux] +builder = cmake + +[build.not(os=linux)] +# boost.fiber is required and that is not available on macos. +# libzmq doesn't currently build on windows. +builder = nop + +[dependencies] +boost +folly +fbthrift +googletest +libzmq + +[shipit.pathmap] +fbcode/fbzmq = fbzmq +fbcode/fbzmq/public_tld = . + +[shipit.strip] diff --git a/build/fbcode_builder/manifests/fizz b/build/fbcode_builder/manifests/fizz new file mode 100644 index 000000000..72f29973f --- /dev/null +++ b/build/fbcode_builder/manifests/fizz @@ -0,0 +1,36 @@ +[manifest] +name = fizz +fbsource_path = fbcode/fizz +shipit_project = fizz +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebookincubator/fizz.git + +[build] +builder = cmake +subdir = fizz + +[cmake.defines] +BUILD_EXAMPLES = OFF + +[cmake.defines.test=on] +BUILD_TESTS = ON + +[cmake.defines.all(os=windows, test=on)] +BUILD_TESTS = OFF + +[cmake.defines.test=off] +BUILD_TESTS = OFF + +[dependencies] +folly +libsodium +zstd + +[dependencies.all(test=on, not(os=windows))] +googletest_1_8 + +[shipit.pathmap] +fbcode/fizz/public_tld = . +fbcode/fizz = fizz diff --git a/build/fbcode_builder/manifests/flex b/build/fbcode_builder/manifests/flex new file mode 100644 index 000000000..f266c4033 --- /dev/null +++ b/build/fbcode_builder/manifests/flex @@ -0,0 +1,32 @@ +[manifest] +name = flex + +[rpms] +flex + +[debs] +flex + +[download.not(os=windows)] +url = https://github.com/westes/flex/releases/download/v2.6.4/flex-2.6.4.tar.gz +sha256 = e87aae032bf07c26f85ac0ed3250998c37621d95f8bd748b31f15b33c45ee995 + +[download.os=windows] +url = https://github.com/lexxmark/winflexbison/releases/download/v2.5.17/winflexbison-2.5.17.zip +sha256 = 3dc27a16c21b717bcc5de8590b564d4392a0b8577170c058729d067d95ded825 + +[build.not(os=windows)] +builder = autoconf +subdir = flex-2.6.4 + +[build.os=windows] +builder = nop + +[install.files.os=windows] +data = bin/data +win_flex.exe = bin/flex.exe + +# Moral equivalent to this PR that fixes a crash when bootstrapping flex +# on linux: https://github.com/easybuilders/easybuild-easyconfigs/pull/5792 +[autoconf.args.os=linux] +CFLAGS=-D_GNU_SOURCE diff --git a/build/fbcode_builder/manifests/fmt b/build/fbcode_builder/manifests/fmt new file mode 100644 index 000000000..21503d202 --- /dev/null +++ b/build/fbcode_builder/manifests/fmt @@ -0,0 +1,14 @@ +[manifest] +name = fmt + +[download] +url = https://github.com/fmtlib/fmt/archive/6.1.1.tar.gz +sha256 = bf4e50955943c1773cc57821d6c00f7e2b9e10eb435fafdd66739d36056d504e + +[build] +builder = cmake +subdir = fmt-6.1.1 + +[cmake.defines] +FMT_TEST = OFF +FMT_DOC = OFF diff --git a/build/fbcode_builder/manifests/folly b/build/fbcode_builder/manifests/folly new file mode 100644 index 000000000..9647b17f8 --- /dev/null +++ b/build/fbcode_builder/manifests/folly @@ -0,0 +1,58 @@ +[manifest] +name = folly +fbsource_path = fbcode/folly +shipit_project = folly +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebook/folly.git + +[build] +builder = cmake + +[dependencies] +gflags +glog +googletest +boost +libevent +double-conversion +fmt +lz4 +snappy +zstd +# no openssl or zlib in the linux case, why? +# these are usually installed on the system +# and are the easiest system deps to pull in. +# In the future we want to be able to express +# that a system dep is sufficient in the manifest +# for eg: openssl and zlib, but for now we don't +# have it. + +# macOS doesn't expose the openssl api so we need +# to build our own. +[dependencies.os=darwin] +openssl + +# Windows has neither openssl nor zlib, so we get +# to provide both +[dependencies.os=windows] +openssl +zlib + +[shipit.pathmap] +fbcode/folly/public_tld = . +fbcode/folly = folly + +[shipit.strip] +^fbcode/folly/folly-config\.h$ +^fbcode/folly/public_tld/build/facebook_.* + +[cmake.defines] +BUILD_SHARED_LIBS=OFF + +[cmake.defines.test=on] +BUILD_TESTS=ON + +[cmake.defines.test=off] +BUILD_TESTS=OFF diff --git a/build/fbcode_builder/manifests/gflags b/build/fbcode_builder/manifests/gflags new file mode 100644 index 000000000..d7ec44eab --- /dev/null +++ b/build/fbcode_builder/manifests/gflags @@ -0,0 +1,17 @@ +[manifest] +name = gflags + +[download] +url = https://github.com/gflags/gflags/archive/v2.2.2.tar.gz +sha256 = 34af2f15cf7367513b352bdcd2493ab14ce43692d2dcd9dfc499492966c64dcf + +[build] +builder = cmake +subdir = gflags-2.2.2 + +[cmake.defines] +BUILD_SHARED_LIBS = ON +BUILD_STATIC_LIBS = ON +#BUILD_gflags_nothreads_LIB = OFF +BUILD_gflags_LIB = ON + diff --git a/build/fbcode_builder/manifests/git-lfs b/build/fbcode_builder/manifests/git-lfs new file mode 100644 index 000000000..38a5e6aeb --- /dev/null +++ b/build/fbcode_builder/manifests/git-lfs @@ -0,0 +1,12 @@ +[manifest] +name = git-lfs + +[download.os=linux] +url = https://github.com/git-lfs/git-lfs/releases/download/v2.9.1/git-lfs-linux-amd64-v2.9.1.tar.gz +sha256 = 2a8e60cf51ec45aa0f4332aa0521d60ec75c76e485d13ebaeea915b9d70ea466 + +[build] +builder = nop + +[install.files] +git-lfs = bin/git-lfs diff --git a/build/fbcode_builder/manifests/glog b/build/fbcode_builder/manifests/glog new file mode 100644 index 000000000..d2354610a --- /dev/null +++ b/build/fbcode_builder/manifests/glog @@ -0,0 +1,16 @@ +[manifest] +name = glog + +[download] +url = https://github.com/google/glog/archive/v0.4.0.tar.gz +sha256 = f28359aeba12f30d73d9e4711ef356dc842886968112162bc73002645139c39c + +[build] +builder = cmake +subdir = glog-0.4.0 + +[dependencies] +gflags + +[cmake.defines] +BUILD_SHARED_LIBS=ON diff --git a/build/fbcode_builder/manifests/gnu-bash b/build/fbcode_builder/manifests/gnu-bash new file mode 100644 index 000000000..89da77ca2 --- /dev/null +++ b/build/fbcode_builder/manifests/gnu-bash @@ -0,0 +1,20 @@ +[manifest] +name = gnu-bash + +[download.os=darwin] +url = https://ftp.gnu.org/gnu/bash/bash-5.1-rc1.tar.gz +sha256 = 0b2684eb1990329d499c96decfe2459f3e150deb915b0a9d03cf1be692b1d6d3 + +[build.os=darwin] +# The buildin FreeBSD bash on OSX is both outdated and incompatible with the +# modern GNU bash, so for the sake of being cross-platform friendly this +# manifest provides GNU bash. +# NOTE: This is the 5.1-rc1 version, which is almost the same as what Homebrew +# uses (Homebrew installs 5.0 with the 18 patches that in fact make the 5.1-rc1 +# version). +builder = autoconf +subdir = bash-5.1-rc1 +build_in_src_dir = true + +[build.not(os=darwin)] +builder = nop diff --git a/build/fbcode_builder/manifests/gnu-coreutils b/build/fbcode_builder/manifests/gnu-coreutils new file mode 100644 index 000000000..1ab4d9d4a --- /dev/null +++ b/build/fbcode_builder/manifests/gnu-coreutils @@ -0,0 +1,15 @@ +[manifest] +name = gnu-coreutils + +[download.os=darwin] +url = https://ftp.gnu.org/gnu/coreutils/coreutils-8.32.tar.gz +sha256 = d5ab07435a74058ab69a2007e838be4f6a90b5635d812c2e26671e3972fca1b8 + +[build.os=darwin] +# The buildin FreeBSD version incompatible with the GNU one, so for the sake of +# being cross-platform friendly this manifest provides the GNU version. +builder = autoconf +subdir = coreutils-8.32 + +[build.not(os=darwin)] +builder = nop diff --git a/build/fbcode_builder/manifests/gnu-grep b/build/fbcode_builder/manifests/gnu-grep new file mode 100644 index 000000000..e6a163d37 --- /dev/null +++ b/build/fbcode_builder/manifests/gnu-grep @@ -0,0 +1,15 @@ +[manifest] +name = gnu-grep + +[download.os=darwin] +url = https://ftp.gnu.org/gnu/grep/grep-3.5.tar.gz +sha256 = 9897220992a8fd38a80b70731462defa95f7ff2709b235fb54864ddd011141dd + +[build.os=darwin] +# The buildin FreeBSD version incompatible with the GNU one, so for the sake of +# being cross-platform friendly this manifest provides the GNU version. +builder = autoconf +subdir = grep-3.5 + +[build.not(os=darwin)] +builder = nop diff --git a/build/fbcode_builder/manifests/gnu-sed b/build/fbcode_builder/manifests/gnu-sed new file mode 100644 index 000000000..9b458df6e --- /dev/null +++ b/build/fbcode_builder/manifests/gnu-sed @@ -0,0 +1,15 @@ +[manifest] +name = gnu-sed + +[download.os=darwin] +url = https://ftp.gnu.org/gnu/sed/sed-4.8.tar.gz +sha256 = 53cf3e14c71f3a149f29d13a0da64120b3c1d3334fba39c4af3e520be053982a + +[build.os=darwin] +# The buildin FreeBSD version incompatible with the GNU one, so for the sake of +# being cross-platform friendly this manifest provides the GNU version. +builder = autoconf +subdir = sed-4.8 + +[build.not(os=darwin)] +builder = nop diff --git a/build/fbcode_builder/manifests/googletest b/build/fbcode_builder/manifests/googletest new file mode 100644 index 000000000..775aac34f --- /dev/null +++ b/build/fbcode_builder/manifests/googletest @@ -0,0 +1,18 @@ +[manifest] +name = googletest + +[download] +url = https://github.com/google/googletest/archive/release-1.10.0.tar.gz +sha256 = 9dc9157a9a1551ec7a7e43daea9a694a0bb5fb8bec81235d8a1e6ef64c716dcb + +[build] +builder = cmake +subdir = googletest-release-1.10.0 + +[cmake.defines] +# Everything else defaults to the shared runtime, so tell gtest that +# it should not use its choice of the static runtime +gtest_force_shared_crt=ON + +[cmake.defines.os=windows] +BUILD_SHARED_LIBS=ON diff --git a/build/fbcode_builder/manifests/googletest_1_8 b/build/fbcode_builder/manifests/googletest_1_8 new file mode 100644 index 000000000..76c0ce51f --- /dev/null +++ b/build/fbcode_builder/manifests/googletest_1_8 @@ -0,0 +1,18 @@ +[manifest] +name = googletest_1_8 + +[download] +url = https://github.com/google/googletest/archive/release-1.8.0.tar.gz +sha256 = 58a6f4277ca2bc8565222b3bbd58a177609e9c488e8a72649359ba51450db7d8 + +[build] +builder = cmake +subdir = googletest-release-1.8.0 + +[cmake.defines] +# Everything else defaults to the shared runtime, so tell gtest that +# it should not use its choice of the static runtime +gtest_force_shared_crt=ON + +[cmake.defines.os=windows] +BUILD_SHARED_LIBS=ON diff --git a/build/fbcode_builder/manifests/gperf b/build/fbcode_builder/manifests/gperf new file mode 100644 index 000000000..13d7a890f --- /dev/null +++ b/build/fbcode_builder/manifests/gperf @@ -0,0 +1,14 @@ +[manifest] +name = gperf + +[download] +url = http://ftp.gnu.org/pub/gnu/gperf/gperf-3.1.tar.gz +sha256 = 588546b945bba4b70b6a3a616e80b4ab466e3f33024a352fc2198112cdbb3ae2 + +[build.not(os=windows)] +builder = autoconf +subdir = gperf-3.1 + +[build.os=windows] +builder = nop + diff --git a/build/fbcode_builder/manifests/iproute2 b/build/fbcode_builder/manifests/iproute2 new file mode 100644 index 000000000..6fb7f77ed --- /dev/null +++ b/build/fbcode_builder/manifests/iproute2 @@ -0,0 +1,13 @@ +[manifest] +name = iproute2 + +[download] +url = https://mirrors.edge.kernel.org/pub/linux/utils/net/iproute2/iproute2-4.12.0.tar.gz +sha256 = 46612a1e2d01bb31932557bccdb1b8618cae9a439dfffc08ef35ed8e197f14ce + +[build.os=linux] +builder = iproute2 +subdir = iproute2-4.12.0 + +[build.not(os=linux)] +builder = nop diff --git a/build/fbcode_builder/manifests/jq b/build/fbcode_builder/manifests/jq new file mode 100644 index 000000000..231818f34 --- /dev/null +++ b/build/fbcode_builder/manifests/jq @@ -0,0 +1,24 @@ +[manifest] +name = jq + +[rpms] +jq + +[debs] +jq + +[download.not(os=windows)] +url = https://github.com/stedolan/jq/releases/download/jq-1.5/jq-1.5.tar.gz +sha256 = c4d2bfec6436341113419debf479d833692cc5cdab7eb0326b5a4d4fbe9f493c + +[build.not(os=windows)] +builder = autoconf +subdir = jq-1.5 + +[build.os=windows] +builder = nop + +[autoconf.args] +# This argument turns off some developers tool and it is recommended in jq's +# README +--disable-maintainer-mode diff --git a/build/fbcode_builder/manifests/katran b/build/fbcode_builder/manifests/katran new file mode 100644 index 000000000..224ccbe21 --- /dev/null +++ b/build/fbcode_builder/manifests/katran @@ -0,0 +1,38 @@ +[manifest] +name = katran +fbsource_path = fbcode/katran +shipit_project = katran +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebookincubator/katran.git + +[build.not(os=linux)] +builder = nop + +[build.os=linux] +builder = cmake +subdir = . + +[cmake.defines.test=on] +BUILD_TESTS=ON + +[cmake.defines.test=off] +BUILD_TESTS=OFF + +[dependencies] +folly +fizz +libbpf +libmnl +zlib +googletest + + +[shipit.pathmap] +fbcode/katran/public_root = . +fbcode/katran = katran + +[shipit.strip] +^fbcode/katran/facebook +^fbcode/katran/OSS_SYNC diff --git a/build/fbcode_builder/manifests/libbpf b/build/fbcode_builder/manifests/libbpf new file mode 100644 index 000000000..0416822e4 --- /dev/null +++ b/build/fbcode_builder/manifests/libbpf @@ -0,0 +1,26 @@ +[manifest] +name = libbpf + +[download] +url = https://github.com/libbpf/libbpf/archive/v0.3.tar.gz +sha256 = c168d84a75b541f753ceb49015d9eb886e3fb5cca87cdd9aabce7e10ad3a1efc + +# BPF only builds on linux, so make it a NOP on other platforms +[build.not(os=linux)] +builder = nop + +[build.os=linux] +builder = make +subdir = libbpf-0.3/src + +[make.build_args] +BUILD_STATIC_ONLY=y + +# libbpf-0.3 requires uapi headers >= 5.8 +[make.install_args] +install +install_uapi_headers +BUILD_STATIC_ONLY=y + +[dependencies] +libelf diff --git a/build/fbcode_builder/manifests/libbpf_0_2_0_beta b/build/fbcode_builder/manifests/libbpf_0_2_0_beta new file mode 100644 index 000000000..072639817 --- /dev/null +++ b/build/fbcode_builder/manifests/libbpf_0_2_0_beta @@ -0,0 +1,26 @@ +[manifest] +name = libbpf_0_2_0_beta + +[download] +url = https://github.com/libbpf/libbpf/archive/b6dd2f2.tar.gz +sha256 = 8db9dca90f5c445ef2362e3c6a00f3d6c4bf36e8782f8e27704109c78e541497 + +# BPF only builds on linux, so make it a NOP on other platforms +[build.not(os=linux)] +builder = nop + +[build.os=linux] +builder = make +subdir = libbpf-b6dd2f2b7df4d3bd35d64aaf521d9ad18d766f53/src + +[make.build_args] +BUILD_STATIC_ONLY=y + +# libbpf now requires uapi headers >= 5.8 +[make.install_args] +install +install_uapi_headers +BUILD_STATIC_ONLY=y + +[dependencies] +libelf diff --git a/build/fbcode_builder/manifests/libcurl b/build/fbcode_builder/manifests/libcurl new file mode 100644 index 000000000..466b4497c --- /dev/null +++ b/build/fbcode_builder/manifests/libcurl @@ -0,0 +1,39 @@ +[manifest] +name = libcurl + +[rpms] +libcurl-devel +libcurl + +[debs] +libcurl4-openssl-dev + +[download] +url = https://curl.haxx.se/download/curl-7.65.1.tar.gz +sha256 = 821aeb78421375f70e55381c9ad2474bf279fc454b791b7e95fc83562951c690 + +[dependencies] +nghttp2 + +# We use system OpenSSL on Linux (see folly's manifest for details) +[dependencies.not(os=linux)] +openssl + +[build.not(os=windows)] +builder = autoconf +subdir = curl-7.65.1 + +[autoconf.args] +# fboss (which added the libcurl dep) doesn't need ldap so it is disabled here. +# if someone in the future wants to add ldap for something else, it won't hurt +# fboss. However, that would require adding an ldap manifest. +# +# For the same reason, we disable libssh2 and libidn2 which aren't really used +# but would require adding manifests if we don't disable them. +--disable-ldap +--without-libssh2 +--without-libidn2 + +[build.os=windows] +builder = cmake +subdir = curl-7.65.1 diff --git a/build/fbcode_builder/manifests/libelf b/build/fbcode_builder/manifests/libelf new file mode 100644 index 000000000..a46aab879 --- /dev/null +++ b/build/fbcode_builder/manifests/libelf @@ -0,0 +1,20 @@ +[manifest] +name = libelf + +[rpms] +elfutils-libelf-devel-static + +[debs] +libelf-dev + +[download] +url = https://ftp.osuosl.org/pub/blfs/conglomeration/libelf/libelf-0.8.13.tar.gz +sha256 = 591a9b4ec81c1f2042a97aa60564e0cb79d041c52faa7416acb38bc95bd2c76d + +# libelf only makes sense on linux, so make it a NOP on other platforms +[build.not(os=linux)] +builder = nop + +[build.os=linux] +builder = autoconf +subdir = libelf-0.8.13 diff --git a/build/fbcode_builder/manifests/libevent b/build/fbcode_builder/manifests/libevent new file mode 100644 index 000000000..eaa39a9e6 --- /dev/null +++ b/build/fbcode_builder/manifests/libevent @@ -0,0 +1,29 @@ +[manifest] +name = libevent + +[rpms] +libevent-devel + +[debs] +libevent-dev + +# Note that the CMakeLists.txt file is present only in +# git repo and not in the release tarball, so take care +# to use the github generated source tarball rather than +# the explicitly uploaded source tarball +[download] +url = https://github.com/libevent/libevent/archive/release-2.1.8-stable.tar.gz +sha256 = 316ddb401745ac5d222d7c529ef1eada12f58f6376a66c1118eee803cb70f83d + +[build] +builder = cmake +subdir = libevent-release-2.1.8-stable + +[cmake.defines] +EVENT__DISABLE_TESTS = ON +EVENT__DISABLE_BENCHMARK = ON +EVENT__DISABLE_SAMPLES = ON +EVENT__DISABLE_REGRESS = ON + +[dependencies.not(os=linux)] +openssl diff --git a/build/fbcode_builder/manifests/libgit2 b/build/fbcode_builder/manifests/libgit2 new file mode 100644 index 000000000..1d6a53e5e --- /dev/null +++ b/build/fbcode_builder/manifests/libgit2 @@ -0,0 +1,24 @@ +[manifest] +name = libgit2 + +[rpms] +libgit2-devel + +[debs] +libgit2-dev + +[download] +url = https://github.com/libgit2/libgit2/archive/v0.28.1.tar.gz +sha256 = 0ca11048795b0d6338f2e57717370208c2c97ad66c6d5eac0c97a8827d13936b + +[build] +builder = cmake +subdir = libgit2-0.28.1 + +[cmake.defines] +# Could turn this on if we also wanted to add a manifest for libssh2 +USE_SSH = OFF +BUILD_CLAR = OFF +# Have to build shared to work around annoying problems with cmake +# mis-parsing the frameworks required to link this on macos :-/ +BUILD_SHARED_LIBS = ON diff --git a/build/fbcode_builder/manifests/libicu b/build/fbcode_builder/manifests/libicu new file mode 100644 index 000000000..c1deda503 --- /dev/null +++ b/build/fbcode_builder/manifests/libicu @@ -0,0 +1,19 @@ +[manifest] +name = libicu + +[rpms] +libicu-devel + +[debs] +libicu-dev + +[download] +url = https://github.com/unicode-org/icu/releases/download/release-68-2/icu4c-68_2-src.tgz +sha256 = c79193dee3907a2199b8296a93b52c5cb74332c26f3d167269487680d479d625 + +[build.not(os=windows)] +builder = autoconf +subdir = icu/source + +[build.os=windows] +builder = nop diff --git a/build/fbcode_builder/manifests/libmnl b/build/fbcode_builder/manifests/libmnl new file mode 100644 index 000000000..9b28b87b9 --- /dev/null +++ b/build/fbcode_builder/manifests/libmnl @@ -0,0 +1,17 @@ +[manifest] +name = libmnl + +[rpms] +libmnl-devel +libmnl-static + +[debs] +libmnl-dev + +[download] +url = http://www.netfilter.org/pub/libmnl/libmnl-1.0.4.tar.bz2 +sha256 = 171f89699f286a5854b72b91d06e8f8e3683064c5901fb09d954a9ab6f551f81 + +[build.os=linux] +builder = autoconf +subdir = libmnl-1.0.4 diff --git a/build/fbcode_builder/manifests/libnl b/build/fbcode_builder/manifests/libnl new file mode 100644 index 000000000..f864acb49 --- /dev/null +++ b/build/fbcode_builder/manifests/libnl @@ -0,0 +1,17 @@ +[manifest] +name = libnl + +[rpms] +libnl3-devel +libnl3 + +[debs] +libnl-3-dev + +[download] +url = https://www.infradead.org/~tgr/libnl/files/libnl-3.2.25.tar.gz +sha256 = 8beb7590674957b931de6b7f81c530b85dc7c1ad8fbda015398bc1e8d1ce8ec5 + +[build.os=linux] +builder = autoconf +subdir = libnl-3.2.25 diff --git a/build/fbcode_builder/manifests/libsai b/build/fbcode_builder/manifests/libsai new file mode 100644 index 000000000..4f422d8e1 --- /dev/null +++ b/build/fbcode_builder/manifests/libsai @@ -0,0 +1,13 @@ +[manifest] +name = libsai + +[download] +url = https://github.com/opencomputeproject/SAI/archive/v1.7.1.tar.gz +sha256 = e18eb1a2a6e5dd286d97e13569d8b78cc1f8229030beed0db4775b9a50ab6a83 + +[build] +builder = nop +subdir = SAI-1.7.1 + +[install.files] +inc = include diff --git a/build/fbcode_builder/manifests/libsodium b/build/fbcode_builder/manifests/libsodium new file mode 100644 index 000000000..d69bfcc4b --- /dev/null +++ b/build/fbcode_builder/manifests/libsodium @@ -0,0 +1,33 @@ +[manifest] +name = libsodium + +[rpms] +libsodium-devel +libsodium-static + +[debs] +libsodium-dev + +[download.not(os=windows)] +url = https://github.com/jedisct1/libsodium/releases/download/1.0.17/libsodium-1.0.17.tar.gz +sha256 = 0cc3dae33e642cc187b5ceb467e0ad0e1b51dcba577de1190e9ffa17766ac2b1 + +[build.not(os=windows)] +builder = autoconf +subdir = libsodium-1.0.17 + +[download.os=windows] +url = https://download.libsodium.org/libsodium/releases/libsodium-1.0.17-msvc.zip +sha256 = f0f32ad8ebd76eee99bb039f843f583f2babca5288a8c26a7261db9694c11467 + +[build.os=windows] +builder = nop + +[install.files.os=windows] +x64/Release/v141/dynamic/libsodium.dll = bin/libsodium.dll +x64/Release/v141/dynamic/libsodium.lib = lib/libsodium.lib +x64/Release/v141/dynamic/libsodium.exp = lib/libsodium.exp +x64/Release/v141/dynamic/libsodium.pdb = lib/libsodium.pdb +include = include + +[autoconf.args] diff --git a/build/fbcode_builder/manifests/libtool b/build/fbcode_builder/manifests/libtool new file mode 100644 index 000000000..1ec99b5f4 --- /dev/null +++ b/build/fbcode_builder/manifests/libtool @@ -0,0 +1,22 @@ +[manifest] +name = libtool + +[rpms] +libtool + +[debs] +libtool + +[download] +url = http://ftp.gnu.org/gnu/libtool/libtool-2.4.6.tar.gz +sha256 = e3bd4d5d3d025a36c21dd6af7ea818a2afcd4dfc1ea5a17b39d7854bcd0c06e3 + +[build] +builder = autoconf +subdir = libtool-2.4.6 + +[dependencies] +automake + +[autoconf.args] +--enable-ltdl-install diff --git a/build/fbcode_builder/manifests/libusb b/build/fbcode_builder/manifests/libusb new file mode 100644 index 000000000..74702d3f0 --- /dev/null +++ b/build/fbcode_builder/manifests/libusb @@ -0,0 +1,23 @@ +[manifest] +name = libusb + +[rpms] +libusb-devel +libusb + +[debs] +libusb-1.0-0-dev + +[download] +url = https://github.com/libusb/libusb/releases/download/v1.0.22/libusb-1.0.22.tar.bz2 +sha256 = 75aeb9d59a4fdb800d329a545c2e6799f732362193b465ea198f2aa275518157 + +[build.os=linux] +builder = autoconf +subdir = libusb-1.0.22 + +[autoconf.args] +# fboss (which added the libusb dep) doesn't need udev so it is disabled here. +# if someone in the future wants to add udev for something else, it won't hurt +# fboss. +--disable-udev diff --git a/build/fbcode_builder/manifests/libyaml b/build/fbcode_builder/manifests/libyaml new file mode 100644 index 000000000..a7ff57316 --- /dev/null +++ b/build/fbcode_builder/manifests/libyaml @@ -0,0 +1,13 @@ +[manifest] +name = libyaml + +[download] +url = http://pyyaml.org/download/libyaml/yaml-0.1.7.tar.gz +sha256 = 8088e457264a98ba451a90b8661fcb4f9d6f478f7265d48322a196cec2480729 + +[build.os=linux] +builder = autoconf +subdir = yaml-0.1.7 + +[build.not(os=linux)] +builder = nop diff --git a/build/fbcode_builder/manifests/libzmq b/build/fbcode_builder/manifests/libzmq new file mode 100644 index 000000000..4f555fa65 --- /dev/null +++ b/build/fbcode_builder/manifests/libzmq @@ -0,0 +1,24 @@ +[manifest] +name = libzmq + +[rpms] +zeromq-devel +zeromq + +[debs] +libzmq3-dev + +[download] +url = https://github.com/zeromq/libzmq/releases/download/v4.3.1/zeromq-4.3.1.tar.gz +sha256 = bcbabe1e2c7d0eec4ed612e10b94b112dd5f06fcefa994a0c79a45d835cd21eb + + +[build] +builder = autoconf +subdir = zeromq-4.3.1 + +[autoconf.args] + +[dependencies] +autoconf +libtool diff --git a/build/fbcode_builder/manifests/lz4 b/build/fbcode_builder/manifests/lz4 new file mode 100644 index 000000000..03dbd9de4 --- /dev/null +++ b/build/fbcode_builder/manifests/lz4 @@ -0,0 +1,17 @@ +[manifest] +name = lz4 + +[rpms] +lz4-devel +lz4-static + +[debs] +liblz4-dev + +[download] +url = https://github.com/lz4/lz4/archive/v1.8.3.tar.gz +sha256 = 33af5936ac06536805f9745e0b6d61da606a1f8b4cc5c04dd3cbaca3b9b4fc43 + +[build] +builder = cmake +subdir = lz4-1.8.3/contrib/cmake_unofficial diff --git a/build/fbcode_builder/manifests/lzo b/build/fbcode_builder/manifests/lzo new file mode 100644 index 000000000..342428ab5 --- /dev/null +++ b/build/fbcode_builder/manifests/lzo @@ -0,0 +1,19 @@ +[manifest] +name = lzo + +[rpms] +lzo-devel + +[debs] +liblzo2-dev + +[download] +url = http://www.oberhumer.com/opensource/lzo/download/lzo-2.10.tar.gz +sha256 = c0f892943208266f9b6543b3ae308fab6284c5c90e627931446fb49b4221a072 + +[build.not(os=windows)] +builder = autoconf +subdir = lzo-2.10 + +[build.os=windows] +builder = nop diff --git a/build/fbcode_builder/manifests/mononoke b/build/fbcode_builder/manifests/mononoke new file mode 100644 index 000000000..7df92c77b --- /dev/null +++ b/build/fbcode_builder/manifests/mononoke @@ -0,0 +1,44 @@ +[manifest] +name = mononoke +fbsource_path = fbcode/eden +shipit_project = eden +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebookexperimental/eden.git + +[build.not(os=windows)] +builder = cargo + +[build.os=windows] +# building Mononoke on windows is not supported +builder = nop + +[cargo] +build_doc = true +workspace_dir = eden/mononoke + +[shipit.pathmap] +fbcode/configerator/structs/scm/mononoke/public_autocargo = configerator/structs/scm/mononoke +fbcode/configerator/structs/scm/mononoke = configerator/structs/scm/mononoke +fbcode/eden/oss = . +fbcode/eden = eden +fbcode/eden/mononoke/public_autocargo = eden/mononoke +fbcode/tools/lfs = tools/lfs +tools/rust/ossconfigs = . + +[shipit.strip] +# strip all code unrelated to mononoke to prevent triggering unnecessary checks +^fbcode/eden/(?!mononoke|scm/lib/xdiff.*)/.*$ +^fbcode/eden/scm/lib/third-party/rust/.*/Cargo.toml$ +^fbcode/eden/mononoke/Cargo\.toml$ +^fbcode/eden/mononoke/(?!public_autocargo).+/Cargo\.toml$ +^fbcode/configerator/structs/scm/mononoke/(?!public_autocargo).+/Cargo\.toml$ +^.*/facebook/.*$ + +[dependencies] +fbthrift-source +rust-shed + +[dependencies.fb=on] +rust diff --git a/build/fbcode_builder/manifests/mononoke_integration b/build/fbcode_builder/manifests/mononoke_integration new file mode 100644 index 000000000..a796e967e --- /dev/null +++ b/build/fbcode_builder/manifests/mononoke_integration @@ -0,0 +1,47 @@ +[manifest] +name = mononoke_integration +fbsource_path = fbcode/eden +shipit_project = eden +shipit_fbcode_builder = true + +[build.not(os=windows)] +builder = make +subdir = eden/mononoke/tests/integration + +[build.os=windows] +# building Mononoke on windows is not supported +builder = nop + +[make.build_args] +build-getdeps + +[make.install_args] +install-getdeps + +[make.test_args] +test-getdeps + +[shipit.pathmap] +fbcode/eden/mononoke/tests/integration = eden/mononoke/tests/integration + +[shipit.strip] +^.*/facebook/.*$ + +[dependencies] +eden_scm +eden_scm_lib_edenapi_tools +jq +mononoke +nmap +python-click +python-dulwich +tree + +[dependencies.os=linux] +sqlite3-bin + +[dependencies.os=darwin] +gnu-bash +gnu-coreutils +gnu-grep +gnu-sed diff --git a/build/fbcode_builder/manifests/mvfst b/build/fbcode_builder/manifests/mvfst new file mode 100644 index 000000000..4f72a9192 --- /dev/null +++ b/build/fbcode_builder/manifests/mvfst @@ -0,0 +1,32 @@ +[manifest] +name = mvfst +fbsource_path = fbcode/quic +shipit_project = mvfst +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebookincubator/mvfst.git + +[build] +builder = cmake +subdir = . + +[cmake.defines.test=on] +BUILD_TESTS = ON + +[cmake.defines.all(os=windows, test=on)] +BUILD_TESTS = OFF + +[cmake.defines.test=off] +BUILD_TESTS = OFF + +[dependencies] +folly +fizz + +[dependencies.all(test=on, not(os=windows))] +googletest_1_8 + +[shipit.pathmap] +fbcode/quic/public_root = . +fbcode/quic = quic diff --git a/build/fbcode_builder/manifests/nghttp2 b/build/fbcode_builder/manifests/nghttp2 new file mode 100644 index 000000000..151daf8af --- /dev/null +++ b/build/fbcode_builder/manifests/nghttp2 @@ -0,0 +1,20 @@ +[manifest] +name = nghttp2 + +[rpms] +libnghttp2-devel +libnghttp2 + +[debs] +libnghttp2-dev + +[download] +url = https://github.com/nghttp2/nghttp2/releases/download/v1.39.2/nghttp2-1.39.2.tar.gz +sha256 = fc820a305e2f410fade1a3260f09229f15c0494fc089b0100312cd64a33a38c0 + +[build] +builder = autoconf +subdir = nghttp2-1.39.2 + +[autoconf.args] +--enable-lib-only diff --git a/build/fbcode_builder/manifests/ninja b/build/fbcode_builder/manifests/ninja new file mode 100644 index 000000000..2b6c5dc8d --- /dev/null +++ b/build/fbcode_builder/manifests/ninja @@ -0,0 +1,26 @@ +[manifest] +name = ninja + +[rpms] +ninja-build + +[debs] +ninja-build + +[download.os=windows] +url = https://github.com/ninja-build/ninja/releases/download/v1.10.2/ninja-win.zip +sha256 = bbde850d247d2737c5764c927d1071cbb1f1957dcabda4a130fa8547c12c695f + +[build.os=windows] +builder = nop + +[install.files.os=windows] +ninja.exe = bin/ninja.exe + +[download.not(os=windows)] +url = https://github.com/ninja-build/ninja/archive/v1.10.2.tar.gz +sha256 = ce35865411f0490368a8fc383f29071de6690cbadc27704734978221f25e2bed + +[build.not(os=windows)] +builder = ninja_bootstrap +subdir = ninja-1.10.2 diff --git a/build/fbcode_builder/manifests/nmap b/build/fbcode_builder/manifests/nmap new file mode 100644 index 000000000..c245e1241 --- /dev/null +++ b/build/fbcode_builder/manifests/nmap @@ -0,0 +1,25 @@ +[manifest] +name = nmap + +[rpms] +nmap + +[debs] +nmap + +[download.not(os=windows)] +url = https://api.github.com/repos/nmap/nmap/tarball/ef8213a36c2e89233c806753a57b5cd473605408 +sha256 = eda39e5a8ef4964fac7db16abf91cc11ff568eac0fa2d680b0bfa33b0ed71f4a + +[build.not(os=windows)] +builder = autoconf +subdir = nmap-nmap-ef8213a +build_in_src_dir = true + +[build.os=windows] +builder = nop + +[autoconf.args] +# Without this option the build was filing to find some third party libraries +# that we don't need +enable_rdma=no diff --git a/build/fbcode_builder/manifests/openr b/build/fbcode_builder/manifests/openr new file mode 100644 index 000000000..754ba8cd5 --- /dev/null +++ b/build/fbcode_builder/manifests/openr @@ -0,0 +1,37 @@ +[manifest] +name = openr +fbsource_path = facebook/openr +shipit_project = openr +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebook/openr.git + +[build.os=linux] +builder = cmake + +[build.not(os=linux)] +# boost.fiber is required and that is not available on macos. +# libzmq doesn't currently build on windows. +builder = nop + +[dependencies] +boost +fb303 +fbthrift +fbzmq +folly +googletest +re2 + +[cmake.defines.test=on] +BUILD_TESTS=ON +ADD_ROOT_TESTS=OFF + +[cmake.defines.test=off] +BUILD_TESTS=OFF + + +[shipit.pathmap] +fbcode/openr = openr +fbcode/openr/public_tld = . diff --git a/build/fbcode_builder/manifests/openssl b/build/fbcode_builder/manifests/openssl new file mode 100644 index 000000000..991196c9a --- /dev/null +++ b/build/fbcode_builder/manifests/openssl @@ -0,0 +1,20 @@ +[manifest] +name = openssl + +[rpms] +openssl-devel +openssl + +[debs] +libssl-dev + +[download] +url = https://www.openssl.org/source/openssl-1.1.1i.tar.gz +sha256 = e8be6a35fe41d10603c3cc635e93289ed00bf34b79671a3a4de64fcee00d5242 + +[build] +builder = openssl +subdir = openssl-1.1.1i + +[dependencies.os=windows] +perl diff --git a/build/fbcode_builder/manifests/osxfuse b/build/fbcode_builder/manifests/osxfuse new file mode 100644 index 000000000..b6c6c551f --- /dev/null +++ b/build/fbcode_builder/manifests/osxfuse @@ -0,0 +1,12 @@ +[manifest] +name = osxfuse + +[download] +url = https://github.com/osxfuse/osxfuse/archive/osxfuse-3.8.3.tar.gz +sha256 = 93bab6731bdfe8dc1ef069483437270ce7fe5a370f933d40d8d0ef09ba846c0c + +[build] +builder = nop + +[install.files] +osxfuse-osxfuse-3.8.3/common = include diff --git a/build/fbcode_builder/manifests/patchelf b/build/fbcode_builder/manifests/patchelf new file mode 100644 index 000000000..f9d050424 --- /dev/null +++ b/build/fbcode_builder/manifests/patchelf @@ -0,0 +1,17 @@ +[manifest] +name = patchelf + +[rpms] +patchelf + +[debs] +patchelf + +[download] +url = https://github.com/NixOS/patchelf/archive/0.10.tar.gz +sha256 = b3cb6bdedcef5607ce34a350cf0b182eb979f8f7bc31eae55a93a70a3f020d13 + +[build] +builder = autoconf +subdir = patchelf-0.10 + diff --git a/build/fbcode_builder/manifests/pcre b/build/fbcode_builder/manifests/pcre new file mode 100644 index 000000000..5353d8c27 --- /dev/null +++ b/build/fbcode_builder/manifests/pcre @@ -0,0 +1,18 @@ +[manifest] +name = pcre + +[rpms] +pcre-devel +pcre-static + +[debs] +libpcre3-dev + +[download] +url = https://ftp.pcre.org/pub/pcre/pcre-8.43.tar.gz +sha256 = 0b8e7465dc5e98c757cc3650a20a7843ee4c3edf50aaf60bb33fd879690d2c73 + +[build] +builder = cmake +subdir = pcre-8.43 + diff --git a/build/fbcode_builder/manifests/perl b/build/fbcode_builder/manifests/perl new file mode 100644 index 000000000..32bddc51c --- /dev/null +++ b/build/fbcode_builder/manifests/perl @@ -0,0 +1,11 @@ +[manifest] +name = perl + +[download.os=windows] +url = http://strawberryperl.com/download/5.28.1.1/strawberry-perl-5.28.1.1-64bit-portable.zip +sha256 = 935c95ba096fa11c4e1b5188732e3832d330a2a79e9882ab7ba8460ddbca810d + +[build.os=windows] +builder = nop +subdir = perl + diff --git a/build/fbcode_builder/manifests/pexpect b/build/fbcode_builder/manifests/pexpect new file mode 100644 index 000000000..682e66a54 --- /dev/null +++ b/build/fbcode_builder/manifests/pexpect @@ -0,0 +1,12 @@ +[manifest] +name = pexpect + +[download] +url = https://files.pythonhosted.org/packages/0e/3e/377007e3f36ec42f1b84ec322ee12141a9e10d808312e5738f52f80a232c/pexpect-4.7.0-py2.py3-none-any.whl +sha256 = 2094eefdfcf37a1fdbfb9aa090862c1a4878e5c7e0e7e7088bdb511c558e5cd1 + +[build] +builder = python-wheel + +[dependencies] +python-ptyprocess diff --git a/build/fbcode_builder/manifests/protobuf b/build/fbcode_builder/manifests/protobuf new file mode 100644 index 000000000..7f21e4821 --- /dev/null +++ b/build/fbcode_builder/manifests/protobuf @@ -0,0 +1,17 @@ +[manifest] +name = protobuf + +[rpms] +protobuf-devel + +[debs] +libprotobuf-dev + +[git] +repo_url = https://github.com/protocolbuffers/protobuf.git + +[build.not(os=windows)] +builder = autoconf + +[build.os=windows] +builder = nop diff --git a/build/fbcode_builder/manifests/proxygen b/build/fbcode_builder/manifests/proxygen new file mode 100644 index 000000000..5452a2454 --- /dev/null +++ b/build/fbcode_builder/manifests/proxygen @@ -0,0 +1,39 @@ +[manifest] +name = proxygen +fbsource_path = fbcode/proxygen +shipit_project = proxygen +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebook/proxygen.git + +[build.os=windows] +builder = nop + +[build] +builder = cmake +subdir = . + +[cmake.defines] +BUILD_QUIC = ON + +[cmake.defines.test=on] +BUILD_TESTS = ON + +[cmake.defines.test=off] +BUILD_TESTS = OFF + +[dependencies] +zlib +gperf +folly +fizz +wangle +mvfst + +[dependencies.test=on] +googletest_1_8 + +[shipit.pathmap] +fbcode/proxygen/public_tld = . +fbcode/proxygen = proxygen diff --git a/build/fbcode_builder/manifests/python b/build/fbcode_builder/manifests/python new file mode 100644 index 000000000..e51c0ab51 --- /dev/null +++ b/build/fbcode_builder/manifests/python @@ -0,0 +1,17 @@ +[manifest] +name = python + +[rpms] +python3 +python3-devel + +[debs] +python3-all-dev + +[download.os=linux] +url = https://www.python.org/ftp/python/3.7.6/Python-3.7.6.tgz +sha256 = aeee681c235ad336af116f08ab6563361a0c81c537072c1b309d6e4050aa2114 + +[build.os=linux] +builder = autoconf +subdir = Python-3.7.6 diff --git a/build/fbcode_builder/manifests/python-click b/build/fbcode_builder/manifests/python-click new file mode 100644 index 000000000..ea9a9d2d3 --- /dev/null +++ b/build/fbcode_builder/manifests/python-click @@ -0,0 +1,9 @@ +[manifest] +name = python-click + +[download] +url = https://files.pythonhosted.org/packages/d2/3d/fa76db83bf75c4f8d338c2fd15c8d33fdd7ad23a9b5e57eb6c5de26b430e/click-7.1.2-py2.py3-none-any.whl +sha256 = dacca89f4bfadd5de3d7489b7c8a566eee0d3676333fbb50030263894c38c0dc + +[build] +builder = python-wheel diff --git a/build/fbcode_builder/manifests/python-dulwich b/build/fbcode_builder/manifests/python-dulwich new file mode 100644 index 000000000..0d995e12f --- /dev/null +++ b/build/fbcode_builder/manifests/python-dulwich @@ -0,0 +1,19 @@ +[manifest] +name = python-dulwich + +# The below links point to custom github forks of project dulwich, because the +# 0.18.6 version didn't have an official rollout of wheel packages. + +[download.os=linux] +url = https://github.com/lukaspiatkowski/dulwich/releases/download/dulwich-0.18.6-wheel/dulwich-0.18.6-cp36-cp36m-linux_x86_64.whl +sha256 = e96f545f3d003e67236785473caaba2c368e531ea85fd508a3bd016ebac3a6d8 + +[download.os=darwin] +url = https://github.com/lukaspiatkowski/dulwich/releases/download/dulwich-0.18.6-wheel/dulwich-0.18.6-cp37-cp37m-macosx_10_14_x86_64.whl +sha256 = 8373652056284ad40ea5220b659b3489b0a91f25536322345a3e4b5d29069308 + +[build.not(os=windows)] +builder = python-wheel + +[build.os=windows] +builder = nop diff --git a/build/fbcode_builder/manifests/python-ptyprocess b/build/fbcode_builder/manifests/python-ptyprocess new file mode 100644 index 000000000..adc60e048 --- /dev/null +++ b/build/fbcode_builder/manifests/python-ptyprocess @@ -0,0 +1,9 @@ +[manifest] +name = python-ptyprocess + +[download] +url = https://files.pythonhosted.org/packages/d1/29/605c2cc68a9992d18dada28206eeada56ea4bd07a239669da41674648b6f/ptyprocess-0.6.0-py2.py3-none-any.whl +sha256 = d7cc528d76e76342423ca640335bd3633420dc1366f258cb31d05e865ef5ca1f + +[build] +builder = python-wheel diff --git a/build/fbcode_builder/manifests/python-six b/build/fbcode_builder/manifests/python-six new file mode 100644 index 000000000..a712188dc --- /dev/null +++ b/build/fbcode_builder/manifests/python-six @@ -0,0 +1,9 @@ +[manifest] +name = python-six + +[download] +url = https://files.pythonhosted.org/packages/73/fb/00a976f728d0d1fecfe898238ce23f502a721c0ac0ecfedb80e0d88c64e9/six-1.12.0-py2.py3-none-any.whl +sha256 = 3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c + +[build] +builder = python-wheel diff --git a/build/fbcode_builder/manifests/python-toml b/build/fbcode_builder/manifests/python-toml new file mode 100644 index 000000000..b49a3b8fb --- /dev/null +++ b/build/fbcode_builder/manifests/python-toml @@ -0,0 +1,9 @@ +[manifest] +name = python-toml + +[download] +url = https://files.pythonhosted.org/packages/a2/12/ced7105d2de62fa7c8fb5fce92cc4ce66b57c95fb875e9318dba7f8c5db0/toml-0.10.0-py2.py3-none-any.whl +sha256 = 235682dd292d5899d361a811df37e04a8828a5b1da3115886b73cf81ebc9100e + +[build] +builder = python-wheel diff --git a/build/fbcode_builder/manifests/re2 b/build/fbcode_builder/manifests/re2 new file mode 100644 index 000000000..eb4d6a92c --- /dev/null +++ b/build/fbcode_builder/manifests/re2 @@ -0,0 +1,17 @@ +[manifest] +name = re2 + +[rpms] +re2 +re2-devel + +[debs] +libre2-dev + +[download] +url = https://github.com/google/re2/archive/2019-06-01.tar.gz +sha256 = 02b7d73126bd18e9fbfe5d6375a8bb13fadaf8e99e48cbb062e4500fc18e8e2e + +[build] +builder = cmake +subdir = re2-2019-06-01 diff --git a/build/fbcode_builder/manifests/rocksdb b/build/fbcode_builder/manifests/rocksdb new file mode 100644 index 000000000..323e6dc6d --- /dev/null +++ b/build/fbcode_builder/manifests/rocksdb @@ -0,0 +1,41 @@ +[manifest] +name = rocksdb + +[download] +url = https://github.com/facebook/rocksdb/archive/v6.8.1.tar.gz +sha256 = ca192a06ed3bcb9f09060add7e9d0daee1ae7a8705a3d5ecbe41867c5e2796a2 + +[dependencies] +lz4 +snappy + +[build] +builder = cmake +subdir = rocksdb-6.8.1 + +[cmake.defines] +WITH_SNAPPY=ON +WITH_LZ4=ON +WITH_TESTS=OFF +WITH_BENCHMARK_TOOLS=OFF +# We get relocation errors with the static gflags lib, +# and there's no clear way to make it pick the shared gflags +# so just turn it off. +WITH_GFLAGS=OFF +# mac pro machines don't have some of the newer features that +# rocksdb enables by default; ask it to disable their use even +# when building on new hardware +PORTABLE = ON +# Disable the use of -Werror +FAIL_ON_WARNINGS = OFF + +[cmake.defines.os=windows] +ROCKSDB_INSTALL_ON_WINDOWS=ON +# RocksDB hard codes the paths to the snappy libs to something +# that doesn't exist; ignoring the usual cmake rules. As a result, +# we can't build it with snappy without either patching rocksdb or +# without introducing more complex logic to the build system to +# connect the snappy build outputs to rocksdb's custom logic here. +# Let's just turn it off on windows. +WITH_SNAPPY=OFF +WITH_LZ4=OFF diff --git a/build/fbcode_builder/manifests/rust-shed b/build/fbcode_builder/manifests/rust-shed new file mode 100644 index 000000000..c94b3fdd6 --- /dev/null +++ b/build/fbcode_builder/manifests/rust-shed @@ -0,0 +1,34 @@ +[manifest] +name = rust-shed +fbsource_path = fbcode/common/rust/shed +shipit_project = rust-shed +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebookexperimental/rust-shed.git + +[build] +builder = cargo + +[cargo] +build_doc = true +workspace_dir = + +[shipit.pathmap] +fbcode/common/rust/shed = shed +fbcode/common/rust/shed/public_autocargo = shed +fbcode/common/rust/shed/public_tld = . +tools/rust/ossconfigs = . + +[shipit.strip] +^fbcode/common/rust/shed/(?!public_autocargo|public_tld).+/Cargo\.toml$ + +[dependencies] +fbthrift +# macOS doesn't expose the openssl api so we need to build our own. +# Windows doesn't have openssl and Linux might contain an old version, +# so we get to provide it +openssl + +[dependencies.fb=on] +rust diff --git a/build/fbcode_builder/manifests/snappy b/build/fbcode_builder/manifests/snappy new file mode 100644 index 000000000..2f46a7734 --- /dev/null +++ b/build/fbcode_builder/manifests/snappy @@ -0,0 +1,25 @@ +[manifest] +name = snappy + +[rpms] +snappy +snappy-devel + +[debs] +libsnappy-dev + +[download] +url = https://github.com/google/snappy/archive/1.1.7.tar.gz +sha256 = 3dfa02e873ff51a11ee02b9ca391807f0c8ea0529a4924afa645fbf97163f9d4 + +[build] +builder = cmake +subdir = snappy-1.1.7 + +[cmake.defines] +SNAPPY_BUILD_TESTS = OFF + +# Avoid problems like `relocation R_X86_64_PC32 against symbol` on ELF systems +# when linking rocksdb, which builds PIC even when building a static lib +[cmake.defines.os=linux] +BUILD_SHARED_LIBS = ON diff --git a/build/fbcode_builder/manifests/sqlite3 b/build/fbcode_builder/manifests/sqlite3 new file mode 100644 index 000000000..2463f5761 --- /dev/null +++ b/build/fbcode_builder/manifests/sqlite3 @@ -0,0 +1,21 @@ +[manifest] +name = sqlite3 + +[rpms] +sqlite-devel +sqlite-libs + +[debs] +libsqlite3-dev + +[download] +url = https://sqlite.org/2019/sqlite-amalgamation-3280000.zip +sha256 = d02fc4e95cfef672b45052e221617a050b7f2e20103661cda88387349a9b1327 + +[dependencies] +cmake +ninja + +[build] +builder = sqlite +subdir = sqlite-amalgamation-3280000 diff --git a/build/fbcode_builder/manifests/sqlite3-bin b/build/fbcode_builder/manifests/sqlite3-bin new file mode 100644 index 000000000..aa138d499 --- /dev/null +++ b/build/fbcode_builder/manifests/sqlite3-bin @@ -0,0 +1,28 @@ +[manifest] +name = sqlite3-bin + +[rpms] +sqlite + +[debs] +sqlite3 + +[download.os=linux] +url = https://github.com/sqlite/sqlite/archive/version-3.33.0.tar.gz +sha256 = 48e5f989eefe9af0ac758096f82ead0f3c7b58118ac17cc5810495bd5084a331 + +[build.os=linux] +builder = autoconf +subdir = sqlite-version-3.33.0 + +[build.not(os=linux)] +# MacOS comes with sqlite3 preinstalled and don't need Windows here +builder = nop + +[dependencies.os=linux] +tcl + +[autoconf.args] +# This flag disabled tcl as a runtime library used for some functionality, +# but tcl is still a required dependency as it is used by the build files +--disable-tcl diff --git a/build/fbcode_builder/manifests/tcl b/build/fbcode_builder/manifests/tcl new file mode 100644 index 000000000..5e9892f37 --- /dev/null +++ b/build/fbcode_builder/manifests/tcl @@ -0,0 +1,20 @@ +[manifest] +name = tcl + +[rpms] +tcl + +[debs] +tcl + +[download] +url = https://github.com/tcltk/tcl/archive/core-8-7a3.tar.gz +sha256 = 22d748f0c9652f3ecc195fed3f24a1b6eea8d449003085e6651197951528982e + +[build.os=linux] +builder = autoconf +subdir = tcl-core-8-7a3/unix + +[build.not(os=linux)] +# This is for sqlite3 on Linux for now +builder = nop diff --git a/build/fbcode_builder/manifests/tree b/build/fbcode_builder/manifests/tree new file mode 100644 index 000000000..0c982f35a --- /dev/null +++ b/build/fbcode_builder/manifests/tree @@ -0,0 +1,34 @@ +[manifest] +name = tree + +[rpms] +tree + +[debs] +tree + +[download.os=linux] +url = https://salsa.debian.org/debian/tree-packaging/-/archive/debian/1.8.0-1/tree-packaging-debian-1.8.0-1.tar.gz +sha256 = a841eee1d52bfd64a48f54caab9937b9bd92935055c48885c4ab1ae4dab7fae5 + +[download.os=darwin] +# The official package of tree source requires users of non-Linux platform to +# comment/uncomment certain lines in the Makefile to build for their platform. +# Besauce getdeps.py doesn't have that functionality we just use this custom +# fork of tree which has proper lines uncommented for a OSX build +url = https://github.com/lukaspiatkowski/tree-command/archive/debian/1.8.0-1-macos.tar.gz +sha256 = 9cbe889553d95cf5a2791dd0743795d46a3c092c5bba691769c0e5c52e11229e + +[build.os=linux] +builder = make +subdir = tree-packaging-debian-1.8.0-1 + +[build.os=darwin] +builder = make +subdir = tree-command-debian-1.8.0-1-macos + +[build.os=windows] +builder = nop + +[make.install_args] +install diff --git a/build/fbcode_builder/manifests/wangle b/build/fbcode_builder/manifests/wangle new file mode 100644 index 000000000..6b330d620 --- /dev/null +++ b/build/fbcode_builder/manifests/wangle @@ -0,0 +1,27 @@ +[manifest] +name = wangle +fbsource_path = fbcode/wangle +shipit_project = wangle +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebook/wangle.git + +[build] +builder = cmake +subdir = wangle + +[cmake.defines.test=on] +BUILD_TESTS=ON + +[cmake.defines.test=off] +BUILD_TESTS=OFF + +[dependencies] +folly +googletest +fizz + +[shipit.pathmap] +fbcode/wangle/public_tld = . +fbcode/wangle = wangle diff --git a/build/fbcode_builder/manifests/watchman b/build/fbcode_builder/manifests/watchman new file mode 100644 index 000000000..0fcd6bb9f --- /dev/null +++ b/build/fbcode_builder/manifests/watchman @@ -0,0 +1,45 @@ +[manifest] +name = watchman +fbsource_path = fbcode/watchman +shipit_project = watchman +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebook/watchman.git + +[build] +builder = cmake + +[dependencies] +boost +cpptoml +fb303 +fbthrift +folly +pcre +googletest + +[dependencies.fb=on] +rust + +[shipit.pathmap] +fbcode/watchman = watchman +fbcode/watchman/oss = . +fbcode/eden/fs = eden/fs + +[shipit.strip] +^fbcode/eden/fs/(?!.*\.thrift|service/shipit_test_file\.txt) + +[cmake.defines.fb=on] +ENABLE_EDEN_SUPPORT=ON + +# FB macos specific settings +[cmake.defines.all(fb=on,os=darwin)] +# this path is coupled with the FB internal watchman-osx.spec +WATCHMAN_STATE_DIR=/opt/facebook/watchman/var/run/watchman +# tell cmake not to try to create /opt/facebook/... +INSTALL_WATCHMAN_STATE_DIR=OFF +USE_SYS_PYTHON=OFF + +[depends.environment] +WATCHMAN_VERSION_OVERRIDE diff --git a/build/fbcode_builder/manifests/yaml-cpp b/build/fbcode_builder/manifests/yaml-cpp new file mode 100644 index 000000000..bffa540fe --- /dev/null +++ b/build/fbcode_builder/manifests/yaml-cpp @@ -0,0 +1,20 @@ +[manifest] +name = yaml-cpp + +[download] +url = https://github.com/jbeder/yaml-cpp/archive/yaml-cpp-0.6.2.tar.gz +sha256 = e4d8560e163c3d875fd5d9e5542b5fd5bec810febdcba61481fe5fc4e6b1fd05 + +[build.os=linux] +builder = cmake +subdir = yaml-cpp-yaml-cpp-0.6.2 + +[build.not(os=linux)] +builder = nop + +[dependencies] +boost +googletest + +[cmake.defines] +YAML_CPP_BUILD_TESTS=OFF diff --git a/build/fbcode_builder/manifests/zlib b/build/fbcode_builder/manifests/zlib new file mode 100644 index 000000000..8df0e3e48 --- /dev/null +++ b/build/fbcode_builder/manifests/zlib @@ -0,0 +1,22 @@ +[manifest] +name = zlib + +[rpms] +zlib-devel +zlib-static + +[debs] +zlib1g-dev + +[download] +url = http://www.zlib.net/zlib-1.2.11.tar.gz +sha256 = c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1 + +[build.os=windows] +builder = cmake +subdir = zlib-1.2.11 + +# Every platform but windows ships with zlib, so just skip +# building on not(windows) +[build.not(os=windows)] +builder = nop diff --git a/build/fbcode_builder/manifests/zstd b/build/fbcode_builder/manifests/zstd new file mode 100644 index 000000000..71db9d5c6 --- /dev/null +++ b/build/fbcode_builder/manifests/zstd @@ -0,0 +1,28 @@ +[manifest] +name = zstd + +[rpms] +libzstd-devel +libzstd + +[debs] +libzstd-dev + +[download] +url = https://github.com/facebook/zstd/releases/download/v1.4.5/zstd-1.4.5.tar.gz +sha256 = 98e91c7c6bf162bf90e4e70fdbc41a8188b9fa8de5ad840c401198014406ce9e + +[build] +builder = cmake +subdir = zstd-1.4.5/build/cmake + +# The zstd cmake build explicitly sets the install name +# for the shared library in such a way that cmake discards +# the path to the library from the install_name, rendering +# the library non-resolvable during the build. The short +# term solution for this is just to link static on macos. +[cmake.defines.os=darwin] +ZSTD_BUILD_SHARED = OFF + +[cmake.defines.os=windows] +ZSTD_BUILD_SHARED = OFF diff --git a/build/fbcode_builder/parse_args.py b/build/fbcode_builder/parse_args.py new file mode 100644 index 000000000..8d5e35330 --- /dev/null +++ b/build/fbcode_builder/parse_args.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +"Argument parsing logic shared by all fbcode_builder CLI tools." + +import argparse +import logging + +from shell_quoting import raw_shell, ShellQuoted + + +def parse_args_to_fbcode_builder_opts(add_args_fn, top_level_opts, opts, help): + """ + + Provides some standard arguments: --debug, --option, --shell-quoted-option + + Then, calls `add_args_fn(parser)` to add application-specific arguments. + + `opts` are first used as defaults for the various command-line + arguments. Then, the parsed arguments are mapped back into `opts`, + which then become the values for `FBCodeBuilder.option()`, to be used + both by the builder and by `get_steps_fn()`. + + `help` is printed in response to the `--help` argument. + + """ + top_level_opts = set(top_level_opts) + + parser = argparse.ArgumentParser( + description=help, formatter_class=argparse.RawDescriptionHelpFormatter + ) + + add_args_fn(parser) + + parser.add_argument( + "--option", + nargs=2, + metavar=("KEY", "VALUE"), + action="append", + default=[ + (k, v) + for k, v in opts.items() + if k not in top_level_opts and not isinstance(v, ShellQuoted) + ], + help="Set project-specific options. These are assumed to be raw " + "strings, to be shell-escaped as needed. Default: %(default)s.", + ) + parser.add_argument( + "--shell-quoted-option", + nargs=2, + metavar=("KEY", "VALUE"), + action="append", + default=[ + (k, raw_shell(v)) + for k, v in opts.items() + if k not in top_level_opts and isinstance(v, ShellQuoted) + ], + help="Set project-specific options. These are assumed to be shell-" + "quoted, and may be used in commands as-is. Default: %(default)s.", + ) + + parser.add_argument("--debug", action="store_true", help="Log more") + args = parser.parse_args() + + logging.basicConfig( + level=logging.DEBUG if args.debug else logging.INFO, + format="%(levelname)s: %(message)s", + ) + + # Map command-line args back into opts. + logging.debug("opts before command-line arguments: {0}".format(opts)) + + new_opts = {} + for key in top_level_opts: + val = getattr(args, key) + # Allow clients to unset a default by passing a value of None in opts + if val is not None: + new_opts[key] = val + for key, val in args.option: + new_opts[key] = val + for key, val in args.shell_quoted_option: + new_opts[key] = ShellQuoted(val) + + logging.debug("opts after command-line arguments: {0}".format(new_opts)) + + return new_opts diff --git a/build/fbcode_builder/shell_builder.py b/build/fbcode_builder/shell_builder.py new file mode 100644 index 000000000..e0d5429ad --- /dev/null +++ b/build/fbcode_builder/shell_builder.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +""" +shell_builder.py allows running the fbcode_builder logic +on the host rather than in a container. + +It emits a bash script with set -exo pipefail configured such that +any failing step will cause the script to exit with failure. + +== How to run it? == + +cd build +python fbcode_builder/shell_builder.py > ~/run.sh +bash ~/run.sh +""" + +import distutils.spawn +import os + +from fbcode_builder import FBCodeBuilder +from shell_quoting import raw_shell, shell_comment, shell_join, ShellQuoted +from utils import recursively_flatten_list + + +class ShellFBCodeBuilder(FBCodeBuilder): + def _render_impl(self, steps): + return raw_shell(shell_join("\n", recursively_flatten_list(steps))) + + def set_env(self, key, value): + return ShellQuoted("export {key}={val}").format(key=key, val=value) + + def workdir(self, dir): + return [ + ShellQuoted("mkdir -p {d} && cd {d}").format(d=dir), + ] + + def run(self, shell_cmd): + return ShellQuoted("{cmd}").format(cmd=shell_cmd) + + def step(self, name, actions): + assert "\n" not in name, "Name {0} would span > 1 line".format(name) + b = ShellQuoted("") + return [ShellQuoted("### {0} ###".format(name)), b] + actions + [b] + + def setup(self): + steps = ( + [ + ShellQuoted("set -exo pipefail"), + ] + + self.create_python_venv() + + self.python_venv() + ) + if self.has_option("ccache_dir"): + ccache_dir = self.option("ccache_dir") + steps += [ + ShellQuoted( + # Set CCACHE_DIR before the `ccache` invocations below. + "export CCACHE_DIR={ccache_dir} " + 'CC="ccache ${{CC:-gcc}}" CXX="ccache ${{CXX:-g++}}"' + ).format(ccache_dir=ccache_dir) + ] + return steps + + def comment(self, comment): + return shell_comment(comment) + + def copy_local_repo(self, dir, dest_name): + return [ + ShellQuoted("cp -r {dir} {dest_name}").format(dir=dir, dest_name=dest_name), + ] + + +def find_project_root(): + here = os.path.dirname(os.path.realpath(__file__)) + maybe_root = os.path.dirname(os.path.dirname(here)) + if os.path.isdir(os.path.join(maybe_root, ".git")): + return maybe_root + raise RuntimeError( + "I expected shell_builder.py to be in the " + "build/fbcode_builder subdir of a git repo" + ) + + +def persistent_temp_dir(repo_root): + escaped = repo_root.replace("/", "sZs").replace("\\", "sZs").replace(":", "") + return os.path.join(os.path.expandvars("$HOME"), ".fbcode_builder-" + escaped) + + +if __name__ == "__main__": + from utils import read_fbcode_builder_config, build_fbcode_builder_config + + repo_root = find_project_root() + temp = persistent_temp_dir(repo_root) + + config = read_fbcode_builder_config("fbcode_builder_config.py") + builder = ShellFBCodeBuilder(projects_dir=temp) + + if distutils.spawn.find_executable("ccache"): + builder.add_option( + "ccache_dir", os.environ.get("CCACHE_DIR", os.path.join(temp, ".ccache")) + ) + builder.add_option("prefix", os.path.join(temp, "installed")) + builder.add_option("make_parallelism", 4) + builder.add_option( + "{project}:local_repo_dir".format(project=config["github_project"]), repo_root + ) + make_steps = build_fbcode_builder_config(config) + steps = make_steps(builder) + print(builder.render(steps)) diff --git a/build/fbcode_builder/shell_quoting.py b/build/fbcode_builder/shell_quoting.py new file mode 100644 index 000000000..7429226bd --- /dev/null +++ b/build/fbcode_builder/shell_quoting.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +""" + +Almost every FBCodeBuilder string is ultimately passed to a shell. Escaping +too little or too much tends to be the most common error. The utilities in +this file give a systematic way of avoiding such bugs: + - When you write literal strings destined for the shell, use `ShellQuoted`. + - When these literal strings are parameterized, use `ShellQuoted.format`. + - Any parameters that are raw strings get `shell_quote`d automatically, + while any ShellQuoted parameters will be left intact. + - Use `path_join` to join path components. + - Use `shell_join` to join already-quoted command arguments or shell lines. + +""" + +import os +from collections import namedtuple + + +class ShellQuoted(namedtuple("ShellQuoted", ("do_not_use_raw_str",))): + """ + + Wrap a string with this to make it transparent to shell_quote(). It + will almost always suffice to use ShellQuoted.format(), path_join(), + or shell_join(). + + If you really must, use raw_shell() to access the raw string. + + """ + + def __new__(cls, s): + "No need to nest ShellQuoted." + return super(ShellQuoted, cls).__new__( + cls, s.do_not_use_raw_str if isinstance(s, ShellQuoted) else s + ) + + def __str__(self): + raise RuntimeError( + "One does not simply convert {0} to a string -- use path_join() " + "or ShellQuoted.format() instead".format(repr(self)) + ) + + def __repr__(self): + return "{0}({1})".format(self.__class__.__name__, repr(self.do_not_use_raw_str)) + + def format(self, **kwargs): + """ + + Use instead of str.format() when the arguments are either + `ShellQuoted()` or raw strings needing to be `shell_quote()`d. + + Positional args are deliberately not supported since they are more + error-prone. + + """ + return ShellQuoted( + self.do_not_use_raw_str.format( + **dict( + (k, shell_quote(v).do_not_use_raw_str) for k, v in kwargs.items() + ) + ) + ) + + +def shell_quote(s): + "Quotes a string if it is not already quoted" + return ( + s + if isinstance(s, ShellQuoted) + else ShellQuoted("'" + str(s).replace("'", "'\\''") + "'") + ) + + +def raw_shell(s): + "Not a member of ShellQuoted so we get a useful error for raw strings" + if isinstance(s, ShellQuoted): + return s.do_not_use_raw_str + raise RuntimeError("{0} should have been ShellQuoted".format(s)) + + +def shell_join(delim, it): + "Joins an iterable of ShellQuoted with a delimiter between each two" + return ShellQuoted(delim.join(raw_shell(s) for s in it)) + + +def path_join(*args): + "Joins ShellQuoted and raw pieces of paths to make a shell-quoted path" + return ShellQuoted(os.path.join(*[raw_shell(shell_quote(s)) for s in args])) + + +def shell_comment(c): + "Do not shell-escape raw strings in comments, but do handle line breaks." + return ShellQuoted("# {c}").format( + c=ShellQuoted( + (raw_shell(c) if isinstance(c, ShellQuoted) else c).replace("\n", "\n# ") + ) + ) diff --git a/build/fbcode_builder/specs/__init__.py b/build/fbcode_builder/specs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/build/fbcode_builder/specs/fbthrift.py b/build/fbcode_builder/specs/fbthrift.py new file mode 100644 index 000000000..f0c7e7ac7 --- /dev/null +++ b/build/fbcode_builder/specs/fbthrift.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import specs.fizz as fizz +import specs.fmt as fmt +import specs.folly as folly +import specs.sodium as sodium +import specs.wangle as wangle +import specs.zstd as zstd + + +def fbcode_builder_spec(builder): + return { + "depends_on": [fmt, folly, fizz, sodium, wangle, zstd], + "steps": [ + builder.fb_github_cmake_install("fbthrift/thrift"), + ], + } diff --git a/build/fbcode_builder/specs/fbzmq.py b/build/fbcode_builder/specs/fbzmq.py new file mode 100644 index 000000000..78c8bc9dd --- /dev/null +++ b/build/fbcode_builder/specs/fbzmq.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import specs.fbthrift as fbthrift +import specs.fmt as fmt +import specs.folly as folly +import specs.gmock as gmock +import specs.sodium as sodium +from shell_quoting import ShellQuoted + + +def fbcode_builder_spec(builder): + builder.add_option("zeromq/libzmq:git_hash", "v4.2.2") + return { + "depends_on": [fmt, folly, fbthrift, gmock, sodium], + "steps": [ + builder.github_project_workdir("zeromq/libzmq", "."), + builder.step( + "Build and install zeromq/libzmq", + [ + builder.run(ShellQuoted("./autogen.sh")), + builder.configure(), + builder.make_and_install(), + ], + ), + builder.fb_github_project_workdir("fbzmq/_build", "facebook"), + builder.step( + "Build and install fbzmq/", + [ + builder.cmake_configure("fbzmq/_build"), + # we need the pythonpath to find the thrift compiler + builder.run( + ShellQuoted( + 'PYTHONPATH="$PYTHONPATH:"{p}/lib/python2.7/site-packages ' + "make -j {n}" + ).format( + p=builder.option("prefix"), + n=builder.option("make_parallelism"), + ) + ), + builder.run(ShellQuoted("make install")), + ], + ), + ], + } diff --git a/build/fbcode_builder/specs/fizz.py b/build/fbcode_builder/specs/fizz.py new file mode 100644 index 000000000..82f26e67c --- /dev/null +++ b/build/fbcode_builder/specs/fizz.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import specs.fmt as fmt +import specs.folly as folly +import specs.gmock as gmock +import specs.sodium as sodium +import specs.zstd as zstd + + +def fbcode_builder_spec(builder): + builder.add_option( + "fizz/fizz/build:cmake_defines", + { + # Fizz's build is kind of broken, in the sense that both `mvfst` + # and `proxygen` depend on files that are only installed with + # `BUILD_TESTS` enabled, e.g. `fizz/crypto/test/TestUtil.h`. + "BUILD_TESTS": "ON" + }, + ) + return { + "depends_on": [gmock, fmt, folly, sodium, zstd], + "steps": [ + builder.fb_github_cmake_install( + "fizz/fizz/build", github_org="facebookincubator" + ) + ], + } diff --git a/build/fbcode_builder/specs/fmt.py b/build/fbcode_builder/specs/fmt.py new file mode 100644 index 000000000..395316799 --- /dev/null +++ b/build/fbcode_builder/specs/fmt.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + + +def fbcode_builder_spec(builder): + builder.add_option("fmtlib/fmt:git_hash", "6.2.1") + builder.add_option( + "fmtlib/fmt:cmake_defines", + { + # Avoids a bizarred failure to run tests in Bistro: + # test_crontab_selector: error while loading shared libraries: + # libfmt.so.6: cannot open shared object file: + # No such file or directory + "BUILD_SHARED_LIBS": "OFF", + }, + ) + return { + "steps": [ + builder.github_project_workdir("fmtlib/fmt", "build"), + builder.cmake_install("fmtlib/fmt"), + ], + } diff --git a/build/fbcode_builder/specs/folly.py b/build/fbcode_builder/specs/folly.py new file mode 100644 index 000000000..e89d5e955 --- /dev/null +++ b/build/fbcode_builder/specs/folly.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import specs.fmt as fmt + + +def fbcode_builder_spec(builder): + return { + "depends_on": [fmt], + "steps": [ + # on macOS the filesystem is typically case insensitive. + # We need to ensure that the CWD is not the folly source + # dir when we build, otherwise the system will decide + # that `folly/String.h` is the file it wants when including + # `string.h` and the build will fail. + builder.fb_github_project_workdir("folly/_build"), + builder.cmake_install("facebook/folly"), + ], + } diff --git a/build/fbcode_builder/specs/gmock.py b/build/fbcode_builder/specs/gmock.py new file mode 100644 index 000000000..774137301 --- /dev/null +++ b/build/fbcode_builder/specs/gmock.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + + +def fbcode_builder_spec(builder): + builder.add_option("google/googletest:git_hash", "release-1.8.1") + builder.add_option( + "google/googletest:cmake_defines", + { + "BUILD_GTEST": "ON", + # Avoid problems with MACOSX_RPATH + "BUILD_SHARED_LIBS": "OFF", + }, + ) + return { + "steps": [ + builder.github_project_workdir("google/googletest", "build"), + builder.cmake_install("google/googletest"), + ], + } diff --git a/build/fbcode_builder/specs/mvfst.py b/build/fbcode_builder/specs/mvfst.py new file mode 100644 index 000000000..ce8b003d9 --- /dev/null +++ b/build/fbcode_builder/specs/mvfst.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import specs.fizz as fizz +import specs.folly as folly +import specs.gmock as gmock + + +def fbcode_builder_spec(builder): + # Projects that **depend** on mvfst should don't need to build tests. + builder.add_option( + "mvfst/build:cmake_defines", + { + # This is set to ON in the mvfst `fbcode_builder_config.py` + "BUILD_TESTS": "OFF" + }, + ) + return { + "depends_on": [gmock, folly, fizz], + "steps": [ + builder.fb_github_cmake_install( + "mvfst/build", github_org="facebookincubator" + ) + ], + } diff --git a/build/fbcode_builder/specs/proxygen.py b/build/fbcode_builder/specs/proxygen.py new file mode 100644 index 000000000..6a584d710 --- /dev/null +++ b/build/fbcode_builder/specs/proxygen.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import specs.fizz as fizz +import specs.fmt as fmt +import specs.folly as folly +import specs.gmock as gmock +import specs.mvfst as mvfst +import specs.sodium as sodium +import specs.wangle as wangle +import specs.zstd as zstd + + +def fbcode_builder_spec(builder): + # Projects that **depend** on proxygen should don't need to build tests + # or QUIC support. + builder.add_option( + "proxygen/proxygen:cmake_defines", + { + # These 2 are set to ON in `proxygen_quic.py` + "BUILD_QUIC": "OFF", + "BUILD_TESTS": "OFF", + # For bistro + "BUILD_SHARED_LIBS": "OFF", + }, + ) + + return { + "depends_on": [gmock, fmt, folly, wangle, fizz, sodium, zstd, mvfst], + "steps": [builder.fb_github_cmake_install("proxygen/proxygen", "..")], + } diff --git a/build/fbcode_builder/specs/proxygen_quic.py b/build/fbcode_builder/specs/proxygen_quic.py new file mode 100644 index 000000000..b4959fb89 --- /dev/null +++ b/build/fbcode_builder/specs/proxygen_quic.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import specs.fizz as fizz +import specs.fmt as fmt +import specs.folly as folly +import specs.gmock as gmock +import specs.mvfst as mvfst +import specs.sodium as sodium +import specs.wangle as wangle +import specs.zstd as zstd + +# DO NOT USE THIS AS A LIBRARY -- this is currently effectively just part +# ofthe implementation of proxygen's `fbcode_builder_config.py`. This is +# why this builds tests and sets `BUILD_QUIC`. +def fbcode_builder_spec(builder): + builder.add_option( + "proxygen/proxygen:cmake_defines", + {"BUILD_QUIC": "ON", "BUILD_SHARED_LIBS": "OFF", "BUILD_TESTS": "ON"}, + ) + return { + "depends_on": [gmock, fmt, folly, wangle, fizz, sodium, zstd, mvfst], + "steps": [builder.fb_github_cmake_install("proxygen/proxygen", "..")], + } diff --git a/build/fbcode_builder/specs/re2.py b/build/fbcode_builder/specs/re2.py new file mode 100644 index 000000000..cf4e08a0b --- /dev/null +++ b/build/fbcode_builder/specs/re2.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + + +def fbcode_builder_spec(builder): + return { + "steps": [ + builder.github_project_workdir("google/re2", "build"), + builder.cmake_install("google/re2"), + ], + } diff --git a/build/fbcode_builder/specs/rocksdb.py b/build/fbcode_builder/specs/rocksdb.py new file mode 100644 index 000000000..9ebfe4739 --- /dev/null +++ b/build/fbcode_builder/specs/rocksdb.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + + +def fbcode_builder_spec(builder): + builder.add_option( + "rocksdb/_build:cmake_defines", + { + "USE_RTTI": "1", + "PORTABLE": "ON", + }, + ) + return { + "steps": [ + builder.fb_github_cmake_install("rocksdb/_build"), + ], + } diff --git a/build/fbcode_builder/specs/sodium.py b/build/fbcode_builder/specs/sodium.py new file mode 100644 index 000000000..8be9833cf --- /dev/null +++ b/build/fbcode_builder/specs/sodium.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from shell_quoting import ShellQuoted + + +def fbcode_builder_spec(builder): + builder.add_option("jedisct1/libsodium:git_hash", "stable") + return { + "steps": [ + builder.github_project_workdir("jedisct1/libsodium", "."), + builder.step( + "Build and install jedisct1/libsodium", + [ + builder.run(ShellQuoted("./autogen.sh")), + builder.configure(), + builder.make_and_install(), + ], + ), + ], + } diff --git a/build/fbcode_builder/specs/wangle.py b/build/fbcode_builder/specs/wangle.py new file mode 100644 index 000000000..62b5b3c86 --- /dev/null +++ b/build/fbcode_builder/specs/wangle.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import specs.fizz as fizz +import specs.fmt as fmt +import specs.folly as folly +import specs.gmock as gmock +import specs.sodium as sodium + + +def fbcode_builder_spec(builder): + # Projects that **depend** on wangle need not spend time on tests. + builder.add_option( + "wangle/wangle/build:cmake_defines", + { + # This is set to ON in the wangle `fbcode_builder_config.py` + "BUILD_TESTS": "OFF" + }, + ) + return { + "depends_on": [gmock, fmt, folly, fizz, sodium], + "steps": [builder.fb_github_cmake_install("wangle/wangle/build")], + } diff --git a/build/fbcode_builder/specs/zstd.py b/build/fbcode_builder/specs/zstd.py new file mode 100644 index 000000000..14d9a1249 --- /dev/null +++ b/build/fbcode_builder/specs/zstd.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from shell_quoting import ShellQuoted + + +def fbcode_builder_spec(builder): + # This API should change rarely, so build the latest tag instead of master. + builder.add_option( + "facebook/zstd:git_hash", + ShellQuoted("$(git describe --abbrev=0 --tags origin/master)"), + ) + return { + "steps": [ + builder.github_project_workdir("facebook/zstd", "."), + builder.step( + "Build and install zstd", + [ + builder.make_and_install( + make_vars={ + "PREFIX": builder.option("prefix"), + } + ) + ], + ), + ], + } diff --git a/build/fbcode_builder/travis.yml b/build/fbcode_builder/travis.yml new file mode 100644 index 000000000..d2bb60778 --- /dev/null +++ b/build/fbcode_builder/travis.yml @@ -0,0 +1,51 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Facebook projects that use `fbcode_builder` for continuous integration +# share this Travis configuration to run builds via Docker. + +# Docker disables IPv6 in containers by default. Enable it for unit tests that need [::1]. +before_script: + - if [[ "$TRAVIS_OS_NAME" != "osx" ]]; + then + sudo build/fbcode_builder/docker_enable_ipv6.sh; + fi + +env: + global: + - travis_cache_dir=$HOME/travis_ccache + # Travis times out after 50 minutes. Very generously leave 10 minutes + # for setup (e.g. cache download, compression, and upload), so we never + # fail to cache the progress we made. + - docker_build_timeout=40m + +cache: + # Our build caches can be 200-300MB, so increase the timeout to 7 minutes + # to make sure we never fail to cache the progress we made. + timeout: 420 + directories: + - $HOME/travis_ccache # see docker_build_with_ccache.sh + +# Ugh, `services:` must be in the matrix, or we get `docker: command not found` +# https://github.com/travis-ci/travis-ci/issues/5142 +matrix: + include: + - env: ['os_image=ubuntu:18.04', gcc_version=7] + services: [docker] + +addons: + apt: + packages: python2.7 + +script: + # We don't want to write the script inline because of Travis kludginess -- + # it looks like it escapes " and \ in scripts when using `matrix:`. + - ./build/fbcode_builder/travis_docker_build.sh diff --git a/build/fbcode_builder/travis_docker_build.sh b/build/fbcode_builder/travis_docker_build.sh new file mode 100755 index 000000000..d4cba10ef --- /dev/null +++ b/build/fbcode_builder/travis_docker_build.sh @@ -0,0 +1,42 @@ +#!/bin/bash -uex +# Copyright (c) Facebook, Inc. and its affiliates. +# .travis.yml in the top-level dir explains why this is a separate script. +# Read the docs: ./make_docker_context.py --help + +os_image=${os_image?Must be set by Travis} +gcc_version=${gcc_version?Must be set by Travis} +make_parallelism=${make_parallelism:-4} +# ccache is off unless requested +travis_cache_dir=${travis_cache_dir:-} +# The docker build never times out, unless specified +docker_build_timeout=${docker_build_timeout:-} + +cur_dir="$(realpath "$(dirname "$0")")" + +if [[ "$travis_cache_dir" == "" ]]; then + echo "ccache disabled, enable by setting env. var. travis_cache_dir" + ccache_tgz="" +elif [[ -e "$travis_cache_dir/ccache.tgz" ]]; then + ccache_tgz="$travis_cache_dir/ccache.tgz" +else + echo "$travis_cache_dir/ccache.tgz does not exist, starting with empty cache" + ccache_tgz=$(mktemp) + tar -T /dev/null -czf "$ccache_tgz" +fi + +docker_context_dir=$( + cd "$cur_dir/.." # Let the script find our fbcode_builder_config.py + "$cur_dir/make_docker_context.py" \ + --os-image "$os_image" \ + --gcc-version "$gcc_version" \ + --make-parallelism "$make_parallelism" \ + --local-repo-dir "$cur_dir/../.." \ + --ccache-tgz "$ccache_tgz" +) +cd "${docker_context_dir?Failed to make Docker context directory}" + +# Make it safe to iterate on the .sh in the tree while the script runs. +cp "$cur_dir/docker_build_with_ccache.sh" . +exec ./docker_build_with_ccache.sh \ + --build-timeout "$docker_build_timeout" \ + "$travis_cache_dir" diff --git a/build/fbcode_builder/utils.py b/build/fbcode_builder/utils.py new file mode 100644 index 000000000..02459a200 --- /dev/null +++ b/build/fbcode_builder/utils.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +"Miscellaneous utility functions." + +import itertools +import logging +import os +import shutil +import subprocess +import sys +from contextlib import contextmanager + + +def recursively_flatten_list(l): + return itertools.chain.from_iterable( + (recursively_flatten_list(i) if type(i) is list else (i,)) for i in l + ) + + +def run_command(*cmd, **kwargs): + "The stdout of most fbcode_builder utilities is meant to be parsed." + logging.debug("Running: {0} with {1}".format(cmd, kwargs)) + kwargs["stdout"] = sys.stderr + subprocess.check_call(cmd, **kwargs) + + +@contextmanager +def make_temp_dir(d): + os.mkdir(d) + try: + yield d + finally: + shutil.rmtree(d, ignore_errors=True) + + +def _inner_read_config(path): + """ + Helper to read a named config file. + The grossness with the global is a workaround for this python bug: + https://bugs.python.org/issue21591 + The bug prevents us from defining either a local function or a lambda + in the scope of read_fbcode_builder_config below. + """ + global _project_dir + full_path = os.path.join(_project_dir, path) + return read_fbcode_builder_config(full_path) + + +def read_fbcode_builder_config(filename): + # Allow one spec to read another + # When doing so, treat paths as relative to the config's project directory. + # _project_dir is a "local" for _inner_read_config; see the comments + # in that function for an explanation of the use of global. + global _project_dir + _project_dir = os.path.dirname(filename) + + scope = {"read_fbcode_builder_config": _inner_read_config} + with open(filename) as config_file: + code = compile(config_file.read(), filename, mode="exec") + exec(code, scope) + return scope["config"] + + +def steps_for_spec(builder, spec, processed_modules=None): + """ + Sets `builder` configuration, and returns all the builder steps + necessary to build `spec` and its dependencies. + + Traverses the dependencies in depth-first order, honoring the sequencing + in each 'depends_on' list. + """ + if processed_modules is None: + processed_modules = set() + steps = [] + for module in spec.get("depends_on", []): + if module not in processed_modules: + processed_modules.add(module) + steps.extend( + steps_for_spec( + builder, module.fbcode_builder_spec(builder), processed_modules + ) + ) + steps.extend(spec.get("steps", [])) + return steps + + +def build_fbcode_builder_config(config): + return lambda builder: builder.build( + steps_for_spec(builder, config["fbcode_builder_spec"](builder)) + ) diff --git a/build/fbcode_builder_config.py b/build/fbcode_builder_config.py new file mode 100644 index 000000000..85018bf05 --- /dev/null +++ b/build/fbcode_builder_config.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +'fbcode_builder steps to build rsocket' + +import specs.rsocket as rsocket + + +def fbcode_builder_spec(builder): + return { + 'depends_on': [rsocket], + } + + +config = { + 'github_project': 'rsocket/rsocket-cpp', + 'fbcode_builder_spec': fbcode_builder_spec, +} diff --git a/cmake/FindFolly.cmake b/cmake/FindFolly.cmake deleted file mode 100644 index 30736a77f..000000000 --- a/cmake/FindFolly.cmake +++ /dev/null @@ -1,15 +0,0 @@ -cmake_minimum_required(VERSION 3.2) - -include(FindPackageHandleStandardArgs) - -if (FOLLY_INSTALL_DIR) - set(lib_paths ${FOLLY_INSTALL_DIR}/lib) - set(include_paths ${FOLLY_INSTALL_DIR}/include) -endif () - -find_library(FOLLY_LIBRARY folly PATHS ${lib_paths}) -find_library(FOLLY_BENCHMARK_LIBRARY follybenchmark PATHS ${lib_paths}) -find_path(FOLLY_INCLUDE_DIR "folly/String.h" PATHS ${include_paths}) - -find_package_handle_standard_args(Folly - DEFAULT_MSG FOLLY_LIBRARY FOLLY_BENCHMARK_LIBRARY FOLLY_INCLUDE_DIR) diff --git a/cmake/InstallFolly.cmake b/cmake/InstallFolly.cmake index 432c8c7fd..2bd17460c 100644 --- a/cmake/InstallFolly.cmake +++ b/cmake/InstallFolly.cmake @@ -1,73 +1,22 @@ +# Copyright (c) 2018, Facebook, Inc. +# All rights reserved. +# if (NOT FOLLY_INSTALL_DIR) - set(FOLLY_INSTALL_DIR $ENV{HOME}/folly) + set(FOLLY_INSTALL_DIR ${CMAKE_BINARY_DIR}/folly-install) endif () -# Check if the correct version of folly is already installed. -set(FOLLY_VERSION v2017.12.11.00) -set(FOLLY_VERSION_FILE ${FOLLY_INSTALL_DIR}/${FOLLY_VERSION}) if (RSOCKET_INSTALL_DEPS) - if (NOT EXISTS ${FOLLY_VERSION_FILE}) - # Remove the old version of folly. - file(REMOVE_RECURSE ${FOLLY_INSTALL_DIR}) - set(INSTALL_FOLLY True) - endif () -endif () - -if (INSTALL_FOLLY) - # Build and install folly. - ExternalProject_Add( - folly-ext - GIT_REPOSITORY https://github.com/facebook/folly - GIT_TAG ${FOLLY_VERSION} - BINARY_DIR folly-ext-prefix/src/folly-ext/folly - CONFIGURE_COMMAND autoreconf -ivf - COMMAND ./configure CXX=${CMAKE_CXX_COMPILER} - --prefix=${FOLLY_INSTALL_DIR} - BUILD_COMMAND make -j4 - INSTALL_COMMAND make install - COMMAND cmake -E touch ${FOLLY_VERSION_FILE}) - - set(FOLLY_INCLUDE_DIR ${FOLLY_INSTALL_DIR}/include) - set(lib ${CMAKE_SHARED_LIBRARY_PREFIX}folly${CMAKE_SHARED_LIBRARY_SUFFIX}) - set(benchlib ${CMAKE_SHARED_LIBRARY_PREFIX}follybenchmark${CMAKE_SHARED_LIBRARY_SUFFIX}) - set(FOLLY_LIBRARY ${FOLLY_INSTALL_DIR}/lib/${lib}) - set(FOLLY_BENCHMARK_LIBRARY ${FOLLY_INSTALL_DIR}/lib/${benchlib}) - - # CMake requires directories listed in INTERFACE_INCLUDE_DIRECTORIES to exist. - file(MAKE_DIRECTORY ${FOLLY_INCLUDE_DIR}) -else () - # Use installed folly. - find_package(Folly REQUIRED) + execute_process( + COMMAND + ${CMAKE_SOURCE_DIR}/scripts/build_folly.sh + ${CMAKE_BINARY_DIR}/folly-src + ${FOLLY_INSTALL_DIR} + RESULT_VARIABLE folly_result + ) + if (NOT "${folly_result}" STREQUAL "0") + message(FATAL_ERROR "failed to build folly") + endif() endif () find_package(Threads) -find_library(EVENT_LIBRARY event) - -add_library(folly SHARED IMPORTED) -set_property(TARGET folly PROPERTY IMPORTED_LOCATION ${FOLLY_LIBRARY}) -set_property(TARGET folly - APPEND PROPERTY INTERFACE_LINK_LIBRARIES - ${EXTRA_LINK_FLAGS} ${EVENT_LIBRARY} ${CMAKE_THREAD_LIBS_INIT}) -if (TARGET folly-ext) - add_dependencies(folly folly-ext) -endif () - -add_library(folly-benchmark SHARED IMPORTED) -set_property(TARGET folly-benchmark PROPERTY IMPORTED_LOCATION ${FOLLY_BENCHMARK_LIBRARY}) -set_property(TARGET folly-benchmark - APPEND PROPERTY INTERFACE_LINK_LIBRARIES - ${EXTRA_LINK_FLAGS} ${EVENT_LIBRARY} ${CMAKE_THREAD_LIBS_INIT}) -if (TARGET folly-ext) - add_dependencies(folly-benchmark folly-ext) -endif () - -# Folly includes are marked as system to prevent errors on non-standard -# extensions when compiling with -pedantic and -Werror. -set_property(TARGET folly - APPEND PROPERTY INTERFACE_SYSTEM_INCLUDE_DIRECTORIES ${FOLLY_INCLUDE_DIR}) -set_property(TARGET folly - APPEND PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${FOLLY_INCLUDE_DIR}) -set_property(TARGET folly-benchmark - APPEND PROPERTY INTERFACE_SYSTEM_INCLUDE_DIRECTORIES ${FOLLY_INCLUDE_DIR}) -set_property(TARGET folly-benchmark - APPEND PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${FOLLY_INCLUDE_DIR}) +find_package(folly CONFIG REQUIRED PATHS ${FOLLY_INSTALL_DIR}) diff --git a/cmake/rsocket-config.cmake.in b/cmake/rsocket-config.cmake.in new file mode 100644 index 000000000..d5579a856 --- /dev/null +++ b/cmake/rsocket-config.cmake.in @@ -0,0 +1,12 @@ +# Copyright (c) 2018, Facebook, Inc. +# All rights reserved. + +@PACKAGE_INIT@ + +if(NOT TARGET rsocket::ReactiveSocket) + include("${PACKAGE_PREFIX_DIR}/lib/cmake/rsocket/rsocket-exports.cmake") +endif() + +if (NOT rsocket_FIND_QUIETLY) + message(STATUS "Found rsocket: ${PACKAGE_PREFIX_DIR}") +endif() diff --git a/devtools/format_all.sh b/devtools/format_all.sh index aed32b572..235b985e2 100755 --- a/devtools/format_all.sh +++ b/devtools/format_all.sh @@ -1,4 +1,7 @@ #!/usr/bin/env bash +# +# Copyright 2004-present Facebook. All Rights Reserved. +# set -xue cd "$(dirname "$0")/.." diff --git a/examples/conditional-request-handling/JsonRequestHandler.cpp b/examples/conditional-request-handling/JsonRequestHandler.cpp deleted file mode 100644 index c19d40032..000000000 --- a/examples/conditional-request-handling/JsonRequestHandler.cpp +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "JsonRequestHandler.h" -#include -#include "yarpl/Flowable.h" - -using namespace rsocket; -using namespace yarpl::flowable; - -/// Handles a new inbound Stream requested by the other end. -yarpl::Reference> -JsonRequestResponder::handleRequestStream(Payload request, StreamId) { - LOG(INFO) << "JsonRequestResponder.handleRequestStream " << request; - - // string from payload data - auto requestString = request.moveDataToString(); - - return Flowables::range(1, 100)->map([name = std::move(requestString)]( - int64_t v) { - std::stringstream ss; - ss << "Hello (should be JSON) " << name << " " << v << "!"; - std::string s = ss.str(); - return Payload(s, "metadata"); - }); -} diff --git a/examples/conditional-request-handling/JsonRequestHandler.h b/examples/conditional-request-handling/JsonRequestHandler.h deleted file mode 100644 index f24f06ccf..000000000 --- a/examples/conditional-request-handling/JsonRequestHandler.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "rsocket/Payload.h" -#include "rsocket/RSocket.h" - -class JsonRequestResponder : public rsocket::RSocketResponder { - public: - /// Handles a new inbound Stream requested by the other end. - yarpl::Reference> - handleRequestStream(rsocket::Payload request, rsocket::StreamId streamId) - override; -}; diff --git a/examples/conditional-request-handling/TextRequestHandler.cpp b/examples/conditional-request-handling/TextRequestHandler.cpp deleted file mode 100644 index a6f0717a1..000000000 --- a/examples/conditional-request-handling/TextRequestHandler.cpp +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "TextRequestHandler.h" -#include -#include "yarpl/Flowable.h" - -using namespace rsocket; -using namespace yarpl::flowable; - -/// Handles a new inbound Stream requested by the other end. -yarpl::Reference> -TextRequestResponder::handleRequestStream(Payload request, StreamId) { - LOG(INFO) << "TextRequestResponder.handleRequestStream " << request; - - // string from payload data - auto requestString = request.moveDataToString(); - - return Flowables::range(1, 100)->map([name = std::move(requestString)]( - int64_t v) { - std::stringstream ss; - ss << "Hello " << name << " " << v << "!"; - std::string s = ss.str(); - return Payload(s, "metadata"); - }); -} diff --git a/examples/conditional-request-handling/TextRequestHandler.h b/examples/conditional-request-handling/TextRequestHandler.h deleted file mode 100644 index 604fdbeea..000000000 --- a/examples/conditional-request-handling/TextRequestHandler.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "rsocket/Payload.h" -#include "rsocket/RSocket.h" - -class TextRequestResponder : public rsocket::RSocketResponder { - public: - /// Handles a new inbound Stream requested by the other end. - yarpl::Reference> - handleRequestStream(rsocket::Payload request, rsocket::StreamId streamId) - override; -}; diff --git a/rsocket/ColdResumeHandler.cpp b/rsocket/ColdResumeHandler.cpp index db3a2dce8..870faef48 100644 --- a/rsocket/ColdResumeHandler.cpp +++ b/rsocket/ColdResumeHandler.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/ColdResumeHandler.h" @@ -6,7 +18,6 @@ #include -using namespace yarpl; using namespace yarpl::flowable; namespace rsocket { @@ -14,20 +25,22 @@ namespace rsocket { std::string ColdResumeHandler::generateStreamToken( const Payload&, StreamId streamId, - StreamType) { + StreamType) const { return folly::to(streamId); } -Reference> ColdResumeHandler::handleResponderResumeStream( +std::shared_ptr> +ColdResumeHandler::handleResponderResumeStream( std::string /* streamToken */, size_t /* publisherAllowance */) { - return Flowables::error( + return Flowable::error( std::logic_error("ResumeHandler method not implemented")); } -Reference> ColdResumeHandler::handleRequesterResumeStream( +std::shared_ptr> +ColdResumeHandler::handleRequesterResumeStream( std::string /* streamToken */, size_t /* consumerAllowance */) { - return yarpl::make_ref>(); -} + return std::make_shared>(); } +} // namespace rsocket diff --git a/rsocket/ColdResumeHandler.h b/rsocket/ColdResumeHandler.h index dadedc3ec..f4190e16f 100644 --- a/rsocket/ColdResumeHandler.h +++ b/rsocket/ColdResumeHandler.h @@ -1,10 +1,23 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include "yarpl/Flowable.h" #include "rsocket/Payload.h" +#include "rsocket/framing/FrameHeader.h" #include "rsocket/internal/Common.h" namespace rsocket { @@ -16,14 +29,15 @@ class ColdResumeHandler { virtual ~ColdResumeHandler() = default; // Generate an application-aware streamToken for the given stream parameters. - virtual std::string generateStreamToken(const Payload&, StreamId, StreamType); + virtual std::string + generateStreamToken(const Payload&, StreamId streamId, StreamType) const; // This method will be called for each REQUEST_STREAM for which the // application acted as a responder. The default action would be to return a // Flowable which errors out immediately. // The second parameter is the allowance which the application received // before cold-start and hasn't been fulfilled yet. - virtual yarpl::Reference> + virtual std::shared_ptr> handleResponderResumeStream( std::string streamToken, size_t publisherAllowance); @@ -33,9 +47,10 @@ class ColdResumeHandler { // Subscriber which cancels the stream immediately after getting subscribed. // The second parameter is the allowance which the application requested // before cold-start and hasn't been fulfilled yet. - virtual yarpl::Reference> + virtual std::shared_ptr> handleRequesterResumeStream( std::string streamToken, size_t consumerAllowance); }; -} + +} // namespace rsocket diff --git a/rsocket/ConnectionAcceptor.h b/rsocket/ConnectionAcceptor.h index 92aa5f535..3e94a4416 100644 --- a/rsocket/ConnectionAcceptor.h +++ b/rsocket/ConnectionAcceptor.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -12,9 +24,8 @@ class EventBase; namespace rsocket { -using OnDuplexConnectionAccept = std::function, - folly::EventBase&)>; +using OnDuplexConnectionAccept = std::function< + void(std::unique_ptr, folly::EventBase&)>; /** * Common interface for a server that accepts connections and turns them into @@ -24,8 +35,6 @@ using OnDuplexConnectionAccept = std::function #include #include "rsocket/DuplexConnection.h" +#include "rsocket/framing/ProtocolVersion.h" namespace folly { class EventBase; @@ -12,6 +25,8 @@ class EventBase; namespace rsocket { +enum class ResumeStatus { NEW_SESSION, RESUMING }; + /** * Common interface for a client to create connections and turn them into * DuplexConnections. @@ -44,6 +59,8 @@ class ConnectionFactory { * * Resource creation depends on the particular implementation. */ - virtual folly::Future connect() = 0; + virtual folly::Future connect( + ProtocolVersion, + ResumeStatus resume) = 0; }; } // namespace rsocket diff --git a/rsocket/DuplexConnection.h b/rsocket/DuplexConnection.h index c093c8488..7aaff2156 100644 --- a/rsocket/DuplexConnection.h +++ b/rsocket/DuplexConnection.h @@ -1,42 +1,27 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include -#include "yarpl/flowable/Subscriber.h" +#include -namespace folly { -class IOBuf; -} +#include "yarpl/flowable/Subscriber.h" namespace rsocket { -using yarpl::Reference; - -class DuplexSubscriber : - public yarpl::flowable::Subscriber> -{ -public: - void onSubscribe(Reference sub) override { - subscription_ = sub; - } - void onComplete() override { - subscription_.reset(); - } - void onError(folly::exception_wrapper) override { - subscription_.reset(); - } - -protected: - Reference subscription() { - return subscription_; - } - -private: - Reference subscription_; -}; - /// Represents a connection of the underlying protocol, on top of which the /// RSocket protocol is layered. The underlying protocol MUST provide an /// ordered, guaranteed, bidirectional transport of frames. Moreover, frame @@ -55,7 +40,6 @@ class DuplexSubscriber : class DuplexConnection { public: using Subscriber = yarpl::flowable::Subscriber>; - using DuplexSubscriber = rsocket::DuplexSubscriber; virtual ~DuplexConnection() = default; @@ -63,7 +47,7 @@ class DuplexConnection { /// /// If setInput() has already been called, then calling setInput() again will /// complete the previous subscriber. - virtual void setInput(yarpl::Reference) = 0; + virtual void setInput(std::shared_ptr) = 0; /// Write a serialized frame to the connection. /// @@ -75,4 +59,5 @@ class DuplexConnection { return false; } }; -} + +} // namespace rsocket diff --git a/rsocket/Payload.cpp b/rsocket/Payload.cpp index 55810c784..b4037d888 100644 --- a/rsocket/Payload.cpp +++ b/rsocket/Payload.cpp @@ -1,65 +1,61 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/Payload.h" + #include #include -#include "rsocket/framing/Frame.h" + +#include "rsocket/internal/Common.h" namespace rsocket { -Payload::Payload( - std::unique_ptr _data, - std::unique_ptr _metadata) - : data(std::move(_data)), metadata(std::move(_metadata)) {} - -Payload::Payload(const std::string& _data, const std::string& _metadata) - : data(folly::IOBuf::copyBuffer(_data)) { - if (!_metadata.empty()) { - metadata = folly::IOBuf::copyBuffer(_metadata); - } -} +namespace { -void Payload::checkFlags(FrameFlags flags) const { - DCHECK(!!(flags & FrameFlags::METADATA) == bool(metadata)); +std::string moveIOBufToString(std::unique_ptr buf) { + return buf ? buf->moveToFbString().toStdString() : ""; } -std::ostream& operator<<(std::ostream& os, const Payload& payload) { - return os << "Metadata(" - << (payload.metadata - ? folly::to( - payload.metadata->computeChainDataLength()) - : "0") - << (payload.metadata - ? "): '" + - folly::humanify( - payload.metadata->cloneAsValue().moveToFbString().substr(0, 80)) + - "'" - : "): ") - << ", Data(" - << (payload.data ? folly::to( - payload.data->computeChainDataLength()) - : "0") - << (payload.data - ? "): '" + - folly::humanify( - payload.data->cloneAsValue().moveToFbString().substr(0, 80)) + - "'" - : "): "); +std::string cloneIOBufToString(std::unique_ptr const& buf) { + return buf ? buf->cloneAsValue().moveToFbString().toStdString() : ""; } -static std::string moveIOBufToString(std::unique_ptr iobuf) { - if (!iobuf) { - return ""; +} // namespace + +Payload::Payload( + std::unique_ptr d, + std::unique_ptr m) + : data{std::move(d)}, metadata{std::move(m)} {} + +Payload::Payload(folly::StringPiece d, folly::StringPiece m) + : data{folly::IOBuf::copyBuffer(d.data(), d.size())} { + if (!m.empty()) { + metadata = folly::IOBuf::copyBuffer(m.data(), m.size()); } - return iobuf->moveToFbString().toStdString(); } -static std::string cloneIOBufToString( - std::unique_ptr const& iobuf) { - if (!iobuf) { - return ""; - } - return iobuf->cloneAsValue().moveToFbString().toStdString(); +std::ostream& operator<<(std::ostream& os, const Payload& payload) { + return os << "Metadata(" + << (payload.metadata ? payload.metadata->computeChainDataLength() + : 0) + << "): " + << (payload.metadata ? "'" + humanify(payload.metadata) + "'" + : "") + << ", Data(" + << (payload.data ? payload.data->computeChainDataLength() : 0) + << "): " + << (payload.data ? "'" + humanify(payload.data) + "'" : ""); } std::string Payload::moveDataToString() { @@ -88,15 +84,28 @@ Payload Payload::clone() const { if (data) { out.data = data->clone(); } - if (metadata) { out.metadata = metadata->clone(); } return out; } -FrameFlags Payload::getFlags() const { - return (metadata != nullptr ? FrameFlags::METADATA : FrameFlags::EMPTY); +ErrorWithPayload::ErrorWithPayload(Payload&& payload) + : payload(std::move(payload)) {} + +ErrorWithPayload::ErrorWithPayload(const ErrorWithPayload& oth) { + payload = oth.payload.clone(); +} + +ErrorWithPayload& ErrorWithPayload::operator=(const ErrorWithPayload& oth) { + payload = oth.payload.clone(); + return *this; +} + +std::ostream& operator<<( + std::ostream& os, + const ErrorWithPayload& errorWithPayload) { + return os << "rsocket::ErrorWithPayload: " << errorWithPayload.payload; } } // namespace rsocket diff --git a/rsocket/Payload.h b/rsocket/Payload.h index b43f745f0..c21587014 100644 --- a/rsocket/Payload.h +++ b/rsocket/Payload.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -6,8 +18,6 @@ #include #include -#include "rsocket/framing/FrameFlags.h" - namespace rsocket { /// The type of a read-only view on a binary buffer. @@ -20,16 +30,13 @@ struct Payload { std::unique_ptr metadata = std::unique_ptr()); explicit Payload( - const std::string& data, - const std::string& metadata = std::string()); + folly::StringPiece data, + folly::StringPiece metadata = folly::StringPiece{}); explicit operator bool() const { return data != nullptr || metadata != nullptr; } - FrameFlags getFlags() const; - void checkFlags(FrameFlags flags) const; - std::string moveDataToString(); std::string cloneDataToString() const; @@ -44,5 +51,23 @@ struct Payload { std::unique_ptr metadata; }; -std::ostream& operator<<(std::ostream& os, const Payload& payload); -} +struct ErrorWithPayload : public std::exception { + explicit ErrorWithPayload(Payload&& payload); + + // folly::ExceptionWrapper requires exceptions to have copy constructors + ErrorWithPayload(const ErrorWithPayload& oth); + ErrorWithPayload& operator=(const ErrorWithPayload&); + ErrorWithPayload(ErrorWithPayload&&) = default; + ErrorWithPayload& operator=(ErrorWithPayload&&) = default; + + const char* what() const noexcept override { + return "ErrorWithPayload"; + } + + Payload payload; +}; + +std::ostream& operator<<(std::ostream& os, const Payload&); +std::ostream& operator<<(std::ostream& os, const ErrorWithPayload&); + +} // namespace rsocket diff --git a/rsocket/README.md b/rsocket/README.md new file mode 100644 index 000000000..3b811aca9 --- /dev/null +++ b/rsocket/README.md @@ -0,0 +1,27 @@ +# rsocket-cpp + +C++ implementation of [RSocket](https://rsocket.io) + + +[![Coverage Status](https://coveralls.io/repos/github/rsocket/rsocket-cpp/badge.svg?branch=master)](https://coveralls.io/github/rsocket/rsocket-cpp?branch=master) + +# Dependencies + +Install `folly`: + +``` +brew install folly +``` + +# Building and running tests + +After installing dependencies as above, you can build and run tests with: + +``` +# inside root ./rsocket-cpp +mkdir -p build +cd build +cmake -DCMAKE_BUILD_TYPE=DEBUG ../ +make -j +./tests +``` diff --git a/rsocket/RSocket.cpp b/rsocket/RSocket.cpp index f990df0cd..e83c5ca71 100644 --- a/rsocket/RSocket.cpp +++ b/rsocket/RSocket.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/RSocket.h" @@ -14,44 +26,51 @@ folly::Future> RSocket::createConnectedClient( std::shared_ptr resumeManager, std::shared_ptr coldResumeHandler, folly::EventBase* stateMachineEvb) { - auto createRSC = [ - connectionFactory, - setupParameters = std::move(setupParameters), - responder = std::move(responder), - keepaliveInterval, - stats = std::move(stats), - connectionEvents = std::move(connectionEvents), - resumeManager = std::move(resumeManager), - coldResumeHandler = std::move(coldResumeHandler), - stateMachineEvb - ](ConnectionFactory::ConnectedDuplexConnection connection) mutable { - VLOG(3) << "createConnectedClient received DuplexConnection"; - return RSocket::createClientFromConnection( - std::move(connection.connection), - connection.eventBase, - std::move(setupParameters), - std::move(connectionFactory), - std::move(responder), - keepaliveInterval, - std::move(stats), - std::move(connectionEvents), - std::move(resumeManager), - std::move(coldResumeHandler), - stateMachineEvb); - }; + CHECK(resumeManager) + << "provide ResumeManager::makeEmpty() instead of nullptr"; + auto protocolVersion = setupParameters.protocolVersion; + auto createRSC = + [connectionFactory, + setupParameters = std::move(setupParameters), + responder = std::move(responder), + keepaliveInterval, + stats = std::move(stats), + connectionEvents = std::move(connectionEvents), + resumeManager = std::move(resumeManager), + coldResumeHandler = std::move(coldResumeHandler), + stateMachineEvb]( + ConnectionFactory::ConnectedDuplexConnection connection) mutable { + VLOG(3) << "createConnectedClient received DuplexConnection"; + return RSocket::createClientFromConnection( + std::move(connection.connection), + connection.eventBase, + std::move(setupParameters), + std::move(connectionFactory), + std::move(responder), + keepaliveInterval, + std::move(stats), + std::move(connectionEvents), + std::move(resumeManager), + std::move(coldResumeHandler), + stateMachineEvb); + }; - return connectionFactory->connect().then([createRSC = std::move(createRSC)]( - ConnectionFactory::ConnectedDuplexConnection connection) mutable { - // fromConnection method must be called from the transport eventBase - // and since there is no guarantee that the Future returned from the - // connectionFactory::connect method is executed on the event base, we - // have to ensure it by using folly::via - auto* transportEvb = &connection.eventBase; - return via(transportEvb, [ - connection = std::move(connection), - createRSC = std::move(createRSC) - ]() mutable { return createRSC(std::move(connection)); }); - }); + return connectionFactory->connect(protocolVersion, ResumeStatus::NEW_SESSION) + .thenValue( + [createRSC = std::move(createRSC)]( + ConnectionFactory::ConnectedDuplexConnection connection) mutable { + // fromConnection method must be called from the transport eventBase + // and since there is no guarantee that the Future returned from the + // connectionFactory::connect method is executed on the event base, + // we have to ensure it by using folly::via + auto transportEvb = &connection.eventBase; + return folly::via( + transportEvb, + [connection = std::move(connection), + createRSC = std::move(createRSC)]() mutable { + return createRSC(std::move(connection)); + }); + }); } folly::Future> RSocket::createResumedClient( @@ -77,8 +96,8 @@ folly::Future> RSocket::createResumedClient( std::move(coldResumeHandler), stateMachineEvb); - return c->resume() - .then([client = std::unique_ptr(c)]() mutable { + return c->resume().thenValue( + [client = std::unique_ptr(c)](auto&&) mutable { return std::move(client); }); } @@ -86,7 +105,7 @@ folly::Future> RSocket::createResumedClient( std::unique_ptr RSocket::createClientFromConnection( std::unique_ptr connection, folly::EventBase& transportEvb, - SetupParameters setupParameters, + SetupParameters params, std::shared_ptr connectionFactory, std::shared_ptr responder, std::chrono::milliseconds keepaliveInterval, @@ -95,10 +114,10 @@ std::unique_ptr RSocket::createClientFromConnection( std::shared_ptr resumeManager, std::shared_ptr coldResumeHandler, folly::EventBase* stateMachineEvb) { - auto c = std::unique_ptr(new RSocketClient( + auto client = std::unique_ptr(new RSocketClient( std::move(connectionFactory), - setupParameters.protocolVersion, - setupParameters.token, + params.protocolVersion, + params.token, std::move(responder), keepaliveInterval, std::move(stats), @@ -106,11 +125,9 @@ std::unique_ptr RSocket::createClientFromConnection( std::move(resumeManager), std::move(coldResumeHandler), stateMachineEvb)); - c->fromConnection( - std::move(connection), - transportEvb, - std::move(setupParameters)); - return c; + client->fromConnection( + std::move(connection), transportEvb, std::move(params)); + return client; } std::unique_ptr RSocket::createServer( @@ -119,4 +136,5 @@ std::unique_ptr RSocket::createServer( return std::make_unique( std::move(connectionAcceptor), std::move(stats)); } + } // namespace rsocket diff --git a/rsocket/RSocket.h b/rsocket/RSocket.h index f68f26497..13f642830 100644 --- a/rsocket/RSocket.h +++ b/rsocket/RSocket.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -23,7 +35,7 @@ class RSocket { std::shared_ptr stats = RSocketStats::noop(), std::shared_ptr connectionEvents = std::shared_ptr(), - std::shared_ptr resumeManager = nullptr, + std::shared_ptr resumeManager = ResumeManager::makeEmpty(), std::shared_ptr coldResumeHandler = std::shared_ptr(), folly::EventBase* stateMachineEvb = nullptr); @@ -41,11 +53,11 @@ class RSocket { std::shared_ptr stats = RSocketStats::noop(), std::shared_ptr connectionEvents = std::shared_ptr(), - ProtocolVersion protocolVersion = ProtocolVersion::Current(), + ProtocolVersion protocolVersion = ProtocolVersion::Latest, folly::EventBase* stateMachineEvb = nullptr); - // Creates a RSocketClient from an existing DuplexConnection - // keepaliveInterval of 0 will result in no keepAlives + // Creates a RSocketClient from an existing DuplexConnection. A keepalive + // interval of 0 will result in no keepalives. static std::unique_ptr createClientFromConnection( std::unique_ptr connection, folly::EventBase& transportEvb, @@ -55,11 +67,9 @@ class RSocket { std::make_shared(), std::chrono::milliseconds keepaliveInterval = kDefaultKeepaliveInterval, std::shared_ptr stats = RSocketStats::noop(), - std::shared_ptr connectionEvents = - std::shared_ptr(), - std::shared_ptr resumeManager = nullptr, - std::shared_ptr coldResumeHandler = - std::shared_ptr(), + std::shared_ptr connectionEvents = nullptr, + std::shared_ptr resumeManager = ResumeManager::makeEmpty(), + std::shared_ptr coldResumeHandler = nullptr, folly::EventBase* stateMachineEvb = nullptr); // A convenience function to create RSocketServer @@ -68,13 +78,9 @@ class RSocket { std::shared_ptr stats = RSocketStats::noop()); RSocket() = delete; - RSocket(const RSocket&) = delete; - RSocket(RSocket&&) = delete; - RSocket& operator=(const RSocket&) = delete; - RSocket& operator=(RSocket&&) = delete; }; -} +} // namespace rsocket diff --git a/rsocket/RSocketClient.cpp b/rsocket/RSocketClient.cpp index 45177ed1b..7f10b3ead 100644 --- a/rsocket/RSocketClient.cpp +++ b/rsocket/RSocketClient.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/RSocketClient.h" #include "rsocket/RSocketRequester.h" @@ -10,8 +22,6 @@ #include "rsocket/internal/ClientResumeStatusCallback.h" #include "rsocket/internal/KeepaliveTimer.h" -using namespace folly; - namespace rsocket { RSocketClient::RSocketClient( @@ -34,13 +44,17 @@ RSocketClient::RSocketClient( coldResumeHandler_(coldResumeHandler), protocolVersion_(protocolVersion), token_(std::move(token)), - evb_(stateMachineEvb) {} + evb_(stateMachineEvb) { + CHECK(resumeManager_) + << "provide ResumeManager::makeEmpty() instead of nullptr"; +} RSocketClient::~RSocketClient() { VLOG(3) << "~RSocketClient .."; evb_->runImmediatelyOrRunInEventBaseThreadAndWait([sm = stateMachine_] { - std::runtime_error exn{"RSocketClient is closing"}; + auto exn = folly::make_exception_wrapper( + "RSocketClient is closing"); sm->close(std::move(exn), StreamCompletionSignal::CONNECTION_END); }); } @@ -49,88 +63,94 @@ const std::shared_ptr& RSocketClient::getRequester() const { return requester_; } -folly::Future RSocketClient::resume() { - VLOG(2) << "Resuming connection"; +// Returns if this client is currently disconnected +bool RSocketClient::isDisconnected() const { + return stateMachine_->isDisconnected(); +} +folly::Future RSocketClient::resume() { CHECK(connectionFactory_) << "The client was likely created without ConnectionFactory. Can't " << "resume"; - return connectionFactory_->connect().then( - [this](ConnectionFactory::ConnectedDuplexConnection connection) mutable { - - if (!evb_) { - // cold-resumption. EventBase hasn't been explicitly set for SM by - // the application. Use the transports eventBase. - evb_ = &connection.eventBase; - } - - class ResumeCallback : public ClientResumeStatusCallback { - public: - explicit ResumeCallback(folly::Promise promise) - : promise_(std::move(promise)) {} - - void onResumeOk() noexcept override { - promise_.setValue(); - } - - void onResumeError(folly::exception_wrapper ex) noexcept override { - promise_.setException(ex); - } - - private: - folly::Promise promise_; - }; - - folly::Promise promise; - auto future = promise.getFuture(); - - auto resumeCallback = - std::make_unique(std::move(promise)); - std::unique_ptr framedConnection; - if (connection.connection->isFramed()) { - framedConnection = std::move(connection.connection); - } else { - framedConnection = std::make_unique( - std::move(connection.connection), protocolVersion_); - } - auto transport = - yarpl::make_ref(std::move(framedConnection)); - - yarpl::Reference ft; - if (evb_ != &connection.eventBase) { - // If the StateMachine EventBase is different from the transport - // EventBase, then use ScheduledFrameTransport and - // ScheduledFrameProcessor to ensure the RSocketStateMachine and - // Transport live on the desired EventBases - ft = yarpl::make_ref( - std::move(transport), - &connection.eventBase, /* Transport EventBase */ - evb_); /* StateMachine EventBase */ - } else { - ft = std::move(transport); - } - - evb_->runInEventBaseThread([ - this, - frameTransport = std::move(ft), - resumeCallback = std::move(resumeCallback), - connection = std::move(connection) - ]() mutable { - if (!stateMachine_) { - createState(); - } - - stateMachine_->resumeClient( - token_, - std::move(frameTransport), - std::move(resumeCallback), - protocolVersion_); - }); - - return future; - - }); + return connectionFactory_->connect(protocolVersion_, ResumeStatus::RESUMING) + .thenValue( + [this]( + ConnectionFactory::ConnectedDuplexConnection connection) mutable { + return resumeFromConnection(std::move(connection)); + }); +} + +folly::Future RSocketClient::resumeFromConnection( + ConnectionFactory::ConnectedDuplexConnection connection) { + VLOG(2) << "Resuming connection"; + + if (!evb_) { + // Cold-resumption. EventBase hasn't been explicitly set for SM by the + // application. Use the transport's eventBase. + evb_ = &connection.eventBase; + } + + class ResumeCallback : public ClientResumeStatusCallback { + public: + explicit ResumeCallback(folly::Promise promise) + : promise_(std::move(promise)) {} + + void onResumeOk() noexcept override { + promise_.setValue(); + } + + void onResumeError(folly::exception_wrapper ex) noexcept override { + promise_.setException(ex); + } + + private: + folly::Promise promise_; + }; + + folly::Promise promise; + auto future = promise.getFuture(); + + auto resumeCallback = std::make_unique(std::move(promise)); + std::unique_ptr framedConnection; + if (connection.connection->isFramed()) { + framedConnection = std::move(connection.connection); + } else { + framedConnection = std::make_unique( + std::move(connection.connection), protocolVersion_); + } + auto transport = + std::make_shared(std::move(framedConnection)); + + std::shared_ptr ft; + if (evb_ != &connection.eventBase) { + // If the StateMachine EventBase is different from the transport + // EventBase, then use ScheduledFrameTransport and + // ScheduledFrameProcessor to ensure the RSocketStateMachine and + // Transport live on the desired EventBases + ft = std::make_shared( + std::move(transport), + &connection.eventBase, /* Transport EventBase */ + evb_); /* StateMachine EventBase */ + } else { + ft = std::move(transport); + } + + evb_->runInEventBaseThread([this, + frameTransport = std::move(ft), + callback = std::move(resumeCallback)]() mutable { + if (!stateMachine_) { + createState(); + } + + stateMachine_->resumeClient( + token_, + std::move(frameTransport), + std::move(callback), + protocolVersion_); + }); + + return future; } folly::Future RSocketClient::disconnect( @@ -140,7 +160,7 @@ folly::Future RSocketClient::disconnect( std::runtime_error{"RSocketClient must always have a state machine"}); } - auto work = [ sm = stateMachine_, e = std::move(ew) ]() mutable { + auto work = [sm = stateMachine_, e = std::move(ew)]() mutable { sm->disconnect(std::move(e)); }; @@ -157,43 +177,38 @@ folly::Future RSocketClient::disconnect( void RSocketClient::fromConnection( std::unique_ptr connection, folly::EventBase& transportEvb, - SetupParameters setupParameters) { + SetupParameters params) { if (!evb_) { // If no EventBase is given for the stateMachine, then use the transport's // EventBase to drive the stateMachine. evb_ = &transportEvb; } createState(); - std::unique_ptr framedConnection; + + std::unique_ptr framed; if (connection->isFramed()) { - framedConnection = std::move(connection); + framed = std::move(connection); } else { - framedConnection = std::make_unique( - std::move(connection), setupParameters.protocolVersion); + framed = std::make_unique( + std::move(connection), params.protocolVersion); } - auto transport = - yarpl::make_ref(std::move(framedConnection)); - if (evb_ != &transportEvb) { - // If the StateMachine EventBase is different from the transport - // EventBase, then use ScheduledFrameTransport and ScheduledFrameProcessor - // to ensure the RSocketStateMachine and Transport live on the desired - // EventBases - auto scheduledFT = yarpl::make_ref( - std::move(transport), - &transportEvb, /* Transport EventBase */ - evb_); /* StateMachine EventBase */ - evb_->runInEventBaseThread([ - stateMachine = stateMachine_, - scheduledFT = std::move(scheduledFT), - setupParameters = std::move(setupParameters) - ]() mutable { - stateMachine->connectClient( - std::move(scheduledFT), std::move(setupParameters)); - }); - } else { - stateMachine_->connectClient( - std::move(transport), std::move(setupParameters)); + auto transport = std::make_shared(std::move(framed)); + + if (evb_ == &transportEvb) { + stateMachine_->connectClient(std::move(transport), std::move(params)); + return; } + + // If the StateMachine EventBase is different from the transport EventBase, + // then use ScheduledFrameTransport and ScheduledFrameProcessor to ensure the + // RSocketStateMachine and Transport live on the desired EventBases. + auto scheduledFT = std::make_shared( + std::move(transport), &transportEvb, evb_); + evb_->runInEventBaseThread([stateMachine = stateMachine_, + scheduledFT = std::move(scheduledFT), + params = std::move(params)]() mutable { + stateMachine->connectClient(std::move(scheduledFT), std::move(params)); + }); } void RSocketClient::createState() { diff --git a/rsocket/RSocketClient.h b/rsocket/RSocketClient.h index 077fb8dbc..070a3f6be 100644 --- a/rsocket/RSocketClient.h +++ b/rsocket/RSocketClient.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -27,22 +39,32 @@ class RSocketClient { ~RSocketClient(); RSocketClient(const RSocketClient&) = delete; - RSocketClient(RSocketClient&&) = default; + RSocketClient(RSocketClient&&) = delete; RSocketClient& operator=(const RSocketClient&) = delete; - RSocketClient& operator=(RSocketClient&&) = default; + RSocketClient& operator=(RSocketClient&&) = delete; friend class RSocket; // Returns the RSocketRequester associated with the RSocketClient. const std::shared_ptr& getRequester() const; - // Resumes the connection. If a stateMachine already exists, - // it provides a warm-resumption. If a stateMachine does not exist, - // it does a cold-resumption. The returned future resolves on successful - // resumption. Else either a ConnectionException or a ResumptionException - // is raised. + // Returns if this client is currently disconnected + bool isDisconnected() const; + + // Resumes the client's connection. If the client was previously connected + // this will attempt a warm-resumption. Otherwise this will attempt a + // cold-resumption. + // + // Uses the internal ConnectionFactory instance to re-connect. folly::Future resume(); + // Like resume(), but this doesn't use a ConnectionFactory and instead takes + // the connection and transport EventBase by argument. + // + // Prefer using resume() if possible. + folly::Future resumeFromConnection( + ConnectionFactory::ConnectedDuplexConnection); + // Disconnect the underlying transport. folly::Future disconnect(folly::exception_wrapper = {}); @@ -70,9 +92,9 @@ class RSocketClient { // Creates RSocketStateMachine and RSocketRequester void createState(); - std::shared_ptr connectionFactory_; + const std::shared_ptr connectionFactory_; std::shared_ptr responder_; - std::chrono::milliseconds keepaliveInterval_; + const std::chrono::milliseconds keepaliveInterval_; std::shared_ptr stats_; std::shared_ptr connectionEvents_; std::shared_ptr resumeManager_; @@ -81,8 +103,8 @@ class RSocketClient { std::shared_ptr stateMachine_; std::shared_ptr requester_; - ProtocolVersion protocolVersion_; - ResumeIdentificationToken token_; + const ProtocolVersion protocolVersion_; + const ResumeIdentificationToken token_; // Remember the StateMachine's evb (supplied through constructor). If no // EventBase is provided, the underlying transport's EventBase will be used @@ -94,6 +116,5 @@ class RSocketClient { // EventBase, but the transport ends up being in different EventBase after // resumption, and vice versa. folly::EventBase* evb_{nullptr}; - }; -} +} // namespace rsocket diff --git a/rsocket/RSocketConnectionEvents.h b/rsocket/RSocketConnectionEvents.h index 8df3ee6a5..177a819d2 100644 --- a/rsocket/RSocketConnectionEvents.h +++ b/rsocket/RSocketConnectionEvents.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -39,4 +51,4 @@ class RSocketConnectionEvents { // typically gets called after onConnected() virtual void onStreamsResumed() {} }; -} +} // namespace rsocket diff --git a/rsocket/RSocketErrors.h b/rsocket/RSocketErrors.h index 8add6294f..e570e7532 100644 --- a/rsocket/RSocketErrors.h +++ b/rsocket/RSocketErrors.h @@ -1,7 +1,20 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once +#include #include namespace rsocket { @@ -19,7 +32,7 @@ class RSocketError : public std::runtime_error { * https://github.com/ReactiveSocket/reactivesocket/blob/master/Protocol.md#error-codes * @return */ - virtual int getErrorCode() = 0; + virtual int getErrorCode() const = 0; }; /** @@ -29,7 +42,7 @@ class InvalidSetupError : public RSocketError { public: using RSocketError::RSocketError; - int getErrorCode() override { + int getErrorCode() const override { return 0x00000001; } @@ -45,7 +58,7 @@ class UnsupportedSetupError : public RSocketError { public: using RSocketError::RSocketError; - int getErrorCode() override { + int getErrorCode() const override { return 0x00000002; } @@ -61,7 +74,7 @@ class RejectedSetupError : public RSocketError { public: using RSocketError::RSocketError; - int getErrorCode() override { + int getErrorCode() const override { return 0x00000003; } @@ -77,7 +90,7 @@ class RejectedResumeError : public RSocketError { public: using RSocketError::RSocketError; - int getErrorCode() override { + int getErrorCode() const override { return 0x00000004; } @@ -93,7 +106,7 @@ class ConnectionError : public RSocketError { public: using RSocketError::RSocketError; - int getErrorCode() override { + int getErrorCode() const override { return 0x00000101; } @@ -103,13 +116,13 @@ class ConnectionError : public RSocketError { }; /** -* Error Code: CONNECTION_CLOSE 0x00000102 -*/ + * Error Code: CONNECTION_CLOSE 0x00000102 + */ class ConnectionCloseError : public RSocketError { public: using RSocketError::RSocketError; - int getErrorCode() override { + int getErrorCode() const override { return 0x00000102; } @@ -117,4 +130,4 @@ class ConnectionCloseError : public RSocketError { return "CONNECTION_CLOSE"; } }; -} +} // namespace rsocket diff --git a/rsocket/RSocketException.h b/rsocket/RSocketException.h index ae99ab70d..9dc9d61e7 100644 --- a/rsocket/RSocketException.h +++ b/rsocket/RSocketException.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -21,4 +33,4 @@ class ResumptionException : public RSocketException { class ConnectionException : public RSocketException { using RSocketException::RSocketException; }; -} +} // namespace rsocket diff --git a/rsocket/RSocketParameters.cpp b/rsocket/RSocketParameters.cpp index b59b5ba62..08f221e44 100644 --- a/rsocket/RSocketParameters.cpp +++ b/rsocket/RSocketParameters.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/RSocketParameters.h" @@ -14,4 +26,4 @@ std::ostream& operator<<( << " token: " << setupPayload.token << " resumable: " << setupPayload.resumable; } -} +} // namespace rsocket diff --git a/rsocket/RSocketParameters.h b/rsocket/RSocketParameters.h index 8d1181d0c..0605bcfcf 100644 --- a/rsocket/RSocketParameters.h +++ b/rsocket/RSocketParameters.h @@ -1,12 +1,26 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include +#include +#include #include +#include + #include "rsocket/Payload.h" -#include "rsocket/framing/FrameSerializer.h" -#include "rsocket/internal/Common.h" +#include "rsocket/framing/Frame.h" namespace rsocket { @@ -15,8 +29,8 @@ using OnRSocketResume = class RSocketParameters { public: - RSocketParameters(bool _resumable, ProtocolVersion _protocolVersion) - : resumable(_resumable), protocolVersion(std::move(_protocolVersion)) {} + RSocketParameters(bool resume, ProtocolVersion version) + : resumable{resume}, protocolVersion{std::move(version)} {} bool resumable; ProtocolVersion protocolVersion; @@ -25,18 +39,18 @@ class RSocketParameters { class SetupParameters : public RSocketParameters { public: explicit SetupParameters( - std::string _metadataMimeType = "text/plain", - std::string _dataMimeType = "text/plain", - Payload _payload = Payload(), - bool _resumable = false, - const ResumeIdentificationToken& _token = + std::string metadataMime = "text/plain", + std::string dataMime = "text/plain", + Payload buf = Payload(), + bool resume = false, + ResumeIdentificationToken resumeToken = ResumeIdentificationToken::generateNew(), - ProtocolVersion _protocolVersion = ProtocolVersion::Current()) - : RSocketParameters(_resumable, _protocolVersion), - metadataMimeType(std::move(_metadataMimeType)), - dataMimeType(std::move(_dataMimeType)), - payload(std::move(_payload)), - token(_token) {} + ProtocolVersion version = ProtocolVersion::Latest) + : RSocketParameters(resume, version), + metadataMimeType(std::move(metadataMime)), + dataMimeType(std::move(dataMime)), + payload(std::move(buf)), + token(resumeToken) {} std::string metadataMimeType; std::string dataMimeType; @@ -49,18 +63,18 @@ std::ostream& operator<<(std::ostream&, const SetupParameters&); class ResumeParameters : public RSocketParameters { public: ResumeParameters( - ResumeIdentificationToken _token, - ResumePosition _serverPosition, - ResumePosition _clientPosition, - ProtocolVersion _protocolVersion) - : RSocketParameters(true, _protocolVersion), - token(std::move(_token)), - serverPosition(_serverPosition), - clientPosition(_clientPosition) {} + ResumeIdentificationToken resumeToken, + ResumePosition serverPos, + ResumePosition clientPos, + ProtocolVersion version) + : RSocketParameters(true, version), + token(std::move(resumeToken)), + serverPosition(serverPos), + clientPosition(clientPos) {} ResumeIdentificationToken token; ResumePosition serverPosition; ResumePosition clientPosition; }; -} // reactivesocket +} // namespace rsocket diff --git a/rsocket/RSocketRequester.cpp b/rsocket/RSocketRequester.cpp index c216711e6..cf1799506 100644 --- a/rsocket/RSocketRequester.cpp +++ b/rsocket/RSocketRequester.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/RSocketRequester.h" @@ -10,161 +22,160 @@ #include "yarpl/single/SingleSubscriptions.h" using namespace folly; -using namespace yarpl; namespace rsocket { +namespace { + +template +void runOnCorrectThread(folly::EventBase& evb, Fn fn) { + if (evb.isInEventBaseThread()) { + fn(); + } else { + evb.runInEventBaseThread(std::move(fn)); + } +} + +} // namespace + RSocketRequester::RSocketRequester( std::shared_ptr srs, EventBase& eventBase) - : stateMachine_(std::move(srs)), eventBase_(&eventBase) {} + : stateMachine_{std::move(srs)}, eventBase_{&eventBase} {} RSocketRequester::~RSocketRequester() { VLOG(1) << "Destroying RSocketRequester"; } void RSocketRequester::closeSocket() { - eventBase_->add([stateMachine = std::move(stateMachine_)] { + eventBase_->runInEventBaseThread([stateMachine = std::move(stateMachine_)] { VLOG(2) << "Closing RSocketStateMachine on EventBase"; - stateMachine->close( - folly::exception_wrapper(), StreamCompletionSignal::SOCKET_CLOSED); + stateMachine->close({}, StreamCompletionSignal::SOCKET_CLOSED); }); } -yarpl::Reference> +std::shared_ptr> RSocketRequester::requestChannel( - yarpl::Reference> + std::shared_ptr> requestStream) { - CHECK(stateMachine_); // verify the socket was not closed - - return yarpl::flowable::Flowables::fromPublisher([ - eb = eventBase_, - requestStream = std::move(requestStream), - srs = stateMachine_ - ](yarpl::Reference> subscriber) mutable { - auto lambda = [ - requestStream = std::move(requestStream), - subscriber = std::move(subscriber), - srs = std::move(srs), - eb - ]() mutable { - auto responseSink = srs->streamsFactory().createChannelRequester( - yarpl::make_ref>( - std::move(subscriber), *eb)); - // responseSink is wrapped with thread scheduling - // so all emissions happen on the right thread - - // if we don't get a responseSink back, that means that - // the requesting peer wasn't connected (or similar error) - // and the Flowable it gets back will immediately call onError - if (responseSink) { - requestStream->subscribe(yarpl::make_ref>( - std::move(responseSink), *eb)); - } - }; - if (eb->isInEventBaseThread()) { - lambda(); - } else { - eb->runInEventBaseThread(std::move(lambda)); - } - }); + return requestChannel({}, false, std::move(requestStream)); +} + +std::shared_ptr> +RSocketRequester::requestChannel( + Payload request, + std::shared_ptr> + requestStream) { + return requestChannel(std::move(request), true, std::move(requestStream)); +} + +std::shared_ptr> +RSocketRequester::requestChannel( + Payload request, + bool hasInitialRequest, + std::shared_ptr> + requestStreamFlowable) { + CHECK(stateMachine_); + + return yarpl::flowable::internal::flowableFromSubscriber( + [eb = eventBase_, + req = std::move(request), + hasInitialRequest, + requestStream = std::move(requestStreamFlowable), + srs = stateMachine_]( + std::shared_ptr> subscriber) { + auto lambda = [eb, + r = req.clone(), + hasInitialRequest, + requestStream, + srs, + subs = std::move(subscriber)]() mutable { + auto scheduled = + std::make_shared>( + std::move(subs), *eb); + auto responseSink = srs->requestChannel( + std::move(r), hasInitialRequest, std::move(scheduled)); + // responseSink is wrapped with thread scheduling + // so all emissions happen on the right thread. + + // If we don't get a responseSink back, that means that + // the requesting peer wasn't connected (or similar error) + // and the Flowable it gets back will immediately call onError. + if (responseSink) { + auto scheduledResponse = + std::make_shared>( + std::move(responseSink), *eb); + requestStream->subscribe(std::move(scheduledResponse)); + } + }; + runOnCorrectThread(*eb, std::move(lambda)); + }); } -yarpl::Reference> +std::shared_ptr> RSocketRequester::requestStream(Payload request) { - CHECK(stateMachine_); // verify the socket was not closed - - return yarpl::flowable::Flowables::fromPublisher([ - eb = eventBase_, - request = std::move(request), - srs = stateMachine_ - ](yarpl::Reference> subscriber) mutable { - auto lambda = [ - request = std::move(request), - subscriber = std::move(subscriber), - srs = std::move(srs), - eb - ]() mutable { - srs->streamsFactory().createStreamRequester( - std::move(request), - yarpl::make_ref>( - std::move(subscriber), *eb)); - }; - if (eb->isInEventBaseThread()) { - lambda(); - } else { - eb->runInEventBaseThread(std::move(lambda)); - } - }); + CHECK(stateMachine_); + + return yarpl::flowable::internal::flowableFromSubscriber( + [eb = eventBase_, req = std::move(request), srs = stateMachine_]( + std::shared_ptr> subscriber) { + auto lambda = + [eb, r = req.clone(), srs, subs = std::move(subscriber)]() mutable { + auto scheduled = + std::make_shared>( + std::move(subs), *eb); + srs->requestStream(std::move(r), std::move(scheduled)); + }; + runOnCorrectThread(*eb, std::move(lambda)); + }); } -yarpl::Reference> +std::shared_ptr> RSocketRequester::requestResponse(Payload request) { - CHECK(stateMachine_); // verify the socket was not closed - - return yarpl::single::Single::create([ - eb = eventBase_, - request = std::move(request), - srs = stateMachine_ - ](yarpl::Reference> observer) mutable { - auto lambda = [ - request = std::move(request), - observer = std::move(observer), - eb, - srs = std::move(srs) - ]() mutable { - srs->streamsFactory().createRequestResponseRequester( - std::move(request), - yarpl::make_ref>( - std::move(observer), *eb)); - }; - if (eb->isInEventBaseThread()) { - lambda(); - } else { - eb->runInEventBaseThread(std::move(lambda)); - } - }); + CHECK(stateMachine_); + + return yarpl::single::Single::create( + [eb = eventBase_, req = std::move(request), srs = stateMachine_]( + std::shared_ptr> observer) { + auto lambda = [eb, + r = req.clone(), + srs, + obs = std::move(observer)]() mutable { + auto scheduled = + std::make_shared>( + std::move(obs), *eb); + srs->requestResponse(std::move(r), std::move(scheduled)); + }; + runOnCorrectThread(*eb, std::move(lambda)); + }); } -yarpl::Reference> RSocketRequester::fireAndForget( +std::shared_ptr> RSocketRequester::fireAndForget( rsocket::Payload request) { - CHECK(stateMachine_); // verify the socket was not closed - - return yarpl::single::Single::create([ - eb = eventBase_, - request = std::move(request), - srs = stateMachine_ - ](yarpl::Reference> subscriber) mutable { - auto lambda = [ - request = std::move(request), - subscriber = std::move(subscriber), - srs = std::move(srs) - ]() mutable { - // TODO pass in SingleSubscriber for underlying layers to - // call onSuccess/onError once put on network - srs->fireAndForget(std::move(request)); - // right now just immediately call onSuccess - subscriber->onSubscribe(yarpl::single::SingleSubscriptions::empty()); - subscriber->onSuccess(); - }; - if (eb->isInEventBaseThread()) { - lambda(); - } else { - eb->runInEventBaseThread(std::move(lambda)); - } - }); + CHECK(stateMachine_); + + return yarpl::single::Single::create( + [eb = eventBase_, req = std::move(request), srs = stateMachine_]( + std::shared_ptr> subscriber) { + auto lambda = + [r = req.clone(), srs, subs = std::move(subscriber)]() mutable { + // TODO: Pass in SingleSubscriber for underlying layers to call + // onSuccess/onError once put on network. + srs->fireAndForget(std::move(r)); + subs->onSubscribe(yarpl::single::SingleSubscriptions::empty()); + subs->onSuccess(); + }; + runOnCorrectThread(*eb, std::move(lambda)); + }); } void RSocketRequester::metadataPush(std::unique_ptr metadata) { - CHECK(stateMachine_); // verify the socket was not closed + CHECK(stateMachine_); - eventBase_->runInEventBaseThread( - [ srs = stateMachine_, metadata = std::move(metadata) ]() mutable { - srs->metadataPush(std::move(metadata)); + runOnCorrectThread( + *eventBase_, [srs = stateMachine_, meta = std::move(metadata)]() mutable { + srs->metadataPush(std::move(meta)); }); } -DuplexConnection* RSocketRequester::getConnection() { - return stateMachine_? stateMachine_->getConnection() : nullptr; -} } // namespace rsocket diff --git a/rsocket/RSocketRequester.h b/rsocket/RSocketRequester.h index a7058268c..a87d15955 100644 --- a/rsocket/RSocketRequester.h +++ b/rsocket/RSocketRequester.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -51,7 +63,7 @@ class RSocketRequester { * Interaction model details can be found at * https://github.com/ReactiveSocket/reactivesocket/blob/master/Protocol.md#request-stream */ - virtual yarpl::Reference> + virtual std::shared_ptr> requestStream(rsocket::Payload request); /** @@ -60,9 +72,20 @@ class RSocketRequester { * Interaction model details can be found at * https://github.com/ReactiveSocket/reactivesocket/blob/master/Protocol.md#request-channel */ - virtual yarpl::Reference> + virtual std::shared_ptr> requestChannel( - yarpl::Reference> requests); + std::shared_ptr> requests); + + /** + * As requestStream function accepts an initial request, this version of + * requestChannel also accepts an initial request. + * @see requestChannel + * @see requestStream + */ + virtual std::shared_ptr> + requestChannel( + Payload request, + std::shared_ptr> requests); /** * Send a single request and get a single response. @@ -70,7 +93,7 @@ class RSocketRequester { * Interaction model details can be found at * https://github.com/ReactiveSocket/reactivesocket/blob/master/Protocol.md#stream-sequences-request-response */ - virtual yarpl::Reference> + virtual std::shared_ptr> requestResponse(rsocket::Payload request); /** @@ -85,7 +108,7 @@ class RSocketRequester { * Interaction model details can be found at * https://github.com/ReactiveSocket/reactivesocket/blob/master/Protocol.md#request-fire-n-forget */ - virtual yarpl::Reference> fireAndForget( + virtual std::shared_ptr> fireAndForget( rsocket::Payload request); /** @@ -93,15 +116,16 @@ class RSocketRequester { */ virtual void metadataPush(std::unique_ptr metadata); - /** - * To be used only temporarily to check the transport's status. - */ - virtual DuplexConnection* getConnection(); - virtual void closeSocket(); protected: + virtual std::shared_ptr> + requestChannel( + Payload request, + bool hasInitialRequest, + std::shared_ptr> requests); + std::shared_ptr stateMachine_; folly::EventBase* eventBase_; }; -} +} // namespace rsocket diff --git a/rsocket/RSocketResponder.cpp b/rsocket/RSocketResponder.cpp index 5b5b00274..892d2e12e 100644 --- a/rsocket/RSocketResponder.cpp +++ b/rsocket/RSocketResponder.cpp @@ -1,35 +1,86 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/RSocketResponder.h" #include +#include namespace rsocket { -yarpl::Reference> -RSocketResponder::handleRequestResponse(rsocket::Payload, rsocket::StreamId) { - return yarpl::single::Singles::error( +using namespace yarpl::flowable; +using namespace yarpl::single; + +void RSocketResponderCore::handleRequestStream( + Payload, + StreamId, + std::shared_ptr> response) noexcept { + response->onSubscribe(Subscription::create()); + response->onError(std::logic_error("handleRequestStream not implemented")); +} + +void RSocketResponderCore::handleRequestResponse( + Payload, + StreamId, + std::shared_ptr> responseObserver) noexcept { + responseObserver->onSubscribe(SingleSubscriptions::empty()); + responseObserver->onError( + std::logic_error("handleRequestResponse not implemented")); +} + +void RSocketResponderCore::handleFireAndForget(Payload, StreamId) { + // No default implementation, no error response to provide. +} + +void RSocketResponderCore::handleMetadataPush(std::unique_ptr) { + // No default implementation, no error response to provide. +} + +std::shared_ptr> RSocketResponderCore::handleRequestChannel( + Payload, + StreamId, + std::shared_ptr> response) noexcept { + response->onSubscribe(Subscription::create()); + response->onError(std::logic_error("handleRequestStream not implemented")); + + // cancel immediately + return std::make_shared>(); +} + +std::shared_ptr> RSocketResponder::handleRequestResponse( + Payload, + StreamId) { + return Singles::error( std::logic_error("handleRequestResponse not implemented")); } -yarpl::Reference> -RSocketResponder::handleRequestStream(rsocket::Payload, rsocket::StreamId) { - return yarpl::flowable::Flowables::error( +std::shared_ptr> RSocketResponder::handleRequestStream( + Payload, + StreamId) { + return Flowable::error( std::logic_error("handleRequestStream not implemented")); } -yarpl::Reference> -RSocketResponder::handleRequestChannel( - rsocket::Payload, - yarpl::Reference>, - rsocket::StreamId) { - return yarpl::flowable::Flowables::error( +std::shared_ptr> RSocketResponder::handleRequestChannel( + Payload, + std::shared_ptr>, + StreamId) { + return Flowable::error( std::logic_error("handleRequestChannel not implemented")); } -void RSocketResponder::handleFireAndForget( - rsocket::Payload, - rsocket::StreamId) { +void RSocketResponder::handleFireAndForget(Payload, StreamId) { // No default implementation, no error response to provide. } @@ -38,17 +89,15 @@ void RSocketResponder::handleMetadataPush(std::unique_ptr) { } /// Handles a new Channel requested by the other end. -yarpl::Reference> -RSocketResponder::handleRequestChannelCore( +std::shared_ptr> +RSocketResponderAdapter::handleRequestChannel( Payload request, StreamId streamId, - const yarpl::Reference>& - response) noexcept { - class EagerSubscriberBridge - : public yarpl::flowable::Subscriber { + std::shared_ptr> response) noexcept { + class EagerSubscriberBridge : public Subscriber { public: - void onSubscribe(yarpl::Reference - subscription) noexcept override { + void onSubscribe( + std::shared_ptr subscription) noexcept override { CHECK(!subscription_); subscription_ = std::move(subscription); if (inner_) { @@ -56,13 +105,12 @@ RSocketResponder::handleRequestChannelCore( } } - void onNext(rsocket::Payload element) noexcept override { + void onNext(Payload element) noexcept override { DCHECK(inner_); inner_->onNext(std::move(element)); } void onComplete() noexcept override { - DCHECK(inner_); if (auto inner = std::move(inner_)) { inner->onComplete(); subscription_.reset(); @@ -81,8 +129,7 @@ RSocketResponder::handleRequestChannelCore( } } - void subscribe( - yarpl::Reference> inner) { + void subscribe(std::shared_ptr> inner) { CHECK(!inner_); // only one call to subscribe is supported CHECK(inner); @@ -100,19 +147,19 @@ RSocketResponder::handleRequestChannelCore( } private: - yarpl::Reference> inner_; - yarpl::Reference subscription_; + std::shared_ptr> inner_; + std::shared_ptr subscription_; folly::exception_wrapper error_; bool completed_{false}; }; - auto eagerSubscriber = yarpl::make_ref(); - auto flowable = handleRequestChannel( + auto eagerSubscriber = std::make_shared(); + auto flowable = inner_->handleRequestChannel( std::move(request), - yarpl::flowable::Flowables::fromPublisher( - [eagerSubscriber]( - yarpl::Reference> - subscriber) { eagerSubscriber->subscribe(subscriber); }), + internal::flowableFromSubscriber( + [eagerSubscriber](std::shared_ptr> subscriber) { + eagerSubscriber->subscribe(subscriber); + }), std::move(streamId)); // bridge from the existing eager RequestHandler and old Subscriber type // to the lazy Flowable and new Subscriber type @@ -121,22 +168,33 @@ RSocketResponder::handleRequestChannelCore( } /// Handles a new Stream requested by the other end. -void RSocketResponder::handleRequestStreamCore( +void RSocketResponderAdapter::handleRequestStream( Payload request, StreamId streamId, - const yarpl::Reference>& - response) noexcept { - auto flowable = handleRequestStream(std::move(request), std::move(streamId)); + std::shared_ptr> response) noexcept { + auto flowable = + inner_->handleRequestStream(std::move(request), std::move(streamId)); flowable->subscribe(std::move(response)); } /// Handles a new inbound RequestResponse requested by the other end. -void RSocketResponder::handleRequestResponseCore( +void RSocketResponderAdapter::handleRequestResponse( Payload request, StreamId streamId, - const yarpl::Reference>& - responseObserver) noexcept { - auto single = handleRequestResponse(std::move(request), streamId); + std::shared_ptr> responseObserver) noexcept { + auto single = inner_->handleRequestResponse(std::move(request), streamId); single->subscribe(std::move(responseObserver)); } + +void RSocketResponderAdapter::handleFireAndForget( + Payload request, + StreamId streamId) { + inner_->handleFireAndForget(std::move(request), streamId); +} + +void RSocketResponderAdapter::handleMetadataPush( + std::unique_ptr buf) { + inner_->handleMetadataPush(std::move(buf)); +} + } // namespace rsocket diff --git a/rsocket/RSocketResponder.h b/rsocket/RSocketResponder.h index f73e244d9..eedcc2ef8 100644 --- a/rsocket/RSocketResponder.h +++ b/rsocket/RSocketResponder.h @@ -1,14 +1,52 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include "rsocket/Payload.h" -#include "rsocket/internal/Common.h" +#include "rsocket/framing/FrameHeader.h" #include "yarpl/Flowable.h" #include "yarpl/Single.h" namespace rsocket { +class RSocketResponderCore { + public: + virtual ~RSocketResponderCore() = default; + + virtual void handleFireAndForget(Payload request, StreamId streamId); + + virtual void handleMetadataPush(std::unique_ptr metadata); + + virtual std::shared_ptr> + handleRequestChannel( + Payload request, + StreamId streamId, + std::shared_ptr> response) noexcept; + + virtual void handleRequestStream( + Payload request, + StreamId streamId, + std::shared_ptr> response) noexcept; + + virtual void handleRequestResponse( + Payload request, + StreamId streamId, + std::shared_ptr> + response) noexcept; +}; + /** * Responder APIs to handle requests on an RSocket connection. * @@ -37,28 +75,28 @@ class RSocketResponder { * * Returns a Single with the response. */ - virtual yarpl::Reference> - handleRequestResponse(rsocket::Payload request, rsocket::StreamId streamId); + virtual std::shared_ptr> handleRequestResponse( + Payload request, + StreamId streamId); /** * Called when a new `requestStream` occurs from an RSocketRequester. * * Returns a Flowable with the response stream. */ - virtual yarpl::Reference> - handleRequestStream(rsocket::Payload request, rsocket::StreamId streamId); + virtual std::shared_ptr> + handleRequestStream(Payload request, StreamId streamId); /** * Called when a new `requestChannel` occurs from an RSocketRequester. * * Returns a Flowable with the response stream. */ - virtual yarpl::Reference> + virtual std::shared_ptr> handleRequestChannel( - rsocket::Payload request, - yarpl::Reference> - requestStream, - rsocket::StreamId streamId); + Payload request, + std::shared_ptr> requestStream, + StreamId streamId); /** * Called when a new `fireAndForget` occurs from an RSocketRequester. @@ -75,30 +113,42 @@ class RSocketResponder { * No response. */ virtual void handleMetadataPush(std::unique_ptr metadata); +}; + +class RSocketResponderAdapter : public RSocketResponderCore { + public: + explicit RSocketResponderAdapter(std::shared_ptr inner) + : inner_(std::move(inner)) {} + virtual ~RSocketResponderAdapter() = default; /// Internal method for handling channel requests, not intended to be used by /// application code. - yarpl::Reference> - handleRequestChannelCore( + std::shared_ptr> handleRequestChannel( Payload request, StreamId streamId, - const yarpl::Reference>& - response) noexcept; + std::shared_ptr> response) noexcept + override; /// Internal method for handling stream requests, not intended to be used /// by application code. - void handleRequestStreamCore( + void handleRequestStream( Payload request, StreamId streamId, - const yarpl::Reference>& - response) noexcept; + std::shared_ptr> response) noexcept + override; /// Internal method for handling request-response requests, not intended to be /// used by application code. - void handleRequestResponseCore( + void handleRequestResponse( Payload request, StreamId streamId, - const yarpl::Reference>& - response) noexcept; + std::shared_ptr> response) noexcept + override; + + void handleFireAndForget(Payload request, StreamId streamId) override; + void handleMetadataPush(std::unique_ptr buf) override; + + private: + std::shared_ptr inner_; }; -} +} // namespace rsocket diff --git a/rsocket/RSocketServer.cpp b/rsocket/RSocketServer.cpp index 6bb078a8e..1e202810d 100644 --- a/rsocket/RSocketServer.cpp +++ b/rsocket/RSocketServer.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/RSocketServer.h" #include @@ -9,6 +21,7 @@ #include "rsocket/framing/FramedDuplexConnection.h" #include "rsocket/framing/ScheduledFrameTransport.h" #include "rsocket/internal/ConnectionSet.h" +#include "rsocket/internal/WarmResumeManager.h" namespace rsocket { @@ -20,7 +33,7 @@ RSocketServer::RSocketServer( return new rsocket::SetupResumeAcceptor{ folly::EventBaseManager::get()->getExistingEventBase()}; }), - connectionSet_(std::make_shared()), + connectionSet_(std::make_unique()), stats_(std::move(stats)) {} RSocketServer::~RSocketServer() { @@ -51,7 +64,7 @@ void RSocketServer::shutdownAndWait() { folly::collectAll(closingFutures).get(); // Close off all outstanding connections. - connectionSet_.reset(); + connectionSet_->shutdownAndWait(); } void RSocketServer::start( @@ -108,12 +121,20 @@ void RSocketServer::acceptConnection( acceptor->accept( std::move(framedConnection), - std::bind( - &RSocketServer::onRSocketSetup, - this, - serviceHandler, - std::placeholders::_1, - std::placeholders::_2), + [serviceHandler, + weakConSet = std::weak_ptr(connectionSet_), + scheduledResponder = useScheduledResponder_]( + std::unique_ptr conn, + SetupParameters params) mutable { + if (auto connectionSet = weakConSet.lock()) { + RSocketServer::onRSocketSetup( + serviceHandler, + std::move(connectionSet), + scheduledResponder, + std::move(conn), + std::move(params)); + } + }, std::bind( &RSocketServer::onRSocketResume, this, @@ -124,80 +145,104 @@ void RSocketServer::acceptConnection( void RSocketServer::onRSocketSetup( std::shared_ptr serviceHandler, - yarpl::Reference frameTransport, + std::shared_ptr connectionSet, + bool scheduledResponder, + std::unique_ptr connection, SetupParameters setupParams) { - auto eventBase = folly::EventBaseManager::get()->getExistingEventBase(); + const auto eventBase = folly::EventBaseManager::get()->getExistingEventBase(); VLOG(2) << "Received new setup payload on " << eventBase->getName(); CHECK(eventBase); auto result = serviceHandler->onNewSetup(setupParams); if (result.hasError()) { - VLOG(3) << "Terminating SETUP attempt from client. No Responder"; - throw result.error(); + VLOG(3) << "Terminating SETUP attempt from client. " + << result.error().what(); + connection->send( + FrameSerializer::createFrameSerializer(setupParams.protocolVersion) + ->serializeOut(Frame_ERROR::rejectedSetup(result.error().what()))); + return; } auto connectionParams = std::move(result.value()); if (!connectionParams.responder) { LOG(ERROR) << "Received invalid Responder. Dropping connection"; - throw RSocketException("Received invalid Responder from server"); + connection->send( + FrameSerializer::createFrameSerializer(setupParams.protocolVersion) + ->serializeOut(Frame_ERROR::rejectedSetup( + "Received invalid Responder from server"))); + return; } - auto rs = std::make_shared( - useScheduledResponder_ + const auto rs = std::make_shared( + scheduledResponder ? std::make_shared( - std::move(connectionParams.responder), *eventBase) + std::move(connectionParams.responder), *eventBase) : std::move(connectionParams.responder), nullptr, RSocketMode::SERVER, - std::move(connectionParams.stats), + connectionParams.stats, std::move(connectionParams.connectionEvents), - nullptr, /* resumeManager */ + setupParams.resumable + ? std::make_shared(connectionParams.stats) + : ResumeManager::makeEmpty(), nullptr /* coldResumeHandler */); - connectionSet_->insert(rs, eventBase); - rs->registerSet(connectionSet_); + if (!connectionSet->insert(rs, eventBase)) { + VLOG(1) << "Server is closed, so ignore the connection"; + connection->send( + FrameSerializer::createFrameSerializer(setupParams.protocolVersion) + ->serializeOut(Frame_ERROR::rejectedSetup( + "Server ignores the connection attempt"))); + return; + } + rs->registerCloseCallback(connectionSet.get()); auto requester = std::make_shared(rs, *eventBase); auto serverState = std::shared_ptr( - new RSocketServerState(*eventBase, rs, requester)); + new RSocketServerState(*eventBase, rs, std::move(requester))); serviceHandler->onNewRSocketState(std::move(serverState), setupParams.token); - rs->connectServer(std::move(frameTransport), std::move(setupParams)); + rs->connectServer( + std::make_shared(std::move(connection)), + std::move(setupParams)); } void RSocketServer::onRSocketResume( std::shared_ptr serviceHandler, - yarpl::Reference frameTransport, + std::unique_ptr connection, ResumeParameters resumeParams) { auto result = serviceHandler->onResume(resumeParams.token); if (result.hasError()) { stats_->resumeFailedNoState(); VLOG(3) << "Terminating RESUME attempt from client. No ServerState found"; - throw result.error(); + connection->send( + FrameSerializer::createFrameSerializer(resumeParams.protocolVersion) + ->serializeOut(Frame_ERROR::rejectedSetup(result.error().what()))); + return; } - auto serverState = std::move(result.value()); + const auto serverState = std::move(result.value()); CHECK(serverState); - auto* eventBase = folly::EventBaseManager::get()->getExistingEventBase(); + const auto eventBase = folly::EventBaseManager::get()->getExistingEventBase(); VLOG(2) << "Resuming client on " << eventBase->getName(); if (!serverState->eventBase_.isInEventBaseThread()) { // If the resumed connection is on a different EventBase, then use // ScheduledFrameTransport and ScheduledFrameProcessor to ensure the // RSocketStateMachine continues to live on the same EventBase and the // IO happens in the new EventBase - auto scheduledFT = yarpl::make_ref( - std::move(frameTransport), + auto scheduledFT = std::make_shared( + std::make_shared(std::move(connection)), eventBase, /* Transport EventBase */ &serverState->eventBase_); /* StateMachine EventBase */ - serverState->eventBase_.runInEventBaseThread([ - serverState, - scheduledFT = std::move(scheduledFT), - resumeParams = std::move(resumeParams) - ]() { - serverState->rSocketStateMachine_->resumeServer( - std::move(scheduledFT), resumeParams); - }); + serverState->eventBase_.runInEventBaseThread( + [serverState, + scheduledFT = std::move(scheduledFT), + resumeParams = std::move(resumeParams)]() mutable { + serverState->rSocketStateMachine_->resumeServer( + std::move(scheduledFT), resumeParams); + }); } else { // If the resumed connection is on the same EventBase, then the // RSocketStateMachine and Transport can continue living in the same // EventBase without any thread hopping between them. serverState->rSocketStateMachine_->resumeServer( - std::move(frameTransport), resumeParams); + std::make_shared(std::move(connection)), + resumeParams); } } @@ -216,4 +261,8 @@ folly::Optional RSocketServer::listeningPort() const { : folly::none; } +size_t RSocketServer::getNumConnections() { + return connectionSet_ ? connectionSet_->size() : 0; +} + } // namespace rsocket diff --git a/rsocket/RSocketServer.h b/rsocket/RSocketServer.h index 9c7c6243d..39dae66a3 100644 --- a/rsocket/RSocketServer.h +++ b/rsocket/RSocketServer.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -12,6 +24,7 @@ #include "rsocket/RSocketParameters.h" #include "rsocket/RSocketResponder.h" #include "rsocket/RSocketServiceHandler.h" +#include "rsocket/internal/ConnectionSet.h" #include "rsocket/internal/SetupResumeAcceptor.h" namespace rsocket { @@ -90,17 +103,24 @@ class RSocketServer { */ void setSingleThreadedResponder(); + /** + * Number of active connections to this server. + */ + size_t getNumConnections(); + private: - void onRSocketSetup( + static void onRSocketSetup( std::shared_ptr serviceHandler, - yarpl::Reference frameTransport, + std::shared_ptr connectionSet, + bool scheduledResponder, + std::unique_ptr connection, rsocket::SetupParameters setupPayload); void onRSocketResume( std::shared_ptr serviceHandler, - yarpl::Reference frameTransport, + std::unique_ptr connection, rsocket::ResumeParameters setupPayload); - std::unique_ptr duplexConnectionAcceptor_; + const std::unique_ptr duplexConnectionAcceptor_; bool started{false}; class SetupResumeAcceptorTag {}; diff --git a/rsocket/RSocketServerState.h b/rsocket/RSocketServerState.h index c5fcd137f..c5d010dbb 100644 --- a/rsocket/RSocketServerState.h +++ b/rsocket/RSocketServerState.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -22,6 +34,10 @@ class RSocketServerState { return rSocketRequester_; } + folly::EventBase* eventBase() { + return &eventBase_; + } + friend class RSocketServer; private: @@ -34,7 +50,7 @@ class RSocketServerState { rSocketRequester_(rSocketRequester) {} folly::EventBase& eventBase_; - std::shared_ptr rSocketStateMachine_; - std::shared_ptr rSocketRequester_; + const std::shared_ptr rSocketStateMachine_; + const std::shared_ptr rSocketRequester_; }; -} +} // namespace rsocket diff --git a/rsocket/RSocketServiceHandler.cpp b/rsocket/RSocketServiceHandler.cpp index f118a9715..8e3f8d341 100644 --- a/rsocket/RSocketServiceHandler.cpp +++ b/rsocket/RSocketServiceHandler.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/RSocketServiceHandler.h" @@ -16,7 +28,7 @@ RSocketServiceHandler::onResume(ResumeIdentificationToken) { bool RSocketServiceHandler::canResume( const std::vector& /* cleanStreamIds */, const std::vector& /* dirtyStreamIds */, - ResumeIdentificationToken) { + ResumeIdentificationToken) const { return true; } @@ -29,7 +41,7 @@ std::shared_ptr RSocketServiceHandler::create( const SetupParameters& setupParameters) override { try { return RSocketConnectionParams(onNewSetupFn_(setupParameters)); - } catch(const std::exception& e) { + } catch (const std::exception& e) { return folly::Unexpected( ConnectionException(e.what())); } @@ -40,4 +52,4 @@ std::shared_ptr RSocketServiceHandler::create( }; return std::make_shared(std::move(onNewSetupFn)); } -} +} // namespace rsocket diff --git a/rsocket/RSocketServiceHandler.h b/rsocket/RSocketServiceHandler.h index e0bd6ff70..b67caa358 100644 --- a/rsocket/RSocketServiceHandler.h +++ b/rsocket/RSocketServiceHandler.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -34,7 +46,6 @@ struct RSocketConnectionParams { std::shared_ptr connectionEvents; }; - // This class has to be implemented by the application. The methods can be // called from different threads and it is the application's responsibility to // ensure thread-safety. @@ -90,10 +101,10 @@ class RSocketServiceHandler { virtual bool canResume( const std::vector& /* cleanStreamIds */, const std::vector& /* dirtyStreamIds */, - ResumeIdentificationToken); + ResumeIdentificationToken) const; // Convenience constructor to create a simple RSocketServiceHandler. static std::shared_ptr create( OnNewSetupFn onNewSetupFn); }; -} +} // namespace rsocket diff --git a/rsocket/RSocketStats.cpp b/rsocket/RSocketStats.cpp index d81df87af..ee7bc6f70 100644 --- a/rsocket/RSocketStats.cpp +++ b/rsocket/RSocketStats.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/RSocketStats.h" @@ -7,7 +19,11 @@ namespace rsocket { class NoopStats : public RSocketStats { public: NoopStats() = default; - ~NoopStats() = default; + NoopStats(const NoopStats&) = delete; // non construction-copyable + NoopStats& operator=(const NoopStats&) = delete; // non copyable + NoopStats& operator=(const NoopStats&&) = delete; // non movable + NoopStats(NoopStats&&) = delete; // non construction-movable + ~NoopStats() override = default; void socketCreated() override {} void socketConnected() override {} @@ -26,11 +42,8 @@ class NoopStats : public RSocketStats { void bytesRead(size_t) override {} void frameWritten(FrameType) override {} void frameRead(FrameType) override {} - virtual void serverResume( - folly::Optional, - int64_t, - int64_t, - ResumeOutcome) override {} + void serverResume(folly::Optional, int64_t, int64_t, ResumeOutcome) + override {} void resumeBufferChanged(int, int) override {} void streamBufferChanged(int64_t, int64_t) override {} @@ -40,15 +53,9 @@ class NoopStats : public RSocketStats { void keepaliveReceived() override {} static std::shared_ptr instance() { - static auto singleton = std::make_shared(); + static const auto singleton = std::make_shared(); return singleton; } - - private: - NoopStats(const NoopStats&) = delete; // non construction-copyable - NoopStats& operator=(const NoopStats&) = delete; // non copyable - NoopStats& operator=(const NoopStats&&) = delete; // non movable - NoopStats(NoopStats&&) = delete; // non construction-movable }; std::shared_ptr RSocketStats::noop() { diff --git a/rsocket/RSocketStats.h b/rsocket/RSocketStats.h index 5aad35913..8e7480b91 100644 --- a/rsocket/RSocketStats.h +++ b/rsocket/RSocketStats.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -53,5 +65,7 @@ class RSocketStats { virtual void resumeFailedNoState() {} virtual void keepaliveSent() {} virtual void keepaliveReceived() {} + virtual void unknownFrameReceived() { + } // TODO(lehecka): add to all implementations }; } // namespace rsocket diff --git a/rsocket/ResumeManager.h b/rsocket/ResumeManager.h index 59239f3d0..198539916 100644 --- a/rsocket/ResumeManager.h +++ b/rsocket/ResumeManager.h @@ -1,11 +1,21 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include - #include - +#include #include "rsocket/framing/Frame.h" #include "rsocket/framing/FrameTransportImpl.h" @@ -55,6 +65,8 @@ using StreamResumeInfos = std::unordered_map; // - lastSentPosition() would return 350 class ResumeManager { public: + static std::shared_ptr makeEmpty(); + virtual ~ResumeManager() {} // The following methods will be called for each frame which is being @@ -116,13 +128,13 @@ class ResumeManager { virtual void onStreamClosed(StreamId streamId) = 0; // Returns the cached stream information. - virtual const StreamResumeInfos& getStreamResumeInfos() = 0; + virtual const StreamResumeInfos& getStreamResumeInfos() const = 0; // Returns the largest used StreamId so far. - virtual StreamId getLargestUsedStreamId() = 0; + virtual StreamId getLargestUsedStreamId() const = 0; // Utility method to check frames which should be tracked for resumption. - inline bool shouldTrackFrame(const FrameType frameType) { + virtual bool shouldTrackFrame(const FrameType frameType) const { switch (frameType) { case FrameType::REQUEST_CHANNEL: case FrameType::REQUEST_STREAM: @@ -146,4 +158,4 @@ class ResumeManager { } } }; -} +} // namespace rsocket diff --git a/benchmarks/BaselinesAsyncSocket.cpp b/rsocket/benchmarks/BaselinesAsyncSocket.cpp similarity index 90% rename from benchmarks/BaselinesAsyncSocket.cpp rename to rsocket/benchmarks/BaselinesAsyncSocket.cpp index 44d1b929f..467a36f68 100644 --- a/benchmarks/BaselinesAsyncSocket.cpp +++ b/rsocket/benchmarks/BaselinesAsyncSocket.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -8,8 +20,6 @@ #define PORT (35437) -using namespace folly; - // namespace { // // class TcpReader : public ::folly::AsyncTransportWrapper::ReadCallback { @@ -109,8 +119,8 @@ using namespace folly; // std::move(socket), eventBase_, loadSize_, recvBufferLength_); // } // -// void acceptError(const std::exception& ex) noexcept override { -// LOG(FATAL) << "acceptError" << ex.what() << std::endl; +// void acceptError(folly::exception_wrapper ex) noexcept override { +// LOG(FATAL) << "acceptError" << ex << std::endl; // eventBase_.terminateLoopSoon(); // } // @@ -199,9 +209,9 @@ using namespace folly; //} static void BM_Baseline_AsyncSocket_SendReceive( - size_t loadSize, - size_t msgLength, - size_t recvLength) { + size_t /*loadSize*/, + size_t /*msgLength*/, + size_t /*recvLength*/) { LOG_EVERY_N(INFO, 10000) << "TODO(lehecka): benchmark needs updating, " << "it has memory corruption bugs"; // EventBase serverEventBase; @@ -227,24 +237,28 @@ static void BM_Baseline_AsyncSocket_SendReceive( } BENCHMARK(BM_Baseline_AsyncSocket_Throughput_100MB_s40B_r1024B, n) { + (void)n; constexpr size_t loadSizeB = 100 * 1024 * 1024; constexpr size_t sendSizeB = 40; constexpr size_t receiveSizeB = 1024; BM_Baseline_AsyncSocket_SendReceive(loadSizeB, sendSizeB, receiveSizeB); } BENCHMARK(BM_Baseline_AsyncSocket_Throughput_100MB_s40B_r4096B, n) { + (void)n; constexpr size_t loadSizeB = 100 * 1024 * 1024; constexpr size_t sendSizeB = 40; constexpr size_t receiveSizeB = 4096; BM_Baseline_AsyncSocket_SendReceive(loadSizeB, sendSizeB, receiveSizeB); } BENCHMARK(BM_Baseline_AsyncSocket_Throughput_100MB_s80B_r4096B, n) { + (void)n; constexpr size_t loadSizeB = 100 * 1024 * 1024; constexpr size_t sendSizeB = 80; constexpr size_t receiveSizeB = 4096; BM_Baseline_AsyncSocket_SendReceive(loadSizeB, sendSizeB, receiveSizeB); } BENCHMARK(BM_Baseline_AsyncSocket_Throughput_100MB_s4096B_r4096B, n) { + (void)n; constexpr size_t loadSizeB = 100 * 1024 * 1024; constexpr size_t sendSizeB = 4096; constexpr size_t receiveSizeB = 4096; @@ -252,16 +266,19 @@ BENCHMARK(BM_Baseline_AsyncSocket_Throughput_100MB_s4096B_r4096B, n) { } BENCHMARK(BM_Baseline_AsyncSocket_Latency_1M_msgs_32B, n) { + (void)n; constexpr size_t messageSizeB = 32; constexpr size_t loadSizeB = 1000000 * messageSizeB; BM_Baseline_AsyncSocket_SendReceive(loadSizeB, messageSizeB, messageSizeB); } BENCHMARK(BM_Baseline_AsyncSocket_Latency_1M_msgs_128B, n) { + (void)n; constexpr size_t messageSizeB = 128; constexpr size_t loadSizeB = 1000000 * messageSizeB; BM_Baseline_AsyncSocket_SendReceive(loadSizeB, messageSizeB, messageSizeB); } BENCHMARK(BM_Baseline_AsyncSocket_Latency_1M_msgs_4kB, n) { + (void)n; constexpr size_t messageSizeB = 4096; constexpr size_t loadSizeB = 1000000 * messageSizeB; BM_Baseline_AsyncSocket_SendReceive(loadSizeB, messageSizeB, messageSizeB); diff --git a/benchmarks/BaselinesTcp.cpp b/rsocket/benchmarks/BaselinesTcp.cpp similarity index 75% rename from benchmarks/BaselinesTcp.cpp rename to rsocket/benchmarks/BaselinesTcp.cpp index f50e2173a..d9e22892d 100644 --- a/benchmarks/BaselinesTcp.cpp +++ b/rsocket/benchmarks/BaselinesTcp.cpp @@ -1,10 +1,23 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include #include #include #include +#include #include #include #include @@ -23,12 +36,9 @@ static void BM_Baseline_TCP_SendReceive( std::thread t([&]() { int serverSock = socket(AF_INET, SOCK_STREAM, 0); int sock = -1; - struct sockaddr_in addr; + struct sockaddr_in addr = {}; socklen_t addrlen = sizeof(addr); - char message[MAX_MESSAGE_LENGTH]; - - std::memset(message, 0, sizeof(message)); - std::memset(&addr, 0, sizeof(addr)); + std::array message = {}; if (serverSock < 0) { perror("acceptor socket"); @@ -37,7 +47,7 @@ static void BM_Baseline_TCP_SendReceive( int enable = 1; if (setsockopt( - serverSock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable)) < + serverSock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable)) < 0) { perror("setsocketopt SO_REUSEADDR"); return; @@ -60,7 +70,7 @@ static void BM_Baseline_TCP_SendReceive( accepting.store(true); if ((sock = accept( - serverSock, reinterpret_cast(&addr), &addrlen)) < + serverSock, reinterpret_cast(&addr), &addrlen)) < 0) { perror("accept"); return; @@ -70,7 +80,7 @@ static void BM_Baseline_TCP_SendReceive( size_t sentBytes = 0; while (sentBytes < loadSize) { - if (send(sock, message, msgLength, 0) != + if (send(sock, message.data(), msgLength, 0) != static_cast(msgLength)) { perror("send"); return; @@ -86,13 +96,10 @@ static void BM_Baseline_TCP_SendReceive( std::this_thread::yield(); } - int sock = socket(AF_INET, SOCK_STREAM, 0); - struct sockaddr_in addr; - socklen_t addrlen = sizeof(addr); - char message[MAX_MESSAGE_LENGTH]; - - std::memset(message, 0, sizeof(message)); - std::memset(&addr, 0, sizeof(addr)); + const int sock = socket(AF_INET, SOCK_STREAM, 0); + struct sockaddr_in addr = {}; + const socklen_t addrlen = sizeof(addr); + std::array message = {}; if (sock < 0) { perror("connector socket"); @@ -113,7 +120,7 @@ static void BM_Baseline_TCP_SendReceive( size_t receivedBytes = 0; while (receivedBytes < loadSize) { - ssize_t recved = recv(sock, message, recvLength, 0); + const ssize_t recved = recv(sock, message.data(), recvLength, 0); if (recved < 0) { perror("recv"); @@ -128,24 +135,28 @@ static void BM_Baseline_TCP_SendReceive( } BENCHMARK(BM_Baseline_TCP_Throughput_100MB_s40B_r1024B, n) { + (void)n; constexpr size_t loadSizeB = 100 * 1024 * 1024; constexpr size_t sendSizeB = 40; constexpr size_t receiveSizeB = 1024; BM_Baseline_TCP_SendReceive(loadSizeB, sendSizeB, receiveSizeB); } BENCHMARK(BM_Baseline_TCP_Throughput_100MB_s40B_r4096B, n) { + (void)n; constexpr size_t loadSizeB = 100 * 1024 * 1024; constexpr size_t sendSizeB = 40; constexpr size_t receiveSizeB = 4096; BM_Baseline_TCP_SendReceive(loadSizeB, sendSizeB, receiveSizeB); } BENCHMARK(BM_Baseline_TCP_Throughput_100MB_s80B_r4096B, n) { + (void)n; constexpr size_t loadSizeB = 100 * 1024 * 1024; constexpr size_t sendSizeB = 80; constexpr size_t receiveSizeB = 4096; BM_Baseline_TCP_SendReceive(loadSizeB, sendSizeB, receiveSizeB); } BENCHMARK(BM_Baseline_TCP_Throughput_100MB_s4096B_r4096B, n) { + (void)n; constexpr size_t loadSizeB = 100 * 1024 * 1024; constexpr size_t sendSizeB = 4096; constexpr size_t receiveSizeB = 4096; @@ -153,16 +164,19 @@ BENCHMARK(BM_Baseline_TCP_Throughput_100MB_s4096B_r4096B, n) { } BENCHMARK(BM_Baseline_TCP_Latency_1M_msgs_32B, n) { + (void)n; constexpr size_t messageSizeB = 32; constexpr size_t loadSizeB = 1000000 * messageSizeB; BM_Baseline_TCP_SendReceive(loadSizeB, messageSizeB, messageSizeB); } BENCHMARK(BM_Baseline_TCP_Latency_1M_msgs_128B, n) { + (void)n; constexpr size_t messageSizeB = 128; constexpr size_t loadSizeB = 1000000 * messageSizeB; BM_Baseline_TCP_SendReceive(loadSizeB, messageSizeB, messageSizeB); } BENCHMARK(BM_Baseline_TCP_Latency_1M_msgs_4kB, n) { + (void)n; constexpr size_t messageSizeB = 4096; constexpr size_t loadSizeB = 1000000 * messageSizeB; BM_Baseline_TCP_SendReceive(loadSizeB, messageSizeB, messageSizeB); diff --git a/rsocket/benchmarks/Benchmarks.cpp b/rsocket/benchmarks/Benchmarks.cpp new file mode 100644 index 000000000..69a2abc91 --- /dev/null +++ b/rsocket/benchmarks/Benchmarks.cpp @@ -0,0 +1,27 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +int main(int argc, char** argv) { + folly::init(&argc, &argv); + + FLAGS_logtostderr = true; + + LOG(INFO) << "Running benchmarks... (takes minutes)"; + folly::runBenchmarks(); + + return 0; +} diff --git a/benchmarks/CMakeLists.txt b/rsocket/benchmarks/CMakeLists.txt similarity index 89% rename from benchmarks/CMakeLists.txt rename to rsocket/benchmarks/CMakeLists.txt index 172309927..4d0c4f51c 100644 --- a/benchmarks/CMakeLists.txt +++ b/rsocket/benchmarks/CMakeLists.txt @@ -1,5 +1,5 @@ add_library(fixture Fixture.cpp Fixture.h) -target_link_libraries(fixture ReactiveSocket folly) +target_link_libraries(fixture ReactiveSocket Folly::folly) function(benchmark NAME FILE) add_executable(${NAME} ${FILE} Benchmarks.cpp) @@ -7,9 +7,9 @@ function(benchmark NAME FILE) ${NAME} fixture ReactiveSocket - folly-benchmark - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + Folly::follybenchmark + glog::glog + gflags) endfunction() benchmark(baselines_tcp BaselinesTcp.cpp) diff --git a/benchmarks/FireForgetThroughputTcp.cpp b/rsocket/benchmarks/FireForgetThroughputTcp.cpp similarity index 68% rename from benchmarks/FireForgetThroughputTcp.cpp rename to rsocket/benchmarks/FireForgetThroughputTcp.cpp index c8c9ffe62..03f5a29f6 100644 --- a/benchmarks/FireForgetThroughputTcp.cpp +++ b/rsocket/benchmarks/FireForgetThroughputTcp.cpp @@ -1,7 +1,19 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "benchmarks/Fixture.h" -#include "benchmarks/Latch.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/benchmarks/Fixture.h" +#include "rsocket/benchmarks/Latch.h" #include #include @@ -31,7 +43,7 @@ class Responder : public RSocketResponder { private: Latch& latch_; }; -} +} // namespace BENCHMARK(FireForgetThroughput, n) { (void)n; @@ -63,7 +75,8 @@ BENCHMARK(FireForgetThroughput, n) { for (auto& client : fixture->clients) { client->getRequester() ->fireAndForget(Payload("TcpFireAndForget")) - ->subscribe(yarpl::make_ref>()); + ->subscribe( + std::make_shared>()); } } diff --git a/benchmarks/Fixture.cpp b/rsocket/benchmarks/Fixture.cpp similarity index 70% rename from benchmarks/Fixture.cpp rename to rsocket/benchmarks/Fixture.cpp index 90a98bdb3..2a42fd222 100644 --- a/benchmarks/Fixture.cpp +++ b/rsocket/benchmarks/Fixture.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/benchmarks/Fixture.h" @@ -17,7 +29,7 @@ std::shared_ptr makeClient( std::make_unique(*eventBase, std::move(address)); return RSocket::createConnectedClient(std::move(factory)).get(); } -} +} // namespace Fixture::Fixture( Fixture::Options fixtureOpts, @@ -47,4 +59,4 @@ Fixture::Fixture( workers.push_back(std::move(worker)); } } -} +} // namespace rsocket diff --git a/benchmarks/Fixture.h b/rsocket/benchmarks/Fixture.h similarity index 60% rename from benchmarks/Fixture.h rename to rsocket/benchmarks/Fixture.h index 083dbb35a..a1b290f7d 100644 --- a/benchmarks/Fixture.h +++ b/rsocket/benchmarks/Fixture.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -39,4 +51,4 @@ struct Fixture { std::vector> clients; const Options options; }; -} +} // namespace rsocket diff --git a/rsocket/benchmarks/Latch.h b/rsocket/benchmarks/Latch.h new file mode 100644 index 000000000..fc5422169 --- /dev/null +++ b/rsocket/benchmarks/Latch.h @@ -0,0 +1,43 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +/// Simple implementation of a latch synchronization primitive, for testing. +class Latch { + public: + explicit Latch(size_t limit) : limit_{limit} {} + + void wait() { + baton_.wait(); + } + + bool timed_wait(std::chrono::milliseconds timeout) { + return baton_.timed_wait(timeout); + } + + void post() { + auto const old = count_.fetch_add(1); + if (old == limit_ - 1) { + baton_.post(); + } + } + + private: + folly::Baton<> baton_; + std::atomic count_{0}; + const size_t limit_{0}; +}; diff --git a/benchmarks/README.md b/rsocket/benchmarks/README.md similarity index 100% rename from benchmarks/README.md rename to rsocket/benchmarks/README.md diff --git a/benchmarks/RequestResponseThroughputTcp.cpp b/rsocket/benchmarks/RequestResponseThroughputTcp.cpp similarity index 72% rename from benchmarks/RequestResponseThroughputTcp.cpp rename to rsocket/benchmarks/RequestResponseThroughputTcp.cpp index 605c193d4..aace80fd2 100644 --- a/benchmarks/RequestResponseThroughputTcp.cpp +++ b/rsocket/benchmarks/RequestResponseThroughputTcp.cpp @@ -1,8 +1,20 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "benchmarks/Fixture.h" -#include "benchmarks/Latch.h" -#include "benchmarks/Throughput.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/benchmarks/Fixture.h" +#include "rsocket/benchmarks/Latch.h" +#include "rsocket/benchmarks/Throughput.h" #include #include @@ -32,7 +44,7 @@ class Observer : public yarpl::single::SingleObserverBase { public: explicit Observer(Latch& latch) : latch_{latch} {} - void onSubscribe(yarpl::Reference + void onSubscribe(std::shared_ptr subscription) override { yarpl::single::SingleObserverBase::onSubscribe( std::move(subscription)); @@ -51,7 +63,7 @@ class Observer : public yarpl::single::SingleObserverBase { private: Latch& latch_; }; -} +} // namespace BENCHMARK(RequestResponseThroughput, n) { (void)n; @@ -84,7 +96,7 @@ BENCHMARK(RequestResponseThroughput, n) { auto& client = fixture->clients[i % opts.clients]; client->getRequester() ->requestResponse(Payload("RequestResponseTcp")) - ->subscribe(yarpl::make_ref(latch)); + ->subscribe(std::make_shared(latch)); } constexpr std::chrono::minutes timeout{5}; diff --git a/benchmarks/StreamThroughputMemory.cpp b/rsocket/benchmarks/StreamThroughputMemory.cpp similarity index 75% rename from benchmarks/StreamThroughputMemory.cpp rename to rsocket/benchmarks/StreamThroughputMemory.cpp index 1de925ac8..c5128152e 100644 --- a/benchmarks/StreamThroughputMemory.cpp +++ b/rsocket/benchmarks/StreamThroughputMemory.cpp @@ -1,6 +1,18 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "benchmarks/Throughput.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/benchmarks/Throughput.h" #include #include @@ -32,7 +44,7 @@ class DirectDuplexConnection : public DuplexConnection { DirectDuplexConnection(std::shared_ptr state, folly::EventBase& evb) : state_{std::move(state)}, evb_{evb} {} - ~DirectDuplexConnection() { + ~DirectDuplexConnection() override { *state_->destroyed.wlock() = true; } @@ -42,7 +54,7 @@ class DirectDuplexConnection : public DuplexConnection { other_->other_ = this; } - void setInput(yarpl::Reference input) override { + void setInput(std::shared_ptr input) override { input_ = std::move(input); } @@ -53,7 +65,7 @@ class DirectDuplexConnection : public DuplexConnection { } other_->evb_.runInEventBaseThread( - [ state = state_, other = other_, b = std::move(buf) ]() mutable { + [state = state_, other = other_, b = std::move(buf)]() mutable { auto destroyed = state->destroyed.rlock(); if (*destroyed) { return; @@ -69,7 +81,7 @@ class DirectDuplexConnection : public DuplexConnection { DirectDuplexConnection* other_{nullptr}; - yarpl::Reference input_; + std::shared_ptr input_; }; class Acceptor : public ConnectionAcceptor { @@ -82,7 +94,7 @@ class Acceptor : public ConnectionAcceptor { void start(OnDuplexConnectionAccept onAccept) override { worker_.getEventBase()->runInEventBaseThread( - [ this, onAccept = std::move(onAccept) ]() mutable { + [this, onAccept = std::move(onAccept)]() mutable { auto server = std::make_unique( std::move(state_), *worker_.getEventBase()); server->tie(client_); @@ -124,10 +136,12 @@ class Factory : public ConnectionFactory { server_->start([responder](const SetupParameters&) { return responder; }); } - folly::Future connect() override { + folly::Future connect( + ProtocolVersion, + ResumeStatus /* unused */) override { return folly::via(worker_.getEventBase(), [this] { - return ConnectedDuplexConnection{std::move(connection_), - *worker_.getEventBase()}; + return ConnectedDuplexConnection{ + std::move(connection_), *worker_.getEventBase()}; }); } @@ -150,7 +164,7 @@ BENCHMARK(StreamThroughput, n) { (void)n; std::shared_ptr client; - yarpl::Reference subscriber; + std::shared_ptr subscriber; folly::ScopedEventBaseThread worker; @@ -164,7 +178,7 @@ BENCHMARK(StreamThroughput, n) { client->getRequester() ->requestStream(Payload("InMemoryStream")) - ->subscribe(yarpl::make_ref(latch, FLAGS_items)); + ->subscribe(std::make_shared(latch, FLAGS_items)); constexpr std::chrono::minutes timeout{5}; if (!latch.timed_wait(timeout)) { diff --git a/benchmarks/StreamThroughputTcp.cpp b/rsocket/benchmarks/StreamThroughputTcp.cpp similarity index 69% rename from benchmarks/StreamThroughputTcp.cpp rename to rsocket/benchmarks/StreamThroughputTcp.cpp index f833d93af..4f9c5e343 100644 --- a/benchmarks/StreamThroughputTcp.cpp +++ b/rsocket/benchmarks/StreamThroughputTcp.cpp @@ -1,7 +1,19 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -#include "benchmarks/Fixture.h" -#include "benchmarks/Throughput.h" +#include "rsocket/benchmarks/Fixture.h" +#include "rsocket/benchmarks/Throughput.h" #include #include @@ -54,7 +66,7 @@ BENCHMARK(StreamThroughput, n) { for (auto& client : fixture->clients) { client->getRequester() ->requestStream(Payload("TcpStream")) - ->subscribe(yarpl::make_ref(latch, FLAGS_items)); + ->subscribe(std::make_shared(latch, FLAGS_items)); } } diff --git a/benchmarks/Throughput.h b/rsocket/benchmarks/Throughput.h similarity index 65% rename from benchmarks/Throughput.h rename to rsocket/benchmarks/Throughput.h index 6f6fa4351..c5c215e99 100644 --- a/benchmarks/Throughput.h +++ b/rsocket/benchmarks/Throughput.h @@ -1,9 +1,21 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include "benchmarks/Latch.h" #include "rsocket/RSocketResponder.h" +#include "rsocket/benchmarks/Latch.h" namespace rsocket { @@ -14,14 +26,14 @@ class FixedResponder : public RSocketResponder { : message_{folly::IOBuf::copyBuffer(message)} {} /// Infinitely streams back the message. - yarpl::Reference> handleRequestStream( + std::shared_ptr> handleRequestStream( Payload, StreamId) override { - return yarpl::flowable::Flowables::fromGenerator( + return yarpl::flowable::Flowable::fromGenerator( [msg = message_->clone()] { return Payload(msg->clone()); }); } - yarpl::Reference> handleRequestResponse( + std::shared_ptr> handleRequestResponse( Payload, StreamId) override { return yarpl::single::Singles::fromGenerator( @@ -72,4 +84,4 @@ class BoundedSubscriber : public yarpl::flowable::BaseSubscriber { size_t requested_{0}; std::atomic received_{0}; }; -} +} // namespace rsocket diff --git a/examples/README.md b/rsocket/examples/README.md similarity index 100% rename from examples/README.md rename to rsocket/examples/README.md diff --git a/examples/channel-hello-world/ChannelHelloWorld_Client.cpp b/rsocket/examples/channel-hello-world/ChannelHelloWorld_Client.cpp similarity index 50% rename from examples/channel-hello-world/ChannelHelloWorld_Client.cpp rename to rsocket/examples/channel-hello-world/ChannelHelloWorld_Client.cpp index 75b0568ce..55c12bee6 100644 --- a/examples/channel-hello-world/ChannelHelloWorld_Client.cpp +++ b/rsocket/examples/channel-hello-world/ChannelHelloWorld_Client.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include @@ -6,13 +18,12 @@ #include #include -#include "rsocket/examples/util/ExampleSubscriber.h" #include "rsocket/RSocket.h" +#include "rsocket/examples/util/ExampleSubscriber.h" #include "rsocket/transports/tcp/TcpConnectionFactory.h" #include "yarpl/Flowable.h" -using namespace rsocket_example; using namespace rsocket; using namespace yarpl::flowable; @@ -30,16 +41,17 @@ int main(int argc, char* argv[]) { address.setFromHostPort(FLAGS_host, FLAGS_port); auto client = RSocket::createConnectedClient( - std::make_unique(*worker.getEventBase(), - std::move(address))) - .get(); + std::make_unique( + *worker.getEventBase(), std::move(address))) + .get(); client->getRequester() - ->requestChannel(Flowables::justN({"initialPayload", "Bob", "Jane"}) - ->map([](std::string v) { - std::cout << "Sending: " << v << std::endl; - return Payload(v); - })) + ->requestChannel( + Payload("initialPayload"), + Flowable<>::justN({"Bob", "Jane"})->map([](std::string v) { + std::cout << "Sending: " << v << std::endl; + return Payload(v); + })) ->subscribe([](Payload p) { std::cout << "Received: " << p.moveDataToString() << std::endl; }); diff --git a/examples/channel-hello-world/ChannelHelloWorld_Server.cpp b/rsocket/examples/channel-hello-world/ChannelHelloWorld_Server.cpp similarity index 69% rename from examples/channel-hello-world/ChannelHelloWorld_Server.cpp rename to rsocket/examples/channel-hello-world/ChannelHelloWorld_Server.cpp index 40cf04092..9be2a69d9 100644 --- a/examples/channel-hello-world/ChannelHelloWorld_Server.cpp +++ b/rsocket/examples/channel-hello-world/ChannelHelloWorld_Server.cpp @@ -1,6 +1,19 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include +#include #include #include @@ -18,9 +31,9 @@ DEFINE_int32(port, 9898, "port to connect to"); class HelloChannelRequestResponder : public rsocket::RSocketResponder { public: /// Handles a new inbound Stream requested by the other end. - yarpl::Reference> handleRequestChannel( + std::shared_ptr> handleRequestChannel( rsocket::Payload initialPayload, - yarpl::Reference> request, + std::shared_ptr> request, rsocket::StreamId) override { std::cout << "Initial request " << initialPayload.cloneDataToString() << std::endl; diff --git a/examples/channel-hello-world/README.md b/rsocket/examples/channel-hello-world/README.md similarity index 100% rename from examples/channel-hello-world/README.md rename to rsocket/examples/channel-hello-world/README.md diff --git a/examples/conditional-request-handling/ConditionalRequestHandling_Client.cpp b/rsocket/examples/conditional-request-handling/ConditionalRequestHandling_Client.cpp similarity index 78% rename from examples/conditional-request-handling/ConditionalRequestHandling_Client.cpp rename to rsocket/examples/conditional-request-handling/ConditionalRequestHandling_Client.cpp index 15703f3da..033649c17 100644 --- a/examples/conditional-request-handling/ConditionalRequestHandling_Client.cpp +++ b/rsocket/examples/conditional-request-handling/ConditionalRequestHandling_Client.cpp @@ -1,19 +1,29 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include #include #include -#include "rsocket/examples/util/ExampleSubscriber.h" #include "rsocket/RSocket.h" +#include "rsocket/examples/util/ExampleSubscriber.h" #include "rsocket/transports/tcp/TcpConnectionFactory.h" #include "yarpl/Flowable.h" using namespace ::folly; -using namespace ::rsocket_example; using namespace ::rsocket; -using namespace yarpl::flowable; DEFINE_string(host, "localhost", "host to connect to"); DEFINE_int32(port, 9898, "host:port to connect to"); @@ -41,7 +51,7 @@ class ChannelConnectionEvents : public RSocketConnectionEvents { private: std::atomic closed_{false}; }; -} +} // namespace void sendRequest(std::string mimeType) { folly::ScopedEventBaseThread worker; diff --git a/examples/conditional-request-handling/ConditionalRequestHandling_Server.cpp b/rsocket/examples/conditional-request-handling/ConditionalRequestHandling_Server.cpp similarity index 66% rename from examples/conditional-request-handling/ConditionalRequestHandling_Server.cpp rename to rsocket/examples/conditional-request-handling/ConditionalRequestHandling_Server.cpp index bd8a5bf46..97e640f98 100644 --- a/examples/conditional-request-handling/ConditionalRequestHandling_Server.cpp +++ b/rsocket/examples/conditional-request-handling/ConditionalRequestHandling_Server.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -9,7 +21,6 @@ #include "rsocket/RSocket.h" #include "rsocket/transports/tcp/TcpConnectionAcceptor.h" -using namespace ::folly; using namespace ::rsocket; DEFINE_int32(port, 9898, "port to connect to"); diff --git a/rsocket/examples/conditional-request-handling/JsonRequestHandler.cpp b/rsocket/examples/conditional-request-handling/JsonRequestHandler.cpp new file mode 100644 index 000000000..563ea7d10 --- /dev/null +++ b/rsocket/examples/conditional-request-handling/JsonRequestHandler.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "JsonRequestHandler.h" +#include +#include +#include "yarpl/Flowable.h" + +using namespace rsocket; +using namespace yarpl::flowable; + +/// Handles a new inbound Stream requested by the other end. +std::shared_ptr> +JsonRequestResponder::handleRequestStream(Payload request, StreamId) { + LOG(INFO) << "JsonRequestResponder.handleRequestStream " << request; + + // string from payload data + auto requestString = request.moveDataToString(); + + return Flowable<>::range(1, 100)->map( + [name = std::move(requestString)](int64_t v) { + std::stringstream ss; + ss << "Hello (should be JSON) " << name << " " << v << "!"; + std::string s = ss.str(); + return Payload(s, "metadata"); + }); +} diff --git a/rsocket/examples/conditional-request-handling/JsonRequestHandler.h b/rsocket/examples/conditional-request-handling/JsonRequestHandler.h new file mode 100644 index 000000000..2bc0f45ad --- /dev/null +++ b/rsocket/examples/conditional-request-handling/JsonRequestHandler.h @@ -0,0 +1,26 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "rsocket/Payload.h" +#include "rsocket/RSocket.h" + +class JsonRequestResponder : public rsocket::RSocketResponder { + public: + /// Handles a new inbound Stream requested by the other end. + std::shared_ptr> + handleRequestStream(rsocket::Payload request, rsocket::StreamId streamId) + override; +}; diff --git a/examples/conditional-request-handling/README.md b/rsocket/examples/conditional-request-handling/README.md similarity index 100% rename from examples/conditional-request-handling/README.md rename to rsocket/examples/conditional-request-handling/README.md diff --git a/rsocket/examples/conditional-request-handling/TextRequestHandler.cpp b/rsocket/examples/conditional-request-handling/TextRequestHandler.cpp new file mode 100644 index 000000000..708313186 --- /dev/null +++ b/rsocket/examples/conditional-request-handling/TextRequestHandler.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "TextRequestHandler.h" +#include +#include +#include "yarpl/Flowable.h" + +using namespace rsocket; +using namespace yarpl::flowable; + +/// Handles a new inbound Stream requested by the other end. +std::shared_ptr> +TextRequestResponder::handleRequestStream(Payload request, StreamId) { + LOG(INFO) << "TextRequestResponder.handleRequestStream " << request; + + // string from payload data + auto requestString = request.moveDataToString(); + + return Flowable<>::range(1, 100)->map( + [name = std::move(requestString)](int64_t v) { + std::stringstream ss; + ss << "Hello " << name << " " << v << "!"; + std::string s = ss.str(); + return Payload(s, "metadata"); + }); +} diff --git a/rsocket/examples/conditional-request-handling/TextRequestHandler.h b/rsocket/examples/conditional-request-handling/TextRequestHandler.h new file mode 100644 index 000000000..7098b516e --- /dev/null +++ b/rsocket/examples/conditional-request-handling/TextRequestHandler.h @@ -0,0 +1,26 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "rsocket/Payload.h" +#include "rsocket/RSocket.h" + +class TextRequestResponder : public rsocket::RSocketResponder { + public: + /// Handles a new inbound Stream requested by the other end. + std::shared_ptr> + handleRequestStream(rsocket::Payload request, rsocket::StreamId streamId) + override; +}; diff --git a/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Client.cpp b/rsocket/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Client.cpp similarity index 55% rename from examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Client.cpp rename to rsocket/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Client.cpp index b3b338f6f..c20dd0971 100644 --- a/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Client.cpp +++ b/rsocket/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Client.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include @@ -6,16 +18,13 @@ #include #include -#include "rsocket/examples/util/ExampleSubscriber.h" #include "rsocket/RSocket.h" +#include "rsocket/examples/util/ExampleSubscriber.h" #include "rsocket/transports/tcp/TcpConnectionFactory.h" #include "yarpl/Single.h" -using namespace rsocket_example; using namespace rsocket; -using namespace yarpl; -using namespace yarpl::single; DEFINE_string(host, "localhost", "host to connect to"); DEFINE_int32(port, 9898, "host:port to connect to"); @@ -31,9 +40,9 @@ int main(int argc, char* argv[]) { address.setFromHostPort(FLAGS_host, FLAGS_port); auto client = RSocket::createConnectedClient( - std::make_unique(*worker.getEventBase(), - std::move(address))) - .get(); + std::make_unique( + *worker.getEventBase(), std::move(address))) + .get(); client->getRequester()->fireAndForget(Payload("Hello World!"))->subscribe([] { std::cout << "wrote to network" << std::endl; diff --git a/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Server.cpp b/rsocket/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Server.cpp similarity index 67% rename from examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Server.cpp rename to rsocket/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Server.cpp index c17150c4f..7c1dfbce0 100644 --- a/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Server.cpp +++ b/rsocket/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Server.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -10,8 +22,6 @@ #include "rsocket/transports/tcp/TcpConnectionAcceptor.h" using namespace rsocket; -using namespace yarpl; -using namespace yarpl::single; DEFINE_int32(port, 9898, "port to connect to"); diff --git a/examples/fire-and-forget-hello-world/README.md b/rsocket/examples/fire-and-forget-hello-world/README.md similarity index 100% rename from examples/fire-and-forget-hello-world/README.md rename to rsocket/examples/fire-and-forget-hello-world/README.md diff --git a/examples/request-response-hello-world/README.md b/rsocket/examples/request-response-hello-world/README.md similarity index 100% rename from examples/request-response-hello-world/README.md rename to rsocket/examples/request-response-hello-world/README.md diff --git a/examples/request-response-hello-world/RequestResponseHelloWorld_Client.cpp b/rsocket/examples/request-response-hello-world/RequestResponseHelloWorld_Client.cpp similarity index 56% rename from examples/request-response-hello-world/RequestResponseHelloWorld_Client.cpp rename to rsocket/examples/request-response-hello-world/RequestResponseHelloWorld_Client.cpp index 75da5a397..f9f935a05 100644 --- a/examples/request-response-hello-world/RequestResponseHelloWorld_Client.cpp +++ b/rsocket/examples/request-response-hello-world/RequestResponseHelloWorld_Client.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include @@ -6,16 +18,13 @@ #include #include -#include "rsocket/examples/util/ExampleSubscriber.h" #include "rsocket/RSocket.h" +#include "rsocket/examples/util/ExampleSubscriber.h" #include "rsocket/transports/tcp/TcpConnectionFactory.h" #include "yarpl/Single.h" -using namespace rsocket_example; using namespace rsocket; -using namespace yarpl; -using namespace yarpl::single; DEFINE_string(host, "localhost", "host to connect to"); DEFINE_int32(port, 9898, "host:port to connect to"); @@ -30,9 +39,9 @@ int main(int argc, char* argv[]) { folly::ScopedEventBaseThread worker; auto client = RSocket::createConnectedClient( - std::make_unique( - *worker.getEventBase(), - std::move(address))).get(); + std::make_unique( + *worker.getEventBase(), std::move(address))) + .get(); client->getRequester() ->requestResponse(Payload("Jane")) diff --git a/examples/request-response-hello-world/RequestResponseHelloWorld_Server.cpp b/rsocket/examples/request-response-hello-world/RequestResponseHelloWorld_Server.cpp similarity index 55% rename from examples/request-response-hello-world/RequestResponseHelloWorld_Server.cpp rename to rsocket/examples/request-response-hello-world/RequestResponseHelloWorld_Server.cpp index ad8c824c2..7bba6813a 100644 --- a/examples/request-response-hello-world/RequestResponseHelloWorld_Server.cpp +++ b/rsocket/examples/request-response-hello-world/RequestResponseHelloWorld_Server.cpp @@ -1,6 +1,19 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include +#include #include #include @@ -11,7 +24,6 @@ #include "yarpl/Single.h" using namespace rsocket; -using namespace yarpl; using namespace yarpl::single; DEFINE_int32(port, 9898, "port to connect to"); @@ -19,26 +31,26 @@ DEFINE_int32(port, 9898, "port to connect to"); namespace { class HelloRequestResponseResponder : public rsocket::RSocketResponder { public: - Reference> handleRequestResponse(Payload request, StreamId) - override { + std::shared_ptr> handleRequestResponse( + Payload request, + StreamId) override { std::cout << "HelloRequestResponseRequestResponder.handleRequestResponse " << request << std::endl; // string from payload data auto requestString = request.moveDataToString(); - return Single::create([name = std::move(requestString)]( - auto subscriber) { - - std::stringstream ss; - ss << "Hello " << name << "!"; - std::string s = ss.str(); - subscriber->onSubscribe(SingleSubscriptions::empty()); - subscriber->onSuccess(Payload(s, "metadata")); - }); + return Single::create( + [name = std::move(requestString)](auto subscriber) { + std::stringstream ss; + ss << "Hello " << name << "!"; + std::string s = ss.str(); + subscriber->onSubscribe(SingleSubscriptions::empty()); + subscriber->onSuccess(Payload(s, "metadata")); + }); } }; -} +} // namespace int main(int argc, char* argv[]) { FLAGS_logtostderr = true; diff --git a/examples/resumption/ColdResumption_Client.cpp b/rsocket/examples/resumption/ColdResumption_Client.cpp similarity index 82% rename from examples/resumption/ColdResumption_Client.cpp rename to rsocket/examples/resumption/ColdResumption_Client.cpp index 91ba53f8b..8d443bf0f 100644 --- a/examples/resumption/ColdResumption_Client.cpp +++ b/rsocket/examples/resumption/ColdResumption_Client.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include @@ -8,22 +20,21 @@ #include "rsocket/RSocket.h" +#include "rsocket/test/test_utils/ColdResumeManager.h" #include "rsocket/transports/tcp/TcpConnectionFactory.h" -#include "test/test_utils/ColdResumeManager.h" using namespace rsocket; -using namespace yarpl; using namespace yarpl::flowable; DEFINE_string(host, "localhost", "host to connect to"); DEFINE_int32(port, 9898, "host:port to connect to"); -typedef std::map>> HelloSubscribers; +typedef std::map>> + HelloSubscribers; namespace { -class HelloSubscriber : public virtual Refcounted, - public Subscriber { +class HelloSubscriber : public Subscriber { public: void request(int n) { while (!subscription_) { @@ -37,7 +48,7 @@ class HelloSubscriber : public virtual Refcounted, }; protected: - void onSubscribe(Reference subscription) override { + void onSubscribe(std::shared_ptr subscription) override { subscription_ = subscription; } @@ -49,7 +60,7 @@ class HelloSubscriber : public virtual Refcounted, void onError(folly::exception_wrapper) override {} private: - Reference subscription_; + std::shared_ptr subscription_; std::atomic count_{0}; }; @@ -59,14 +70,14 @@ class HelloResumeHandler : public ColdResumeHandler { : subscribers_(std::move(subscribers)) {} std::string generateStreamToken(const Payload& payload, StreamId, StreamType) - override { + const override { auto streamToken = payload.data->cloneAsValue().moveToFbString().toStdString(); VLOG(3) << "Generated token: " << streamToken; return streamToken; } - Reference> handleRequesterResumeStream( + std::shared_ptr> handleRequesterResumeStream( std::string streamToken, size_t consumerAllowance) override { CHECK(subscribers_.find(streamToken) != subscribers_.end()); @@ -92,7 +103,7 @@ std::unique_ptr getConnFactory( address.setFromHostPort(FLAGS_host, FLAGS_port); return std::make_unique(*eventBase, address); } -} +} // namespace // There are three sessions and three streams. // There is cold-resumption between the three sessions. @@ -117,7 +128,7 @@ int main(int argc, char* argv[]) { auto resumeManager = std::make_shared( RSocketStats::noop(), "" /* inputFile */); { - auto firstSub = yarpl::make_ref(); + auto firstSub = std::make_shared(); auto coldResumeHandler = std::make_shared( HelloSubscribers({{firstPayload, firstSub}})); auto firstClient = RSocket::createConnectedClient( @@ -156,7 +167,7 @@ int main(int argc, char* argv[]) { auto resumeManager = std::make_shared( RSocketStats::noop(), "/tmp/firstResumption.json" /* inputFile */); { - auto firstSub = yarpl::make_ref(); + auto firstSub = std::make_shared(); auto coldResumeHandler = std::make_shared( HelloSubscribers({{firstPayload, firstSub}})); auto secondClient = RSocket::createResumedClient( @@ -170,7 +181,7 @@ int main(int argc, char* argv[]) { // Create another stream to verify StreamIds are set properly after // resumption - auto secondSub = yarpl::make_ref(); + auto secondSub = std::make_shared(); secondClient->getRequester() ->requestStream(Payload(secondPayload)) ->subscribe(secondSub); @@ -192,8 +203,8 @@ int main(int argc, char* argv[]) { { auto resumeManager = std::make_shared( RSocketStats::noop(), "/tmp/secondResumption.json" /* inputFile */); - auto firstSub = yarpl::make_ref(); - auto secondSub = yarpl::make_ref(); + auto firstSub = std::make_shared(); + auto secondSub = std::make_shared(); auto coldResumeHandler = std::make_shared(HelloSubscribers( {{firstPayload, firstSub}, {secondPayload, secondSub}})); @@ -209,7 +220,7 @@ int main(int argc, char* argv[]) { // Create another stream to verify StreamIds are set properly after // resumption - auto thirdSub = yarpl::make_ref(); + auto thirdSub = std::make_shared(); thirdClient->getRequester() ->requestStream(Payload(thirdPayload)) ->subscribe(thirdSub); diff --git a/examples/resumption/Resumption_Server.cpp b/rsocket/examples/resumption/Resumption_Server.cpp similarity index 69% rename from examples/resumption/Resumption_Server.cpp rename to rsocket/examples/resumption/Resumption_Server.cpp index 226e138f5..36ccb6bb7 100644 --- a/examples/resumption/Resumption_Server.cpp +++ b/rsocket/examples/resumption/Resumption_Server.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -17,18 +29,17 @@ DEFINE_int32(port, 9898, "Port to accept connections on"); class HelloStreamRequestResponder : public RSocketResponder { public: - yarpl::Reference> handleRequestStream( + std::shared_ptr> handleRequestStream( rsocket::Payload request, rsocket::StreamId) override { auto requestString = request.moveDataToString(); - return Flowables::range(1, 1000)->map([name = std::move(requestString)]( - int64_t v) { - return Payload(folly::to(v), "metadata"); - }); + return Flowable<>::range(1, 1000)->map( + [name = std::move(requestString)](int64_t v) { + return Payload(folly::to(v), "metadata"); + }); } }; - class HelloServiceHandler : public RSocketServiceHandler { public: folly::Expected onNewSetup( diff --git a/examples/resumption/WarmResumption_Client.cpp b/rsocket/examples/resumption/WarmResumption_Client.cpp similarity index 74% rename from examples/resumption/WarmResumption_Client.cpp rename to rsocket/examples/resumption/WarmResumption_Client.cpp index 3f72c809d..f83a24e85 100644 --- a/examples/resumption/WarmResumption_Client.cpp +++ b/rsocket/examples/resumption/WarmResumption_Client.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include @@ -6,14 +18,13 @@ #include #include -#include "examples/util/ExampleSubscriber.h" #include "rsocket/RSocket.h" +#include "rsocket/examples/util/ExampleSubscriber.h" #include "rsocket/internal/ClientResumeStatusCallback.h" #include "rsocket/transports/tcp/TcpConnectionFactory.h" #include "yarpl/Flowable.h" -using namespace rsocket_example; using namespace rsocket; DEFINE_string(host, "localhost", "host to connect to"); @@ -21,8 +32,7 @@ DEFINE_int32(port, 9898, "host:port to connect to"); namespace { -class HelloSubscriber : public virtual yarpl::Refcounted, - public yarpl::flowable::Subscriber { +class HelloSubscriber : public yarpl::flowable::Subscriber { public: void request(int n) { LOG(INFO) << "... requesting " << n; @@ -43,7 +53,7 @@ class HelloSubscriber : public virtual yarpl::Refcounted, }; protected: - void onSubscribe(yarpl::Reference + void onSubscribe(std::shared_ptr subscription) noexcept override { subscription_ = subscription; } @@ -62,14 +72,14 @@ class HelloSubscriber : public virtual yarpl::Refcounted, } private: - yarpl::Reference subscription_{nullptr}; + std::shared_ptr subscription_{nullptr}; std::atomic count_{0}; }; -} +} // namespace std::unique_ptr getClientAndRequestStream( folly::EventBase* eventBase, - yarpl::Reference subscriber) { + std::shared_ptr subscriber) { folly::SocketAddress address; address.setFromHostPort(FLAGS_host, FLAGS_port); SetupParameters setupParameters; @@ -90,7 +100,7 @@ int main(int argc, char* argv[]) { folly::ScopedEventBaseThread worker1; - auto subscriber1 = yarpl::make_ref(); + auto subscriber1 = std::make_shared(); auto client = getClientAndRequestStream(worker1.getEventBase(), subscriber1); subscriber1->request(7); @@ -104,7 +114,7 @@ int main(int argc, char* argv[]) { client->resume() .via(worker2.getEventBase()) - .then([subscriber1] { + .thenValue([subscriber1](folly::Unit) { // continue with the old client. subscriber1->request(3); while (subscriber1->rcvdCount() < 10) { @@ -112,7 +122,7 @@ int main(int argc, char* argv[]) { } subscriber1->cancel(); }) - .onError([&](folly::exception_wrapper ex) { + .thenError([&](folly::exception_wrapper ex) { LOG(INFO) << "Resumption Failed: " << ex.what(); try { ex.throw_exception(); @@ -124,7 +134,7 @@ int main(int argc, char* argv[]) { LOG(INFO) << "UnknownException " << typeid(e).name(); } // Create a new client - auto subscriber2 = yarpl::make_ref(); + auto subscriber2 = std::make_shared(); auto client = getClientAndRequestStream(worker1.getEventBase(), subscriber2); subscriber2->request(7); diff --git a/examples/stream-hello-world/README.md b/rsocket/examples/stream-hello-world/README.md similarity index 100% rename from examples/stream-hello-world/README.md rename to rsocket/examples/stream-hello-world/README.md diff --git a/examples/stream-hello-world/StreamHelloWorld_Client.cpp b/rsocket/examples/stream-hello-world/StreamHelloWorld_Client.cpp similarity index 56% rename from examples/stream-hello-world/StreamHelloWorld_Client.cpp rename to rsocket/examples/stream-hello-world/StreamHelloWorld_Client.cpp index 2be2bd1b5..07fc75d9f 100644 --- a/examples/stream-hello-world/StreamHelloWorld_Client.cpp +++ b/rsocket/examples/stream-hello-world/StreamHelloWorld_Client.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include @@ -6,13 +18,12 @@ #include #include -#include "rsocket/examples/util/ExampleSubscriber.h" #include "rsocket/RSocket.h" +#include "rsocket/examples/util/ExampleSubscriber.h" #include "rsocket/transports/tcp/TcpConnectionFactory.h" #include "yarpl/Flowable.h" -using namespace rsocket_example; using namespace rsocket; DEFINE_string(host, "localhost", "host to connect to"); @@ -29,9 +40,9 @@ int main(int argc, char* argv[]) { address.setFromHostPort(FLAGS_host, FLAGS_port); auto client = RSocket::createConnectedClient( - std::make_unique(*worker.getEventBase(), - std::move(address))) - .get(); + std::make_unique( + *worker.getEventBase(), std::move(address))) + .get(); client->getRequester() ->requestStream(Payload("Jane")) diff --git a/examples/stream-hello-world/StreamHelloWorld_Server.cpp b/rsocket/examples/stream-hello-world/StreamHelloWorld_Server.cpp similarity index 60% rename from examples/stream-hello-world/StreamHelloWorld_Server.cpp rename to rsocket/examples/stream-hello-world/StreamHelloWorld_Server.cpp index 1d57fb467..c3af1a157 100644 --- a/examples/stream-hello-world/StreamHelloWorld_Server.cpp +++ b/rsocket/examples/stream-hello-world/StreamHelloWorld_Server.cpp @@ -1,6 +1,19 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include +#include #include #include @@ -18,7 +31,7 @@ DEFINE_int32(port, 9898, "port to connect to"); class HelloStreamRequestResponder : public rsocket::RSocketResponder { public: /// Handles a new inbound Stream requested by the other end. - yarpl::Reference> handleRequestStream( + std::shared_ptr> handleRequestStream( rsocket::Payload request, rsocket::StreamId) override { std::cout << "HelloStreamRequestResponder.handleRequestStream " << request @@ -27,13 +40,13 @@ class HelloStreamRequestResponder : public rsocket::RSocketResponder { // string from payload data auto requestString = request.moveDataToString(); - return Flowables::range(1, 10)->map([name = std::move(requestString)]( - int64_t v) { - std::stringstream ss; - ss << "Hello " << name << " " << v << "!"; - std::string s = ss.str(); - return Payload(s, "metadata"); - }); + return Flowable<>::range(1, 10)->map( + [name = std::move(requestString)](int64_t v) { + std::stringstream ss; + ss << "Hello " << name << " " << v << "!"; + std::string s = ss.str(); + return Payload(s, "metadata"); + }); } }; diff --git a/examples/stream-observable-to-flowable/README.md b/rsocket/examples/stream-observable-to-flowable/README.md similarity index 100% rename from examples/stream-observable-to-flowable/README.md rename to rsocket/examples/stream-observable-to-flowable/README.md diff --git a/examples/stream-observable-to-flowable/StreamObservableToFlowable_Client.cpp b/rsocket/examples/stream-observable-to-flowable/StreamObservableToFlowable_Client.cpp similarity index 60% rename from examples/stream-observable-to-flowable/StreamObservableToFlowable_Client.cpp rename to rsocket/examples/stream-observable-to-flowable/StreamObservableToFlowable_Client.cpp index 6b201fb94..18cc8be0c 100644 --- a/examples/stream-observable-to-flowable/StreamObservableToFlowable_Client.cpp +++ b/rsocket/examples/stream-observable-to-flowable/StreamObservableToFlowable_Client.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include @@ -6,15 +18,13 @@ #include #include -#include "rsocket/examples/util/ExampleSubscriber.h" #include "rsocket/RSocket.h" +#include "rsocket/examples/util/ExampleSubscriber.h" #include "rsocket/transports/tcp/TcpConnectionFactory.h" #include "yarpl/Flowable.h" -using namespace rsocket_example; using namespace rsocket; -using yarpl::flowable::Subscribers; DEFINE_string(host, "localhost", "host to connect to"); DEFINE_int32(port, 9898, "host:port to connect to"); @@ -30,9 +40,9 @@ int main(int argc, char* argv[]) { address.setFromHostPort(FLAGS_host, FLAGS_port); auto client = RSocket::createConnectedClient( - std::make_unique(*worker.getEventBase(), - std::move(address))) - .get(); + std::make_unique( + *worker.getEventBase(), std::move(address))) + .get(); client->getRequester() ->requestStream(Payload("TopicX")) diff --git a/examples/stream-observable-to-flowable/StreamObservableToFlowable_Server.cpp b/rsocket/examples/stream-observable-to-flowable/StreamObservableToFlowable_Server.cpp similarity index 59% rename from examples/stream-observable-to-flowable/StreamObservableToFlowable_Server.cpp rename to rsocket/examples/stream-observable-to-flowable/StreamObservableToFlowable_Server.cpp index 1a75b9d8a..30f2b3d5e 100644 --- a/examples/stream-observable-to-flowable/StreamObservableToFlowable_Server.cpp +++ b/rsocket/examples/stream-observable-to-flowable/StreamObservableToFlowable_Server.cpp @@ -1,7 +1,20 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include +#include #include #include @@ -21,7 +34,7 @@ DEFINE_int32(port, 9898, "port to connect to"); class PushStreamRequestResponder : public rsocket::RSocketResponder { public: /// Handles a new inbound Stream requested by the other end. - yarpl::Reference> handleRequestStream( + std::shared_ptr> handleRequestStream( Payload request, rsocket::StreamId) override { std::cout << "PushStreamRequestResponder.handleRequestStream " << request @@ -44,24 +57,24 @@ class PushStreamRequestResponder : public rsocket::RSocketResponder { // This examples uses BackpressureStrategy::DROP which simply // drops any events emitted from the Observable if the Flowable // does not have any credits from the Subscriber. - return Observable::create([name = std::move(requestString)]( - Reference> s) { - // Must make this async since it's an infinite stream - // and will block the IO thread. - // Using a raw thread right now since the 'subscribeOn' - // operator is not ready yet. This can eventually - // be replaced with use of 'subscribeOn'. - std::thread([s, name]() { - int64_t v = 0; - while (!s->isUnsubscribed()) { - std::stringstream ss; - ss << "Event[" << name << "]-" << ++v << "!"; - std::string payloadData = ss.str(); - s->onNext(Payload(payloadData, "metadata")); - } - }).detach(); - - }) + return Observable::create( + [name = std::move(requestString)]( + std::shared_ptr> s) { + // Must make this async since it's an infinite stream + // and will block the IO thread. + // Using a raw thread right now since the 'subscribeOn' + // operator is not ready yet. This can eventually + // be replaced with use of 'subscribeOn'. + std::thread([s, name]() { + int64_t v = 0; + while (!s->isUnsubscribed()) { + std::stringstream ss; + ss << "Event[" << name << "]-" << ++v << "!"; + std::string payloadData = ss.str(); + s->onNext(Payload(payloadData, "metadata")); + } + }).detach(); + }) ->toFlowable(BackpressureStrategy::DROP); } }; diff --git a/examples/util/ExampleSubscriber.cpp b/rsocket/examples/util/ExampleSubscriber.cpp similarity index 76% rename from examples/util/ExampleSubscriber.cpp rename to rsocket/examples/util/ExampleSubscriber.cpp index 2b782be64..6ad535d36 100644 --- a/examples/util/ExampleSubscriber.cpp +++ b/rsocket/examples/util/ExampleSubscriber.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/examples/util/ExampleSubscriber.h" #include @@ -23,7 +35,7 @@ ExampleSubscriber::ExampleSubscriber(int initialRequest, int numToTake) } void ExampleSubscriber::onSubscribe( - yarpl::Reference subscription) noexcept { + std::shared_ptr subscription) noexcept { LOG(INFO) << "ExampleSubscriber " << this << " onSubscribe, requesting " << initialRequest_; subscription_ = std::move(subscription); @@ -69,4 +81,4 @@ void ExampleSubscriber::awaitTerminalEvent() { terminalEventCV_.wait(lk, [this] { return terminated_; }); LOG(INFO) << "ExampleSubscriber " << this << " unblocked"; } -} +} // namespace rsocket_example diff --git a/examples/util/ExampleSubscriber.h b/rsocket/examples/util/ExampleSubscriber.h similarity index 50% rename from examples/util/ExampleSubscriber.h rename to rsocket/examples/util/ExampleSubscriber.h index 00a4db3ad..24a1caa23 100644 --- a/examples/util/ExampleSubscriber.h +++ b/rsocket/examples/util/ExampleSubscriber.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -15,13 +27,12 @@ * Request 5 items to begin with, then 3 more after each receipt of 3. */ namespace rsocket_example { -class ExampleSubscriber - : public yarpl::flowable::Subscriber { +class ExampleSubscriber : public yarpl::flowable::Subscriber { public: ~ExampleSubscriber(); ExampleSubscriber(int initialRequest, int numToTake); - void onSubscribe(yarpl::Reference + void onSubscribe(std::shared_ptr subscription) noexcept override; void onNext(rsocket::Payload) noexcept override; void onComplete() noexcept override; @@ -35,9 +46,9 @@ class ExampleSubscriber int numToTake_; int requested_; int received_; - yarpl::Reference subscription_; + std::shared_ptr subscription_; bool terminated_{false}; std::mutex m_; std::condition_variable terminalEventCV_; }; -} +} // namespace rsocket_example diff --git a/examples/util/README.md b/rsocket/examples/util/README.md similarity index 100% rename from examples/util/README.md rename to rsocket/examples/util/README.md diff --git a/rsocket/framing/ErrorCode.cpp b/rsocket/framing/ErrorCode.cpp index 4b9b4767d..6ee11c348 100644 --- a/rsocket/framing/ErrorCode.cpp +++ b/rsocket/framing/ErrorCode.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/framing/ErrorCode.h" @@ -31,4 +43,4 @@ std::ostream& operator<<(std::ostream& os, ErrorCode errorCode) { } return os << "ErrorCode[" << static_cast(errorCode) << "]"; } -} +} // namespace rsocket diff --git a/rsocket/framing/ErrorCode.h b/rsocket/framing/ErrorCode.h index f6f0ce31d..93f741aaa 100644 --- a/rsocket/framing/ErrorCode.h +++ b/rsocket/framing/ErrorCode.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -40,4 +52,4 @@ enum class ErrorCode : uint32_t { }; std::ostream& operator<<(std::ostream&, ErrorCode); -} +} // namespace rsocket diff --git a/rsocket/framing/Frame.cpp b/rsocket/framing/Frame.cpp index 069d26920..9b3d8cc53 100644 --- a/rsocket/framing/Frame.cpp +++ b/rsocket/framing/Frame.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/framing/Frame.h" @@ -11,10 +23,25 @@ namespace rsocket { -const uint32_t Frame_LEASE::kMaxTtl; -const uint32_t Frame_LEASE::kMaxNumRequests; -const uint32_t Frame_SETUP::kMaxKeepaliveTime; -const uint32_t Frame_SETUP::kMaxLifetime; +namespace detail { + +FrameFlags getFlags(const Payload& p) { + return p.metadata ? FrameFlags::METADATA : FrameFlags::EMPTY_; +} + +void checkFlags(const Payload& p, FrameFlags flags) { + if (bool(p.metadata) != bool(flags & FrameFlags::METADATA)) { + throw std::invalid_argument{ + "Value of METADATA flag doesn't match payload metadata"}; + } +} + +} // namespace detail + +constexpr uint32_t Frame_LEASE::kMaxTtl; +constexpr uint32_t Frame_LEASE::kMaxNumRequests; +constexpr uint32_t Frame_SETUP::kMaxKeepaliveTime; +constexpr uint32_t Frame_SETUP::kMaxLifetime; std::ostream& operator<<(std::ostream& os, const Frame_REQUEST_Base& frame) { return os << frame.header_ << "(" << frame.requestN_ << ", " @@ -53,54 +80,65 @@ std::ostream& operator<<(std::ostream& os, const Frame_PAYLOAD& frame) { return os << frame.header_ << ", " << frame.payload_; } -Frame_ERROR Frame_ERROR::invalidSetup(std::string message) { - return connectionErr(ErrorCode::INVALID_SETUP, std::move(message)); +Frame_ERROR Frame_ERROR::invalidSetup(folly::StringPiece message) { + return connectionErr(ErrorCode::INVALID_SETUP, message); } -Frame_ERROR Frame_ERROR::unsupportedSetup(std::string message) { - return connectionErr(ErrorCode::UNSUPPORTED_SETUP, std::move(message)); +Frame_ERROR Frame_ERROR::unsupportedSetup(folly::StringPiece message) { + return connectionErr(ErrorCode::UNSUPPORTED_SETUP, message); } -Frame_ERROR Frame_ERROR::rejectedSetup(std::string message) { - return connectionErr(ErrorCode::REJECTED_SETUP, std::move(message)); +Frame_ERROR Frame_ERROR::rejectedSetup(folly::StringPiece message) { + return connectionErr(ErrorCode::REJECTED_SETUP, message); } -Frame_ERROR Frame_ERROR::rejectedResume(std::string message) { - return connectionErr(ErrorCode::REJECTED_RESUME, std::move(message)); +Frame_ERROR Frame_ERROR::rejectedResume(folly::StringPiece message) { + return connectionErr(ErrorCode::REJECTED_RESUME, message); } -Frame_ERROR Frame_ERROR::connectionError(std::string message) { - return connectionErr(ErrorCode::CONNECTION_ERROR, std::move(message)); +Frame_ERROR Frame_ERROR::connectionError(folly::StringPiece message) { + return connectionErr(ErrorCode::CONNECTION_ERROR, message); } Frame_ERROR Frame_ERROR::applicationError( StreamId stream, - std::string message) { - return streamErr(ErrorCode::APPLICATION_ERROR, std::move(message), stream); + folly::StringPiece message) { + return streamErr(ErrorCode::APPLICATION_ERROR, message, stream); +} + +Frame_ERROR Frame_ERROR::applicationError(StreamId stream, Payload&& payload) { + if (stream == 0) { + throw std::invalid_argument{"Can't make stream error for stream zero"}; + } + return Frame_ERROR(stream, ErrorCode::APPLICATION_ERROR, std::move(payload)); } -Frame_ERROR Frame_ERROR::rejected(StreamId stream, std::string message) { - return streamErr(ErrorCode::REJECTED, std::move(message), stream); +Frame_ERROR Frame_ERROR::rejected(StreamId stream, folly::StringPiece message) { + return streamErr(ErrorCode::REJECTED, message, stream); } -Frame_ERROR Frame_ERROR::canceled(StreamId stream, std::string message) { - return streamErr(ErrorCode::CANCELED, std::move(message), stream); +Frame_ERROR Frame_ERROR::canceled(StreamId stream, folly::StringPiece message) { + return streamErr(ErrorCode::CANCELED, message, stream); } -Frame_ERROR Frame_ERROR::invalid(StreamId stream, std::string message) { - return streamErr(ErrorCode::INVALID, std::move(message), stream); +Frame_ERROR Frame_ERROR::invalid(StreamId stream, folly::StringPiece message) { + return streamErr(ErrorCode::INVALID, message, stream); } -Frame_ERROR Frame_ERROR::connectionErr(ErrorCode err, std::string message) { - return Frame_ERROR{0, err, Payload{std::move(message)}}; +Frame_ERROR Frame_ERROR::connectionErr( + ErrorCode err, + folly::StringPiece message) { + return Frame_ERROR{0, err, Payload{message}}; } -Frame_ERROR -Frame_ERROR::streamErr(ErrorCode err, std::string message, StreamId stream) { +Frame_ERROR Frame_ERROR::streamErr( + ErrorCode err, + folly::StringPiece message, + StreamId stream) { if (stream == 0) { throw std::invalid_argument{"Can't make stream error for stream zero"}; } - return Frame_ERROR{stream, err, Payload{std::move(message)}}; + return Frame_ERROR{stream, err, Payload{message}}; } std::ostream& operator<<(std::ostream& os, const Frame_ERROR& frame) { @@ -147,7 +185,8 @@ std::ostream& operator<<(std::ostream& os, const Frame_RESUME_OK& frame) { } std::ostream& operator<<(std::ostream& os, const Frame_REQUEST_CHANNEL& frame) { - return os << frame.header_ << ", " << frame.payload_; + return os << frame.header_ << ", initialRequestN=" << frame.requestN_ << ", " + << frame.payload_; } std::ostream& operator<<(std::ostream& os, const Frame_REQUEST_STREAM& frame) { @@ -155,25 +194,4 @@ std::ostream& operator<<(std::ostream& os, const Frame_REQUEST_STREAM& frame) { << frame.payload_; } -StreamType getStreamType(FrameType frameType) { - if (frameType == FrameType::REQUEST_STREAM) { - return StreamType::STREAM; - } else if (frameType == FrameType::REQUEST_CHANNEL) { - return StreamType::CHANNEL; - } else if (frameType == FrameType::REQUEST_RESPONSE) { - return StreamType::REQUEST_RESPONSE; - } else if (frameType == FrameType::REQUEST_FNF) { - return StreamType::FNF; - } else { - LOG(FATAL) << "Unknown open stream frame : " << frameType; - } -} - -bool isNewStreamFrame(FrameType frameType) { - return frameType == FrameType::REQUEST_CHANNEL || - frameType == FrameType::REQUEST_STREAM || - frameType == FrameType::REQUEST_RESPONSE || - frameType == FrameType::REQUEST_FNF; -} - } // namespace rsocket diff --git a/rsocket/framing/Frame.h b/rsocket/framing/Frame.h index 4a66aa57b..8de331f1a 100644 --- a/rsocket/framing/Frame.h +++ b/rsocket/framing/Frame.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -14,7 +26,8 @@ #include "rsocket/framing/FrameFlags.h" #include "rsocket/framing/FrameHeader.h" #include "rsocket/framing/FrameType.h" -#include "rsocket/internal/Common.h" +#include "rsocket/framing/ProtocolVersion.h" +#include "rsocket/framing/ResumeIdentificationToken.h" namespace folly { template @@ -22,11 +35,22 @@ class Optional; namespace io { class Cursor; class QueueAppender; -} -} +} // namespace io +} // namespace folly namespace rsocket { +namespace detail { + +FrameFlags getFlags(const Payload&); + +void checkFlags(const Payload&, FrameFlags); + +} // namespace detail + +using ResumePosition = int64_t; +constexpr ResumePosition kUnspecifiedResumePosition = -1; + /// Frames do not form hierarchy, as we never perform type erasure on a frame. /// We use inheritance only to save code duplication. /// @@ -46,7 +70,7 @@ class Frame_REQUEST_N { Frame_REQUEST_N() = default; Frame_REQUEST_N(StreamId streamId, uint32_t requestN) - : header_(FrameType::REQUEST_N, FrameFlags::EMPTY, streamId), + : header_(FrameType::REQUEST_N, FrameFlags::EMPTY_, streamId), requestN_(requestN) { DCHECK(requestN_ > 0); DCHECK(requestN_ <= kMaxRequestN); @@ -66,12 +90,10 @@ class Frame_REQUEST_Base { FrameFlags flags, uint32_t requestN, Payload payload) - : header_(frameType, flags | payload.getFlags(), streamId), + : header_(frameType, flags | detail::getFlags(payload), streamId), requestN_(requestN), payload_(std::move(payload)) { - // to verify the client didn't set - // METADATA and provided none - payload_.checkFlags(header_.flags); + detail::checkFlags(payload_, header_.flags); // TODO: DCHECK(requestN_ > 0); DCHECK(requestN_ <= Frame_REQUEST_N::kMaxRequestN); } @@ -155,11 +177,10 @@ class Frame_REQUEST_RESPONSE { Frame_REQUEST_RESPONSE(StreamId streamId, FrameFlags flags, Payload payload) : header_( FrameType::REQUEST_RESPONSE, - (flags & AllowedFlags) | payload.getFlags(), + (flags & AllowedFlags) | detail::getFlags(payload), streamId), payload_(std::move(payload)) { - payload_.checkFlags(header_.flags); // to verify the client didn't set - // METADATA and provided none + detail::checkFlags(payload_, header_.flags); } FrameHeader header_; @@ -176,11 +197,10 @@ class Frame_REQUEST_FNF { Frame_REQUEST_FNF(StreamId streamId, FrameFlags flags, Payload payload) : header_( FrameType::REQUEST_FNF, - (flags & AllowedFlags) | payload.getFlags(), + (flags & AllowedFlags) | detail::getFlags(payload), streamId), payload_(std::move(payload)) { - payload_.checkFlags(header_.flags); // to verify the client didn't set - // METADATA and provided none + detail::checkFlags(payload_, header_.flags); } FrameHeader header_; @@ -206,7 +226,7 @@ class Frame_CANCEL { public: Frame_CANCEL() = default; explicit Frame_CANCEL(StreamId streamId) - : header_(FrameType::CANCEL, FrameFlags::EMPTY, streamId) {} + : header_(FrameType::CANCEL, FrameFlags::EMPTY_, streamId) {} FrameHeader header_; }; @@ -221,11 +241,10 @@ class Frame_PAYLOAD { Frame_PAYLOAD(StreamId streamId, FrameFlags flags, Payload payload) : header_( FrameType::PAYLOAD, - (flags & AllowedFlags) | payload.getFlags(), + (flags & AllowedFlags) | detail::getFlags(payload), streamId), payload_(std::move(payload)) { - payload_.checkFlags(header_.flags); // to verify the client didn't set - // METADATA and provided none + detail::checkFlags(payload_, header_.flags); } static Frame_PAYLOAD complete(StreamId streamId); @@ -241,26 +260,27 @@ class Frame_ERROR { Frame_ERROR() = default; Frame_ERROR(StreamId streamId, ErrorCode errorCode, Payload payload) - : header_(FrameType::ERROR, payload.getFlags(), streamId), + : header_(FrameType::ERROR, detail::getFlags(payload), streamId), errorCode_(errorCode), payload_(std::move(payload)) {} // Connection errors. - static Frame_ERROR invalidSetup(std::string); - static Frame_ERROR unsupportedSetup(std::string); - static Frame_ERROR rejectedSetup(std::string); - static Frame_ERROR rejectedResume(std::string); - static Frame_ERROR connectionError(std::string); + static Frame_ERROR invalidSetup(folly::StringPiece); + static Frame_ERROR unsupportedSetup(folly::StringPiece); + static Frame_ERROR rejectedSetup(folly::StringPiece); + static Frame_ERROR rejectedResume(folly::StringPiece); + static Frame_ERROR connectionError(folly::StringPiece); // Stream errors. - static Frame_ERROR applicationError(StreamId, std::string); - static Frame_ERROR rejected(StreamId, std::string); - static Frame_ERROR canceled(StreamId, std::string); - static Frame_ERROR invalid(StreamId, std::string); + static Frame_ERROR applicationError(StreamId, folly::StringPiece); + static Frame_ERROR applicationError(StreamId, Payload&&); + static Frame_ERROR rejected(StreamId, folly::StringPiece); + static Frame_ERROR canceled(StreamId, folly::StringPiece); + static Frame_ERROR invalid(StreamId, folly::StringPiece); private: - static Frame_ERROR connectionErr(ErrorCode, std::string); - static Frame_ERROR streamErr(ErrorCode, std::string, StreamId); + static Frame_ERROR connectionErr(ErrorCode, folly::StringPiece); + static Frame_ERROR streamErr(ErrorCode, folly::StringPiece, StreamId); public: FrameHeader header_; @@ -314,7 +334,7 @@ class Frame_SETUP { Payload payload) : header_( FrameType::SETUP, - (flags & AllowedFlags) | payload.getFlags(), + (flags & AllowedFlags) | detail::getFlags(payload), 0), versionMajor_(versionMajor), versionMinor_(versionMinor), @@ -324,8 +344,7 @@ class Frame_SETUP { metadataMimeType_(metadataMimeType), dataMimeType_(dataMimeType), payload_(std::move(payload)) { - payload_.checkFlags(header_.flags); // to verify the client didn't set - // METADATA and provided none + detail::checkFlags(payload_, header_.flags); DCHECK(keepaliveTime_ > 0); DCHECK(maxLifetime_ > 0); DCHECK(keepaliveTime_ <= kMaxKeepaliveTime); @@ -361,7 +380,7 @@ class Frame_LEASE { std::unique_ptr metadata = std::unique_ptr()) : header_( FrameType::LEASE, - metadata ? FrameFlags::METADATA : FrameFlags::EMPTY, + metadata ? FrameFlags::METADATA : FrameFlags::EMPTY_, 0), ttl_(ttl), numberOfRequests_(numberOfRequests), @@ -388,7 +407,7 @@ class Frame_RESUME { ResumePosition lastReceivedServerPosition, ResumePosition clientPosition, ProtocolVersion protocolVersion) - : header_(FrameType::RESUME, FrameFlags::EMPTY, 0), + : header_(FrameType::RESUME, FrameFlags::EMPTY_, 0), versionMajor_(protocolVersion.major), versionMinor_(protocolVersion.minor), token_(token), @@ -409,15 +428,12 @@ class Frame_RESUME_OK { public: Frame_RESUME_OK() = default; explicit Frame_RESUME_OK(ResumePosition position) - : header_(FrameType::RESUME_OK, FrameFlags::EMPTY, 0), + : header_(FrameType::RESUME_OK, FrameFlags::EMPTY_, 0), position_(position) {} FrameHeader header_; ResumePosition position_{}; }; std::ostream& operator<<(std::ostream&, const Frame_RESUME_OK&); -/// @} -StreamType getStreamType(FrameType frameType); -bool isNewStreamFrame(FrameType frameType); -} +} // namespace rsocket diff --git a/rsocket/framing/FrameFlags.cpp b/rsocket/framing/FrameFlags.cpp index df37ad40e..d95399aa1 100644 --- a/rsocket/framing/FrameFlags.cpp +++ b/rsocket/framing/FrameFlags.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/framing/FrameFlags.h" @@ -10,4 +22,4 @@ namespace rsocket { std::ostream& operator<<(std::ostream& os, FrameFlags flags) { return os << std::bitset<16>{raw(flags)}; } -} +} // namespace rsocket diff --git a/rsocket/framing/FrameFlags.h b/rsocket/framing/FrameFlags.h index bedea38c1..7ab7eacf7 100644 --- a/rsocket/framing/FrameFlags.h +++ b/rsocket/framing/FrameFlags.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -6,10 +18,11 @@ #include namespace rsocket { - enum class FrameFlags : uint16_t { - EMPTY = 0x000, - IGNORE = 0x200, + // Note that win32 defines EMPTY and IGNORE so we use a trailing + // underscore to avoid a collision + EMPTY_ = 0x000, + IGNORE_ = 0x200, METADATA = 0x100, // SETUP. @@ -57,6 +70,6 @@ constexpr FrameFlags operator~(FrameFlags a) { return static_cast(~raw(a)); } -std::ostream& operator<<(std::ostream&, FrameFlags); +std::ostream& operator<<(std::ostream& ostr, FrameFlags a); -} +} // namespace rsocket diff --git a/rsocket/framing/FrameHeader.cpp b/rsocket/framing/FrameHeader.cpp index 4ab055b18..3ee16dfca 100644 --- a/rsocket/framing/FrameHeader.cpp +++ b/rsocket/framing/FrameHeader.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/framing/FrameHeader.h" @@ -10,55 +22,64 @@ namespace rsocket { namespace { -constexpr auto kEmpty = "0x00"; -constexpr auto kMetadata = "METADATA"; -constexpr auto kResumeEnable = "RESUME_ENABLE"; -constexpr auto kLease = "LEASE"; -constexpr auto kKeepAliveRespond = "KEEPALIVE_RESPOND"; -constexpr auto kFollows = "FOLLOWS"; -constexpr auto kComplete = "COMPLETE"; -constexpr auto kNext = "NEXT"; +using FlagString = std::pair; -std::map>> - flagToNameMap{ - {FrameType::REQUEST_N, {}}, - {FrameType::REQUEST_RESPONSE, - {{FrameFlags::METADATA, kMetadata}, {FrameFlags::FOLLOWS, kFollows}}}, - {FrameType::REQUEST_FNF, - {{FrameFlags::METADATA, kMetadata}, {FrameFlags::FOLLOWS, kFollows}}}, - {FrameType::METADATA_PUSH, {}}, - {FrameType::CANCEL, {}}, - {FrameType::PAYLOAD, - {{FrameFlags::METADATA, kMetadata}, - {FrameFlags::FOLLOWS, kFollows}, - {FrameFlags::COMPLETE, kComplete}, - {FrameFlags::NEXT, kNext}}}, - {FrameType::ERROR, {{FrameFlags::METADATA, kMetadata}}}, - {FrameType::KEEPALIVE, - {{FrameFlags::KEEPALIVE_RESPOND, kKeepAliveRespond}}}, - {FrameType::SETUP, - {{FrameFlags::METADATA, kMetadata}, - {FrameFlags::RESUME_ENABLE, kResumeEnable}, - {FrameFlags::LEASE, kLease}}}, - {FrameType::LEASE, {{FrameFlags::METADATA, kMetadata}}}, - {FrameType::RESUME, {}}, - {FrameType::REQUEST_CHANNEL, - {{FrameFlags::METADATA, kMetadata}, - {FrameFlags::FOLLOWS, kFollows}, - {FrameFlags::COMPLETE, kComplete}}}, - {FrameType::REQUEST_STREAM, - {{FrameFlags::METADATA, kMetadata}, {FrameFlags::FOLLOWS, kFollows}}}}; +constexpr std::array kMetadata = { + {std::make_pair(FrameFlags::METADATA, "METADATA")}}; +constexpr std::array kKeepaliveRespond = { + {std::make_pair(FrameFlags::KEEPALIVE_RESPOND, "KEEPALIVE_RESPOND")}}; +constexpr std::array kMetadataFollows = { + {std::make_pair(FrameFlags::METADATA, "METADATA"), + std::make_pair(FrameFlags::FOLLOWS, "FOLLOWS")}}; +constexpr std::array kMetadataResumeEnableLease = { + {std::make_pair(FrameFlags::METADATA, "METADATA"), + std::make_pair(FrameFlags::RESUME_ENABLE, "RESUME_ENABLE"), + std::make_pair(FrameFlags::LEASE, "LEASE")}}; +constexpr std::array kMetadataFollowsComplete = { + {std::make_pair(FrameFlags::METADATA, "METADATA"), + std::make_pair(FrameFlags::FOLLOWS, "FOLLOWS"), + std::make_pair(FrameFlags::COMPLETE, "COMPLETE")}}; +constexpr std::array kMetadataFollowsCompleteNext = { + {std::make_pair(FrameFlags::METADATA, "METADATA"), + std::make_pair(FrameFlags::FOLLOWS, "FOLLOWS"), + std::make_pair(FrameFlags::COMPLETE, "COMPLETE"), + std::make_pair(FrameFlags::NEXT, "NEXT")}}; + +template +constexpr auto toRange(const std::array& arr) { + return folly::Range{arr.data(), arr.size()}; +} + +// constexpr -- Old versions of C++ compiler doesn't support +// compound-statements in constexpr function (no switch statement) +folly::Range allowedFlags(FrameType type) { + switch (type) { + case FrameType::SETUP: + return toRange(kMetadataResumeEnableLease); + case FrameType::LEASE: + case FrameType::ERROR: + return toRange(kMetadata); + case FrameType::KEEPALIVE: + return toRange(kKeepaliveRespond); + case FrameType::REQUEST_RESPONSE: + case FrameType::REQUEST_FNF: + case FrameType::REQUEST_STREAM: + return toRange(kMetadataFollows); + case FrameType::REQUEST_CHANNEL: + return toRange(kMetadataFollowsComplete); + case FrameType::PAYLOAD: + return toRange(kMetadataFollowsCompleteNext); + default: + return {}; + } +} std::ostream& writeFlags(std::ostream& os, FrameFlags frameFlags, FrameType frameType) { - FrameFlags foundFlags = FrameFlags::EMPTY; + FrameFlags foundFlags = FrameFlags::EMPTY_; - // Search the corresponding string value for each flag, insert the missing - // ones as empty - auto const& allowedFlags = flagToNameMap[frameType]; - - std::string delimiter = ""; - for (const auto& pair : allowedFlags) { + std::string delimiter; + for (const auto& pair : allowedFlags(frameType)) { if (!!(frameFlags & pair.first)) { os << delimiter << pair.second; delimiter = "|"; @@ -69,15 +90,17 @@ writeFlags(std::ostream& os, FrameFlags frameFlags, FrameType frameType) { if (foundFlags != frameFlags) { os << frameFlags; } else if (delimiter.empty()) { - os << kEmpty; + os << "0x00"; } return os; } -} + +} // namespace std::ostream& operator<<(std::ostream& os, const FrameHeader& header) { os << header.type << "["; return writeFlags(os, header.flags, header.type) << ", " << header.streamId << "]"; } -} + +} // namespace rsocket diff --git a/rsocket/framing/FrameHeader.h b/rsocket/framing/FrameHeader.h index 0b4eb2900..cb67c895b 100644 --- a/rsocket/framing/FrameHeader.h +++ b/rsocket/framing/FrameHeader.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -10,6 +22,7 @@ namespace rsocket { +/// Header that begins every RSocket frame. class FrameHeader { public: FrameHeader() {} @@ -25,10 +38,15 @@ class FrameHeader { return !!(flags & FrameFlags::NEXT); } + bool flagsFollows() const { + return !!(flags & FrameFlags::FOLLOWS); + } + FrameType type{FrameType::RESERVED}; - FrameFlags flags{FrameFlags::EMPTY}; + FrameFlags flags{FrameFlags::EMPTY_}; StreamId streamId{0}; }; std::ostream& operator<<(std::ostream&, const FrameHeader&); -} + +} // namespace rsocket diff --git a/rsocket/framing/FrameProcessor.h b/rsocket/framing/FrameProcessor.h index e66de9cb0..70c5eae3e 100644 --- a/rsocket/framing/FrameProcessor.h +++ b/rsocket/framing/FrameProcessor.h @@ -1,13 +1,21 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include "rsocket/internal/Common.h" - -namespace folly { -class IOBuf; -class exception_wrapper; -} +#include +#include namespace rsocket { @@ -19,4 +27,4 @@ class FrameProcessor { virtual void onTerminal(folly::exception_wrapper) = 0; }; -} // reactivesocket +} // namespace rsocket diff --git a/rsocket/framing/FrameSerializer.cpp b/rsocket/framing/FrameSerializer.cpp index 72a1958b7..92904944b 100644 --- a/rsocket/framing/FrameSerializer.cpp +++ b/rsocket/framing/FrameSerializer.cpp @@ -1,52 +1,25 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/framing/FrameSerializer.h" - -#include -#include - -#include "rsocket/framing/FrameSerializer_v0.h" -#include "rsocket/framing/FrameSerializer_v0_1.h" #include "rsocket/framing/FrameSerializer_v1_0.h" -DEFINE_string( - rs_use_protocol_version, - "", - "override for the ReactiveSocket protocol version to be used" - " [MAJOR.MINOR]."); - namespace rsocket { -constexpr const ProtocolVersion ProtocolVersion::Latest = - FrameSerializerV1_0::Version; - -ProtocolVersion ProtocolVersion::Current() { - if (FLAGS_rs_use_protocol_version.empty()) { - return ProtocolVersion::Latest; - } - - if (FLAGS_rs_use_protocol_version == "*") { - return ProtocolVersion::Unknown; - } - - if (FLAGS_rs_use_protocol_version.size() != 3) { - LOG(ERROR) << "unknown protocol version " << FLAGS_rs_use_protocol_version - << " defaulting to v" << ProtocolVersion::Latest; - return ProtocolVersion::Latest; - } - - return ProtocolVersion( - folly::to(FLAGS_rs_use_protocol_version[0] - '0'), - folly::to(FLAGS_rs_use_protocol_version[2] - '0')); -} - std::unique_ptr FrameSerializer::createFrameSerializer( const ProtocolVersion& protocolVersion) { - if (protocolVersion == FrameSerializerV0::Version) { - return std::make_unique(); - } else if (protocolVersion == FrameSerializerV0_1::Version) { - return std::make_unique(); - } else if (protocolVersion == FrameSerializerV1_0::Version) { + if (protocolVersion == FrameSerializerV1_0::Version) { return std::make_unique(); } @@ -59,13 +32,34 @@ std::unique_ptr FrameSerializer::createFrameSerializer( std::unique_ptr FrameSerializer::createAutodetectedSerializer( const folly::IOBuf& firstFrame) { auto detectedVersion = FrameSerializerV1_0::detectProtocolVersion(firstFrame); - if (detectedVersion == ProtocolVersion::Unknown) { - detectedVersion = FrameSerializerV0_1::detectProtocolVersion(firstFrame); - } return createFrameSerializer(detectedVersion); } -std::ostream& operator<<(std::ostream& os, const ProtocolVersion& version) { - return os << version.major << "." << version.minor; +bool& FrameSerializer::preallocateFrameSizeField() { + return preallocateFrameSizeField_; } + +folly::IOBufQueue FrameSerializer::createBufferQueue(size_t bufferSize) const { + const auto prependSize = + preallocateFrameSizeField_ ? frameLengthFieldSize() : 0; + auto buf = folly::IOBuf::createCombined(bufferSize + prependSize); + buf->advance(prependSize); + folly::IOBufQueue queue(folly::IOBufQueue::cacheChainLength()); + queue.append(std::move(buf)); + return queue; } + +folly::Optional FrameSerializer::peekStreamId( + const ProtocolVersion& protocolVersion, + const folly::IOBuf& frame, + bool skipFrameLengthBytes) { + if (protocolVersion == FrameSerializerV1_0::Version) { + return FrameSerializerV1_0().peekStreamId(frame, skipFrameLengthBytes); + } + + auto* msg = "unknown protocol version"; + DCHECK(false) << msg; + return folly::none; +} + +} // namespace rsocket diff --git a/rsocket/framing/FrameSerializer.h b/rsocket/framing/FrameSerializer.h index 787225f56..7ee0bafae 100644 --- a/rsocket/framing/FrameSerializer.h +++ b/rsocket/framing/FrameSerializer.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -15,7 +27,7 @@ class FrameSerializer { public: virtual ~FrameSerializer() = default; - virtual ProtocolVersion protocolVersion() = 0; + virtual ProtocolVersion protocolVersion() const = 0; static std::unique_ptr createFrameSerializer( const ProtocolVersion& protocolVersion); @@ -23,65 +35,81 @@ class FrameSerializer { static std::unique_ptr createAutodetectedSerializer( const folly::IOBuf& firstFrame); - virtual FrameType peekFrameType(const folly::IOBuf& in) = 0; - virtual folly::Optional peekStreamId(const folly::IOBuf& in) = 0; + static folly::Optional peekStreamId( + const ProtocolVersion& protocolVersion, + const folly::IOBuf& frame, + bool skipFrameLengthBytes); + virtual FrameType peekFrameType(const folly::IOBuf& in) const = 0; + virtual folly::Optional peekStreamId( + const folly::IOBuf& in, + bool skipFrameLengthBytes) const = 0; + + virtual std::unique_ptr serializeOut( + Frame_REQUEST_STREAM&&) const = 0; + virtual std::unique_ptr serializeOut( + Frame_REQUEST_CHANNEL&&) const = 0; + virtual std::unique_ptr serializeOut( + Frame_REQUEST_RESPONSE&&) const = 0; virtual std::unique_ptr serializeOut( - Frame_REQUEST_STREAM&&) = 0; + Frame_REQUEST_FNF&&) const = 0; virtual std::unique_ptr serializeOut( - Frame_REQUEST_CHANNEL&&) = 0; + Frame_REQUEST_N&&) const = 0; virtual std::unique_ptr serializeOut( - Frame_REQUEST_RESPONSE&&) = 0; - virtual std::unique_ptr serializeOut(Frame_REQUEST_FNF&&) = 0; - virtual std::unique_ptr serializeOut(Frame_REQUEST_N&&) = 0; - virtual std::unique_ptr serializeOut(Frame_METADATA_PUSH&&) = 0; - virtual std::unique_ptr serializeOut(Frame_CANCEL&&) = 0; - virtual std::unique_ptr serializeOut(Frame_PAYLOAD&&) = 0; - virtual std::unique_ptr serializeOut(Frame_ERROR&&) = 0; + Frame_METADATA_PUSH&&) const = 0; + virtual std::unique_ptr serializeOut(Frame_CANCEL&&) const = 0; + virtual std::unique_ptr serializeOut(Frame_PAYLOAD&&) const = 0; + virtual std::unique_ptr serializeOut(Frame_ERROR&&) const = 0; virtual std::unique_ptr serializeOut( - Frame_KEEPALIVE&&, - bool) = 0; - virtual std::unique_ptr serializeOut(Frame_SETUP&&) = 0; - virtual std::unique_ptr serializeOut(Frame_LEASE&&) = 0; - virtual std::unique_ptr serializeOut(Frame_RESUME&&) = 0; - virtual std::unique_ptr serializeOut(Frame_RESUME_OK&&) = 0; + Frame_KEEPALIVE&&) const = 0; + virtual std::unique_ptr serializeOut(Frame_SETUP&&) const = 0; + virtual std::unique_ptr serializeOut(Frame_LEASE&&) const = 0; + virtual std::unique_ptr serializeOut(Frame_RESUME&&) const = 0; + virtual std::unique_ptr serializeOut( + Frame_RESUME_OK&&) const = 0; virtual bool deserializeFrom( Frame_REQUEST_STREAM&, - std::unique_ptr) = 0; + std::unique_ptr) const = 0; virtual bool deserializeFrom( Frame_REQUEST_CHANNEL&, - std::unique_ptr) = 0; + std::unique_ptr) const = 0; virtual bool deserializeFrom( Frame_REQUEST_RESPONSE&, - std::unique_ptr) = 0; + std::unique_ptr) const = 0; virtual bool deserializeFrom( Frame_REQUEST_FNF&, - std::unique_ptr) = 0; - virtual bool deserializeFrom( - Frame_REQUEST_N&, - std::unique_ptr) = 0; + std::unique_ptr) const = 0; + virtual bool deserializeFrom(Frame_REQUEST_N&, std::unique_ptr) + const = 0; virtual bool deserializeFrom( Frame_METADATA_PUSH&, - std::unique_ptr) = 0; - virtual bool deserializeFrom( - Frame_CANCEL&, - std::unique_ptr) = 0; - virtual bool deserializeFrom( - Frame_PAYLOAD&, - std::unique_ptr) = 0; - virtual bool deserializeFrom(Frame_ERROR&, std::unique_ptr) = 0; - virtual bool deserializeFrom( - Frame_KEEPALIVE&, - std::unique_ptr, - bool supportsResumability) = 0; - virtual bool deserializeFrom(Frame_SETUP&, std::unique_ptr) = 0; - virtual bool deserializeFrom(Frame_LEASE&, std::unique_ptr) = 0; - virtual bool deserializeFrom( - Frame_RESUME&, - std::unique_ptr) = 0; - virtual bool deserializeFrom( - Frame_RESUME_OK&, - std::unique_ptr) = 0; + std::unique_ptr) const = 0; + virtual bool deserializeFrom(Frame_CANCEL&, std::unique_ptr) + const = 0; + virtual bool deserializeFrom(Frame_PAYLOAD&, std::unique_ptr) + const = 0; + virtual bool deserializeFrom(Frame_ERROR&, std::unique_ptr) + const = 0; + virtual bool deserializeFrom(Frame_KEEPALIVE&, std::unique_ptr) + const = 0; + virtual bool deserializeFrom(Frame_SETUP&, std::unique_ptr) + const = 0; + virtual bool deserializeFrom(Frame_LEASE&, std::unique_ptr) + const = 0; + virtual bool deserializeFrom(Frame_RESUME&, std::unique_ptr) + const = 0; + virtual bool deserializeFrom(Frame_RESUME_OK&, std::unique_ptr) + const = 0; + + virtual size_t frameLengthFieldSize() const = 0; + bool& preallocateFrameSizeField(); + + protected: + folly::IOBufQueue createBufferQueue(size_t bufferSize) const; + + private: + bool preallocateFrameSizeField_{false}; }; -} + +} // namespace rsocket diff --git a/rsocket/framing/FrameSerializer_v0.cpp b/rsocket/framing/FrameSerializer_v0.cpp deleted file mode 100644 index c7de24e08..000000000 --- a/rsocket/framing/FrameSerializer_v0.cpp +++ /dev/null @@ -1,785 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "rsocket/framing/FrameSerializer_v0.h" - -#include - -namespace rsocket { - -constexpr const ProtocolVersion FrameSerializerV0::Version; -constexpr const size_t FrameSerializerV0::kFrameHeaderSize; // bytes - -namespace { -constexpr static const auto kMaxMetadataLength = 0xFFFFFF; // 24bit max value - -enum class FrameType_V0 : uint16_t { - RESERVED = 0x0000, - SETUP = 0x0001, - LEASE = 0x0002, - KEEPALIVE = 0x0003, - REQUEST_RESPONSE = 0x0004, - REQUEST_FNF = 0x0005, - REQUEST_STREAM = 0x0006, - REQUEST_SUB = 0x0007, - REQUEST_CHANNEL = 0x0008, - REQUEST_N = 0x0009, - CANCEL = 0x000A, - RESPONSE = 0x000B, - ERROR = 0x000C, - METADATA_PUSH = 0x000D, - RESUME = 0x000E, - RESUME_OK = 0x000F, - EXT = 0xFFFF, -}; - -enum class FrameFlags_V0 : uint16_t { - EMPTY = 0x0000, - IGNORE = 0x8000, - METADATA = 0x4000, - - FOLLOWS = 0x2000, - KEEPALIVE_RESPOND = 0x2000, - LEASE = 0x2000, - COMPLETE = 0x1000, - RESUME_ENABLE = 0x0800, -}; - -constexpr inline FrameFlags_V0 operator&(FrameFlags_V0 a, FrameFlags_V0 b) { - return static_cast( - static_cast(a) & static_cast(b)); -} - -inline uint16_t& operator|=(uint16_t& a, FrameFlags_V0 b) { - return (a |= static_cast(b)); -} - -constexpr inline bool operator!(FrameFlags_V0 a) { - return !static_cast(a); -} -} // namespace - -static folly::IOBufQueue createBufferQueue(size_t bufferSize) { - auto buf = folly::IOBuf::createCombined(bufferSize); - folly::IOBufQueue queue(folly::IOBufQueue::cacheChainLength()); - queue.append(std::move(buf)); - return queue; -} - -ProtocolVersion FrameSerializerV0::protocolVersion() { - return Version; -} - -static uint16_t serializeFrameType(FrameType frameType) { - switch (frameType) { - case FrameType::RESERVED: - case FrameType::SETUP: - case FrameType::LEASE: - case FrameType::KEEPALIVE: - case FrameType::REQUEST_RESPONSE: - case FrameType::REQUEST_FNF: - case FrameType::REQUEST_STREAM: - return static_cast(frameType); - - case FrameType::REQUEST_CHANNEL: - case FrameType::REQUEST_N: - case FrameType::CANCEL: - case FrameType::PAYLOAD: - case FrameType::ERROR: - case FrameType::METADATA_PUSH: - case FrameType::RESUME: - case FrameType::RESUME_OK: - return static_cast(frameType) + 1; - - case FrameType::EXT: - return static_cast(FrameType_V0::EXT); - - default: - CHECK(false); - return 0; - } -} - -static FrameType deserializeFrameType(uint16_t frameType) { - if (frameType > static_cast(FrameType_V0::RESUME_OK) && - frameType != static_cast(FrameType_V0::EXT)) { - return FrameType::RESERVED; - } - - switch (static_cast(frameType)) { - case FrameType_V0::RESERVED: - case FrameType_V0::SETUP: - case FrameType_V0::LEASE: - case FrameType_V0::KEEPALIVE: - case FrameType_V0::REQUEST_RESPONSE: - case FrameType_V0::REQUEST_FNF: - case FrameType_V0::REQUEST_STREAM: - return static_cast(frameType); - - case FrameType_V0::REQUEST_SUB: - return FrameType::REQUEST_STREAM; - - case FrameType_V0::REQUEST_CHANNEL: - case FrameType_V0::REQUEST_N: - case FrameType_V0::CANCEL: - case FrameType_V0::RESPONSE: - case FrameType_V0::ERROR: - case FrameType_V0::METADATA_PUSH: - case FrameType_V0::RESUME: - case FrameType_V0::RESUME_OK: - return static_cast(frameType - 1); - - case FrameType_V0::EXT: - return FrameType::EXT; - - default: - CHECK(false); - return FrameType::RESERVED; - } -} - -static uint16_t serializeFrameFlags(FrameFlags frameType) { - uint16_t result = 0; - if (!!(frameType & FrameFlags::IGNORE)) { - result |= FrameFlags_V0::IGNORE; - } - if (!!(frameType & FrameFlags::METADATA)) { - result |= FrameFlags_V0::METADATA; - } - return result; -} - -static FrameFlags deserializeFrameFlags(FrameFlags_V0 flags) { - FrameFlags result = FrameFlags::EMPTY; - - if (!!(flags & FrameFlags_V0::IGNORE)) { - result |= FrameFlags::IGNORE; - } - if (!!(flags & FrameFlags_V0::METADATA)) { - result |= FrameFlags::METADATA; - } - return result; -} - -static void serializeHeaderInto( - folly::io::QueueAppender& appender, - const FrameHeader& header, - uint16_t extraFlags) { - appender.writeBE(serializeFrameType(header.type)); - appender.writeBE(serializeFrameFlags(header.flags) | extraFlags); - appender.writeBE(header.streamId); -} - -static void deserializeHeaderFrom( - folly::io::Cursor& cur, - FrameHeader& header, - FrameFlags_V0& flags) { - header.type = deserializeFrameType(cur.readBE()); - - flags = static_cast(cur.readBE()); - header.flags = deserializeFrameFlags(flags); - - header.streamId = cur.readBE(); -} - -static void serializeMetadataInto( - folly::io::QueueAppender& appender, - std::unique_ptr metadata) { - if (metadata == nullptr) { - return; - } - - // Use signed int because the first bit in metadata length is reserved. - if (metadata->length() >= kMaxMetadataLength - sizeof(uint32_t)) { - CHECK(false) << "Metadata is too big to serialize"; - } - - appender.writeBE( - static_cast(metadata->length() + sizeof(uint32_t))); - appender.insert(std::move(metadata)); -} - -std::unique_ptr FrameSerializerV0::deserializeMetadataFrom( - folly::io::Cursor& cur, - FrameFlags flags) { - if (!(flags & FrameFlags::METADATA)) { - return nullptr; - } - - const auto length = cur.readBE(); - - if (length >= kMaxMetadataLength) { - throw std::runtime_error("Metadata is too big to deserialize"); - } - - if (length <= sizeof(uint32_t)) { - throw std::runtime_error("Metadata is too small to encode its size"); - } - - const auto metadataPayloadLength = - length - static_cast(sizeof(uint32_t)); - - // TODO: Check if metadataPayloadLength exceeds frame length minus frame - // header size. - - std::unique_ptr metadata; - cur.clone(metadata, metadataPayloadLength); - return metadata; -} - -static std::unique_ptr deserializeDataFrom( - folly::io::Cursor& cur) { - std::unique_ptr data; - auto totalLength = cur.totalLength(); - - if (totalLength > 0) { - cur.clone(data, totalLength); - } - return data; -} - -static Payload deserializePayloadFrom( - folly::io::Cursor& cur, - FrameFlags flags) { - auto metadata = FrameSerializerV0::deserializeMetadataFrom(cur, flags); - auto data = deserializeDataFrom(cur); - return Payload(std::move(data), std::move(metadata)); -} - -static void serializePayloadInto( - folly::io::QueueAppender& appender, - Payload&& payload) { - serializeMetadataInto(appender, std::move(payload.metadata)); - if (payload.data) { - appender.insert(std::move(payload.data)); - } -} - -static uint32_t payloadFramingSize(const Payload& payload) { - return (payload.metadata != nullptr ? sizeof(uint32_t) : 0); -} - -static std::unique_ptr serializeOutInternal( - Frame_REQUEST_Base&& frame) { - auto queue = createBufferQueue( - FrameSerializerV0::kFrameHeaderSize + sizeof(uint32_t) + - payloadFramingSize(frame.payload_)); - uint16_t extraFlags = 0; - if (!!(frame.header_.flags & FrameFlags::FOLLOWS)) { - extraFlags |= FrameFlags_V0::FOLLOWS; - } - if (!!(frame.header_.flags & FrameFlags::COMPLETE)) { - extraFlags |= FrameFlags_V0::COMPLETE; - } - - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, extraFlags); - - appender.writeBE(frame.requestN_); - serializePayloadInto(appender, std::move(frame.payload_)); - return queue.move(); -} - -static bool deserializeFromInternal( - Frame_REQUEST_Base& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - - if (!!(flags & FrameFlags_V0::FOLLOWS)) { - frame.header_.flags |= FrameFlags::FOLLOWS; - } - if (!!(flags & FrameFlags_V0::COMPLETE)) { - frame.header_.flags |= FrameFlags::COMPLETE; - } - - frame.requestN_ = cur.readBE(); - frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags); - } catch (...) { - return false; - } - return true; -} - -FrameType FrameSerializerV0::peekFrameType(const folly::IOBuf& in) { - folly::io::Cursor cur(&in); - try { - return deserializeFrameType(cur.readBE()); - } catch (...) { - return FrameType::RESERVED; - } -} - -folly::Optional FrameSerializerV0::peekStreamId( - const folly::IOBuf& in) { - folly::io::Cursor cur(&in); - try { - cur.skip(sizeof(uint16_t)); // type - cur.skip(sizeof(uint16_t)); // flags - return folly::make_optional(cur.readBE()); - } catch (...) { - return folly::none; - } -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_REQUEST_STREAM&& frame) { - return serializeOutInternal(std::move(frame)); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_REQUEST_CHANNEL&& frame) { - return serializeOutInternal(std::move(frame)); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_REQUEST_RESPONSE&& frame) { - uint16_t extraFlags = 0; - if (!!(frame.header_.flags & FrameFlags::FOLLOWS)) { - extraFlags |= FrameFlags_V0::FOLLOWS; - } - - auto queue = - createBufferQueue(kFrameHeaderSize + payloadFramingSize(frame.payload_)); - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, extraFlags); - serializePayloadInto(appender, std::move(frame.payload_)); - return queue.move(); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_REQUEST_FNF&& frame) { - uint16_t extraFlags = 0; - if (!!(frame.header_.flags & FrameFlags::FOLLOWS)) { - extraFlags |= FrameFlags_V0::FOLLOWS; - } - - auto queue = - createBufferQueue(kFrameHeaderSize + payloadFramingSize(frame.payload_)); - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, extraFlags); - serializePayloadInto(appender, std::move(frame.payload_)); - return queue.move(); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_REQUEST_N&& frame) { - auto queue = createBufferQueue(kFrameHeaderSize + sizeof(uint32_t)); - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, /*extraFlags=*/0); - appender.writeBE(frame.requestN_); - return queue.move(); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_METADATA_PUSH&& frame) { - auto queue = createBufferQueue(kFrameHeaderSize + sizeof(uint32_t)); - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, /*extraFlags=*/0); - serializeMetadataInto(appender, std::move(frame.metadata_)); - return queue.move(); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_CANCEL&& frame) { - auto queue = createBufferQueue(kFrameHeaderSize); - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, /*extraFlags=*/0); - return queue.move(); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_PAYLOAD&& frame) { - uint16_t extraFlags = 0; - if (!!(frame.header_.flags & FrameFlags::FOLLOWS)) { - extraFlags |= FrameFlags_V0::FOLLOWS; - } - if (!!(frame.header_.flags & FrameFlags::COMPLETE)) { - extraFlags |= FrameFlags_V0::COMPLETE; - } - - auto queue = - createBufferQueue(kFrameHeaderSize + payloadFramingSize(frame.payload_)); - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, extraFlags); - serializePayloadInto(appender, std::move(frame.payload_)); - return queue.move(); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_ERROR&& frame) { - auto queue = createBufferQueue( - kFrameHeaderSize + sizeof(uint32_t) + payloadFramingSize(frame.payload_)); - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, /*extraFlags=*/0); - appender.writeBE(static_cast(frame.errorCode_)); - serializePayloadInto(appender, std::move(frame.payload_)); - return queue.move(); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_KEEPALIVE&& frame, - bool resumeable) { - uint16_t extraFlags = 0; - if (!!(frame.header_.flags & FrameFlags::KEEPALIVE_RESPOND)) { - extraFlags |= FrameFlags_V0::KEEPALIVE_RESPOND; - } - - auto queue = createBufferQueue(kFrameHeaderSize + sizeof(int64_t)); - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, extraFlags); - // TODO: Remove hack: - // https://github.com/ReactiveSocket/reactivesocket-cpp/issues/243 - if (resumeable) { - appender.writeBE(frame.position_); - } - if (frame.data_) { - appender.insert(std::move(frame.data_)); - } - return queue.move(); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_SETUP&& frame) { - auto queue = createBufferQueue( - kFrameHeaderSize + 3 * sizeof(uint32_t) + frame.token_.data().size() + 2 + - frame.metadataMimeType_.length() + frame.dataMimeType_.length() + - payloadFramingSize(frame.payload_)); - uint16_t extraFlags = 0; - if (!!(frame.header_.flags & FrameFlags::RESUME_ENABLE)) { - extraFlags |= FrameFlags_V0::RESUME_ENABLE; - } - if (!!(frame.header_.flags & FrameFlags::LEASE)) { - extraFlags |= FrameFlags_V0::LEASE; - } - - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - - serializeHeaderInto(appender, frame.header_, extraFlags); - CHECK( - frame.versionMajor_ != ProtocolVersion::Unknown.major || - frame.versionMinor_ != ProtocolVersion::Unknown.minor); - appender.writeBE(static_cast(frame.versionMajor_)); - appender.writeBE(static_cast(frame.versionMinor_)); - appender.writeBE(static_cast(frame.keepaliveTime_)); - appender.writeBE(static_cast(frame.maxLifetime_)); - - // TODO: Remove hack: - // https://github.com/ReactiveSocket/reactivesocket-cpp/issues/243 - if (!!(frame.header_.flags & FrameFlags::RESUME_ENABLE)) { - appender.push(frame.token_.data().data(), frame.token_.data().size()); - } - - CHECK( - frame.metadataMimeType_.length() <= std::numeric_limits::max()); - appender.writeBE(static_cast(frame.metadataMimeType_.length())); - appender.push( - reinterpret_cast(frame.metadataMimeType_.data()), - frame.metadataMimeType_.length()); - - CHECK(frame.dataMimeType_.length() <= std::numeric_limits::max()); - appender.writeBE(static_cast(frame.dataMimeType_.length())); - appender.push( - reinterpret_cast(frame.dataMimeType_.data()), - frame.dataMimeType_.length()); - - serializePayloadInto(appender, std::move(frame.payload_)); - return queue.move(); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_LEASE&& frame) { - auto queue = createBufferQueue( - kFrameHeaderSize + 3 * 2 * sizeof(uint32_t) + - (frame.metadata_ ? sizeof(uint32_t) : 0)); - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, /*extraFlags=*/0); - appender.writeBE(static_cast(frame.ttl_)); - appender.writeBE(static_cast(frame.numberOfRequests_)); - serializeMetadataInto(appender, std::move(frame.metadata_)); - return queue.move(); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_RESUME&& frame) { - auto queue = createBufferQueue(kFrameHeaderSize + 16 + sizeof(int64_t)); - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, /*extraFlags=*/0); - CHECK(frame.token_.data().size() <= 16); - appender.push(frame.token_.data().data(), frame.token_.data().size()); - appender.writeBE(frame.lastReceivedServerPosition_); - return queue.move(); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_RESUME_OK&& frame) { - auto queue = createBufferQueue(kFrameHeaderSize + sizeof(int64_t)); - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, /*extraFlags=*/0); - appender.writeBE(frame.position_); - return queue.move(); -} - -bool FrameSerializerV0::deserializeFrom( - Frame_REQUEST_STREAM& frame, - std::unique_ptr in) { - return deserializeFromInternal(frame, std::move(in)); -} - -bool FrameSerializerV0::deserializeFrom( - Frame_REQUEST_CHANNEL& frame, - std::unique_ptr in) { - return deserializeFromInternal(frame, std::move(in)); -} - -bool FrameSerializerV0::deserializeFrom( - Frame_REQUEST_RESPONSE& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - - if (!!(flags & FrameFlags_V0::FOLLOWS)) { - frame.header_.flags |= FrameFlags::FOLLOWS; - } - - frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags); - } catch (...) { - return false; - } - return true; -} - -bool FrameSerializerV0::deserializeFrom( - Frame_REQUEST_FNF& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - - if (!!(flags & FrameFlags_V0::FOLLOWS)) { - frame.header_.flags |= FrameFlags::FOLLOWS; - } - - frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags); - } catch (...) { - return false; - } - return true; -} - -bool FrameSerializerV0::deserializeFrom( - Frame_REQUEST_N& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - frame.requestN_ = cur.readBE(); - } catch (...) { - return false; - } - return true; -} - -bool FrameSerializerV0::deserializeFrom( - Frame_METADATA_PUSH& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - frame.metadata_ = deserializeMetadataFrom(cur, frame.header_.flags); - } catch (...) { - return false; - } - return frame.metadata_ != nullptr; -} - -bool FrameSerializerV0::deserializeFrom( - Frame_CANCEL& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - } catch (...) { - return false; - } - return true; -} - -bool FrameSerializerV0::deserializeFrom( - Frame_PAYLOAD& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - - if (!!(flags & FrameFlags_V0::FOLLOWS)) { - frame.header_.flags |= FrameFlags::FOLLOWS; - } - if (!!(flags & FrameFlags_V0::COMPLETE)) { - frame.header_.flags |= FrameFlags::COMPLETE; - } - - frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags); - } catch (...) { - return false; - } - return true; -} - -bool FrameSerializerV0::deserializeFrom( - Frame_ERROR& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - frame.errorCode_ = static_cast(cur.readBE()); - frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags); - } catch (...) { - return false; - } - return true; -} - -bool FrameSerializerV0::deserializeFrom( - Frame_KEEPALIVE& frame, - std::unique_ptr in, - bool resumable) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - - if (!!(flags & FrameFlags_V0::KEEPALIVE_RESPOND)) { - frame.header_.flags |= FrameFlags::KEEPALIVE_RESPOND; - } - - // TODO: Remove hack: - // https://github.com/ReactiveSocket/reactivesocket-cpp/issues/243 - if (resumable) { - frame.position_ = cur.readBE(); - } else { - frame.position_ = 0; - } - frame.data_ = deserializeDataFrom(cur); - } catch (...) { - return false; - } - return true; -} - -bool FrameSerializerV0::deserializeFrom( - Frame_SETUP& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - - if (!!(flags & FrameFlags_V0::RESUME_ENABLE)) { - frame.header_.flags |= FrameFlags::RESUME_ENABLE; - } - if (!!(flags & FrameFlags_V0::LEASE)) { - frame.header_.flags |= FrameFlags::LEASE; - } - - frame.versionMajor_ = cur.readBE(); - frame.versionMinor_ = cur.readBE(); - - auto keepaliveTime = cur.readBE(); - if (keepaliveTime <= 0) { - return false; - } - frame.keepaliveTime_ = - std::min(keepaliveTime, Frame_SETUP::kMaxKeepaliveTime); - - auto maxLifetime = cur.readBE(); - if (maxLifetime <= 0) { - return false; - } - frame.maxLifetime_ = - std::min(maxLifetime, Frame_SETUP::kMaxLifetime); - - // TODO: Remove hack: - // https://github.com/ReactiveSocket/reactivesocket-cpp/issues/243 - if (!!(frame.header_.flags & FrameFlags::RESUME_ENABLE)) { - std::vector data(16); - cur.pull(data.data(), data.size()); - frame.token_.set(std::move(data)); - } else { - frame.token_ = ResumeIdentificationToken(); - } - - auto mdmtLen = cur.readBE(); - frame.metadataMimeType_ = cur.readFixedString(mdmtLen); - - auto dmtLen = cur.readBE(); - frame.dataMimeType_ = cur.readFixedString(dmtLen); - frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags); - } catch (...) { - return false; - } - return true; -} - -bool FrameSerializerV0::deserializeFrom( - Frame_LEASE& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - frame.ttl_ = std::min(cur.readBE(), Frame_LEASE::kMaxTtl); - frame.numberOfRequests_ = - std::min(cur.readBE(), Frame_LEASE::kMaxNumRequests); - frame.metadata_ = deserializeMetadataFrom(cur, frame.header_.flags); - } catch (...) { - return false; - } - return true; -} - -bool FrameSerializerV0::deserializeFrom( - Frame_RESUME& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - std::vector data(16); - cur.pull(data.data(), data.size()); - auto protocolVer = protocolVersion(); - frame.versionMajor_ = protocolVer.major; - frame.versionMinor_ = protocolVer.minor; - frame.token_.set(std::move(data)); - frame.lastReceivedServerPosition_ = cur.readBE(); - frame.clientPosition_ = kUnspecifiedResumePosition; - } catch (...) { - return false; - } - return true; -} - -bool FrameSerializerV0::deserializeFrom( - Frame_RESUME_OK& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - frame.position_ = cur.readBE(); - } catch (...) { - return false; - } - return true; -} - -} // reactivesocket diff --git a/rsocket/framing/FrameSerializer_v0.h b/rsocket/framing/FrameSerializer_v0.h deleted file mode 100644 index 74e356351..000000000 --- a/rsocket/framing/FrameSerializer_v0.h +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "rsocket/framing/FrameSerializer.h" - -namespace rsocket { - -class FrameSerializerV0 : public FrameSerializer { - public: - constexpr static const ProtocolVersion Version = ProtocolVersion(0, 0); - constexpr static const size_t kFrameHeaderSize = 8; // bytes - - ProtocolVersion protocolVersion() override; - - FrameType peekFrameType(const folly::IOBuf& in) override; - folly::Optional peekStreamId(const folly::IOBuf& in) override; - - std::unique_ptr serializeOut(Frame_REQUEST_STREAM&&) override; - std::unique_ptr serializeOut(Frame_REQUEST_CHANNEL&&) override; - std::unique_ptr serializeOut(Frame_REQUEST_RESPONSE&&) override; - std::unique_ptr serializeOut(Frame_REQUEST_FNF&&) override; - std::unique_ptr serializeOut(Frame_REQUEST_N&&) override; - std::unique_ptr serializeOut(Frame_METADATA_PUSH&&) override; - std::unique_ptr serializeOut(Frame_CANCEL&&) override; - std::unique_ptr serializeOut(Frame_PAYLOAD&&) override; - std::unique_ptr serializeOut(Frame_ERROR&&) override; - std::unique_ptr serializeOut(Frame_KEEPALIVE&&, bool) override; - std::unique_ptr serializeOut(Frame_SETUP&&) override; - std::unique_ptr serializeOut(Frame_LEASE&&) override; - std::unique_ptr serializeOut(Frame_RESUME&&) override; - std::unique_ptr serializeOut(Frame_RESUME_OK&&) override; - - bool deserializeFrom(Frame_REQUEST_STREAM&, std::unique_ptr) - override; - bool deserializeFrom(Frame_REQUEST_CHANNEL&, std::unique_ptr) - override; - bool deserializeFrom(Frame_REQUEST_RESPONSE&, std::unique_ptr) - override; - bool deserializeFrom(Frame_REQUEST_FNF&, std::unique_ptr) - override; - bool deserializeFrom(Frame_REQUEST_N&, std::unique_ptr) - override; - bool deserializeFrom(Frame_METADATA_PUSH&, std::unique_ptr) - override; - bool deserializeFrom(Frame_CANCEL&, std::unique_ptr) override; - bool deserializeFrom(Frame_PAYLOAD&, std::unique_ptr) override; - bool deserializeFrom(Frame_ERROR&, std::unique_ptr) override; - bool deserializeFrom(Frame_KEEPALIVE&, std::unique_ptr, bool) - override; - bool deserializeFrom(Frame_SETUP&, std::unique_ptr) override; - bool deserializeFrom(Frame_LEASE&, std::unique_ptr) override; - bool deserializeFrom(Frame_RESUME&, std::unique_ptr) override; - bool deserializeFrom(Frame_RESUME_OK&, std::unique_ptr) - override; - - static std::unique_ptr deserializeMetadataFrom( - folly::io::Cursor& cur, - FrameFlags flags); -}; -} // reactivesocket diff --git a/rsocket/framing/FrameSerializer_v0_1.cpp b/rsocket/framing/FrameSerializer_v0_1.cpp deleted file mode 100644 index b42322437..000000000 --- a/rsocket/framing/FrameSerializer_v0_1.cpp +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "rsocket/framing/FrameSerializer_v0_1.h" - -#include - -namespace rsocket { - -constexpr const ProtocolVersion FrameSerializerV0_1::Version; -constexpr const size_t FrameSerializerV0_1::kMinBytesNeededForAutodetection; - -ProtocolVersion FrameSerializerV0_1::protocolVersion() { - return Version; -} - -ProtocolVersion FrameSerializerV0_1::detectProtocolVersion( - const folly::IOBuf& firstFrame, - size_t skipBytes) { - // SETUP frame - // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // | Frame Type = SETUP |0|M|L|S| Flags | - // +-------------------------------+-+-+-+-+-----------------------+ - // | Stream ID = 0 | - // +-------------------------------+-------------------------------+ - // | Major Version | Minor Version | - // +-------------------------------+-------------------------------+ - // ... - // +-------------------------------+-------------------------------+ - - // RESUME frame - // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // | Frame Type = RESUME | Flags | - // +-------------------------------+-------------------------------+ - // | Stream ID = 0 | - // +-------------------------------+-------------------------------+ - // | | - // | Resume Identification Token | - // | | - // | | - // +-------------------------------+-------------------------------+ - // | Resume Position | - // | | - // +-------------------------------+-------------------------------+ - - folly::io::Cursor cur(&firstFrame); - try { - cur.skip(skipBytes); - - auto frameType = cur.readBE(); - cur.skip(sizeof(uint16_t)); // flags - auto streamId = cur.readBE(); - - constexpr static const auto kSETUP = 0x0001; - constexpr static const auto kRESUME = 0x000E; - - VLOG(4) << "frameType=" << frameType << "streamId=" << streamId; - - if (frameType == kSETUP && streamId == 0) { - auto majorVersion = cur.readBE(); - auto minorVersion = cur.readBE(); - - VLOG(4) << "majorVersion=" << majorVersion - << " minorVersion=" << minorVersion; - - if (majorVersion == 0 && (minorVersion == 0 || minorVersion == 1)) { - return ProtocolVersion(majorVersion, minorVersion); - } - } else if (frameType == kRESUME && streamId == 0) { - return FrameSerializerV0_1::Version; - } - } catch (...) { - } - return ProtocolVersion::Unknown; -} - -} // reactivesocket diff --git a/rsocket/framing/FrameSerializer_v0_1.h b/rsocket/framing/FrameSerializer_v0_1.h deleted file mode 100644 index c53774f8b..000000000 --- a/rsocket/framing/FrameSerializer_v0_1.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "rsocket/framing/FrameSerializer_v0.h" - -namespace rsocket { - -class FrameSerializerV0_1 : public FrameSerializerV0 { - public: - constexpr static const ProtocolVersion Version = ProtocolVersion(0, 1); - constexpr static const size_t kMinBytesNeededForAutodetection = 12; // bytes - - static ProtocolVersion detectProtocolVersion( - const folly::IOBuf& firstFrame, - size_t skipBytes = 0); - - ProtocolVersion protocolVersion() override; -}; -} // reactivesocket diff --git a/rsocket/framing/FrameSerializer_v1_0.cpp b/rsocket/framing/FrameSerializer_v1_0.cpp index 122e3515f..446246d11 100644 --- a/rsocket/framing/FrameSerializer_v1_0.cpp +++ b/rsocket/framing/FrameSerializer_v1_0.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/framing/FrameSerializer_v1_0.h" @@ -11,21 +23,14 @@ constexpr const size_t FrameSerializerV1_0::kFrameHeaderSize; constexpr const size_t FrameSerializerV1_0::kMinBytesNeededForAutodetection; namespace { -constexpr const auto kMedatadaLengthSize = 3; // bytes -constexpr const auto kMaxMetadataLength = 0xFFFFFF; // 24bit max value +constexpr const uint32_t kMedatadaLengthSize = 3u; // bytes +constexpr const uint32_t kMaxMetadataLength = 0xFFFFFFu; // 24bit max value } // namespace -ProtocolVersion FrameSerializerV1_0::protocolVersion() { +ProtocolVersion FrameSerializerV1_0::protocolVersion() const { return Version; } -static folly::IOBufQueue createBufferQueue(size_t bufferSize) { - auto buf = folly::IOBuf::createCombined(bufferSize); - folly::IOBufQueue queue(folly::IOBufQueue::cacheChainLength()); - queue.append(std::move(buf)); - return queue; -} - static FrameType deserializeFrameType(uint16_t frameType) { if (frameType > static_cast(FrameType::RESUME_OK) && frameType != static_cast(FrameType::EXT)) { @@ -64,13 +69,12 @@ static void serializeMetadataInto( return; } - // Use signed int because the first bit in metadata length is reserved. - if (metadata->length() > kMaxMetadataLength) { - CHECK(false) << "Metadata is too big to serialize"; - } - // metadata length field not included in the medatadata length - uint32_t metadataLength = static_cast(metadata->length()); + uint32_t metadataLength = + static_cast(metadata->computeChainDataLength()); + CHECK_LT(metadataLength, kMaxMetadataLength) + << "Metadata is too big to serialize"; + appender.write(static_cast(metadataLength >> 16)); // first byte appender.write( static_cast((metadataLength >> 8) & 0xFF)); // second byte @@ -91,9 +95,8 @@ std::unique_ptr FrameSerializerV1_0::deserializeMetadataFrom( metadataLength |= static_cast(cur.read() << 8); metadataLength |= cur.read(); - if (metadataLength > kMaxMetadataLength) { - throw std::runtime_error("Metadata is too big to deserialize"); - } + CHECK_LE(metadataLength, kMaxMetadataLength) + << "Read out the 24-bit integer incorrectly somehow"; std::unique_ptr metadata; cur.clone(metadata, metadataLength); @@ -132,8 +135,8 @@ static uint32_t payloadFramingSize(const Payload& payload) { return (payload.metadata != nullptr ? kMedatadaLengthSize : 0); } -static std::unique_ptr serializeOutInternal( - Frame_REQUEST_Base&& frame) { +std::unique_ptr FrameSerializerV1_0::serializeOutInternal( + Frame_REQUEST_Base&& frame) const { auto queue = createBufferQueue( FrameSerializerV1_0::kFrameHeaderSize + sizeof(uint32_t) + payloadFramingSize(frame.payload_)); @@ -174,7 +177,7 @@ static size_t getResumeIdTokenFramingLength( : 0; } -FrameType FrameSerializerV1_0::peekFrameType(const folly::IOBuf& in) { +FrameType FrameSerializerV1_0::peekFrameType(const folly::IOBuf& in) const { folly::io::Cursor cur(&in); try { cur.skip(sizeof(int32_t)); // streamId @@ -186,9 +189,13 @@ FrameType FrameSerializerV1_0::peekFrameType(const folly::IOBuf& in) { } folly::Optional FrameSerializerV1_0::peekStreamId( - const folly::IOBuf& in) { + const folly::IOBuf& in, + bool skipFrameLengthBytes) const { folly::io::Cursor cur(&in); try { + if (skipFrameLengthBytes) { + cur.skip(3); // skip 3 bytes for frame length + } auto streamId = cur.readBE(); if (streamId < 0) { return folly::none; @@ -200,17 +207,17 @@ folly::Optional FrameSerializerV1_0::peekStreamId( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_REQUEST_STREAM&& frame) { + Frame_REQUEST_STREAM&& frame) const { return serializeOutInternal(std::move(frame)); } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_REQUEST_CHANNEL&& frame) { + Frame_REQUEST_CHANNEL&& frame) const { return serializeOutInternal(std::move(frame)); } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_REQUEST_RESPONSE&& frame) { + Frame_REQUEST_RESPONSE&& frame) const { auto queue = createBufferQueue(kFrameHeaderSize + payloadFramingSize(frame.payload_)); folly::io::QueueAppender appender(&queue, /* do not grow */ 0); @@ -220,7 +227,7 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_REQUEST_FNF&& frame) { + Frame_REQUEST_FNF&& frame) const { auto queue = createBufferQueue(kFrameHeaderSize + payloadFramingSize(frame.payload_)); folly::io::QueueAppender appender(&queue, /* do not grow */ 0); @@ -230,7 +237,7 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_REQUEST_N&& frame) { + Frame_REQUEST_N&& frame) const { auto queue = createBufferQueue(kFrameHeaderSize + sizeof(uint32_t)); folly::io::QueueAppender appender(&queue, /* do not grow */ 0); serializeHeaderInto(appender, frame.header_); @@ -239,7 +246,7 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_METADATA_PUSH&& frame) { + Frame_METADATA_PUSH&& frame) const { auto queue = createBufferQueue(kFrameHeaderSize); folly::io::QueueAppender appender(&queue, /* do not grow */ 0); serializeHeaderInto(appender, frame.header_); @@ -250,7 +257,7 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_CANCEL&& frame) { + Frame_CANCEL&& frame) const { auto queue = createBufferQueue(kFrameHeaderSize); folly::io::QueueAppender appender(&queue, /* do not grow */ 0); serializeHeaderInto(appender, frame.header_); @@ -258,7 +265,7 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_PAYLOAD&& frame) { + Frame_PAYLOAD&& frame) const { auto queue = createBufferQueue(kFrameHeaderSize + payloadFramingSize(frame.payload_)); folly::io::QueueAppender appender(&queue, /* do not grow */ 0); @@ -268,7 +275,7 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_ERROR&& frame) { + Frame_ERROR&& frame) const { auto queue = createBufferQueue( kFrameHeaderSize + sizeof(uint32_t) + payloadFramingSize(frame.payload_)); folly::io::QueueAppender appender(&queue, /* do not grow */ 0); @@ -279,8 +286,7 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_KEEPALIVE&& frame, - bool /*resumeable*/) { + Frame_KEEPALIVE&& frame) const { auto queue = createBufferQueue(kFrameHeaderSize + sizeof(int64_t)); folly::io::QueueAppender appender(&queue, /* do not grow */ 0); serializeHeaderInto(appender, frame.header_); @@ -292,7 +298,7 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_SETUP&& frame) { + Frame_SETUP&& frame) const { auto queue = createBufferQueue( kFrameHeaderSize + sizeof(uint16_t) + sizeof(uint16_t) + sizeof(int32_t) + sizeof(int32_t) + @@ -334,7 +340,7 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_LEASE&& frame) { + Frame_LEASE&& frame) const { auto queue = createBufferQueue(kFrameHeaderSize + sizeof(int32_t) + sizeof(int32_t)); folly::io::QueueAppender appender(&queue, /* do not grow */ 0); @@ -348,7 +354,7 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_RESUME&& frame) { + Frame_RESUME&& frame) const { auto queue = createBufferQueue( kFrameHeaderSize + sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint16_t) + frame.token_.data().size() + sizeof(int32_t) + @@ -371,7 +377,7 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_RESUME_OK&& frame) { + Frame_RESUME_OK&& frame) const { auto queue = createBufferQueue(kFrameHeaderSize + sizeof(int64_t)); folly::io::QueueAppender appender(&queue, /* do not grow */ 0); serializeHeaderInto(appender, frame.header_); @@ -381,19 +387,19 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( bool FrameSerializerV1_0::deserializeFrom( Frame_REQUEST_STREAM& frame, - std::unique_ptr in) { + std::unique_ptr in) const { return deserializeFromInternal(frame, std::move(in)); } bool FrameSerializerV1_0::deserializeFrom( Frame_REQUEST_CHANNEL& frame, - std::unique_ptr in) { + std::unique_ptr in) const { return deserializeFromInternal(frame, std::move(in)); } bool FrameSerializerV1_0::deserializeFrom( Frame_REQUEST_RESPONSE& frame, - std::unique_ptr in) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); @@ -406,7 +412,7 @@ bool FrameSerializerV1_0::deserializeFrom( bool FrameSerializerV1_0::deserializeFrom( Frame_REQUEST_FNF& frame, - std::unique_ptr in) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); @@ -419,7 +425,7 @@ bool FrameSerializerV1_0::deserializeFrom( bool FrameSerializerV1_0::deserializeFrom( Frame_REQUEST_N& frame, - std::unique_ptr in) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); @@ -436,7 +442,7 @@ bool FrameSerializerV1_0::deserializeFrom( bool FrameSerializerV1_0::deserializeFrom( Frame_METADATA_PUSH& frame, - std::unique_ptr in) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); @@ -451,7 +457,7 @@ bool FrameSerializerV1_0::deserializeFrom( bool FrameSerializerV1_0::deserializeFrom( Frame_CANCEL& frame, - std::unique_ptr in) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); @@ -463,7 +469,7 @@ bool FrameSerializerV1_0::deserializeFrom( bool FrameSerializerV1_0::deserializeFrom( Frame_PAYLOAD& frame, - std::unique_ptr in) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); @@ -476,7 +482,7 @@ bool FrameSerializerV1_0::deserializeFrom( bool FrameSerializerV1_0::deserializeFrom( Frame_ERROR& frame, - std::unique_ptr in) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); @@ -490,8 +496,7 @@ bool FrameSerializerV1_0::deserializeFrom( bool FrameSerializerV1_0::deserializeFrom( Frame_KEEPALIVE& frame, - std::unique_ptr in, - bool /*resumable*/) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); @@ -509,7 +514,7 @@ bool FrameSerializerV1_0::deserializeFrom( bool FrameSerializerV1_0::deserializeFrom( Frame_SETUP& frame, - std::unique_ptr in) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); @@ -552,7 +557,7 @@ bool FrameSerializerV1_0::deserializeFrom( bool FrameSerializerV1_0::deserializeFrom( Frame_LEASE& frame, - std::unique_ptr in) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); @@ -577,7 +582,7 @@ bool FrameSerializerV1_0::deserializeFrom( bool FrameSerializerV1_0::deserializeFrom( Frame_RESUME& frame, - std::unique_ptr in) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); @@ -609,7 +614,7 @@ bool FrameSerializerV1_0::deserializeFrom( bool FrameSerializerV1_0::deserializeFrom( Frame_RESUME_OK& frame, - std::unique_ptr in) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); @@ -678,4 +683,8 @@ ProtocolVersion FrameSerializerV1_0::detectProtocolVersion( } return ProtocolVersion::Unknown; } + +size_t FrameSerializerV1_0::frameLengthFieldSize() const { + return 3; // bytes } +} // namespace rsocket diff --git a/rsocket/framing/FrameSerializer_v1_0.h b/rsocket/framing/FrameSerializer_v1_0.h index 4414339b2..f636584dd 100644 --- a/rsocket/framing/FrameSerializer_v1_0.h +++ b/rsocket/framing/FrameSerializer_v1_0.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -12,55 +24,74 @@ class FrameSerializerV1_0 : public FrameSerializer { constexpr static const size_t kFrameHeaderSize = 6; // bytes constexpr static const size_t kMinBytesNeededForAutodetection = 10; // bytes - ProtocolVersion protocolVersion() override; + ProtocolVersion protocolVersion() const override; static ProtocolVersion detectProtocolVersion( const folly::IOBuf& firstFrame, size_t skipBytes = 0); - FrameType peekFrameType(const folly::IOBuf& in) override; - folly::Optional peekStreamId(const folly::IOBuf& in) override; + FrameType peekFrameType(const folly::IOBuf& in) const override; + folly::Optional peekStreamId( + const folly::IOBuf& in, + bool skipFrameLengthBytes) const override; - std::unique_ptr serializeOut(Frame_REQUEST_STREAM&&) override; - std::unique_ptr serializeOut(Frame_REQUEST_CHANNEL&&) override; - std::unique_ptr serializeOut(Frame_REQUEST_RESPONSE&&) override; - std::unique_ptr serializeOut(Frame_REQUEST_FNF&&) override; - std::unique_ptr serializeOut(Frame_REQUEST_N&&) override; - std::unique_ptr serializeOut(Frame_METADATA_PUSH&&) override; - std::unique_ptr serializeOut(Frame_CANCEL&&) override; - std::unique_ptr serializeOut(Frame_PAYLOAD&&) override; - std::unique_ptr serializeOut(Frame_ERROR&&) override; - std::unique_ptr serializeOut(Frame_KEEPALIVE&&, bool) override; - std::unique_ptr serializeOut(Frame_SETUP&&) override; - std::unique_ptr serializeOut(Frame_LEASE&&) override; - std::unique_ptr serializeOut(Frame_RESUME&&) override; - std::unique_ptr serializeOut(Frame_RESUME_OK&&) override; + std::unique_ptr serializeOut( + Frame_REQUEST_STREAM&&) const override; + std::unique_ptr serializeOut( + Frame_REQUEST_CHANNEL&&) const override; + std::unique_ptr serializeOut( + Frame_REQUEST_RESPONSE&&) const override; + std::unique_ptr serializeOut( + Frame_REQUEST_FNF&&) const override; + std::unique_ptr serializeOut(Frame_REQUEST_N&&) const override; + std::unique_ptr serializeOut( + Frame_METADATA_PUSH&&) const override; + std::unique_ptr serializeOut(Frame_CANCEL&&) const override; + std::unique_ptr serializeOut(Frame_PAYLOAD&&) const override; + std::unique_ptr serializeOut(Frame_ERROR&&) const override; + std::unique_ptr serializeOut(Frame_KEEPALIVE&&) const override; + std::unique_ptr serializeOut(Frame_SETUP&&) const override; + std::unique_ptr serializeOut(Frame_LEASE&&) const override; + std::unique_ptr serializeOut(Frame_RESUME&&) const override; + std::unique_ptr serializeOut(Frame_RESUME_OK&&) const override; bool deserializeFrom(Frame_REQUEST_STREAM&, std::unique_ptr) - override; + const override; bool deserializeFrom(Frame_REQUEST_CHANNEL&, std::unique_ptr) - override; + const override; bool deserializeFrom(Frame_REQUEST_RESPONSE&, std::unique_ptr) - override; + const override; bool deserializeFrom(Frame_REQUEST_FNF&, std::unique_ptr) - override; + const override; bool deserializeFrom(Frame_REQUEST_N&, std::unique_ptr) - override; + const override; bool deserializeFrom(Frame_METADATA_PUSH&, std::unique_ptr) - override; - bool deserializeFrom(Frame_CANCEL&, std::unique_ptr) override; - bool deserializeFrom(Frame_PAYLOAD&, std::unique_ptr) override; - bool deserializeFrom(Frame_ERROR&, std::unique_ptr) override; - bool deserializeFrom(Frame_KEEPALIVE&, std::unique_ptr, bool) - override; - bool deserializeFrom(Frame_SETUP&, std::unique_ptr) override; - bool deserializeFrom(Frame_LEASE&, std::unique_ptr) override; - bool deserializeFrom(Frame_RESUME&, std::unique_ptr) override; + const override; + bool deserializeFrom(Frame_CANCEL&, std::unique_ptr) + const override; + bool deserializeFrom(Frame_PAYLOAD&, std::unique_ptr) + const override; + bool deserializeFrom(Frame_ERROR&, std::unique_ptr) + const override; + bool deserializeFrom(Frame_KEEPALIVE&, std::unique_ptr) + const override; + bool deserializeFrom(Frame_SETUP&, std::unique_ptr) + const override; + bool deserializeFrom(Frame_LEASE&, std::unique_ptr) + const override; + bool deserializeFrom(Frame_RESUME&, std::unique_ptr) + const override; bool deserializeFrom(Frame_RESUME_OK&, std::unique_ptr) - override; + const override; static std::unique_ptr deserializeMetadataFrom( folly::io::Cursor& cur, FrameFlags flags); + + private: + std::unique_ptr serializeOutInternal( + Frame_REQUEST_Base&& frame) const; + + size_t frameLengthFieldSize() const override; }; -} +} // namespace rsocket diff --git a/rsocket/framing/FrameTransport.h b/rsocket/framing/FrameTransport.h index 1bf0c7b46..6c5ed3ef1 100644 --- a/rsocket/framing/FrameTransport.h +++ b/rsocket/framing/FrameTransport.h @@ -1,18 +1,38 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #pragma once -#include "yarpl/Refcounted.h" +#include +#include "rsocket/DuplexConnection.h" #include "rsocket/framing/FrameProcessor.h" namespace rsocket { // Refer to FrameTransportImpl for documentation on the implementation -class FrameTransport : public virtual yarpl::Refcounted { +class FrameTransport { public: + virtual ~FrameTransport() = default; virtual void setFrameProcessor(std::shared_ptr) = 0; virtual void outputFrameOrDrop(std::unique_ptr) = 0; virtual void close() = 0; + // Just for observation purposes! + // TODO(T25011919): remove virtual DuplexConnection* getConnection() = 0; + + virtual bool isConnectionFramed() const = 0; }; -} +} // namespace rsocket diff --git a/rsocket/framing/FrameTransportImpl.cpp b/rsocket/framing/FrameTransportImpl.cpp index 29faf5188..8e49b9bac 100644 --- a/rsocket/framing/FrameTransportImpl.cpp +++ b/rsocket/framing/FrameTransportImpl.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/framing/FrameTransportImpl.h" @@ -36,7 +48,7 @@ void FrameTransportImpl::connect() { // will create a hard reference for that case and keep the object alive // until setInput method returns auto connectionCopy = connection_; - connectionCopy->setInput(this->ref_from_this(this)); + connectionCopy->setInput(shared_from_this()); } } @@ -56,8 +68,7 @@ void FrameTransportImpl::close() { if (!connection_) { return; } - - auto oldConnection = std::move(connection_); + connection_.reset(); if (auto subscription = std::move(connectionInputSub_)) { subscription->cancel(); @@ -65,7 +76,7 @@ void FrameTransportImpl::close() { } void FrameTransportImpl::onSubscribe( - yarpl::Reference subscription) { + std::shared_ptr subscription) { if (!connection_) { return; } @@ -77,8 +88,10 @@ void FrameTransportImpl::onSubscribe( } void FrameTransportImpl::onNext(std::unique_ptr frame) { - CHECK(frameProcessor_); - frameProcessor_->processFrame(std::move(frame)); + // Copy in case frame processing calls through to close(). + if (auto const processor = frameProcessor_) { + processor->processFrame(std::move(frame)); + } } void FrameTransportImpl::terminateProcessor(folly::exception_wrapper ex) { @@ -115,4 +128,9 @@ void FrameTransportImpl::outputFrameOrDrop( } } +bool FrameTransportImpl::isConnectionFramed() const { + CHECK(connection_); + return connection_->isFramed(); +} + } // namespace rsocket diff --git a/rsocket/framing/FrameTransportImpl.h b/rsocket/framing/FrameTransportImpl.h index b7e9c0d1d..36ce9b526 100644 --- a/rsocket/framing/FrameTransportImpl.h +++ b/rsocket/framing/FrameTransportImpl.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -13,9 +25,11 @@ namespace rsocket { class FrameProcessor; -class FrameTransportImpl : public FrameTransport, - /// Registered as an input in the DuplexConnection. - public DuplexConnection::Subscriber { +class FrameTransportImpl + : public FrameTransport, + /// Registered as an input in the DuplexConnection. + public DuplexConnection::Subscriber, + public std::enable_shared_from_this { public: explicit FrameTransportImpl(std::unique_ptr connection); ~FrameTransportImpl(); @@ -37,16 +51,18 @@ class FrameTransportImpl : public FrameTransport, return connection_.get(); } - private: - void connect(); + bool isConnectionFramed() const override; // Subscriber. - void onSubscribe(yarpl::Reference) override; + void onSubscribe(std::shared_ptr) override; void onNext(std::unique_ptr) override; void onComplete() override; void onError(folly::exception_wrapper) override; + private: + void connect(); + /// Terminates the FrameProcessor. Will queue up the exception if no /// processor is set, overwriting any previously queued exception. void terminateProcessor(folly::exception_wrapper); @@ -54,7 +70,8 @@ class FrameTransportImpl : public FrameTransport, std::shared_ptr frameProcessor_; std::shared_ptr connection_; - yarpl::Reference connectionOutput_; - yarpl::Reference connectionInputSub_; + std::shared_ptr connectionOutput_; + std::shared_ptr connectionInputSub_; }; + } // namespace rsocket diff --git a/rsocket/framing/FrameType.cpp b/rsocket/framing/FrameType.cpp index d41bd44a3..8fb4fd140 100644 --- a/rsocket/framing/FrameType.cpp +++ b/rsocket/framing/FrameType.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/framing/FrameType.h" @@ -57,4 +69,4 @@ std::ostream& operator<<(std::ostream& os, FrameType type) { } return os << str; } -} +} // namespace rsocket diff --git a/rsocket/framing/FrameType.h b/rsocket/framing/FrameType.h index a4a020b48..726f9cd75 100644 --- a/rsocket/framing/FrameType.h +++ b/rsocket/framing/FrameType.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -32,4 +44,4 @@ folly::StringPiece toString(FrameType); std::ostream& operator<<(std::ostream&, FrameType); -} +} // namespace rsocket diff --git a/rsocket/framing/FramedDuplexConnection.cpp b/rsocket/framing/FramedDuplexConnection.cpp index c5fd451c8..9dec14a76 100644 --- a/rsocket/framing/FramedDuplexConnection.cpp +++ b/rsocket/framing/FramedDuplexConnection.cpp @@ -1,17 +1,24 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/framing/FramedDuplexConnection.h" - #include - -#include "rsocket/framing/FrameSerializer.h" #include "rsocket/framing/FrameSerializer_v1_0.h" #include "rsocket/framing/FramedReader.h" namespace rsocket { -using namespace yarpl::flowable; - namespace { constexpr auto kMaxFrameLength = 0xFFFFFF; // 24bit max value @@ -29,7 +36,7 @@ void writeFrameLength( auto shift = (frameSizeFieldLength - 1) * 8; while (frameSizeFieldLength--) { - auto byte = (frameLength >> shift) & 0xFF; + const auto byte = (frameLength >> shift) & 0xFF; cur.write(static_cast(byte)); shift -= 8; } @@ -44,27 +51,17 @@ size_t getFrameSizeFieldLength(ProtocolVersion version) { } } -size_t getPayloadLength(ProtocolVersion version, size_t payloadLength) { - DCHECK(version != ProtocolVersion::Unknown); - if (version < FrameSerializerV1_0::Version) { - return payloadLength + getFrameSizeFieldLength(version); - } else { - return payloadLength; - } -} - std::unique_ptr prependSize( ProtocolVersion version, std::unique_ptr payload) { CHECK(payload); const auto frameSizeFieldLength = getFrameSizeFieldLength(version); - // the frame size includes the payload size and the size value - auto payloadLength = - getPayloadLength(version, payload->computeChainDataLength()); - if (payloadLength > kMaxFrameLength) { - return nullptr; - } + const auto payloadLength = payload->computeChainDataLength(); + + CHECK_LE(payloadLength, kMaxFrameLength) + << "payloadLength: " << payloadLength + << " kMaxFrameLength: " << kMaxFrameLength; if (payload->headroom() >= frameSizeFieldLength) { // move the data pointer back and write value to the payload @@ -102,20 +99,13 @@ void FramedDuplexConnection::send(std::unique_ptr buf) { } auto sized = prependSize(*protocolVersion_, std::move(buf)); - if (!sized) { - protocolVersion_.reset(); - inputReader_.reset(); - inner_.reset(); - return; - } - inner_->send(std::move(sized)); } void FramedDuplexConnection::setInput( - yarpl::Reference framesSink) { + std::shared_ptr framesSink) { if (!inputReader_) { - inputReader_ = yarpl::make_ref(protocolVersion_); + inputReader_ = std::make_shared(protocolVersion_); inner_->setInput(inputReader_); } inputReader_->setInput(std::move(framesSink)); diff --git a/rsocket/framing/FramedDuplexConnection.h b/rsocket/framing/FramedDuplexConnection.h index 2e4ef8e81..2073266ea 100644 --- a/rsocket/framing/FramedDuplexConnection.h +++ b/rsocket/framing/FramedDuplexConnection.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -20,7 +32,7 @@ class FramedDuplexConnection : public virtual DuplexConnection { void send(std::unique_ptr) override; - void setInput(yarpl::Reference) override; + void setInput(std::shared_ptr) override; bool isFramed() const override { return true; @@ -31,8 +43,8 @@ class FramedDuplexConnection : public virtual DuplexConnection { } private: - std::unique_ptr inner_; - yarpl::Reference inputReader_; - std::shared_ptr protocolVersion_; + const std::unique_ptr inner_; + std::shared_ptr inputReader_; + const std::shared_ptr protocolVersion_; }; -} +} // namespace rsocket diff --git a/rsocket/framing/FramedReader.cpp b/rsocket/framing/FramedReader.cpp index 4daed0bcb..02edba694 100644 --- a/rsocket/framing/FramedReader.cpp +++ b/rsocket/framing/FramedReader.cpp @@ -1,11 +1,23 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/framing/FramedReader.h" #include -#include "rsocket/framing/FrameSerializer_v0_1.h" #include "rsocket/framing/FrameSerializer_v1_0.h" +#include "rsocket/internal/Common.h" namespace rsocket { @@ -13,23 +25,19 @@ using namespace yarpl::flowable; namespace { -constexpr size_t kFrameLengthFieldLengthV0_1 = sizeof(int32_t); constexpr size_t kFrameLengthFieldLengthV1_0 = 3; /// Get the byte size of the frame length field in an RSocket frame. size_t frameSizeFieldLength(ProtocolVersion version) { DCHECK_NE(version, ProtocolVersion::Unknown); - return version < FrameSerializerV1_0::Version ? kFrameLengthFieldLengthV0_1 - : kFrameLengthFieldLengthV1_0; + return kFrameLengthFieldLengthV1_0; } /// Get the minimum size for a valid RSocket frame (including its frame length /// field). size_t minimalFrameLength(ProtocolVersion version) { DCHECK_NE(version, ProtocolVersion::Unknown); - return version < FrameSerializerV1_0::Version - ? FrameSerializerV0::kFrameHeaderSize + frameSizeFieldLength(version) - : FrameSerializerV1_0::kFrameHeaderSize; + return FrameSerializerV1_0::kFrameHeaderSize; } /// Compute the length of the entire frame (including its frame length field), @@ -48,10 +56,10 @@ size_t frameSizeWithoutLengthField(ProtocolVersion version, size_t frameSize) { ? frameSize - frameSizeFieldLength(version) : frameSize; } -} +} // namespace size_t FramedReader::readFrameLength() const { - auto fieldLength = frameSizeFieldLength(*version_); + const auto fieldLength = frameSizeFieldLength(*version_); DCHECK_GT(fieldLength, 0); folly::io::Cursor cur{payloadQueue_.front()}; @@ -66,9 +74,9 @@ size_t FramedReader::readFrameLength() const { return frameLength; } -void FramedReader::onSubscribe(yarpl::Reference subscription) { - DuplexConnection::DuplexSubscriber::onSubscribe(subscription); - subscription->request(std::numeric_limits::max()); +void FramedReader::onSubscribe(std::shared_ptr subscription) { + subscription_ = std::move(subscription); + subscription_->request(std::numeric_limits::max()); } void FramedReader::onNext(std::unique_ptr payload) { @@ -84,7 +92,7 @@ void FramedReader::parseFrames() { } // Delivering onNext can trigger termination and destroy this instance. - auto thisPtr = this->ref_from_this(this); + auto const self = shared_from_this(); dispatchingFrames_ = true; @@ -113,7 +121,8 @@ void FramedReader::parseFrames() { } payloadQueue_.trimStart(frameSizeFieldLen); - auto payloadSize = frameSizeWithoutLengthField(*version_, nextFrameSize); + const auto payloadSize = + frameSizeWithoutLengthField(*version_, nextFrameSize); DCHECK_GT(payloadSize, 0) << "folly::IOBufQueue::split(0) returns a nullptr, can't have that"; @@ -131,7 +140,7 @@ void FramedReader::parseFrames() { void FramedReader::onComplete() { payloadQueue_.move(); - DuplexConnection::DuplexSubscriber::onComplete(); + auto subscription = std::move(subscription_); if (auto subscriber = std::move(inner_)) { // After this call the instance can be destroyed! subscriber->onComplete(); @@ -140,7 +149,7 @@ void FramedReader::onComplete() { void FramedReader::onError(folly::exception_wrapper ex) { payloadQueue_.move(); - DuplexConnection::DuplexSubscriber::onError({}); + auto subscription = std::move(subscription_); if (auto subscriber = std::move(inner_)) { // After this call the instance can be destroyed! subscriber->onError(std::move(ex)); @@ -158,11 +167,11 @@ void FramedReader::cancel() { } void FramedReader::setInput( - yarpl::Reference inner) { + std::shared_ptr inner) { CHECK(!inner_) << "Must cancel original input to FramedReader before setting a new one"; inner_ = std::move(inner); - inner_->onSubscribe(this->ref_from_this(this)); + inner_->onSubscribe(shared_from_this()); } bool FramedReader::ensureOrAutodetectProtocolVersion() { @@ -170,33 +179,24 @@ bool FramedReader::ensureOrAutodetectProtocolVersion() { return true; } - auto minBytesNeeded = std::max( - FrameSerializerV0_1::kMinBytesNeededForAutodetection, - FrameSerializerV1_0::kMinBytesNeededForAutodetection); + const auto minBytesNeeded = + FrameSerializerV1_0::kMinBytesNeededForAutodetection; DCHECK_GT(minBytesNeeded, 0); if (payloadQueue_.chainLength() < minBytesNeeded) { return false; } - DCHECK_GT(minBytesNeeded, kFrameLengthFieldLengthV0_1); DCHECK_GT(minBytesNeeded, kFrameLengthFieldLengthV1_0); auto const& firstFrame = *payloadQueue_.front(); - auto detected = FrameSerializerV1_0::detectProtocolVersion( + const auto detectedV1 = FrameSerializerV1_0::detectProtocolVersion( firstFrame, kFrameLengthFieldLengthV1_0); - if (detected != ProtocolVersion::Unknown) { + if (detectedV1 != ProtocolVersion::Unknown) { *version_ = FrameSerializerV1_0::Version; return true; } - detected = FrameSerializerV0_1::detectProtocolVersion( - firstFrame, kFrameLengthFieldLengthV0_1); - if (detected != ProtocolVersion::Unknown) { - *version_ = FrameSerializerV0_1::Version; - return true; - } - error("Could not detect protocol version from framing"); return false; } @@ -205,12 +205,13 @@ void FramedReader::error(std::string errorMsg) { VLOG(1) << "error: " << errorMsg; payloadQueue_.move(); - if (DuplexConnection::DuplexSubscriber::subscription()) { - DuplexConnection::DuplexSubscriber::subscription()->cancel(); + if (auto subscription = std::move(subscription_)) { + subscription->cancel(); } if (auto subscriber = std::move(inner_)) { // After this call the instance can be destroyed! subscriber->onError(std::runtime_error{std::move(errorMsg)}); } } -} + +} // namespace rsocket diff --git a/rsocket/framing/FramedReader.h b/rsocket/framing/FramedReader.h index a4b361058..d0bc05a4f 100644 --- a/rsocket/framing/FramedReader.h +++ b/rsocket/framing/FramedReader.h @@ -1,31 +1,44 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include #include "rsocket/DuplexConnection.h" +#include "rsocket/framing/ProtocolVersion.h" #include "rsocket/internal/Allowance.h" -#include "rsocket/internal/Common.h" #include "yarpl/flowable/Subscription.h" namespace rsocket { -class FramedReader : public DuplexConnection::DuplexSubscriber, - public yarpl::flowable::Subscription { +class FramedReader : public DuplexConnection::Subscriber, + public yarpl::flowable::Subscription, + public std::enable_shared_from_this { public: explicit FramedReader(std::shared_ptr version) : version_{std::move(version)} {} /// Set the inner subscriber which will be getting full frame payloads. - void setInput(yarpl::Reference); + void setInput(std::shared_ptr); /// Cancel the subscription and error the inner subscriber. void error(std::string); // Subscriber. - void onSubscribe(yarpl::Reference) override; + void onSubscribe(std::shared_ptr) override; void onNext(std::unique_ptr) override; void onComplete() override; void onError(folly::exception_wrapper) override; @@ -41,12 +54,14 @@ class FramedReader : public DuplexConnection::DuplexSubscriber, size_t readFrameLength() const; - yarpl::Reference inner_; + std::shared_ptr subscription_; + std::shared_ptr inner_; Allowance allowance_; bool dispatchingFrames_{false}; folly::IOBufQueue payloadQueue_{folly::IOBufQueue::cacheChainLength()}; - std::shared_ptr version_; + const std::shared_ptr version_; }; -} + +} // namespace rsocket diff --git a/rsocket/framing/Framer.cpp b/rsocket/framing/Framer.cpp new file mode 100644 index 000000000..0fb97763b --- /dev/null +++ b/rsocket/framing/Framer.cpp @@ -0,0 +1,204 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/framing/Framer.h" +#include +#include "rsocket/framing/FrameSerializer_v1_0.h" + +namespace rsocket { + +namespace { + +constexpr size_t kFrameLengthFieldLengthV1_0 = 3; +constexpr auto kMaxFrameLength = 0xFFFFFF; // 24bit max value + +template +void writeFrameLength( + TWriter& cur, + size_t frameLength, + size_t frameSizeFieldLength) { + DCHECK(frameSizeFieldLength > 0); + + // starting from the highest byte + // frameSizeFieldLength == 3 => shift = [16,8,0] + // frameSizeFieldLength == 4 => shift = [24,16,8,0] + auto shift = (frameSizeFieldLength - 1) * 8; + + while (frameSizeFieldLength--) { + const auto byte = (frameLength >> shift) & 0xFF; + cur.write(static_cast(byte)); + shift -= 8; + } +} +} // namespace + +/// Get the byte size of the frame length field in an RSocket frame. +size_t Framer::frameSizeFieldLength() const { + DCHECK_NE(protocolVersion_, ProtocolVersion::Unknown); + if (protocolVersion_ < FrameSerializerV1_0::Version) { + return sizeof(int32_t); + } else { + return 3; // bytes + } +} + +/// Get the minimum size for a valid RSocket frame (including its frame length +/// field). +size_t Framer::minimalFrameLength() const { + DCHECK_NE(protocolVersion_, ProtocolVersion::Unknown); + return FrameSerializerV1_0::kFrameHeaderSize; +} + +/// Compute the length of the entire frame (including its frame length field), +/// if given only its frame length field. +size_t Framer::frameSizeWithLengthField(size_t frameSize) const { + return protocolVersion_ < FrameSerializerV1_0::Version + ? frameSize + : frameSize + frameSizeFieldLength(); +} + +/// Compute the length of the frame (excluding its frame length field), if given +/// only its frame length field. +size_t Framer::frameSizeWithoutLengthField(size_t frameSize) const { + DCHECK_NE(protocolVersion_, ProtocolVersion::Unknown); + return protocolVersion_ < FrameSerializerV1_0::Version + ? frameSize - frameSizeFieldLength() + : frameSize; +} + +size_t Framer::readFrameLength() const { + const auto fieldLength = frameSizeFieldLength(); + DCHECK_GT(fieldLength, 0); + + folly::io::Cursor cur{payloadQueue_.front()}; + size_t frameLength = 0; + + // Reading of arbitrary-sized big-endian integer. + for (size_t i = 0; i < fieldLength; ++i) { + frameLength <<= 8; + frameLength |= cur.read(); + } + + return frameLength; +} + +void Framer::addFrameChunk(std::unique_ptr payload) { + payloadQueue_.append(std::move(payload)); + parseFrames(); +} + +void Framer::parseFrames() { + if (payloadQueue_.empty() || !ensureOrAutodetectProtocolVersion()) { + // At this point we dont have enough bytes on the wire or we errored out. + return; + } + + while (!payloadQueue_.empty()) { + auto const frameSizeFieldLen = frameSizeFieldLength(); + if (payloadQueue_.chainLength() < frameSizeFieldLen) { + // We don't even have the next frame size value. + break; + } + + auto const nextFrameSize = readFrameLength(); + if (nextFrameSize < minimalFrameLength()) { + error("Invalid frame - Frame size smaller than minimum"); + break; + } + + if (payloadQueue_.chainLength() < frameSizeWithLengthField(nextFrameSize)) { + // Need to accumulate more data. + break; + } + + auto payloadSize = frameSizeWithoutLengthField(nextFrameSize); + if (stripFrameLengthField_) { + payloadQueue_.trimStart(frameSizeFieldLen); + } else { + payloadSize += frameSizeFieldLen; + } + + DCHECK_GT(payloadSize, 0) + << "folly::IOBufQueue::split(0) returns a nullptr, can't have that"; + auto nextFrame = payloadQueue_.split(payloadSize); + onFrame(std::move(nextFrame)); + } +} + +bool Framer::ensureOrAutodetectProtocolVersion() { + if (protocolVersion_ != ProtocolVersion::Unknown) { + return true; + } + + const auto minBytesNeeded = + FrameSerializerV1_0::kMinBytesNeededForAutodetection; + DCHECK_GT(minBytesNeeded, 0); + if (payloadQueue_.chainLength() < minBytesNeeded) { + return false; + } + + DCHECK_GT(minBytesNeeded, kFrameLengthFieldLengthV1_0); + + auto const& firstFrame = *payloadQueue_.front(); + + const auto detectedV1 = FrameSerializerV1_0::detectProtocolVersion( + firstFrame, kFrameLengthFieldLengthV1_0); + if (detectedV1 != ProtocolVersion::Unknown) { + protocolVersion_ = FrameSerializerV1_0::Version; + return true; + } + + error("Could not detect protocol version from data"); + return false; +} + +std::unique_ptr Framer::prependSize( + std::unique_ptr payload) { + CHECK(payload); + + const auto frameSizeFieldLengthValue = frameSizeFieldLength(); + const auto payloadLength = payload->computeChainDataLength(); + + CHECK_LE(payloadLength, kMaxFrameLength) + << "payloadLength: " << payloadLength + << " kMaxFrameLength: " << kMaxFrameLength; + + if (payload->headroom() >= frameSizeFieldLengthValue) { + // move the data pointer back and write value to the payload + payload->prepend(frameSizeFieldLengthValue); + folly::io::RWPrivateCursor cur(payload.get()); + writeFrameLength(cur, payloadLength, frameSizeFieldLengthValue); + return payload; + } else { + auto newPayload = folly::IOBuf::createCombined(frameSizeFieldLengthValue); + folly::io::Appender appender(newPayload.get(), /* do not grow */ 0); + writeFrameLength(appender, payloadLength, frameSizeFieldLengthValue); + newPayload->appendChain(std::move(payload)); + return newPayload; + } +} + +StreamId Framer::peekStreamId( + const folly::IOBuf& frame, + bool skipFrameLengthBytes) const { + return FrameSerializer::peekStreamId( + protocolVersion_, frame, skipFrameLengthBytes) + .value(); +} + +std::unique_ptr Framer::drainPayloadQueue() { + return payloadQueue_.move(); +} + +} // namespace rsocket diff --git a/rsocket/framing/Framer.h b/rsocket/framing/Framer.h new file mode 100644 index 000000000..2ff740492 --- /dev/null +++ b/rsocket/framing/Framer.h @@ -0,0 +1,73 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "rsocket/framing/ProtocolVersion.h" +#include "rsocket/internal/Common.h" + +namespace rsocket { + +/// +/// Frames class is used to parse individual rsocket frames from the stream of +/// incoming payload chunks. Every time a frame is parsed the onFrame method is +/// invoked. +/// Each rsocket frame is prepended with the frame length by +/// prependSize method. +/// +class Framer { + public: + Framer(ProtocolVersion protocolVersion, bool stripFrameLengthField) + : protocolVersion_{protocolVersion}, + stripFrameLengthField_{stripFrameLengthField} {} + virtual ~Framer() {} + + /// For processing incoming frame chunks + void addFrameChunk(std::unique_ptr); + + /// Prepends payload size to the beginning of he IOBuf based on the + /// set protocol version + std::unique_ptr prependSize( + std::unique_ptr payload); + + /// derived class can override this method to react to termination + virtual void error(const char*) = 0; + virtual void onFrame(std::unique_ptr) = 0; + + ProtocolVersion protocolVersion() const { + return protocolVersion_; + } + + StreamId peekStreamId(const folly::IOBuf& frame, bool) const; + + std::unique_ptr drainPayloadQueue(); + + private: + // to explicitly trigger parsing frames + void parseFrames(); + bool ensureOrAutodetectProtocolVersion(); + + size_t readFrameLength() const; + size_t frameSizeFieldLength() const; + size_t minimalFrameLength() const; + size_t frameSizeWithLengthField(size_t frameSize) const; + size_t frameSizeWithoutLengthField(size_t frameSize) const; + + folly::IOBufQueue payloadQueue_{folly::IOBufQueue::cacheChainLength()}; + ProtocolVersion protocolVersion_; + const bool stripFrameLengthField_; +}; + +} // namespace rsocket diff --git a/rsocket/framing/ProtocolVersion.cpp b/rsocket/framing/ProtocolVersion.cpp new file mode 100644 index 000000000..ee8f54c5f --- /dev/null +++ b/rsocket/framing/ProtocolVersion.cpp @@ -0,0 +1,32 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/framing/ProtocolVersion.h" + +#include +#include + +namespace rsocket { + +const ProtocolVersion ProtocolVersion::Unknown = ProtocolVersion( + std::numeric_limits::max(), + std::numeric_limits::max()); + +const ProtocolVersion ProtocolVersion::Latest = ProtocolVersion(1, 0); + +std::ostream& operator<<(std::ostream& os, const ProtocolVersion& version) { + return os << version.major << "." << version.minor; +} + +} // namespace rsocket diff --git a/rsocket/framing/ProtocolVersion.h b/rsocket/framing/ProtocolVersion.h new file mode 100644 index 000000000..3daf24dad --- /dev/null +++ b/rsocket/framing/ProtocolVersion.h @@ -0,0 +1,75 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace rsocket { + +// Bug in GCC: https://bugzilla.redhat.com/show_bug.cgi?id=130601 +#pragma push_macro("major") +#pragma push_macro("minor") +#undef major +#undef minor + +struct ProtocolVersion { + uint16_t major{}; + uint16_t minor{}; + + constexpr ProtocolVersion() = default; + constexpr ProtocolVersion(uint16_t _major, uint16_t _minor) + : major(_major), minor(_minor) {} + + static const ProtocolVersion Unknown; + static const ProtocolVersion Latest; +}; + +#pragma pop_macro("major") +#pragma pop_macro("minor") + +std::ostream& operator<<(std::ostream&, const ProtocolVersion&); + +constexpr bool operator==( + const ProtocolVersion& left, + const ProtocolVersion& right) { + return left.major == right.major && left.minor == right.minor; +} + +constexpr bool operator!=( + const ProtocolVersion& left, + const ProtocolVersion& right) { + return !(left == right); +} + +constexpr bool operator<( + const ProtocolVersion& left, + const ProtocolVersion& right) { + return left != ProtocolVersion::Unknown && + right != ProtocolVersion::Unknown && + (left.major < right.major || + (left.major == right.major && left.minor < right.minor)); +} + +constexpr bool operator>( + const ProtocolVersion& left, + const ProtocolVersion& right) { + return left != ProtocolVersion::Unknown && + right != ProtocolVersion::Unknown && + (left.major > right.major || + (left.major == right.major && left.minor > right.minor)); +} + +} // namespace rsocket diff --git a/rsocket/framing/ResumeIdentificationToken.cpp b/rsocket/framing/ResumeIdentificationToken.cpp new file mode 100644 index 000000000..31011f14d --- /dev/null +++ b/rsocket/framing/ResumeIdentificationToken.cpp @@ -0,0 +1,86 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/framing/ResumeIdentificationToken.h" + +#include +#include +#include +#include + +#include +#include + +namespace rsocket { + +constexpr const char* kHexChars = "0123456789abcdef"; + +ResumeIdentificationToken::ResumeIdentificationToken() {} + +ResumeIdentificationToken::ResumeIdentificationToken(const std::string& token) { + const auto getNibble = [&token](size_t i) { + uint8_t nibble; + if (token[i] >= '0' && token[i] <= '9') { + nibble = token[i] - '0'; + } else if (token[i] >= 'a' && token[i] <= 'f') { + nibble = token[i] - 'a' + 10; + } else { + throw std::invalid_argument("ResumeToken not in right format: " + token); + } + return nibble; + }; + if (token.size() < 2 || token[0] != '0' || token[1] != 'x' || + (token.size() % 2) != 0) { + throw std::invalid_argument("ResumeToken not in right format: " + token); + } + for (size_t i = 2 /* skipping '0x' */; i < token.size(); i += 2) { + const uint8_t firstNibble = getNibble(i + 0); + const uint8_t secondNibble = getNibble(i + 1); + bits_.push_back((firstNibble << 4) | secondNibble); + } +} + +ResumeIdentificationToken ResumeIdentificationToken::generateNew() { + constexpr size_t kSize = 16; + std::vector data; + data.reserve(kSize); + for (size_t i = 0; i < kSize; i++) { + data.push_back(static_cast(folly::Random::rand32())); + } + return ResumeIdentificationToken(std::move(data)); +} + +void ResumeIdentificationToken::set(std::vector newBits) { + CHECK(newBits.size() <= std::numeric_limits::max()); + bits_ = std::move(newBits); +} + +std::string ResumeIdentificationToken::str() const { + std::stringstream out; + out << *this; + return out.str(); +} + +std::ostream& operator<<( + std::ostream& out, + const ResumeIdentificationToken& token) { + out << "0x"; + for (const auto b : token.data()) { + out << kHexChars[(b & 0xF0) >> 4]; + out << kHexChars[b & 0x0F]; + } + return out; +} + +} // namespace rsocket diff --git a/rsocket/framing/ResumeIdentificationToken.h b/rsocket/framing/ResumeIdentificationToken.h new file mode 100644 index 000000000..be276ec3e --- /dev/null +++ b/rsocket/framing/ResumeIdentificationToken.h @@ -0,0 +1,65 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +namespace rsocket { + +class ResumeIdentificationToken { + public: + /// Creates an empty token. + ResumeIdentificationToken(); + + // The string token and ::str() function should complement each other. The + // string representation should be of the format + // 0x44ab7cf01fd290b63140d01ee789cfb6 + explicit ResumeIdentificationToken(const std::string&); + + static ResumeIdentificationToken generateNew(); + + const std::vector& data() const { + return bits_; + } + + void set(std::vector newBits); + + bool operator==(const ResumeIdentificationToken& right) const { + return data() == right.data(); + } + + bool operator!=(const ResumeIdentificationToken& right) const { + return data() != right.data(); + } + + bool operator<(const ResumeIdentificationToken& right) const { + return data() < right.data(); + } + + std::string str() const; + + private: + explicit ResumeIdentificationToken(std::vector bits) + : bits_(std::move(bits)) {} + + std::vector bits_; +}; + +std::ostream& operator<<(std::ostream&, const ResumeIdentificationToken&); + +} // namespace rsocket diff --git a/rsocket/framing/ScheduledFrameProcessor.cpp b/rsocket/framing/ScheduledFrameProcessor.cpp index c40c69168..e1abeade9 100644 --- a/rsocket/framing/ScheduledFrameProcessor.cpp +++ b/rsocket/framing/ScheduledFrameProcessor.cpp @@ -1,24 +1,43 @@ -#include "rsocket/framing/ScheduledFrameProcessor.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -#include -#include +#include "rsocket/framing/ScheduledFrameProcessor.h" namespace rsocket { -ScheduledFrameProcessor::~ScheduledFrameProcessor() {} +ScheduledFrameProcessor::ScheduledFrameProcessor( + std::shared_ptr processor, + folly::EventBase* evb) + : evb_{evb}, processor_{std::move(processor)} {} + +ScheduledFrameProcessor::~ScheduledFrameProcessor() = default; void ScheduledFrameProcessor::processFrame( std::unique_ptr ioBuf) { + CHECK(processor_) << "Calling processFrame() after onTerminal()"; + evb_->runInEventBaseThread( - [ fp = frameProcessor_, ioBuf = std::move(ioBuf) ]() mutable { - fp->processFrame(std::move(ioBuf)); + [processor = processor_, buf = std::move(ioBuf)]() mutable { + processor->processFrame(std::move(buf)); }); } -void ScheduledFrameProcessor::onTerminal(folly::exception_wrapper ex) { +void ScheduledFrameProcessor::onTerminal(folly::exception_wrapper ew) { evb_->runInEventBaseThread( - [ ex = std::move(ex), fp = frameProcessor_ ]() mutable { - fp->onTerminal(std::move(ex)); + [e = std::move(ew), processor = std::move(processor_)]() mutable { + processor->onTerminal(std::move(e)); }); } -} + +} // namespace rsocket diff --git a/rsocket/framing/ScheduledFrameProcessor.h b/rsocket/framing/ScheduledFrameProcessor.h index 173dfaa86..e4546af79 100644 --- a/rsocket/framing/ScheduledFrameProcessor.h +++ b/rsocket/framing/ScheduledFrameProcessor.h @@ -1,3 +1,17 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #pragma once #include @@ -16,19 +30,15 @@ namespace rsocket { // (FrameProcessor) in the original EventBase. class ScheduledFrameProcessor : public FrameProcessor { public: - ScheduledFrameProcessor( - std::shared_ptr fp, - folly::EventBase* evb) - : frameProcessor_(std::move(fp)), evb_(evb) {} - + ScheduledFrameProcessor(std::shared_ptr, folly::EventBase*); ~ScheduledFrameProcessor(); - void processFrame(std::unique_ptr ioBuf) override; - void onTerminal(folly::exception_wrapper ex) override; + void processFrame(std::unique_ptr) override; + void onTerminal(folly::exception_wrapper) override; private: - std::shared_ptr frameProcessor_; - folly::EventBase* evb_; + folly::EventBase* const evb_; + std::shared_ptr processor_; }; -} // rsocket +} // namespace rsocket diff --git a/rsocket/framing/ScheduledFrameTransport.cpp b/rsocket/framing/ScheduledFrameTransport.cpp index 7a3656eb2..88f715f16 100644 --- a/rsocket/framing/ScheduledFrameTransport.cpp +++ b/rsocket/framing/ScheduledFrameTransport.cpp @@ -1,33 +1,58 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "rsocket/framing/ScheduledFrameTransport.h" -#include +#include "rsocket/framing/ScheduledFrameProcessor.h" namespace rsocket { -ScheduledFrameTransport::~ScheduledFrameTransport() {} +ScheduledFrameTransport::~ScheduledFrameTransport() = default; void ScheduledFrameTransport::setFrameProcessor( std::shared_ptr fp) { - transportEvb_->runInEventBaseThread( - [ this, self = this->ref_from_this(this), fp = std::move(fp) ]() mutable { - auto scheduledFP = std::make_shared( - std::move(fp), stateMachineEvb_); - frameTransport_->setFrameProcessor(std::move(scheduledFP)); - }); + CHECK(frameTransport_) << "Inner transport already closed"; + + transportEvb_->runInEventBaseThread([stateMachineEvb = stateMachineEvb_, + transport = frameTransport_, + fp = std::move(fp)]() mutable { + auto scheduledFP = std::make_shared( + std::move(fp), stateMachineEvb); + transport->setFrameProcessor(std::move(scheduledFP)); + }); } void ScheduledFrameTransport::outputFrameOrDrop( std::unique_ptr ioBuf) { + CHECK(frameTransport_) << "Inner transport already closed"; + transportEvb_->runInEventBaseThread( - [ ft = frameTransport_, ioBuf = std::move(ioBuf) ]() mutable { - ft->outputFrameOrDrop(std::move(ioBuf)); + [transport = frameTransport_, buf = std::move(ioBuf)]() mutable { + transport->outputFrameOrDrop(std::move(buf)); }); } void ScheduledFrameTransport::close() { - transportEvb_->runInEventBaseThread([ft = frameTransport_]() { - ft->close(); - }); + CHECK(frameTransport_) << "Inner transport already closed"; + + transportEvb_->runInEventBaseThread( + [transport = std::move(frameTransport_)]() { transport->close(); }); +} + +bool ScheduledFrameTransport::isConnectionFramed() const { + CHECK(frameTransport_) << "Inner transport already closed"; + return frameTransport_->isConnectionFramed(); } -} // rsocket +} // namespace rsocket diff --git a/rsocket/framing/ScheduledFrameTransport.h b/rsocket/framing/ScheduledFrameTransport.h index 868f396b7..cc53f9444 100644 --- a/rsocket/framing/ScheduledFrameTransport.h +++ b/rsocket/framing/ScheduledFrameTransport.h @@ -1,9 +1,22 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #pragma once #include -#include "rsocket/framing/FrameTransportImpl.h" -#include "rsocket/framing/ScheduledFrameProcessor.h" +#include "rsocket/framing/FrameTransport.h" namespace rsocket { @@ -15,11 +28,10 @@ namespace rsocket { // original RSocketStateMachine was constructed for the client. Here the // RSocketStateMachine uses this class to schedule events of the Transport in // the new EventBase. -class ScheduledFrameTransport : public FrameTransport, - public yarpl::enable_get_ref { +class ScheduledFrameTransport : public FrameTransport { public: ScheduledFrameTransport( - yarpl::Reference frameTransport, + std::shared_ptr frameTransport, folly::EventBase* transportEvb, folly::EventBase* stateMachineEvb) : transportEvb_(transportEvb), @@ -28,22 +40,24 @@ class ScheduledFrameTransport : public FrameTransport, ~ScheduledFrameTransport(); - void setFrameProcessor(std::shared_ptr fp) override; - void outputFrameOrDrop(std::unique_ptr ioBuf) override; + void setFrameProcessor(std::shared_ptr) override; + void outputFrameOrDrop(std::unique_ptr) override; void close() override; + bool isConnectionFramed() const override; private: DuplexConnection* getConnection() override { DLOG(FATAL) << "ScheduledFrameTransport doesn't support getConnection method, " - "because it can create safe usage issues when EventBase of the " - "transport and the RSocketClient is not the same."; + "because it can create safe usage issues when EventBase of the " + "transport and the RSocketClient is not the same."; return nullptr; } private: - folly::EventBase* transportEvb_; - folly::EventBase* stateMachineEvb_; - yarpl::Reference frameTransport_; + folly::EventBase* const transportEvb_; + folly::EventBase* const stateMachineEvb_; + std::shared_ptr frameTransport_; }; -} + +} // namespace rsocket diff --git a/rsocket/internal/Allowance.h b/rsocket/internal/Allowance.h index 25d29eb7e..059dd3c47 100644 --- a/rsocket/internal/Allowance.h +++ b/rsocket/internal/Allowance.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -70,4 +82,4 @@ class Allowance { "Allowance representation must be an integer type"); ValueType value_{0}; }; -} // reactivesocket +} // namespace rsocket diff --git a/rsocket/internal/ClientResumeStatusCallback.h b/rsocket/internal/ClientResumeStatusCallback.h index 0ec5078df..abe20fc9d 100644 --- a/rsocket/internal/ClientResumeStatusCallback.h +++ b/rsocket/internal/ClientResumeStatusCallback.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -18,4 +30,4 @@ class ClientResumeStatusCallback { virtual void onResumeError(folly::exception_wrapper ex) noexcept = 0; }; -} // reactivesocket +} // namespace rsocket diff --git a/rsocket/internal/Common.cpp b/rsocket/internal/Common.cpp index f111fb772..fbd33f592 100644 --- a/rsocket/internal/Common.cpp +++ b/rsocket/internal/Common.cpp @@ -1,22 +1,29 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/internal/Common.h" +#include + #include #include #include +#include #include namespace rsocket { -namespace { -constexpr const char* HEX_CHARS = {"0123456789abcdef"}; -} - -constexpr const ProtocolVersion ProtocolVersion::Unknown = ProtocolVersion( - std::numeric_limits::max(), - std::numeric_limits::max()); - static const char* getTerminatingSignalErrorMessage(int terminatingSignal) { switch (static_cast(terminatingSignal)) { case StreamCompletionSignal::CONNECTION_END: @@ -44,12 +51,32 @@ static const char* getTerminatingSignalErrorMessage(int terminatingSignal) { } } +folly::StringPiece toString(StreamType t) { + switch (t) { + case StreamType::REQUEST_RESPONSE: + return "REQUEST_RESPONSE"; + case StreamType::STREAM: + return "STREAM"; + case StreamType::CHANNEL: + return "CHANNEL"; + case StreamType::FNF: + return "FNF"; + default: + DCHECK(false); + return "(invalid StreamType)"; + } +} + +std::ostream& operator<<(std::ostream& os, StreamType t) { + return os << toString(t); +} + std::ostream& operator<<(std::ostream& os, RSocketMode mode) { switch (mode) { - case RSocketMode::CLIENT: - return os << "CLIENT"; - case RSocketMode::SERVER: - return os << "SERVER"; + case RSocketMode::CLIENT: + return os << "CLIENT"; + case RSocketMode::SERVER: + return os << "SERVER"; } DLOG(FATAL) << "Invalid RSocketMode"; return os << "INVALID_RSOCKET_MODE"; @@ -80,6 +107,7 @@ std::string to_string(StreamCompletionSignal signal) { } // this should be never hit because the switch is over all cases LOG(FATAL) << "unknown StreamCompletionSignal=" << static_cast(signal); + return ""; } std::ostream& operator<<(std::ostream& os, StreamCompletionSignal signal) { @@ -90,65 +118,23 @@ StreamInterruptedException::StreamInterruptedException(int _terminatingSignal) : std::runtime_error(getTerminatingSignalErrorMessage(_terminatingSignal)), terminatingSignal(_terminatingSignal) {} -ResumeIdentificationToken::ResumeIdentificationToken() {} +std::string humanify(std::unique_ptr const& buf) { + std::string ret; + size_t cursor = 0; -ResumeIdentificationToken::ResumeIdentificationToken(const std::string& token) { - auto getNibble = [&token](size_t i) { - uint8_t nibble; - if (token[i] >= '0' && token[i] <= '9') { - nibble = token[i] - '0'; - } else if (token[i] >= 'a' && token[i] <= 'f') { - nibble = token[i] - 'a' + 10; - } else { - throw std::invalid_argument("ResumeToken not in right format: " + token); + for (const auto& range : *buf) { + for (const unsigned char chr : range) { + if (cursor >= 20) + goto outer; + ret += chr; + cursor++; } - return nibble; - }; - if (token.size() < 2 || token[0] != '0' || token[1] != 'x' || - (token.size() % 2) != 0) { - throw std::invalid_argument("ResumeToken not in right format: " + token); - } - size_t i = 2; - while (i < token.size()) { - uint8_t firstNibble = getNibble(i++); - uint8_t secondNibble = getNibble(i++); - bits_.push_back((firstNibble << 4) | secondNibble); } -} - -ResumeIdentificationToken ResumeIdentificationToken::generateNew() { - constexpr size_t kSize = 16; - std::vector data; - data.reserve(kSize); - for (size_t i = 0; i < kSize; i++) { - data.push_back(static_cast(folly::Random::rand32())); - } - return ResumeIdentificationToken(std::move(data)); -} - -void ResumeIdentificationToken::set(std::vector newBits) { - CHECK(newBits.size() <= std::numeric_limits::max()); - bits_ = std::move(newBits); -} +outer: -std::string ResumeIdentificationToken::str() const { - std::stringstream out; - out << *this; - return out.str(); + return folly::humanify(ret); } - -std::ostream& operator<<( - std::ostream& out, - const ResumeIdentificationToken& token) { - out << "0x"; - for (auto b : token.data()) { - out << HEX_CHARS[(b & 0xF0) >> 4]; - out << HEX_CHARS[b & 0x0F]; - } - return out; -} - std::string hexDump(folly::StringPiece s) { - return folly::hexDump(s.data(), s.size()); + return folly::hexDump(s.data(), std::min(0xFF, s.size())); } -} // reactivesocket +} // namespace rsocket diff --git a/rsocket/internal/Common.h b/rsocket/internal/Common.h index ec044de12..a096a5545 100644 --- a/rsocket/internal/Common.h +++ b/rsocket/internal/Common.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -21,20 +33,18 @@ class IOBuf; template class Range; typedef Range StringPiece; -} +} // namespace folly namespace rsocket { -constexpr std::chrono::seconds kDefaultKeepaliveInterval{5}; - -constexpr int64_t kMaxRequestN = std::numeric_limits::max(); - /// A unique identifier of a stream. using StreamId = uint32_t; -using ResumePosition = int64_t; -constexpr const ResumePosition kUnspecifiedResumePosition = -1; +constexpr std::chrono::seconds kDefaultKeepaliveInterval{5}; + +constexpr int64_t kMaxRequestN = std::numeric_limits::max(); +std::string humanify(std::unique_ptr const&); std::string hexDump(folly::StringPiece s); /// Indicates the reason why the stream stateMachine received a terminal signal @@ -63,6 +73,9 @@ enum class StreamType { FNF, }; +folly::StringPiece toString(StreamType); +std::ostream& operator<<(std::ostream&, StreamType); + enum class RequestOriginator { LOCAL, REMOTE, @@ -74,104 +87,9 @@ std::ostream& operator<<(std::ostream&, StreamCompletionSignal); class StreamInterruptedException : public std::runtime_error { public: explicit StreamInterruptedException(int _terminatingSignal); - int terminatingSignal; + const int terminatingSignal; }; -class ResumeIdentificationToken { - public: - /// Creates an empty token. - ResumeIdentificationToken(); - - // The stringToken and ::str() function should complement - // each other. The string representation should be of the - // format 0x44ab7cf01fd290b63140d01ee789cfb6 - explicit ResumeIdentificationToken(const std::string& stringToken); - - static ResumeIdentificationToken generateNew(); - - const std::vector& data() const { - return bits_; - } - - void set(std::vector newBits); - - bool operator==(const ResumeIdentificationToken& right) const { - return data() == right.data(); - } - - bool operator!=(const ResumeIdentificationToken& right) const { - return data() != right.data(); - } - - bool operator<(const ResumeIdentificationToken& right) const { - return data() < right.data(); - } - - std::string str() const; - - private: - explicit ResumeIdentificationToken(std::vector bits) - : bits_(std::move(bits)) {} - - std::vector bits_; -}; - -std::ostream& operator<<(std::ostream&, const ResumeIdentificationToken&); - -// bug in GCC: https://bugzilla.redhat.com/show_bug.cgi?id=130601 -#pragma push_macro("major") -#pragma push_macro("minor") -#undef major -#undef minor - -struct ProtocolVersion { - uint16_t major{}; - uint16_t minor{}; - - constexpr ProtocolVersion() = default; - constexpr ProtocolVersion(uint16_t _major, uint16_t _minor) - : major(_major), minor(_minor) {} - - static const ProtocolVersion Unknown; - static const ProtocolVersion Latest; - static ProtocolVersion Current(); -}; - -#pragma pop_macro("major") -#pragma pop_macro("minor") - -std::ostream& operator<<(std::ostream&, const ProtocolVersion&); - -constexpr inline bool operator==( - const ProtocolVersion& left, - const ProtocolVersion& right) { - return left.major == right.major && left.minor == right.minor; -} - -constexpr inline bool operator!=( - const ProtocolVersion& left, - const ProtocolVersion& right) { - return !(left == right); -} - -constexpr inline bool operator<( - const ProtocolVersion& left, - const ProtocolVersion& right) { - return left != ProtocolVersion::Unknown && - right != ProtocolVersion::Unknown && - (left.major < right.major || - (left.major == right.major && left.minor < right.minor)); -} - -constexpr inline bool operator>( - const ProtocolVersion& left, - const ProtocolVersion& right) { - return left != ProtocolVersion::Unknown && - right != ProtocolVersion::Unknown && - (left.major > right.major || - (left.major == right.major && left.minor > right.minor)); -} - class FrameSink; -} // reactivesocket +} // namespace rsocket diff --git a/rsocket/internal/ConnectionSet.cpp b/rsocket/internal/ConnectionSet.cpp index 47d6253ba..0ed32db6a 100644 --- a/rsocket/internal/ConnectionSet.cpp +++ b/rsocket/internal/ConnectionSet.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/internal/ConnectionSet.h" @@ -11,20 +23,31 @@ namespace rsocket { ConnectionSet::ConnectionSet() {} ConnectionSet::~ConnectionSet() { - VLOG(1) << "Started ~ConnectionSet"; - SCOPE_EXIT { VLOG(1) << "Finished ~ConnectionSet"; }; + if (!shutDown_) { + shutdownAndWait(); + } +} + +void ConnectionSet::shutdownAndWait() { + VLOG(1) << "Started ConnectionSet::shutdownAndWait"; + shutDown_ = true; + + SCOPE_EXIT { + VLOG(1) << "Finished ConnectionSet::shutdownAndWait"; + }; StateMachineMap map; // Move all the connections out of the synchronized map so we don't block // while closing the state machines. { - auto locked = machines_.lock(); + const auto locked = machines_.lock(); if (locked->empty()) { VLOG(2) << "No connections to close, early exit"; return; } + targetRemoves_ = removes_ + locked->size(); map.swap(*locked); } @@ -34,7 +57,7 @@ ConnectionSet::~ConnectionSet() { auto rsocket = std::move(kv.first); auto evb = kv.second; - auto close = [rs = std::move(rsocket)] { + const auto close = [rs = std::move(rsocket)] { rs->close({}, StreamCompletionSignal::SOCKET_CLOSED); }; @@ -48,18 +71,38 @@ ConnectionSet::~ConnectionSet() { evb->runInEventBaseThread(close); } } + + VLOG(2) << "Waiting for connections to close"; + shutdownDone_.wait(); + VLOG(2) << "Connections have closed"; } -void ConnectionSet::insert( +bool ConnectionSet::insert( std::shared_ptr machine, folly::EventBase* evb) { + VLOG(4) << "insert(" << machine.get() << ", " << evb << ")"; + + if (shutDown_) { + return false; + } machines_.lock()->emplace(std::move(machine), evb); + return true; } -void ConnectionSet::remove( - const std::shared_ptr& machine) { - auto locked = machines_.lock(); - auto const result = locked->erase(machine); +void ConnectionSet::remove(RSocketStateMachine& machine) { + VLOG(4) << "remove(" << &machine << ")"; + + const auto locked = machines_.lock(); + auto const result = locked->erase(machine.shared_from_this()); DCHECK_LE(result, 1); + + if (++removes_ == targetRemoves_) { + shutdownDone_.post(); + } } + +size_t ConnectionSet::size() const { + return machines_.lock()->size(); } + +} // namespace rsocket diff --git a/rsocket/internal/ConnectionSet.h b/rsocket/internal/ConnectionSet.h index 0d32f0034..b679b96f2 100644 --- a/rsocket/internal/ConnectionSet.h +++ b/rsocket/internal/ConnectionSet.h @@ -1,39 +1,60 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include +#include #include #include #include +#include "rsocket/statemachine/RSocketStateMachine.h" + namespace folly { class EventBase; } namespace rsocket { -class RSocketStateMachine; - /// Set of RSocketStateMachine objects. Stores them until they call /// RSocketStateMachine::close(). /// /// Also tracks which EventBase is controlling each state machine so that they /// can be closed on the correct thread. -class ConnectionSet { +class ConnectionSet : public RSocketStateMachine::CloseCallback { public: ConnectionSet(); - ~ConnectionSet(); + virtual ~ConnectionSet(); + + bool insert(std::shared_ptr, folly::EventBase*); + void remove(RSocketStateMachine&) override; - void insert(std::shared_ptr, folly::EventBase*); + size_t size() const; - void remove(const std::shared_ptr&); + void shutdownAndWait(); private: using StateMachineMap = std:: unordered_map, folly::EventBase*>; folly::Synchronized machines_; + folly::Baton<> shutdownDone_; + size_t removes_{0}; + size_t targetRemoves_{0}; + std::atomic shutDown_{false}; }; -} + +} // namespace rsocket diff --git a/rsocket/internal/KeepaliveTimer.cpp b/rsocket/internal/KeepaliveTimer.cpp index 6b61683ee..6fdaa39d0 100644 --- a/rsocket/internal/KeepaliveTimer.cpp +++ b/rsocket/internal/KeepaliveTimer.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/internal/KeepaliveTimer.h" @@ -15,34 +27,40 @@ KeepaliveTimer::~KeepaliveTimer() { stop(); } -std::chrono::milliseconds KeepaliveTimer::keepaliveTime() { +std::chrono::milliseconds KeepaliveTimer::keepaliveTime() const { return period_; } void KeepaliveTimer::schedule() { - auto scheduledGeneration = *generation_; - auto generation = generation_; + const auto scheduledGeneration = *generation_; + const auto generation = generation_; eventBase_.runAfterDelay( - [this, generation, scheduledGeneration]() { + [this, + wpConnection = std::weak_ptr(connection_), + generation, + scheduledGeneration]() { + auto spConnection = wpConnection.lock(); + if (!spConnection) { + return; + } if (*generation == scheduledGeneration) { - sendKeepalive(); + sendKeepalive(*spConnection); } }, static_cast(keepaliveTime().count())); } -void KeepaliveTimer::sendKeepalive() { +void KeepaliveTimer::sendKeepalive(FrameSink& sink) { if (pending_) { - // Make sure connection_ is not deleted (via external call to stop) - // while we still mid-operation - auto localPtr = connection_; stop(); // TODO: we need to use max lifetime from the setup frame for this - localPtr->disconnectOrCloseWithError( + sink.disconnectOrCloseWithError( Frame_ERROR::connectionError("no response to keepalive")); } else { - connection_->sendKeepalive(); + // this must happen before sendKeepalive as it can potentially result in + // stop() being called pending_ = true; + sink.sendKeepalive(); schedule(); } } @@ -51,7 +69,7 @@ void KeepaliveTimer::sendKeepalive() { void KeepaliveTimer::stop() { *generation_ += 1; pending_ = false; - connection_ = nullptr; + connection_.reset(); } // must be called from the same thread as stop @@ -66,4 +84,4 @@ void KeepaliveTimer::start(const std::shared_ptr& connection) { void KeepaliveTimer::keepaliveReceived() { pending_ = false; } -} +} // namespace rsocket diff --git a/rsocket/internal/KeepaliveTimer.h b/rsocket/internal/KeepaliveTimer.h index 3bf764b1b..51bb6c3c2 100644 --- a/rsocket/internal/KeepaliveTimer.h +++ b/rsocket/internal/KeepaliveTimer.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -14,7 +26,7 @@ class KeepaliveTimer { ~KeepaliveTimer(); - std::chrono::milliseconds keepaliveTime(); + std::chrono::milliseconds keepaliveTime() const; void schedule(); @@ -22,15 +34,15 @@ class KeepaliveTimer { void start(const std::shared_ptr& connection); - void sendKeepalive(); + void sendKeepalive(FrameSink& sink); void keepaliveReceived(); private: std::shared_ptr connection_; folly::EventBase& eventBase_; - std::shared_ptr generation_; - std::chrono::milliseconds period_; + const std::shared_ptr generation_; + const std::chrono::milliseconds period_; std::atomic pending_{false}; }; -} +} // namespace rsocket diff --git a/rsocket/internal/ScheduledRSocketResponder.cpp b/rsocket/internal/ScheduledRSocketResponder.cpp index a047c0513..d534657c8 100644 --- a/rsocket/internal/ScheduledRSocketResponder.cpp +++ b/rsocket/internal/ScheduledRSocketResponder.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/internal/ScheduledRSocketResponder.h" @@ -11,67 +23,60 @@ namespace rsocket { ScheduledRSocketResponder::ScheduledRSocketResponder( std::shared_ptr inner, - folly::EventBase& eventBase) : inner_(std::move(inner)), - eventBase_(eventBase) {} + folly::EventBase& eventBase) + : inner_(std::move(inner)), eventBase_(eventBase) {} -yarpl::Reference> +std::shared_ptr> ScheduledRSocketResponder::handleRequestResponse( Payload request, StreamId streamId) { - auto innerFlowable = inner_->handleRequestResponse(std::move(request), - streamId); + auto innerFlowable = + inner_->handleRequestResponse(std::move(request), streamId); return yarpl::single::Singles::create( - [innerFlowable = std::move(innerFlowable), eventBase = &eventBase_]( - yarpl::Reference> - observer) { - innerFlowable->subscribe(yarpl::make_ref< - ScheduledSingleObserver> - (std::move(observer), *eventBase)); - }); + [innerFlowable = std::move(innerFlowable), eventBase = &eventBase_]( + std::shared_ptr> observer) { + innerFlowable->subscribe( + std::make_shared>( + std::move(observer), *eventBase)); + }); } -yarpl::Reference> +std::shared_ptr> ScheduledRSocketResponder::handleRequestStream( Payload request, StreamId streamId) { - auto innerFlowable = inner_->handleRequestStream(std::move(request), - streamId); - return yarpl::flowable::Flowables::fromPublisher( - [innerFlowable = std::move(innerFlowable), eventBase = &eventBase_]( - yarpl::Reference> - subscriber) { - innerFlowable->subscribe(yarpl::make_ref< - ScheduledSubscriber> - (std::move(subscriber), *eventBase)); - }); + auto innerFlowable = + inner_->handleRequestStream(std::move(request), streamId); + return yarpl::flowable::internal::flowableFromSubscriber( + [innerFlowable = std::move(innerFlowable), eventBase = &eventBase_]( + std::shared_ptr> subscriber) { + innerFlowable->subscribe(std::make_shared>( + std::move(subscriber), *eventBase)); + }); } -yarpl::Reference> +std::shared_ptr> ScheduledRSocketResponder::handleRequestChannel( Payload request, - yarpl::Reference> - requestStream, + std::shared_ptr> requestStream, StreamId streamId) { - auto requestStreamFlowable = yarpl::flowable::Flowables::fromPublisher( - [requestStream = std::move(requestStream), eventBase = &eventBase_]( - yarpl::Reference> - subscriber) { - requestStream->subscribe(yarpl::make_ref< - ScheduledSubscriptionSubscriber> - (std::move(subscriber), *eventBase)); - }); - auto innerFlowable = inner_->handleRequestChannel(std::move(request), - std::move( - requestStreamFlowable), - streamId); - return yarpl::flowable::Flowables::fromPublisher( - [innerFlowable = std::move(innerFlowable), eventBase = &eventBase_]( - yarpl::Reference> - subscriber) { - innerFlowable->subscribe(yarpl::make_ref< - ScheduledSubscriber> - (std::move(subscriber), *eventBase)); - }); + auto requestStreamFlowable = + yarpl::flowable::internal::flowableFromSubscriber( + [requestStream = std::move(requestStream), eventBase = &eventBase_]( + std::shared_ptr> + subscriber) { + requestStream->subscribe( + std::make_shared>( + std::move(subscriber), *eventBase)); + }); + auto innerFlowable = inner_->handleRequestChannel( + std::move(request), std::move(requestStreamFlowable), streamId); + return yarpl::flowable::internal::flowableFromSubscriber( + [innerFlowable = std::move(innerFlowable), eventBase = &eventBase_]( + std::shared_ptr> subscriber) { + innerFlowable->subscribe(std::make_shared>( + std::move(subscriber), *eventBase)); + }); } void ScheduledRSocketResponder::handleFireAndForget( @@ -80,4 +85,4 @@ void ScheduledRSocketResponder::handleFireAndForget( inner_->handleFireAndForget(std::move(request), streamId); } -} // rsocket +} // namespace rsocket diff --git a/rsocket/internal/ScheduledRSocketResponder.h b/rsocket/internal/ScheduledRSocketResponder.h index e943c6ef2..fe9039dcc 100644 --- a/rsocket/internal/ScheduledRSocketResponder.h +++ b/rsocket/internal/ScheduledRSocketResponder.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -20,30 +32,24 @@ class ScheduledRSocketResponder : public RSocketResponder { std::shared_ptr inner, folly::EventBase& eventBase); - yarpl::Reference> - handleRequestResponse( + std::shared_ptr> handleRequestResponse( Payload request, StreamId streamId) override; - yarpl::Reference> - handleRequestStream( + std::shared_ptr> handleRequestStream( Payload request, StreamId streamId) override; - yarpl::Reference> - handleRequestChannel( + std::shared_ptr> handleRequestChannel( Payload request, - yarpl::Reference> - requestStream, + std::shared_ptr> requestStream, StreamId streamId) override; - void handleFireAndForget( - Payload request, - StreamId streamId) override; + void handleFireAndForget(Payload request, StreamId streamId) override; private: - std::shared_ptr inner_; + const std::shared_ptr inner_; folly::EventBase& eventBase_; }; -} // rsocket +} // namespace rsocket diff --git a/rsocket/internal/ScheduledSingleObserver.h b/rsocket/internal/ScheduledSingleObserver.h index 878689334..167b5458e 100644 --- a/rsocket/internal/ScheduledSingleObserver.h +++ b/rsocket/internal/ScheduledSingleObserver.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -17,24 +29,23 @@ namespace rsocket { // application code so that calls to on{Subscribe,Success,Error} are // scheduled on the right EventBase. // -template +template class ScheduledSingleObserver : public yarpl::single::SingleObserver { public: ScheduledSingleObserver( - yarpl::Reference> observer, - folly::EventBase& eventBase) : - inner_(std::move(observer)), eventBase_(eventBase) {} + std::shared_ptr> observer, + folly::EventBase& eventBase) + : inner_(std::move(observer)), eventBase_(eventBase) {} - void onSubscribe( - yarpl::Reference subscription) override { + void onSubscribe(std::shared_ptr + subscription) override { if (eventBase_.isInEventBaseThread()) { inner_->onSubscribe(std::move(subscription)); } else { eventBase_.runInEventBaseThread( - [inner = inner_, subscription = std::move(subscription)] - { - inner->onSubscribe(std::move(subscription)); - }); + [inner = inner_, subscription = std::move(subscription)] { + inner->onSubscribe(std::move(subscription)); + }); } } @@ -44,9 +55,9 @@ class ScheduledSingleObserver : public yarpl::single::SingleObserver { inner_->onSuccess(std::move(value)); } else { eventBase_.runInEventBaseThread( - [inner = inner_, value = std::move(value)]() mutable { - inner->onSuccess(std::move(value)); - }); + [inner = inner_, value = std::move(value)]() mutable { + inner->onSuccess(std::move(value)); + }); } } @@ -56,14 +67,14 @@ class ScheduledSingleObserver : public yarpl::single::SingleObserver { inner_->onError(std::move(ex)); } else { eventBase_.runInEventBaseThread( - [inner = inner_, ex = std::move(ex)]() mutable { - inner->onError(std::move(ex)); - }); + [inner = inner_, ex = std::move(ex)]() mutable { + inner->onError(std::move(ex)); + }); } } private: - yarpl::Reference> inner_; + const std::shared_ptr> inner_; folly::EventBase& eventBase_; }; @@ -73,18 +84,19 @@ class ScheduledSingleObserver : public yarpl::single::SingleObserver { // application code will be wrapped with a scheduled subscription to make the // call to Subscription::cancel safe. // -template -class ScheduledSubscriptionSingleObserver : public yarpl::single::SingleObserver { +template +class ScheduledSubscriptionSingleObserver + : public yarpl::single::SingleObserver { public: ScheduledSubscriptionSingleObserver( - yarpl::Reference> observer, - folly::EventBase& eventBase) : - inner_(std::move(observer)), eventBase_(eventBase) {} + std::shared_ptr> observer, + folly::EventBase& eventBase) + : inner_(std::move(observer)), eventBase_(eventBase) {} - void onSubscribe( - yarpl::Reference subscription) override { - inner_->onSubscribe( - yarpl::make_ref(std::move(subscription), eventBase_)); + void onSubscribe(std::shared_ptr + subscription) override { + inner_->onSubscribe(std::make_shared( + std::move(subscription), eventBase_)); } // No further calls to the subscription after this method is invoked. @@ -98,7 +110,7 @@ class ScheduledSubscriptionSingleObserver : public yarpl::single::SingleObserver } private: - yarpl::Reference> inner_; + const std::shared_ptr> inner_; folly::EventBase& eventBase_; }; -} // rsocket +} // namespace rsocket diff --git a/rsocket/internal/ScheduledSingleSubscription.cpp b/rsocket/internal/ScheduledSingleSubscription.cpp index 4f5167608..b56f76c0d 100644 --- a/rsocket/internal/ScheduledSingleSubscription.cpp +++ b/rsocket/internal/ScheduledSingleSubscription.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/internal/ScheduledSingleSubscription.h" @@ -7,20 +19,16 @@ namespace rsocket { ScheduledSingleSubscription::ScheduledSingleSubscription( - yarpl::Reference inner, - folly::EventBase& eventBase) : inner_(std::move(inner)), - eventBase_(eventBase) { -} + std::shared_ptr inner, + folly::EventBase& eventBase) + : inner_(std::move(inner)), eventBase_(eventBase) {} void ScheduledSingleSubscription::cancel() { if (eventBase_.isInEventBaseThread()) { inner_->cancel(); } else { - eventBase_.runInEventBaseThread([inner = inner_] - { - inner->cancel(); - }); + eventBase_.runInEventBaseThread([inner = inner_] { inner->cancel(); }); } } -} // rsocket +} // namespace rsocket diff --git a/rsocket/internal/ScheduledSingleSubscription.h b/rsocket/internal/ScheduledSingleSubscription.h index 5877c4914..1d29412e4 100644 --- a/rsocket/internal/ScheduledSingleSubscription.h +++ b/rsocket/internal/ScheduledSingleSubscription.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -11,20 +23,20 @@ class EventBase; namespace rsocket { // -// A decorator of the SingleSubscription object which schedules the method calls on the -// provided EventBase +// A decorator of the SingleSubscription object which schedules the method calls +// on the provided EventBase // class ScheduledSingleSubscription : public yarpl::single::SingleSubscription { public: ScheduledSingleSubscription( - yarpl::Reference inner, + std::shared_ptr inner, folly::EventBase& eventBase); void cancel() override; private: - yarpl::Reference inner_; + const std::shared_ptr inner_; folly::EventBase& eventBase_; }; -} // rsocket +} // namespace rsocket diff --git a/rsocket/internal/ScheduledSubscriber.h b/rsocket/internal/ScheduledSubscriber.h index f2b4e6109..f73ee44f3 100644 --- a/rsocket/internal/ScheduledSubscriber.h +++ b/rsocket/internal/ScheduledSubscriber.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -22,20 +34,19 @@ template class ScheduledSubscriber : public yarpl::flowable::Subscriber { public: ScheduledSubscriber( - yarpl::Reference> inner, - folly::EventBase& eventBase) : inner_(std::move(inner)), - eventBase_(eventBase) {} + std::shared_ptr> inner, + folly::EventBase& eventBase) + : inner_(std::move(inner)), eventBase_(eventBase) {} void onSubscribe( - yarpl::Reference subscription) override { + std::shared_ptr subscription) override { if (eventBase_.isInEventBaseThread()) { inner_->onSubscribe(std::move(subscription)); } else { eventBase_.runInEventBaseThread( - [inner = inner_, subscription = std::move(subscription)] - { - inner->onSubscribe(std::move(subscription)); - }); + [inner = inner_, subscription = std::move(subscription)] { + inner->onSubscribe(std::move(subscription)); + }); } } @@ -45,10 +56,7 @@ class ScheduledSubscriber : public yarpl::flowable::Subscriber { inner_->onComplete(); } else { eventBase_.runInEventBaseThread( - [inner = inner_] - { - inner->onComplete(); - }); + [inner = inner_] { inner->onComplete(); }); } } @@ -57,9 +65,9 @@ class ScheduledSubscriber : public yarpl::flowable::Subscriber { inner_->onError(std::move(ex)); } else { eventBase_.runInEventBaseThread( - [inner = inner_, ex = std::move(ex)]() mutable { - inner->onError(std::move(ex)); - }); + [inner = inner_, ex = std::move(ex)]() mutable { + inner->onError(std::move(ex)); + }); } } @@ -68,14 +76,14 @@ class ScheduledSubscriber : public yarpl::flowable::Subscriber { inner_->onNext(std::move(value)); } else { eventBase_.runInEventBaseThread( - [inner = inner_, value = std::move(value)]() mutable { - inner->onNext(std::move(value)); - }); + [inner = inner_, value = std::move(value)]() mutable { + inner->onNext(std::move(value)); + }); } } private: - yarpl::Reference> inner_; + const std::shared_ptr> inner_; folly::EventBase& eventBase_; }; @@ -88,36 +96,37 @@ class ScheduledSubscriber : public yarpl::flowable::Subscriber { // request and cancel from any thread. // template -class ScheduledSubscriptionSubscriber - : public yarpl::flowable::Subscriber { +class ScheduledSubscriptionSubscriber : public yarpl::flowable::Subscriber { public: ScheduledSubscriptionSubscriber( - yarpl::Reference> inner, - folly::EventBase& eventBase) : inner_(std::move(inner)), - eventBase_(eventBase) {} + std::shared_ptr> inner, + folly::EventBase& eventBase) + : inner_(std::move(inner)), eventBase_(eventBase) {} void onSubscribe( - yarpl::Reference subscription) override { - inner_->onSubscribe( - yarpl::make_ref(subscription, eventBase_)); + std::shared_ptr sub) override { + auto scheduled = + std::make_shared(std::move(sub), eventBase_); + inner_->onSubscribe(std::move(scheduled)); } - // No further calls to the subscription after this method is invoked. - void onComplete() override { - inner_->onComplete(); + void onNext(T value) override { + inner_->onNext(std::move(value)); } - void onError(folly::exception_wrapper ex) override { - inner_->onError(std::move(ex)); + void onComplete() override { + auto inner = std::move(inner_); + inner->onComplete(); } - void onNext(T value) override { - inner_->onNext(std::move(value)); + void onError(folly::exception_wrapper ew) override { + auto inner = std::move(inner_); + inner->onError(std::move(ew)); } private: - yarpl::Reference> inner_; + std::shared_ptr> inner_; folly::EventBase& eventBase_; }; -} // rsocket +} // namespace rsocket diff --git a/rsocket/internal/ScheduledSubscription.cpp b/rsocket/internal/ScheduledSubscription.cpp index 761f9aa0e..a92687aa9 100644 --- a/rsocket/internal/ScheduledSubscription.cpp +++ b/rsocket/internal/ScheduledSubscription.cpp @@ -1,37 +1,42 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/internal/ScheduledSubscription.h" -#include - namespace rsocket { ScheduledSubscription::ScheduledSubscription( - yarpl::Reference inner, - folly::EventBase& eventBase) : inner_(std::move(inner)), - eventBase_(eventBase) { -} + std::shared_ptr inner, + folly::EventBase& eventBase) + : inner_{std::move(inner)}, eventBase_{eventBase} {} -void ScheduledSubscription::request(int64_t n) noexcept { +void ScheduledSubscription::request(int64_t n) { if (eventBase_.isInEventBaseThread()) { inner_->request(n); } else { - eventBase_.runInEventBaseThread([inner = inner_, n] - { - inner->request(n); - }); + eventBase_.runInEventBaseThread([inner = inner_, n] { inner->request(n); }); } } -void ScheduledSubscription::cancel() noexcept { +void ScheduledSubscription::cancel() { if (eventBase_.isInEventBaseThread()) { - inner_->cancel(); + auto inner = std::move(inner_); + inner->cancel(); } else { - eventBase_.runInEventBaseThread([inner = inner_] - { - inner->cancel(); - }); + eventBase_.runInEventBaseThread( + [inner = std::move(inner_)] { inner->cancel(); }); } } -} // rsocket +} // namespace rsocket diff --git a/rsocket/internal/ScheduledSubscription.h b/rsocket/internal/ScheduledSubscription.h index 9e595472f..14c058cb4 100644 --- a/rsocket/internal/ScheduledSubscription.h +++ b/rsocket/internal/ScheduledSubscription.h @@ -1,32 +1,39 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include "yarpl/flowable/Subscription.h" +#include -namespace folly { -class EventBase; -} +#include "yarpl/flowable/Subscription.h" namespace rsocket { -// -// A decorator of the Subscription object which schedules the method calls on the -// provided EventBase -// +// A wrapper over Subscription that schedules all of the subscription's methods +// on an EventBase. class ScheduledSubscription : public yarpl::flowable::Subscription { public: ScheduledSubscription( - yarpl::Reference inner, - folly::EventBase& eventBase); - - void request(int64_t n) noexcept override; + std::shared_ptr, + folly::EventBase&); - void cancel() noexcept override; + void request(int64_t) override; + void cancel() override; private: - yarpl::Reference inner_; + std::shared_ptr inner_; folly::EventBase& eventBase_; }; -} // rsocket +} // namespace rsocket diff --git a/rsocket/internal/SetupResumeAcceptor.cpp b/rsocket/internal/SetupResumeAcceptor.cpp index dc06d8f46..828e4dbdd 100644 --- a/rsocket/internal/SetupResumeAcceptor.cpp +++ b/rsocket/internal/SetupResumeAcceptor.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/internal/SetupResumeAcceptor.h" @@ -8,73 +20,74 @@ #include "rsocket/framing/Frame.h" #include "rsocket/framing/FrameProcessor.h" #include "rsocket/framing/FrameSerializer.h" -#include "rsocket/framing/FrameTransportImpl.h" namespace rsocket { -namespace { - -/// FrameProcessor that does nothing. Necessary to tell a FrameTransport it can -/// output frames in the cases where we want to error it. -class NoneFrameProcessor final : public FrameProcessor { - void processFrame(std::unique_ptr) override {} - void onTerminal(folly::exception_wrapper) override {} -}; - -} // namespace +/// Subscriber that owns a connection, sets itself as that connection's input, +/// and reads out a single frame before cancelling. +class SetupResumeAcceptor::OneFrameSubscriber final + : public yarpl::flowable::BaseSubscriber> { + public: + OneFrameSubscriber( + SetupResumeAcceptor& acceptor, + std::unique_ptr connection, + SetupResumeAcceptor::OnSetup onSetup, + SetupResumeAcceptor::OnResume onResume) + : acceptor_{acceptor}, + connection_{std::move(connection)}, + onSetup_{std::move(onSetup)}, + onResume_{std::move(onResume)} { + DCHECK(connection_); + DCHECK(onSetup_); + DCHECK(onResume_); + DCHECK(acceptor_.inOwnerThread()); + } -SetupResumeAcceptor::OneFrameSubscriber::OneFrameSubscriber( - SetupResumeAcceptor& acceptor, - std::unique_ptr connection, - SetupResumeAcceptor::OnSetup onSetup, - SetupResumeAcceptor::OnResume onResume) - : acceptor_{acceptor}, - connection_{std::move(connection)}, - onSetup_{std::move(onSetup)}, - onResume_{std::move(onResume)} { - DCHECK(connection_); - DCHECK(onSetup_); - DCHECK(onResume_); - DCHECK(acceptor_.inOwnerThread()); -} + void setInput() { + DCHECK(acceptor_.inOwnerThread()); + connection_->setInput(ref_from_this(this)); + } -void SetupResumeAcceptor::OneFrameSubscriber::setInput() { - DCHECK(acceptor_.inOwnerThread()); - connection_->setInput(ref_from_this(this)); -} + /// Shut down the DuplexConnection, breaking the cycle between it and this + /// subscriber. Expects the DuplexConnection's destructor to call + /// onComplete/onError on its input subscriber (this). + void close() { + auto self = ref_from_this(this); + connection_.reset(); + } -void SetupResumeAcceptor::OneFrameSubscriber::close() { - auto self = ref_from_this(this); - connection_.reset(); -} + void onSubscribeImpl() override { + DCHECK(acceptor_.inOwnerThread()); + this->request(1); + } -void SetupResumeAcceptor::OneFrameSubscriber::onSubscribeImpl() { - DCHECK(acceptor_.inOwnerThread()); - this->request(std::numeric_limits::max()); -} + void onNextImpl(std::unique_ptr buf) override { + DCHECK(connection_) << "OneFrameSubscriber received more than one frame"; + DCHECK(acceptor_.inOwnerThread()); -void SetupResumeAcceptor::OneFrameSubscriber::onNextImpl( - std::unique_ptr buf) { - DCHECK(connection_) << "OneFrameSubscriber received more than one frame"; - DCHECK(acceptor_.inOwnerThread()); + this->cancel(); // calls onTerminateImpl - this->cancel(); // calls onTerminateImpl + acceptor_.processFrame( + std::move(connection_), + std::move(buf), + std::move(onSetup_), + std::move(onResume_)); + } - acceptor_.processFrame( - std::move(connection_), - std::move(buf), - std::move(onSetup_), - std::move(onResume_)); -} + void onCompleteImpl() override {} + void onErrorImpl(folly::exception_wrapper) override {} -void SetupResumeAcceptor::OneFrameSubscriber::onCompleteImpl() {} -void SetupResumeAcceptor::OneFrameSubscriber::onErrorImpl( - folly::exception_wrapper) {} + void onTerminateImpl() override { + DCHECK(acceptor_.inOwnerThread()); + acceptor_.remove(ref_from_this(this)); + } -void SetupResumeAcceptor::OneFrameSubscriber::onTerminateImpl() { - DCHECK(acceptor_.inOwnerThread()); - acceptor_.remove(ref_from_this(this)); -} + private: + SetupResumeAcceptor& acceptor_; + std::unique_ptr connection_; + SetupResumeAcceptor::OnSetup onSetup_; + SetupResumeAcceptor::OnResume onResume_; +}; SetupResumeAcceptor::SetupResumeAcceptor(folly::EventBase* eventBase) : eventBase_{eventBase} { @@ -97,7 +110,7 @@ void SetupResumeAcceptor::processFrame( return; } - auto serializer = FrameSerializer::createAutodetectedSerializer(*buf); + const auto serializer = FrameSerializer::createAutodetectedSerializer(*buf); if (!serializer) { VLOG(2) << "Unable to detect protocol version"; return; @@ -107,7 +120,7 @@ void SetupResumeAcceptor::processFrame( case FrameType::SETUP: { Frame_SETUP frame; if (!serializer->deserializeFrom(frame, std::move(buf))) { - std::string msg{"Cannot decode SETUP frame"}; + constexpr auto msg = "Cannot decode SETUP frame"; auto err = serializer->serializeOut(Frame_ERROR::connectionError(msg)); connection->send(std::move(err)); break; @@ -119,30 +132,20 @@ void SetupResumeAcceptor::processFrame( frame.moveToSetupPayload(params); if (serializer->protocolVersion() != params.protocolVersion) { - std::string msg{"SETUP frame has invalid protocol version"}; + constexpr auto msg = "SETUP frame has invalid protocol version"; auto err = serializer->serializeOut(Frame_ERROR::invalidSetup(msg)); connection->send(std::move(err)); break; } - auto transport = - yarpl::make_ref(std::move(connection)); - - try { - onSetup(transport, std::move(params)); - } catch (const std::exception& exn) { - auto err = Frame_ERROR::rejectedSetup(exn.what()); - transport->setFrameProcessor(std::make_shared()); - transport->outputFrameOrDrop(serializer->serializeOut(std::move(err))); - transport->close(); - } + onSetup(std::move(connection), std::move(params)); break; } case FrameType::RESUME: { Frame_RESUME frame; if (!serializer->deserializeFrom(frame, std::move(buf))) { - std::string msg{"Cannot decode RESUME frame"}; + constexpr auto msg = "Cannot decode RESUME frame"; auto err = serializer->serializeOut(Frame_ERROR::connectionError(msg)); connection->send(std::move(err)); break; @@ -157,28 +160,18 @@ void SetupResumeAcceptor::processFrame( ProtocolVersion(frame.versionMajor_, frame.versionMinor_)); if (serializer->protocolVersion() != params.protocolVersion) { - std::string msg{"RESUME frame has invalid protocol version"}; + constexpr auto msg = "RESUME frame has invalid protocol version"; auto err = serializer->serializeOut(Frame_ERROR::rejectedResume(msg)); connection->send(std::move(err)); break; } - auto transport = - yarpl::make_ref(std::move(connection)); - - try { - onResume(transport, std::move(params)); - } catch (const std::exception& exn) { - auto err = Frame_ERROR::rejectedResume(exn.what()); - transport->setFrameProcessor(std::make_shared()); - transport->outputFrameOrDrop(serializer->serializeOut(std::move(err))); - transport->close(); - } + onResume(std::move(connection), std::move(params)); break; } default: { - std::string msg{"Invalid frame, expected SETUP/RESUME"}; + constexpr auto msg = "Invalid frame, expected SETUP/RESUME"; auto err = serializer->serializeOut(Frame_ERROR::connectionError(msg)); connection->send(std::move(err)); break; @@ -196,14 +189,14 @@ void SetupResumeAcceptor::accept( return; } - auto subscriber = yarpl::make_ref( + const auto subscriber = std::make_shared( *this, std::move(connection), std::move(onSetup), std::move(onResume)); connections_.insert(subscriber); subscriber->setInput(); } void SetupResumeAcceptor::remove( - const yarpl::Reference& + const std::shared_ptr& subscriber) { DCHECK(inOwnerThread()); connections_.erase(subscriber); diff --git a/rsocket/internal/SetupResumeAcceptor.h b/rsocket/internal/SetupResumeAcceptor.h index 96573d89b..7ae246e78 100644 --- a/rsocket/internal/SetupResumeAcceptor.h +++ b/rsocket/internal/SetupResumeAcceptor.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -21,8 +33,6 @@ class exception_wrapper; namespace rsocket { -class FrameTransport; - /// Acceptor of DuplexConnections that lets us decide whether the connection is /// trying to setup a new connection or resume an existing one. /// @@ -30,10 +40,10 @@ class FrameTransport; /// SetupResumeAcceptor::accept() entry point is not thread-safe. class SetupResumeAcceptor final { public: - using OnSetup = - folly::Function, SetupParameters)>; - using OnResume = - folly::Function, ResumeParameters)>; + using OnSetup = folly::Function< + void(std::unique_ptr, SetupParameters) noexcept>; + using OnResume = folly::Function< + void(std::unique_ptr, ResumeParameters) noexcept>; explicit SetupResumeAcceptor(folly::EventBase*); ~SetupResumeAcceptor(); @@ -48,37 +58,7 @@ class SetupResumeAcceptor final { folly::Future close(); private: - /// Subscriber that owns a connection, sets itself as that connection's input, - /// and reads out a single frame before cancelling. - class OneFrameSubscriber final - : public yarpl::flowable::BaseSubscriber> { - public: - OneFrameSubscriber( - SetupResumeAcceptor&, - std::unique_ptr, - SetupResumeAcceptor::OnSetup, - SetupResumeAcceptor::OnResume); - - void setInput(); - - /// Shut down the DuplexConnection, breaking the cycle between it and this - /// subscriber. Expects the DuplexConnection's destructor to call - /// onComplete/onError on its input subscriber (this). - void close(); - - // Subscriber. - void onSubscribeImpl() override; - void onNextImpl(std::unique_ptr) override; - void onCompleteImpl() override; - void onErrorImpl(folly::exception_wrapper) override; - void onTerminateImpl() override; - - private: - SetupResumeAcceptor& acceptor_; - std::unique_ptr connection_; - SetupResumeAcceptor::OnSetup onSetup_; - SetupResumeAcceptor::OnResume onResume_; - }; + class OneFrameSubscriber; void processFrame( std::unique_ptr, @@ -87,7 +67,7 @@ class SetupResumeAcceptor final { OnResume); /// Remove a OneFrameSubscriber from the set. - void remove(const yarpl::Reference&); + void remove(const std::shared_ptr&); /// Close all open connections. void closeAll(); @@ -100,11 +80,11 @@ class SetupResumeAcceptor final { /// work within the owner thread. bool inOwnerThread() const; - std::unordered_set> connections_; + std::unordered_set> connections_; bool closed_{false}; - folly::EventBase* eventBase_; + folly::EventBase* const eventBase_; }; } // namespace rsocket diff --git a/rsocket/internal/StackTraceUtils.h b/rsocket/internal/StackTraceUtils.h index 4d8d05069..b99d5b943 100644 --- a/rsocket/internal/StackTraceUtils.h +++ b/rsocket/internal/StackTraceUtils.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -14,4 +26,4 @@ inline std::string getStackTrace() { } #endif -} // reactivesocket +} // namespace rsocket diff --git a/rsocket/internal/SwappableEventBase.cpp b/rsocket/internal/SwappableEventBase.cpp index 7abd96fe0..f745a9365 100644 --- a/rsocket/internal/SwappableEventBase.cpp +++ b/rsocket/internal/SwappableEventBase.cpp @@ -1,34 +1,47 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "SwappableEventBase.h" namespace rsocket { bool SwappableEventBase::runInEventBaseThread(CbFunc cb) { - std::lock_guard l(hasSebDtored_->l_); + const std::lock_guard l(hasSebDtored_->l_); - if(this->isSwapping()) { + if (this->isSwapping()) { queued_.push_back(std::move(cb)); return false; } - return eb_->runInEventBaseThread([eb = eb_, cb_ = std::move(cb)]() mutable { - return cb_(*eb); - }); + eb_->runInEventBaseThread( + [eb = eb_, cb_ = std::move(cb)]() mutable { return cb_(*eb); }); + + return true; } void SwappableEventBase::setEventBase(folly::EventBase& newEb) { - std::lock_guard l(hasSebDtored_->l_); + const std::lock_guard l(hasSebDtored_->l_); auto const alreadySwapping = this->isSwapping(); nextEb_ = &newEb; - if(alreadySwapping) { + if (alreadySwapping) { return; } eb_->runInEventBaseThread([this, hasSebDtored = hasSebDtored_]() { - std::lock_guard lInner(hasSebDtored->l_); - if(hasSebDtored->destroyed_) { + const std::lock_guard lInner(hasSebDtored->l_); + if (hasSebDtored->destroyed_) { // SEB was destroyed, any queued callbacks were appended to the old eb_ return; } @@ -38,10 +51,9 @@ void SwappableEventBase::setEventBase(folly::EventBase& newEb) { // enqueue tasks that were being buffered while this was waiting // for the previous EB to drain - for(auto& cb : queued_) { - eb_->runInEventBaseThread([cb = std::move(cb), eb = eb_]() mutable { - return cb(*eb); - }); + for (auto& cb : queued_) { + eb_->runInEventBaseThread( + [cb = std::move(cb), eb = eb_]() mutable { return cb(*eb); }); } queued_.clear(); @@ -53,13 +65,12 @@ bool SwappableEventBase::isSwapping() const { } SwappableEventBase::~SwappableEventBase() { - std::lock_guard l(hasSebDtored_->l_); + const std::lock_guard l(hasSebDtored_->l_); hasSebDtored_->destroyed_ = true; - for(auto& cb : queued_) { - eb_->runInEventBaseThread([cb = std::move(cb), eb = eb_]() mutable { - return cb(*eb); - }); + for (auto& cb : queued_) { + eb_->runInEventBaseThread( + [cb = std::move(cb), eb = eb_]() mutable { return cb(*eb); }); } queued_.clear(); } diff --git a/rsocket/internal/SwappableEventBase.h b/rsocket/internal/SwappableEventBase.h index 97df41368..456eb67bf 100644 --- a/rsocket/internal/SwappableEventBase.h +++ b/rsocket/internal/SwappableEventBase.h @@ -1,9 +1,21 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include #include +#include #include namespace rsocket { @@ -18,16 +30,16 @@ class SwappableEventBase final { // lock for synchronization on destroyed_, and all members of the parent SEB std::mutex l_; // has the SEB's destructor ran? - bool destroyed_ {false}; + bool destroyed_{false}; }; -public: + public: using CbFunc = folly::Function; explicit SwappableEventBase(folly::EventBase& eb) - : eb_(&eb), - nextEb_(nullptr), - hasSebDtored_(std::make_shared()) {} + : eb_(&eb), + nextEb_(nullptr), + hasSebDtored_(std::make_shared()) {} // Run or enqueue 'cb', in order with all prior calls to runInEventBaseThread // If setEventBase has been called, and the prior EventBase is still @@ -47,7 +59,7 @@ class SwappableEventBase final { // there are any pending by the time the SEB is destroyed ~SwappableEventBase(); -private: + private: folly::EventBase* eb_; folly::EventBase* nextEb_; // also indicate if we're in the middle of a swap @@ -66,5 +78,4 @@ class SwappableEventBase final { std::vector queued_; }; - -} /* ns rsocket */ +} // namespace rsocket diff --git a/rsocket/internal/WarmResumeManager.cpp b/rsocket/internal/WarmResumeManager.cpp index 436dadf2f..c67de86e9 100644 --- a/rsocket/internal/WarmResumeManager.cpp +++ b/rsocket/internal/WarmResumeManager.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/internal/WarmResumeManager.h" @@ -29,7 +41,7 @@ void WarmResumeManager::trackSentFrame( size_t consumerAllowance) { if (shouldTrackFrame(frameType)) { // TODO(tmont): this could be expensive, find a better way to get length - auto frameDataLength = serializedFrame.computeChainDataLength(); + const auto frameDataLength = serializedFrame.computeChainDataLength(); VLOG(6) << "Track sent frame " << frameType << " Allowance: " << consumerAllowance; @@ -90,8 +102,8 @@ void WarmResumeManager::addFrame( void WarmResumeManager::evictFrame() { DCHECK(!frames_.empty()); - auto position = frames_.size() > 1 ? std::next(frames_.begin())->first - : lastSentPosition_; + const auto position = frames_.size() > 1 ? std::next(frames_.begin())->first + : lastSentPosition_; resetUpToPosition(position); } @@ -102,7 +114,7 @@ void WarmResumeManager::clearFrames(ResumePosition position) { DCHECK(position <= lastSentPosition_); DCHECK(position >= firstSentPosition_); - auto end = std::lower_bound( + const auto end = std::lower_bound( frames_.begin(), frames_.end(), position, @@ -110,7 +122,7 @@ void WarmResumeManager::clearFrames(ResumePosition position) { return pair.first < pos; }); DCHECK(end == frames_.end() || end->first >= firstSentPosition_); - auto pos = end == frames_.end() ? position : end->first; + const auto pos = end == frames_.end() ? position : end->first; stats_->resumeBufferChanged( -static_cast(std::distance(frames_.begin(), end)), -static_cast(pos - firstSentPosition_)); @@ -146,4 +158,16 @@ void WarmResumeManager::sendFramesFromPosition( } } -} // reactivesocket +std::shared_ptr ResumeManager::makeEmpty() { + class Empty : public WarmResumeManager { + public: + Empty() : WarmResumeManager(nullptr, 0) {} + bool shouldTrackFrame(FrameType) const override { + return false; + } + }; + + return std::make_shared(); +} + +} // namespace rsocket diff --git a/rsocket/internal/WarmResumeManager.h b/rsocket/internal/WarmResumeManager.h index 15d78f539..b14969ca9 100644 --- a/rsocket/internal/WarmResumeManager.h +++ b/rsocket/internal/WarmResumeManager.h @@ -1,9 +1,23 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include +#include + #include "rsocket/RSocketStats.h" #include "rsocket/ResumeManager.h" @@ -58,20 +72,22 @@ class WarmResumeManager : public ResumeManager { // No action to perform for WarmResumeManager void onStreamOpen(StreamId, RequestOriginator, std::string, StreamType) - override{}; + override {} // No action to perform for WarmResumeManager - void onStreamClosed(StreamId) override{}; + void onStreamClosed(StreamId) override {} - const StreamResumeInfos& getStreamResumeInfos() override { + const StreamResumeInfos& getStreamResumeInfos() const override { LOG(FATAL) << "Not Implemented for Warm Resumption"; + folly::assume_unreachable(); } - StreamId getLargestUsedStreamId() override { + StreamId getLargestUsedStreamId() const override { LOG(FATAL) << "Not Implemented for Warm Resumption"; + folly::assume_unreachable(); } - size_t size() { + size_t size() const { return size_; } @@ -82,7 +98,7 @@ class WarmResumeManager : public ResumeManager { // Called before clearing cached frames to update stats. void clearFrames(ResumePosition position); - std::shared_ptr stats_; + const std::shared_ptr stats_; // Start position of the send buffer queue ResumePosition firstSentPosition_{0}; @@ -97,4 +113,4 @@ class WarmResumeManager : public ResumeManager { const size_t capacity_; size_t size_{0}; }; -} +} // namespace rsocket diff --git a/rsocket/statemachine/ChannelRequester.cpp b/rsocket/statemachine/ChannelRequester.cpp index 2dd993ce9..6798613a1 100644 --- a/rsocket/statemachine/ChannelRequester.cpp +++ b/rsocket/statemachine/ChannelRequester.cpp @@ -1,75 +1,71 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/statemachine/ChannelRequester.h" namespace rsocket { -using namespace yarpl; -using namespace yarpl::flowable; - void ChannelRequester::onSubscribe( - Reference subscription) noexcept { + std::shared_ptr subscription) { CHECK(!requested_); publisherSubscribe(std::move(subscription)); + + if (hasInitialRequest_) { + initStream(std::move(request_)); + } } -void ChannelRequester::onNext(Payload request) noexcept { +void ChannelRequester::onNext(Payload request) { if (!requested_) { - requested_ = true; - - size_t initialN = - initialResponseAllowance_.consumeUpTo(Frame_REQUEST_N::kMaxRequestN); - size_t remainingN = initialResponseAllowance_.consumeAll(); - // Send as much as possible with the initial request. - CHECK_GE(Frame_REQUEST_N::kMaxRequestN, initialN); - newStream( - StreamType::CHANNEL, - static_cast(initialN), - std::move(request), - false); - // We must inform ConsumerBase about an implicit allowance we have - // requested from the remote end. - ConsumerBase::addImplicitAllowance(initialN); - // Pump the remaining allowance into the ConsumerBase _after_ sending the - // initial request. - if (remainingN) { - ConsumerBase::generateRequest(remainingN); - } + initStream(std::move(request)); return; } - checkPublisherOnNext(); if (!publisherClosed()) { - writePayload(std::move(request), false); + writePayload(std::move(request)); } } // TODO: consolidate code in onCompleteImpl, onErrorImpl, cancelImpl -void ChannelRequester::onComplete() noexcept { +void ChannelRequester::onComplete() { if (!requested_) { - closeStream(StreamCompletionSignal::CANCEL); + endStream(StreamCompletionSignal::CANCEL); + removeFromWriter(); return; } if (!publisherClosed()) { publisherComplete(); - completeStream(); + writeComplete(); tryCompleteChannel(); } } -void ChannelRequester::onError(folly::exception_wrapper ex) noexcept { +void ChannelRequester::onError(folly::exception_wrapper ex) { if (!requested_) { - closeStream(StreamCompletionSignal::CANCEL); + endStream(StreamCompletionSignal::CANCEL); + removeFromWriter(); return; } if (!publisherClosed()) { publisherComplete(); - applicationError(ex.get_exception()->what()); + endStream(StreamCompletionSignal::ERROR); + writeApplicationError(ex.get_exception()->what()); tryCompleteChannel(); } } -void ChannelRequester::request(int64_t n) noexcept { +void ChannelRequester::request(int64_t n) { if (!requested_) { // The initial request has not been sent out yet, hence we must accumulate // the unsynchronised allowance, portion of which will be sent out with @@ -81,54 +77,79 @@ void ChannelRequester::request(int64_t n) noexcept { ConsumerBase::generateRequest(n); } -void ChannelRequester::cancel() noexcept { +void ChannelRequester::cancel() { if (!requested_) { - closeStream(StreamCompletionSignal::CANCEL); + endStream(StreamCompletionSignal::CANCEL); + removeFromWriter(); return; } cancelConsumer(); - cancelStream(); + writeCancel(); tryCompleteChannel(); } -void ChannelRequester::endStream(StreamCompletionSignal signal) { - terminatePublisher(); - ConsumerBase::endStream(signal); -} - -void ChannelRequester::tryCompleteChannel() { - if (publisherClosed() && consumerClosed()) { - closeStream(StreamCompletionSignal::COMPLETE); - } -} - void ChannelRequester::handlePayload( Payload&& payload, - bool complete, - bool next) { + bool flagsComplete, + bool flagsNext, + bool flagsFollows) { CHECK(requested_); - processPayload(std::move(payload), next); + bool finalComplete = processFragmentedPayload( + std::move(payload), flagsNext, flagsComplete, flagsFollows); - if (complete) { + if (finalComplete) { completeConsumer(); tryCompleteChannel(); } } -void ChannelRequester::handleError(folly::exception_wrapper ex) { +void ChannelRequester::handleRequestN(uint32_t n) { CHECK(requested_); - errorConsumer(std::move(ex)); - tryCompleteChannel(); + PublisherBase::processRequestN(n); } -void ChannelRequester::handleRequestN(uint32_t n) { +void ChannelRequester::handleError(folly::exception_wrapper ew) { CHECK(requested_); - PublisherBase::processRequestN(n); + errorConsumer(std::move(ew)); + terminatePublisher(); } void ChannelRequester::handleCancel() { CHECK(requested_); - publisherComplete(); + terminatePublisher(); tryCompleteChannel(); } + +void ChannelRequester::endStream(StreamCompletionSignal signal) { + terminatePublisher(); + ConsumerBase::endStream(signal); +} + +void ChannelRequester::initStream(Payload&& request) { + requested_ = true; + + const size_t initialN = initialResponseAllowance_.consumeUpTo(kMaxRequestN); + const size_t remainingN = initialResponseAllowance_.consumeAll(); + + // Send as much as possible with the initial request. + CHECK_GE(static_cast(kMaxRequestN), initialN); + newStream( + StreamType::CHANNEL, static_cast(initialN), std::move(request)); + // We must inform ConsumerBase about an implicit allowance we have + // requested from the remote end. + ConsumerBase::addImplicitAllowance(initialN); + // Pump the remaining allowance into the ConsumerBase _after_ sending the + // initial request. + if (remainingN) { + ConsumerBase::generateRequest(remainingN); + } +} + +void ChannelRequester::tryCompleteChannel() { + if (publisherClosed() && consumerClosed()) { + endStream(StreamCompletionSignal::COMPLETE); + removeFromWriter(); + } } + +} // namespace rsocket diff --git a/rsocket/statemachine/ChannelRequester.h b/rsocket/statemachine/ChannelRequester.h index 198621311..7c05b2028 100644 --- a/rsocket/statemachine/ChannelRequester.h +++ b/rsocket/statemachine/ChannelRequester.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -14,31 +26,49 @@ class ChannelRequester : public ConsumerBase, public PublisherBase, public yarpl::flowable::Subscriber { public: + ChannelRequester( + Payload request, + std::shared_ptr writer, + StreamId streamId) + : ConsumerBase(std::move(writer), streamId), + PublisherBase(0 /*initialRequestN*/), + request_(std::move(request)), + hasInitialRequest_(true) {} + ChannelRequester(std::shared_ptr writer, StreamId streamId) : ConsumerBase(std::move(writer), streamId), PublisherBase(1 /*initialRequestN*/) {} - private: - void onSubscribe(yarpl::Reference - subscription) noexcept override; - void onNext(Payload) noexcept override; - void onComplete() noexcept override; - void onError(folly::exception_wrapper) noexcept override; - - void request(int64_t) noexcept override; - void cancel() noexcept override; - - void handlePayload(Payload&& payload, bool complete, bool flagsNext) override; - void handleRequestN(uint32_t n) override; - void handleError(folly::exception_wrapper errorPayload) override; + void onSubscribe(std::shared_ptr) override; + void onNext(Payload) override; + void onComplete() override; + void onError(folly::exception_wrapper) override; + + void request(int64_t) override; + void cancel() override; + + void handlePayload( + Payload&& payload, + bool flagsComplete, + bool flagsNext, + bool flagsFollows) override; + void handleRequestN(uint32_t) override; + void handleError(folly::exception_wrapper) override; void handleCancel() override; void endStream(StreamCompletionSignal) override; + + private: + void initStream(Payload&&); void tryCompleteChannel(); - /// An allowance accumulated before the stream is initialised. - /// Remaining part of the allowance is forwarded to the ConsumerBase. + /// An allowance accumulated before the stream is initialised. Remaining part + /// of the allowance is forwarded to the ConsumerBase. Allowance initialResponseAllowance_; + + Payload request_; bool requested_{false}; + bool hasInitialRequest_{false}; }; -} + +} // namespace rsocket diff --git a/rsocket/statemachine/ChannelResponder.cpp b/rsocket/statemachine/ChannelResponder.cpp index e559f0a33..db366e70d 100644 --- a/rsocket/statemachine/ChannelResponder.cpp +++ b/rsocket/statemachine/ChannelResponder.cpp @@ -1,102 +1,122 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/statemachine/ChannelResponder.h" namespace rsocket { -using namespace yarpl; -using namespace yarpl::flowable; - void ChannelResponder::onSubscribe( - Reference subscription) noexcept { + std::shared_ptr subscription) { publisherSubscribe(std::move(subscription)); } -void ChannelResponder::onNext(Payload response) noexcept { - checkPublisherOnNext(); +void ChannelResponder::onNext(Payload response) { if (!publisherClosed()) { - writePayload(std::move(response), false); + writePayload(std::move(response)); } } -void ChannelResponder::onComplete() noexcept { +void ChannelResponder::onComplete() { if (!publisherClosed()) { publisherComplete(); - completeStream(); + writeComplete(); tryCompleteChannel(); } } -void ChannelResponder::onError(folly::exception_wrapper ex) noexcept { +void ChannelResponder::onError(folly::exception_wrapper ex) { if (!publisherClosed()) { publisherComplete(); - applicationError(ex.get_exception()->what()); + endStream(StreamCompletionSignal::ERROR); + if (!ex.with_exception([this](rsocket::ErrorWithPayload& err) { + writeApplicationError(std::move(err.payload)); + })) { + writeApplicationError(ex.get_exception()->what()); + } tryCompleteChannel(); } } -void ChannelResponder::tryCompleteChannel() { - if (publisherClosed() && consumerClosed()) { - closeStream(StreamCompletionSignal::COMPLETE); - } -} - -void ChannelResponder::request(int64_t n) noexcept { +void ChannelResponder::request(int64_t n) { ConsumerBase::generateRequest(n); } -void ChannelResponder::cancel() noexcept { +void ChannelResponder::cancel() { cancelConsumer(); - cancelStream(); + writeCancel(); tryCompleteChannel(); } -void ChannelResponder::endStream(StreamCompletionSignal signal) { - terminatePublisher(); - ConsumerBase::endStream(signal); -} - -// TODO: remove this unused function -void ChannelResponder::processInitialFrame(Frame_REQUEST_CHANNEL&& frame) { - onNextPayloadFrame( - frame.requestN_, - std::move(frame.payload_), - frame.header_.flagsComplete(), - true); -} - void ChannelResponder::handlePayload( Payload&& payload, - bool complete, - bool flagsNext) { - onNextPayloadFrame(0, std::move(payload), complete, flagsNext); -} + bool flagsComplete, + bool flagsNext, + bool flagsFollows) { + payloadFragments_.addPayload(std::move(payload), flagsNext, flagsComplete); + + if (flagsFollows) { + // there will be more fragments to come + return; + } -void ChannelResponder::onNextPayloadFrame( - uint32_t requestN, - Payload&& payload, - bool complete, - bool next) { - processRequestN(requestN); - processPayload(std::move(payload), next); + bool finalFlagsComplete, finalFlagsNext; + Payload finalPayload; + + std::tie(finalPayload, finalFlagsNext, finalFlagsComplete) = + payloadFragments_.consumePayloadAndFlags(); + + if (newStream_) { + newStream_ = false; + auto channelOutputSubscriber = onNewStreamReady( + StreamType::CHANNEL, + std::move(finalPayload), + std::static_pointer_cast(shared_from_this())); + subscribe(std::move(channelOutputSubscriber)); + } else { + processPayload(std::move(finalPayload), finalFlagsNext); + } - if (complete) { + if (finalFlagsComplete) { completeConsumer(); tryCompleteChannel(); } } -void ChannelResponder::handleCancel() { - publisherComplete(); - tryCompleteChannel(); -} - void ChannelResponder::handleRequestN(uint32_t n) { processRequestN(n); } -void ChannelResponder::handleError(folly::exception_wrapper ex) { - errorConsumer(std::move(ex)); +void ChannelResponder::handleError(folly::exception_wrapper ew) { + errorConsumer(std::move(ew)); + terminatePublisher(); +} + +void ChannelResponder::handleCancel() { + terminatePublisher(); tryCompleteChannel(); } + +void ChannelResponder::endStream(StreamCompletionSignal signal) { + terminatePublisher(); + ConsumerBase::endStream(signal); +} + +void ChannelResponder::tryCompleteChannel() { + if (publisherClosed() && consumerClosed()) { + endStream(StreamCompletionSignal::COMPLETE); + removeFromWriter(); + } } + +} // namespace rsocket diff --git a/rsocket/statemachine/ChannelResponder.h b/rsocket/statemachine/ChannelResponder.h index 79ac77d09..c0e6de708 100644 --- a/rsocket/statemachine/ChannelResponder.h +++ b/rsocket/statemachine/ChannelResponder.h @@ -1,9 +1,19 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include - #include "rsocket/statemachine/ConsumerBase.h" #include "rsocket/statemachine/PublisherBase.h" #include "yarpl/flowable/Subscriber.h" @@ -22,31 +32,30 @@ class ChannelResponder : public ConsumerBase, : ConsumerBase(std::move(writer), streamId), PublisherBase(initialRequestN) {} - void processInitialFrame(Frame_REQUEST_CHANNEL&&); + void onSubscribe(std::shared_ptr) override; + void onNext(Payload) override; + void onComplete() override; + void onError(folly::exception_wrapper) override; - private: - void onSubscribe(yarpl::Reference - subscription) noexcept override; - void onNext(Payload) noexcept override; - void onComplete() noexcept override; - void onError(folly::exception_wrapper) noexcept override; + void request(int64_t) override; + void cancel() override; - void request(int64_t n) noexcept override; - void cancel() noexcept override; + void handlePayload( + Payload&& payload, + bool flagsComplete, + bool flagsNext, + bool flagsFollows) override; - void handlePayload(Payload&& payload, bool complete, bool flagsNext) override; - void handleRequestN(uint32_t n) override; + void handleRequestN(uint32_t) override; + void handleError(folly::exception_wrapper) override; void handleCancel() override; - void handleError(folly::exception_wrapper ex) override; - - void onNextPayloadFrame( - uint32_t requestN, - Payload&& payload, - bool complete, - bool next); void endStream(StreamCompletionSignal) override; + private: void tryCompleteChannel(); + + bool newStream_{true}; }; -} + +} // namespace rsocket diff --git a/rsocket/statemachine/ConsumerBase.cpp b/rsocket/statemachine/ConsumerBase.cpp index ddc933330..21d1cedc5 100644 --- a/rsocket/statemachine/ConsumerBase.cpp +++ b/rsocket/statemachine/ConsumerBase.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/statemachine/ConsumerBase.h" @@ -6,29 +18,21 @@ #include -#include "rsocket/Payload.h" -#include "yarpl/flowable/Subscription.h" - namespace rsocket { -using namespace yarpl; -using namespace yarpl::flowable; - void ConsumerBase::subscribe( - Reference> subscriber) { - if (isTerminated()) { - subscriber->onSubscribe(yarpl::flowable::Subscription::empty()); + std::shared_ptr> subscriber) { + if (state_ == State::CLOSED) { + subscriber->onSubscribe(yarpl::flowable::Subscription::create()); subscriber->onComplete(); return; } DCHECK(!consumingSubscriber_); consumingSubscriber_ = std::move(subscriber); - consumingSubscriber_->onSubscribe(this->ref_from_this(this)); + consumingSubscriber_->onSubscribe(shared_from_this()); } -// TODO: this is probably buggy and misused and not needed (when -// completeConsumer exists) void ConsumerBase::cancelConsumer() { state_ = State::CLOSED; VLOG(5) << "ConsumerBase::cancelConsumer()"; @@ -48,6 +52,7 @@ void ConsumerBase::generateRequest(size_t n) { void ConsumerBase::endStream(StreamCompletionSignal signal) { VLOG(5) << "ConsumerBase::endStream(" << signal << ")"; + state_ = State::CLOSED; if (auto subscriber = std::move(consumingSubscriber_)) { if (signal == StreamCompletionSignal::COMPLETE || signal == StreamCompletionSignal::CANCEL) { // TODO: remove CANCEL @@ -58,7 +63,6 @@ void ConsumerBase::endStream(StreamCompletionSignal signal) { subscriber->onError(StreamInterruptedException(static_cast(signal))); } } - StreamStateMachineBase::endStream(signal); } size_t ConsumerBase::getConsumerAllowance() const { @@ -66,25 +70,47 @@ size_t ConsumerBase::getConsumerAllowance() const { } void ConsumerBase::processPayload(Payload&& payload, bool onNext) { - if (payload || onNext) { - // Frames carry application-level payloads are taken into account when - // figuring out flow control allowance. - if (allowance_.tryConsume(1) && activeRequests_.tryConsume(1)) { - sendRequests(); - if (consumingSubscriber_) { - consumingSubscriber_->onNext(std::move(payload)); - } else { - LOG(ERROR) - << "consuming subscriber is missing, might be a race condition on " - " cancel/onNext."; - } - } else { - handleFlowControlError(); - return; - } + if (!payload && !onNext) { + return; + } + + // Frames carrying application-level payloads are taken into account when + // figuring out flow control allowance. + if (!allowance_.tryConsume(1) || !activeRequests_.tryConsume(1)) { + handleFlowControlError(); + return; + } + + sendRequests(); + if (consumingSubscriber_) { + consumingSubscriber_->onNext(std::move(payload)); + } else { + LOG(ERROR) << "Consuming subscriber is missing, might be a race on " + << "cancel/onNext"; } } +bool ConsumerBase::processFragmentedPayload( + Payload&& payload, + bool flagsNext, + bool flagsComplete, + bool flagsFollows) { + payloadFragments_.addPayload(std::move(payload), flagsNext, flagsComplete); + + if (flagsFollows) { + // there will be more fragments to come + return false; + } + + bool finalFlagsComplete, finalFlagsNext; + Payload finalPayload; + + std::tie(finalPayload, finalFlagsNext, finalFlagsComplete) = + payloadFragments_.consumePayloadAndFlags(); + processPayload(std::move(finalPayload), finalFlagsNext); + return finalFlagsComplete; +} + void ConsumerBase::completeConsumer() { state_ = State::CLOSED; VLOG(5) << "ConsumerBase::completeConsumer()"; @@ -93,20 +119,18 @@ void ConsumerBase::completeConsumer() { } } -void ConsumerBase::errorConsumer(folly::exception_wrapper ex) { +void ConsumerBase::errorConsumer(folly::exception_wrapper ew) { state_ = State::CLOSED; VLOG(5) << "ConsumerBase::errorConsumer()"; if (auto subscriber = std::move(consumingSubscriber_)) { - subscriber->onError(std::move(ex)); + subscriber->onError(std::move(ew)); } } void ConsumerBase::sendRequests() { - auto toSync = - std::min(pendingAllowance_.get(), Frame_REQUEST_N::kMaxRequestN); + auto toSync = std::min(pendingAllowance_.get(), kMaxRequestN); auto actives = activeRequests_.get(); - if (actives < (toSync + 1) / 2) { - toSync = toSync - actives; + if (actives <= toSync) { toSync = pendingAllowance_.consumeUpTo(toSync); if (toSync > 0) { writeRequestN(static_cast(toSync)); @@ -117,8 +141,11 @@ void ConsumerBase::sendRequests() { void ConsumerBase::handleFlowControlError() { if (auto subscriber = std::move(consumingSubscriber_)) { - subscriber->onError(std::runtime_error("surplus response")); + subscriber->onError(std::runtime_error("Surplus response")); } - errorStream("flow control error"); + writeInvalidError("Flow control error"); + endStream(StreamCompletionSignal::ERROR); + removeFromWriter(); } + } // namespace rsocket diff --git a/rsocket/statemachine/ConsumerBase.h b/rsocket/statemachine/ConsumerBase.h index fd5c6567b..773c1350c 100644 --- a/rsocket/statemachine/ConsumerBase.h +++ b/rsocket/statemachine/ConsumerBase.h @@ -1,64 +1,75 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include -#include - #include "rsocket/Payload.h" #include "rsocket/internal/Allowance.h" -#include "rsocket/internal/Common.h" -#include "rsocket/statemachine/RSocketStateMachine.h" #include "rsocket/statemachine/StreamStateMachineBase.h" +#include "yarpl/flowable/Subscriber.h" #include "yarpl/flowable/Subscription.h" namespace rsocket { -enum class StreamCompletionSignal; - /// A class that represents a flow-control-aware consumer of data. class ConsumerBase : public StreamStateMachineBase, public yarpl::flowable::Subscription, - public yarpl::enable_get_ref { + public std::enable_shared_from_this { public: using StreamStateMachineBase::StreamStateMachineBase; + void subscribe(std::shared_ptr>); + /// Adds implicit allowance. /// /// This portion of allowance will not be synced to the remote end, but will /// count towards the limit of allowance the remote PublisherBase may use. - void addImplicitAllowance(size_t n); + void addImplicitAllowance(size_t); - void subscribe( - yarpl::Reference> subscriber); - - void generateRequest(size_t n); - - size_t getConsumerAllowance() const override; - - protected: - void cancelConsumer(); + void generateRequest(size_t); bool consumerClosed() const { return state_ == State::CLOSED; } - void endStream(StreamCompletionSignal signal) override; + size_t getConsumerAllowance() const override; + void endStream(StreamCompletionSignal) override; + protected: void processPayload(Payload&&, bool onNext); + // returns true if the stream is completed + bool + processFragmentedPayload(Payload&&, bool next, bool complete, bool follows); + + void cancelConsumer(); void completeConsumer(); - void errorConsumer(folly::exception_wrapper ex); + void errorConsumer(folly::exception_wrapper); private: + enum class State : uint8_t { + RESPONDING, + CLOSED, + }; + void sendRequests(); void handleFlowControlError(); - /// A Subscriber that will consume payloads. - /// This is responsible for delivering a terminal signal to the - /// Subscriber once the stream ends. - yarpl::Reference> consumingSubscriber_; + /// A Subscriber that will consume payloads. This is responsible for + /// delivering a terminal signal to the Subscriber once the stream ends. + std::shared_ptr> consumingSubscriber_; /// A total, net allowance (requested less delivered) by this consumer. Allowance allowance_; @@ -66,13 +77,11 @@ class ConsumerBase : public StreamStateMachineBase, /// REQUEST_N frames. Allowance pendingAllowance_; - /// The number of already requested payload count. - /// Prevent excessive requestN calls. + /// The number of already requested payload count. Prevent excessive requestN + /// calls. Allowance activeRequests_; - enum class State : uint8_t { - RESPONDING, - CLOSED, - } state_{State::RESPONDING}; + State state_{State::RESPONDING}; }; -} + +} // namespace rsocket diff --git a/rsocket/statemachine/FireAndForgetResponder.cpp b/rsocket/statemachine/FireAndForgetResponder.cpp new file mode 100644 index 000000000..2a15b87a6 --- /dev/null +++ b/rsocket/statemachine/FireAndForgetResponder.cpp @@ -0,0 +1,45 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/statemachine/FireAndForgetResponder.h" + +namespace rsocket { + +using namespace yarpl::flowable; + +void FireAndForgetResponder::handlePayload( + Payload&& payload, + bool /*flagsComplete*/, + bool /*flagsNext*/, + bool flagsFollows) { + payloadFragments_.addPayloadIgnoreFlags(std::move(payload)); + + if (flagsFollows) { + // there will be more fragments to come + return; + } + + Payload finalPayload = payloadFragments_.consumePayloadIgnoreFlags(); + onNewStreamReady( + StreamType::FNF, + std::move(finalPayload), + std::shared_ptr>(nullptr)); + removeFromWriter(); +} + +void FireAndForgetResponder::handleCancel() { + removeFromWriter(); +} + +} // namespace rsocket diff --git a/rsocket/statemachine/FireAndForgetResponder.h b/rsocket/statemachine/FireAndForgetResponder.h new file mode 100644 index 000000000..bf9ad3397 --- /dev/null +++ b/rsocket/statemachine/FireAndForgetResponder.h @@ -0,0 +1,41 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "rsocket/statemachine/StreamStateMachineBase.h" +#include "yarpl/flowable/Subscriber.h" +#include "yarpl/single/SingleObserver.h" +#include "yarpl/single/SingleSubscription.h" + +namespace rsocket { + +/// Helper class for handling receiving fragmented payload +class FireAndForgetResponder : public StreamStateMachineBase { + public: + FireAndForgetResponder( + std::shared_ptr writer, + StreamId streamId) + : StreamStateMachineBase(std::move(writer), streamId) {} + + void handlePayload( + Payload&& payload, + bool flagsComplete, + bool flagsNext, + bool flagsFollows) override; + + private: + void handleCancel() override; +}; +} // namespace rsocket diff --git a/rsocket/statemachine/PublisherBase.cpp b/rsocket/statemachine/PublisherBase.cpp index 8a6a99abd..867ae4255 100644 --- a/rsocket/statemachine/PublisherBase.cpp +++ b/rsocket/statemachine/PublisherBase.cpp @@ -1,18 +1,28 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/statemachine/PublisherBase.h" #include -#include "rsocket/statemachine/RSocketStateMachine.h" - namespace rsocket { PublisherBase::PublisherBase(uint32_t initialRequestN) : initialRequestN_(initialRequestN) {} void PublisherBase::publisherSubscribe( - yarpl::Reference subscription) { + std::shared_ptr subscription) { if (state_ == State::CLOSED) { subscription->cancel(); return; @@ -24,12 +34,6 @@ void PublisherBase::publisherSubscribe( } } -void PublisherBase::checkPublisherOnNext() { - // we are either responding and publisherSubscribe method was called - // or we are already terminated - CHECK((state_ == State::RESPONDING) == !!producingSubscription_); -} - void PublisherBase::publisherComplete() { state_ = State::CLOSED; producingSubscription_ = nullptr; @@ -40,12 +44,12 @@ bool PublisherBase::publisherClosed() const { } void PublisherBase::processRequestN(uint32_t requestN) { - if (!requestN || state_ == State::CLOSED) { + if (requestN == 0 || state_ == State::CLOSED) { return; } - // we might not have the subscription set yet as there can be REQUEST_N - // frames scheduled on the executor before onSubscribe method + // We might not have the subscription set yet as there can be REQUEST_N frames + // scheduled on the executor before onSubscribe method. if (producingSubscription_) { producingSubscription_->request(requestN); } else { @@ -59,4 +63,5 @@ void PublisherBase::terminatePublisher() { subscription->cancel(); } } -} + +} // namespace rsocket diff --git a/rsocket/statemachine/PublisherBase.h b/rsocket/statemachine/PublisherBase.h index 6183a682d..b5df39909 100644 --- a/rsocket/statemachine/PublisherBase.h +++ b/rsocket/statemachine/PublisherBase.h @@ -1,42 +1,46 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include "rsocket/Payload.h" #include "rsocket/internal/Allowance.h" #include "yarpl/flowable/Subscription.h" namespace rsocket { -enum class StreamCompletionSignal; - /// A class that represents a flow-control-aware producer of data. class PublisherBase { public: explicit PublisherBase(uint32_t initialRequestN); - void publisherSubscribe( - yarpl::Reference subscription); - - void checkPublisherOnNext(); + void publisherSubscribe(std::shared_ptr); + void processRequestN(uint32_t); void publisherComplete(); - bool publisherClosed() const; - - void processRequestN(uint32_t requestN); + bool publisherClosed() const; void terminatePublisher(); private: - /// A Subscription that constrols production of payloads. - /// This is responsible for delivering a terminal signal to the - /// Subscription once the stream ends. - yarpl::Reference producingSubscription_; - Allowance initialRequestN_; - enum class State : uint8_t { RESPONDING, CLOSED, - } state_{State::RESPONDING}; + }; + + std::shared_ptr producingSubscription_; + Allowance initialRequestN_; + State state_{State::RESPONDING}; }; -} + +} // namespace rsocket diff --git a/rsocket/statemachine/RSocketStateMachine.cpp b/rsocket/statemachine/RSocketStateMachine.cpp index 489eca4af..4d914052a 100644 --- a/rsocket/statemachine/RSocketStateMachine.cpp +++ b/rsocket/statemachine/RSocketStateMachine.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/statemachine/RSocketStateMachine.h" @@ -7,6 +19,7 @@ #include #include #include +#include #include "rsocket/DuplexConnection.h" #include "rsocket/RSocketConnectionEvents.h" @@ -17,15 +30,41 @@ #include "rsocket/framing/FrameSerializer.h" #include "rsocket/framing/FrameTransportImpl.h" #include "rsocket/internal/ClientResumeStatusCallback.h" -#include "rsocket/internal/ConnectionSet.h" #include "rsocket/internal/ScheduledSubscriber.h" #include "rsocket/internal/WarmResumeManager.h" +#include "rsocket/statemachine/ChannelRequester.h" #include "rsocket/statemachine/ChannelResponder.h" -#include "rsocket/statemachine/StreamState.h" +#include "rsocket/statemachine/FireAndForgetResponder.h" +#include "rsocket/statemachine/RequestResponseRequester.h" +#include "rsocket/statemachine/RequestResponseResponder.h" +#include "rsocket/statemachine/StreamRequester.h" +#include "rsocket/statemachine/StreamResponder.h" #include "rsocket/statemachine/StreamStateMachineBase.h" +#include "yarpl/flowable/Subscription.h" +#include "yarpl/single/SingleSubscriptions.h" + namespace rsocket { +namespace { + +void disconnectError( + std::shared_ptr> subscriber) { + std::runtime_error exn{"RSocket connection is disconnected or closed"}; + subscriber->onSubscribe(yarpl::flowable::Subscription::create()); + subscriber->onError(std::move(exn)); +} + +void disconnectError( + std::shared_ptr> observer) { + auto exn = folly::make_exception_wrapper( + "RSocket connection is disconnected or closed"); + observer->onSubscribe(yarpl::single::SingleSubscriptions::empty()); + observer->onError(std::move(exn)); +} + +} // namespace + RSocketStateMachine::RSocketStateMachine( std::shared_ptr requestResponder, std::unique_ptr keepaliveTimer, @@ -34,17 +73,37 @@ RSocketStateMachine::RSocketStateMachine( std::shared_ptr connectionEvents, std::shared_ptr resumeManager, std::shared_ptr coldResumeHandler) + : RSocketStateMachine( + std::make_shared( + std::move(requestResponder)), + std::move(keepaliveTimer), + mode, + std::move(stats), + std::move(connectionEvents), + std::move(resumeManager), + std::move(coldResumeHandler)) {} + +RSocketStateMachine::RSocketStateMachine( + std::shared_ptr requestResponder, + std::unique_ptr keepaliveTimer, + RSocketMode mode, + std::shared_ptr stats, + std::shared_ptr connectionEvents, + std::shared_ptr resumeManager, + std::shared_ptr coldResumeHandler) : mode_{mode}, stats_{stats ? stats : RSocketStats::noop()}, - streamState_{*stats_}, - resumeManager_{resumeManager - ? resumeManager - : std::make_shared(stats_)}, + // Streams initiated by a client MUST use odd-numbered and streams + // initiated by the server MUST use even-numbered stream identifiers + nextStreamId_(mode == RSocketMode::CLIENT ? 1 : 2), + resumeManager_(std::move(resumeManager)), requestResponder_{std::move(requestResponder)}, keepaliveTimer_{std::move(keepaliveTimer)}, coldResumeHandler_{std::move(coldResumeHandler)}, - streamsFactory_{*this, mode}, connectionEvents_{connectionEvents} { + CHECK(resumeManager_) + << "provide ResumeManager::makeEmpty() instead of nullptr"; + // We deliberately do not "open" input or output to avoid having c'tor on the // stack when processing any signals from the connection. See ::connect and // ::onSubscribe. @@ -75,33 +134,37 @@ void RSocketStateMachine::setResumable(bool resumable) { } void RSocketStateMachine::connectServer( - yarpl::Reference frameTransport, + std::shared_ptr frameTransport, const SetupParameters& setupParams) { setResumable(setupParams.resumable); - connect(std::move(frameTransport), setupParams.protocolVersion); + setProtocolVersionOrThrow(setupParams.protocolVersion, frameTransport); + connect(std::move(frameTransport)); sendPendingFrames(); } bool RSocketStateMachine::resumeServer( - yarpl::Reference frameTransport, + std::shared_ptr frameTransport, const ResumeParameters& resumeParams) { - folly::Optional clientAvailable = + const folly::Optional clientAvailable = (resumeParams.clientPosition == kUnspecifiedResumePosition) ? folly::none : folly::make_optional( resumeManager_->impliedPosition() - resumeParams.clientPosition); - int64_t serverAvailable = + const int64_t serverAvailable = resumeManager_->lastSentPosition() - resumeManager_->firstSentPosition(); - int64_t serverDelta = + const int64_t serverDelta = resumeManager_->lastSentPosition() - resumeParams.serverPosition; - std::runtime_error exn{"Connection being resumed, dropping old connection"}; - disconnect(std::move(exn)); - - connect(std::move(frameTransport), resumeParams.protocolVersion); + if (frameTransport) { + stats_->socketDisconnected(); + } + closeFrameTransport( + std::runtime_error{"Connection being resumed, dropping old connection"}); + setProtocolVersionOrThrow(resumeParams.protocolVersion, frameTransport); + connect(std::move(frameTransport)); - auto result = resumeFromPositionOrClose( + const auto result = resumeFromPositionOrClose( resumeParams.serverPosition, resumeParams.clientPosition); stats_->serverResume( @@ -115,17 +178,18 @@ bool RSocketStateMachine::resumeServer( } void RSocketStateMachine::connectClient( - yarpl::Reference transport, + std::shared_ptr transport, SetupParameters params) { auto const version = params.protocolVersion == ProtocolVersion::Unknown - ? ProtocolVersion::Current() + ? ProtocolVersion::Latest : params.protocolVersion; - setFrameSerializer(FrameSerializer::createFrameSerializer(version)); + setProtocolVersionOrThrow(version, transport); setResumable(params.resumable); Frame_SETUP frame( - params.resumable ? FrameFlags::RESUME_ENABLE : FrameFlags::EMPTY, + (params.resumable ? FrameFlags::RESUME_ENABLE : FrameFlags::EMPTY_) | + (params.payload.metadata ? FrameFlags::METADATA : FrameFlags::EMPTY_), version.major, version.minor, getKeepaliveTime(), @@ -140,7 +204,7 @@ void RSocketStateMachine::connectClient( VLOG(3) << "Out: " << frame; - connect(std::move(transport), ProtocolVersion::Unknown); + connect(std::move(transport)); // making sure we send setup frame first outputFrame(frameSerializer_->serializeOut(std::move(frame))); // then the rest of the cached frames will be sent @@ -149,24 +213,19 @@ void RSocketStateMachine::connectClient( void RSocketStateMachine::resumeClient( ResumeIdentificationToken token, - yarpl::Reference transport, + std::shared_ptr transport, std::unique_ptr resumeCallback, ProtocolVersion version) { - // Verify warm-resumption using the same version. - if (frameSerializer_ && frameSerializer_->protocolVersion() != version) { - throw std::invalid_argument{"Client resuming with different version"}; - } - // Cold-resumption. Set the serializer. if (!frameSerializer_) { CHECK(coldResumeHandler_); coldResumeInProgress_ = true; - if (version == ProtocolVersion::Unknown) { - version = ProtocolVersion::Current(); - } - setFrameSerializer(FrameSerializer::createFrameSerializer(version)); } + setProtocolVersionOrThrow( + version == ProtocolVersion::Unknown ? ProtocolVersion::Latest : version, + transport); + Frame_RESUME resumeFrame( std::move(token), resumeManager_->impliedPosition(), @@ -182,53 +241,35 @@ void RSocketStateMachine::resumeClient( outputFrame(frameSerializer_->serializeOut(std::move(resumeFrame))); } -void RSocketStateMachine::connect( - yarpl::Reference transport, - ProtocolVersion version) { +void RSocketStateMachine::connect(std::shared_ptr transport) { VLOG(2) << "Connecting to transport " << transport.get(); CHECK(isDisconnected()); CHECK(transport); - if (version != ProtocolVersion::Unknown) { - if (frameSerializer_) { - if (frameSerializer_->protocolVersion() != version) { - transport->close(); - throw std::runtime_error{"Protocol version mismatch"}; - } - } else { - frameSerializer_ = FrameSerializer::createFrameSerializer(version); - if (!frameSerializer_) { - transport->close(); - throw std::runtime_error{"Invalid protocol version"}; - } - } - } // Keep a reference to the argument, make sure the instance survives until // setFrameProcessor() returns. There can be terminating signals processed in // that call which will nullify frameTransport_. frameTransport_ = transport; + CHECK(frameSerializer_); + frameSerializer_->preallocateFrameSizeField() = + transport->isConnectionFramed(); + if (connectionEvents_) { connectionEvents_->onConnected(); } - // Keep a reference to this, as processing frames might close this - // instance. - auto copyThis = shared_from_this(); - frameTransport_->setFrameProcessor(copyThis); - stats_->socketConnected(); + // Keep a reference to stats, as processing frames might close this instance. + auto const stats = stats_; + frameTransport_->setFrameProcessor(shared_from_this()); + stats->socketConnected(); } void RSocketStateMachine::sendPendingFrames() { DCHECK(!resumeCallback_); - // We are free to try to send frames again. Not all frames might be sent if - // the connection breaks, the rest of them will queue up again. - auto outputFrames = streamState_.moveOutputPendingFrames(); - for (auto& frame : outputFrames) { - outputFrameOrEnqueue(std::move(frame)); - } + StreamsWriterImpl::sendPendingFrames(); // TODO: turn on only after setup frame was received if (keepaliveTimer_) { @@ -279,8 +320,8 @@ void RSocketStateMachine::close( connectionEvents->onClosed(std::move(ex)); } - if (auto set = connectionSet_.lock()) { - set->remove(shared_from_this()); + if (closeCallback_) { + closeCallback_->remove(*this); } } @@ -350,13 +391,13 @@ void RSocketStateMachine::closeWithError(Frame_ERROR&& error) { std::runtime_error exn{error.payload_.cloneDataToString()}; if (frameSerializer_) { - outputFrameOrEnqueue(std::move(error)); + outputFrameOrEnqueue(frameSerializer_->serializeOut(std::move(error))); } close(std::move(exn), signal); } void RSocketStateMachine::reconnect( - yarpl::Reference newFrameTransport, + std::shared_ptr newFrameTransport, std::unique_ptr resumeCallback) { CHECK(newFrameTransport); CHECK(resumeCallback); @@ -368,60 +409,72 @@ void RSocketStateMachine::reconnect( // TODO: output frame buffer should not be written to the new connection until // we receive resume ok resumeCallback_ = std::move(resumeCallback); - connect(std::move(newFrameTransport), ProtocolVersion::Unknown); + connect(std::move(newFrameTransport)); } -void RSocketStateMachine::addStream( - StreamId streamId, - yarpl::Reference stateMachine) { - auto result = - streamState_.streams_.emplace(streamId, std::move(stateMachine)); +void RSocketStateMachine::requestStream( + Payload request, + std::shared_ptr> responseSink) { + if (isDisconnected()) { + disconnectError(std::move(responseSink)); + return; + } + + auto const streamId = getNextStreamId(); + auto stateMachine = std::make_shared( + shared_from_this(), streamId, std::move(request)); + const auto result = streams_.emplace(streamId, stateMachine); DCHECK(result.second); + stateMachine->subscribe(std::move(responseSink)); } -void RSocketStateMachine::endStream( - StreamId streamId, - StreamCompletionSignal signal) { - VLOG(6) << "endStream"; - // The signal must be idempotent. - if (!endStreamInternal(streamId, signal)) { - return; +std::shared_ptr> +RSocketStateMachine::requestChannel( + Payload request, + bool hasInitialRequest, + std::shared_ptr> responseSink) { + if (isDisconnected()) { + disconnectError(std::move(responseSink)); + return nullptr; } - resumeManager_->onStreamClosed(streamId); - DCHECK( - signal == StreamCompletionSignal::CANCEL || - signal == StreamCompletionSignal::COMPLETE || - signal == StreamCompletionSignal::APPLICATION_ERROR || - signal == StreamCompletionSignal::ERROR); + + auto const streamId = getNextStreamId(); + std::shared_ptr stateMachine; + if (hasInitialRequest) { + stateMachine = std::make_shared( + std::move(request), shared_from_this(), streamId); + } else { + stateMachine = + std::make_shared(shared_from_this(), streamId); + } + const auto result = streams_.emplace(streamId, stateMachine); + DCHECK(result.second); + stateMachine->subscribe(std::move(responseSink)); + return stateMachine; } -bool RSocketStateMachine::endStreamInternal( - StreamId streamId, - StreamCompletionSignal signal) { - VLOG(6) << "endStreamInternal"; - auto it = streamState_.streams_.find(streamId); - if (it == streamState_.streams_.end()) { - // Unsubscribe handshake initiated by the connection, we're done. - return false; +void RSocketStateMachine::requestResponse( + Payload request, + std::shared_ptr> responseSink) { + if (isDisconnected()) { + disconnectError(std::move(responseSink)); + return; } - // Remove from the map before notifying the stateMachine. - auto stateMachine = std::move(it->second); - streamState_.streams_.erase(it); - stateMachine->endStream(signal); - return true; + auto const streamId = getNextStreamId(); + auto stateMachine = std::make_shared( + shared_from_this(), streamId, std::move(request)); + const auto result = streams_.emplace(streamId, stateMachine); + DCHECK(result.second); + stateMachine->subscribe(std::move(responseSink)); } void RSocketStateMachine::closeStreams(StreamCompletionSignal signal) { - // Close all streams. - while (!streamState_.streams_.empty()) { - auto oldSize = streamState_.streams_.size(); - auto result = - endStreamInternal(streamState_.streams_.begin()->first, signal); - // TODO(stupaq): what kind of a user action could violate these - // assertions? - DCHECK(result); - DCHECK_EQ(streamState_.streams_.size(), oldSize - 1); + while (!streams_.empty()) { + auto it = streams_.begin(); + auto streamStateMachine = std::move(it->second); + streams_.erase(it); + streamStateMachine->endStream(signal); } } @@ -431,43 +484,25 @@ void RSocketStateMachine::processFrame(std::unique_ptr frame) { return; } - // Necessary in case the only stream state machine closes itself, and takes - // the RSocketStateMachine with it. - auto self = shared_from_this(); - if (!ensureOrAutodetectFrameSerializer(*frame)) { - constexpr folly::StringPiece message{"Cannot detect protocol version"}; - closeWithError(Frame_ERROR::connectionError(message.str())); + constexpr auto msg = "Cannot detect protocol version"; + closeWithError(Frame_ERROR::connectionError(msg)); return; } - auto frameType = frameSerializer_->peekFrameType(*frame); + const auto frameType = frameSerializer_->peekFrameType(*frame); stats_->frameRead(frameType); - auto optStreamId = frameSerializer_->peekStreamId(*frame); + const auto optStreamId = frameSerializer_->peekStreamId(*frame, false); if (!optStreamId) { - constexpr folly::StringPiece message{"Cannot decode stream ID"}; - closeWithError(Frame_ERROR::connectionError(message.str())); + constexpr auto msg = "Cannot decode stream ID"; + closeWithError(Frame_ERROR::connectionError(msg)); return; } - auto frameLength = frame->computeChainDataLength(); - auto streamId = *optStreamId; - if (streamId == 0) { - handleConnectionFrame(frameType, std::move(frame)); - } else if (resumeCallback_) { - // during the time when we are resuming we are can't receive any other - // than connection level frames which drives the resumption - // TODO(lehecka): this assertion should be handled more elegantly using - // different state machine - constexpr folly::StringPiece message{ - "Received stream frame while resuming"}; - LOG(ERROR) << message; - closeWithError(Frame_ERROR::connectionError(message.str())); - return; - } else { - handleStreamFrame(streamId, frameType, std::move(frame)); - } + const auto frameLength = frame->computeChainDataLength(); + const auto streamId = *optStreamId; + handleFrame(streamId, frameType, std::move(frame)); resumeManager_->trackReceivedFrame( frameLength, frameType, streamId, getConsumerAllowance(streamId)); } @@ -477,46 +512,170 @@ void RSocketStateMachine::onTerminal(folly::exception_wrapper ex) { disconnect(std::move(ex)); return; } - auto termSignal = ex ? StreamCompletionSignal::CONNECTION_ERROR - : StreamCompletionSignal::CONNECTION_END; + const auto termSignal = ex ? StreamCompletionSignal::CONNECTION_ERROR + : StreamCompletionSignal::CONNECTION_END; close(std::move(ex), termSignal); } -void RSocketStateMachine::handleConnectionFrame( +void RSocketStateMachine::onKeepAliveFrame( + ResumePosition resumePosition, + std::unique_ptr data, + bool keepAliveRespond) { + resumeManager_->resetUpToPosition(resumePosition); + if (mode_ == RSocketMode::SERVER) { + if (keepAliveRespond) { + sendKeepalive(FrameFlags::EMPTY_, std::move(data)); + } else { + closeWithError(Frame_ERROR::connectionError("keepalive without flag")); + } + } else { + if (keepAliveRespond) { + closeWithError(Frame_ERROR::connectionError( + "client received keepalive with respond flag")); + } else if (keepaliveTimer_) { + keepaliveTimer_->keepaliveReceived(); + } + stats_->keepaliveReceived(); + } +} + +void RSocketStateMachine::onMetadataPushFrame( + std::unique_ptr metadata) { + requestResponder_->handleMetadataPush(std::move(metadata)); +} + +void RSocketStateMachine::onResumeOkFrame(ResumePosition resumePosition) { + if (!resumeCallback_) { + constexpr auto msg = "Received RESUME_OK while not resuming"; + closeWithError(Frame_ERROR::connectionError(msg)); + return; + } + + if (!resumeManager_->isPositionAvailable(resumePosition)) { + auto const msg = folly::sformat( + "Client cannot resume, server position {} is not available", + resumePosition); + closeWithError(Frame_ERROR::connectionError(msg)); + return; + } + + if (coldResumeInProgress_) { + setNextStreamId(resumeManager_->getLargestUsedStreamId()); + for (const auto& it : resumeManager_->getStreamResumeInfos()) { + const auto streamId = it.first; + const StreamResumeInfo& streamResumeInfo = it.second; + if (streamResumeInfo.requester == RequestOriginator::LOCAL && + streamResumeInfo.streamType == StreamType::STREAM) { + auto subscriber = coldResumeHandler_->handleRequesterResumeStream( + streamResumeInfo.streamToken, streamResumeInfo.consumerAllowance); + + auto stateMachine = std::make_shared( + shared_from_this(), streamId, Payload()); + // Set requested to true (since cold resumption) + stateMachine->setRequested(streamResumeInfo.consumerAllowance); + const auto result = streams_.emplace(streamId, stateMachine); + DCHECK(result.second); + stateMachine->subscribe( + std::make_shared>( + std::move(subscriber), + *folly::EventBaseManager::get()->getEventBase())); + } + } + coldResumeInProgress_ = false; + } + + auto resumeCallback = std::move(resumeCallback_); + resumeCallback->onResumeOk(); + resumeFromPosition(resumePosition); +} + +void RSocketStateMachine::onErrorFrame( + StreamId streamId, + ErrorCode errorCode, + Payload payload) { + if (streamId != 0) { + if (!ensureNotInResumption()) { + return; + } + // we ignore messages for streams which don't exist + if (auto stateMachine = getStreamStateMachine(streamId)) { + if (errorCode != ErrorCode::APPLICATION_ERROR) { + // Encapsulate non-user errors with runtime_error, which is more + // suitable for LOGging. + stateMachine->handleError( + std::runtime_error(payload.moveDataToString())); + } else { + // Don't expose user errors + stateMachine->handleError(ErrorWithPayload(std::move(payload))); + } + } + } else { + // TODO: handle INVALID_SETUP, UNSUPPORTED_SETUP, REJECTED_SETUP + if ((errorCode == ErrorCode::CONNECTION_ERROR || + errorCode == ErrorCode::REJECTED_RESUME) && + resumeCallback_) { + auto resumeCallback = std::move(resumeCallback_); + resumeCallback->onResumeError( + ResumptionException(payload.cloneDataToString())); + // fall through + } + close( + std::runtime_error(payload.moveDataToString()), + StreamCompletionSignal::ERROR); + } +} + +void RSocketStateMachine::onSetupFrame() { + // this should be processed in SetupResumeAcceptor + onUnexpectedFrame(0); +} + +void RSocketStateMachine::onResumeFrame() { + // this should be processed in SetupResumeAcceptor + onUnexpectedFrame(0); +} + +void RSocketStateMachine::onReservedFrame() { + onUnexpectedFrame(0); +} + +void RSocketStateMachine::onLeaseFrame() { + onUnexpectedFrame(0); +} + +void RSocketStateMachine::onExtFrame() { + onUnexpectedFrame(0); +} + +void RSocketStateMachine::onUnexpectedFrame(StreamId streamId) { + auto&& msg = folly::sformat("Unexpected frame for stream {}", streamId); + closeWithError(Frame_ERROR::connectionError(msg)); +} + +void RSocketStateMachine::handleFrame( + StreamId streamId, FrameType frameType, std::unique_ptr payload) { switch (frameType) { case FrameType::KEEPALIVE: { Frame_KEEPALIVE frame; - if (!deserializeFrameOrError(isResumable_, frame, std::move(payload))) { + if (!deserializeFrameOrError(frame, std::move(payload))) { return; } VLOG(3) << mode_ << " In: " << frame; - resumeManager_->resetUpToPosition(frame.position_); - if (mode_ == RSocketMode::SERVER) { - if (!!(frame.header_.flags & FrameFlags::KEEPALIVE_RESPOND)) { - sendKeepalive(FrameFlags::EMPTY, std::move(frame.data_)); - } else { - closeWithError( - Frame_ERROR::connectionError("keepalive without flag")); - } - } else { - if (!!(frame.header_.flags & FrameFlags::KEEPALIVE_RESPOND)) { - closeWithError(Frame_ERROR::connectionError( - "client received keepalive with respond flag")); - } else if (keepaliveTimer_) { - keepaliveTimer_->keepaliveReceived(); - } - stats_->keepaliveReceived(); - } + onKeepAliveFrame( + frame.position_, + std::move(frame.data_), + !!(frame.header_.flags & FrameFlags::KEEPALIVE_RESPOND)); return; } case FrameType::METADATA_PUSH: { Frame_METADATA_PUSH frame; - if (deserializeFrameOrError(frame, std::move(payload))) { - VLOG(3) << mode_ << " In: " << frame; - requestResponder_->handleMetadataPush(std::move(frame.metadata_)); + if (!deserializeFrameOrError(frame, std::move(payload))) { + return; } + VLOG(3) << mode_ << " In: " << frame; + onMetadataPushFrame(std::move(frame.metadata_)); return; } case FrameType::RESUME_OK: { @@ -525,47 +684,7 @@ void RSocketStateMachine::handleConnectionFrame( return; } VLOG(3) << mode_ << " In: " << frame; - - if (!resumeCallback_) { - constexpr folly::StringPiece message{ - "Received RESUME_OK while not resuming"}; - closeWithError(Frame_ERROR::connectionError(message.str())); - return; - } - - if (!resumeManager_->isPositionAvailable(frame.position_)) { - auto message = folly::sformat( - "Client cannot resume, server position {} is not available", - frame.position_); - closeWithError(Frame_ERROR::connectionError(std::move(message))); - return; - } - - if (coldResumeInProgress_) { - streamsFactory().setNextStreamId( - resumeManager_->getLargestUsedStreamId()); - for (const auto& it : resumeManager_->getStreamResumeInfos()) { - auto streamId = it.first; - const StreamResumeInfo& streamResumeInfo = it.second; - if (streamResumeInfo.requester == RequestOriginator::LOCAL && - streamResumeInfo.streamType == StreamType::STREAM) { - auto subscriber = coldResumeHandler_->handleRequesterResumeStream( - streamResumeInfo.streamToken, - streamResumeInfo.consumerAllowance); - streamsFactory().createStreamRequester( - yarpl::make_ref>( - std::move(subscriber), - *folly::EventBaseManager::get()->getEventBase()), - streamId, - streamResumeInfo.consumerAllowance); - } - } - coldResumeInProgress_ = false; - } - - auto resumeCallback = std::move(resumeCallback_); - resumeCallback->onResumeOk(); - resumeFromPosition(frame.position_); + onResumeOkFrame(frame.position_); return; } case FrameType::ERROR: { @@ -574,194 +693,295 @@ void RSocketStateMachine::handleConnectionFrame( return; } VLOG(3) << mode_ << " In: " << frame; - - // TODO: handle INVALID_SETUP, UNSUPPORTED_SETUP, REJECTED_SETUP - - if ((frame.errorCode_ == ErrorCode::CONNECTION_ERROR || - frame.errorCode_ == ErrorCode::REJECTED_RESUME) && - resumeCallback_) { - auto resumeCallback = std::move(resumeCallback_); - resumeCallback->onResumeError( - ResumptionException(frame.payload_.cloneDataToString())); - // fall through - } - - close( - std::runtime_error(frame.payload_.moveDataToString()), - StreamCompletionSignal::ERROR); + onErrorFrame(streamId, frame.errorCode_, std::move(frame.payload_)); return; } - case FrameType::SETUP: // this should be processed in SetupResumeAcceptor - case FrameType::RESUME: // this should be processed in SetupResumeAcceptor + case FrameType::SETUP: + onSetupFrame(); + return; + case FrameType::RESUME: + onResumeFrame(); + return; case FrameType::RESERVED: + onReservedFrame(); + return; case FrameType::LEASE: - case FrameType::REQUEST_RESPONSE: - case FrameType::REQUEST_FNF: - case FrameType::REQUEST_STREAM: - case FrameType::REQUEST_CHANNEL: - case FrameType::REQUEST_N: - case FrameType::CANCEL: - case FrameType::PAYLOAD: - case FrameType::EXT: - default: { - auto msg = folly::sformat( - "Unexpected {} frame for stream 0", toString(frameType)); - closeWithError(Frame_ERROR::connectionError(std::move(msg))); + onLeaseFrame(); return; - } - } -} - -void RSocketStateMachine::handleStreamFrame( - StreamId streamId, - FrameType frameType, - std::unique_ptr serializedFrame) { - auto it = streamState_.streams_.find(streamId); - if (it == streamState_.streams_.end()) { - handleUnknownStream(streamId, frameType, std::move(serializedFrame)); - return; - } - - // we are purposely making a copy of the reference here to avoid problems with - // lifetime of the stateMachine when a terminating signal is delivered which - // will cause the stateMachine to be destroyed while in one of its methods - auto stateMachine = it->second; - - switch (frameType) { case FrameType::REQUEST_N: { Frame_REQUEST_N frameRequestN; - if (!deserializeFrameOrError(frameRequestN, std::move(serializedFrame))) { + if (!deserializeFrameOrError(frameRequestN, std::move(payload))) { return; } VLOG(3) << mode_ << " In: " << frameRequestN; - stateMachine->handleRequestN(frameRequestN.requestN_); + onRequestNFrame(streamId, frameRequestN.requestN_); break; } case FrameType::CANCEL: { VLOG(3) << mode_ << " In: " << Frame_CANCEL(streamId); - stateMachine->handleCancel(); + onCancelFrame(streamId); break; } case FrameType::PAYLOAD: { Frame_PAYLOAD framePayload; - if (!deserializeFrameOrError(framePayload, std::move(serializedFrame))) { + if (!deserializeFrameOrError(framePayload, std::move(payload))) { return; } VLOG(3) << mode_ << " In: " << framePayload; - stateMachine->handlePayload( + onPayloadFrame( + streamId, std::move(framePayload.payload_), + framePayload.header_.flagsFollows(), framePayload.header_.flagsComplete(), framePayload.header_.flagsNext()); break; } - case FrameType::ERROR: { - Frame_ERROR frameError; - if (!deserializeFrameOrError(frameError, std::move(serializedFrame))) { + case FrameType::REQUEST_CHANNEL: { + Frame_REQUEST_CHANNEL frame; + if (!deserializeFrameOrError(frame, std::move(payload))) { return; } - VLOG(3) << mode_ << " In: " << frameError; - stateMachine->handleError( - std::runtime_error(frameError.payload_.moveDataToString())); + VLOG(3) << mode_ << " In: " << frame; + onRequestChannelFrame( + streamId, + frame.requestN_, + std::move(frame.payload_), + frame.header_.flagsComplete(), + frame.header_.flagsNext(), + frame.header_.flagsFollows()); break; } - case FrameType::REQUEST_CHANNEL: - case FrameType::REQUEST_RESPONSE: - case FrameType::RESERVED: - case FrameType::SETUP: - case FrameType::LEASE: - case FrameType::KEEPALIVE: - case FrameType::REQUEST_FNF: - case FrameType::REQUEST_STREAM: - case FrameType::METADATA_PUSH: - case FrameType::RESUME: - case FrameType::RESUME_OK: - case FrameType::EXT: { - auto msg = folly::sformat( - "Unexpected {} frame for stream {}", toString(frameType), streamId); - closeWithError(Frame_ERROR::connectionError(std::move(msg))); + case FrameType::REQUEST_STREAM: { + Frame_REQUEST_STREAM frame; + if (!deserializeFrameOrError(frame, std::move(payload))) { + return; + } + VLOG(3) << mode_ << " In: " << frame; + onRequestStreamFrame( + streamId, + frame.requestN_, + std::move(frame.payload_), + frame.header_.flagsFollows()); break; } - default: - // Ignore unknown frames for compatibility with future frame types. + case FrameType::REQUEST_RESPONSE: { + Frame_REQUEST_RESPONSE frame; + if (!deserializeFrameOrError(frame, std::move(payload))) { + return; + } + VLOG(3) << mode_ << " In: " << frame; + onRequestResponseFrame( + streamId, std::move(frame.payload_), frame.header_.flagsFollows()); + break; + } + case FrameType::REQUEST_FNF: { + Frame_REQUEST_FNF frame; + if (!deserializeFrameOrError(frame, std::move(payload))) { + return; + } + VLOG(3) << mode_ << " In: " << frame; + onFireAndForgetFrame( + streamId, std::move(frame.payload_), frame.header_.flagsFollows()); break; + } + case FrameType::EXT: + onExtFrame(); + return; + + default: { + stats_->unknownFrameReceived(); + // per rsocket spec, we will ignore any other unknown frames + return; + } + } +} + +std::shared_ptr +RSocketStateMachine::getStreamStateMachine(StreamId streamId) { + const auto&& it = streams_.find(streamId); + if (it == streams_.end()) { + return nullptr; } + // we are purposely making a copy of the reference here to avoid problems with + // lifetime of the stateMachine when a terminating signal is delivered which + // will cause the stateMachine to be destroyed while in one of its methods + return it->second; } -void RSocketStateMachine::handleUnknownStream( +bool RSocketStateMachine::ensureNotInResumption() { + if (resumeCallback_) { + // during the time when we are resuming we are can't receive any other + // than connection level frames which drives the resumption + // TODO(lehecka): this assertion should be handled more elegantly using + // different state machine + constexpr auto msg = "Received stream frame while resuming"; + LOG(ERROR) << msg; + closeWithError(Frame_ERROR::connectionError(msg)); + return false; + } + return true; +} + +void RSocketStateMachine::onRequestNFrame( StreamId streamId, - FrameType frameType, - std::unique_ptr serializedFrame) { - DCHECK(streamId != 0); - // TODO: comparing string versions is odd because from version - // 10.0 the lexicographic comparison doesn't work - // we should change the version to struct - if (frameSerializer_->protocolVersion() > ProtocolVersion{0, 0} && - !streamsFactory_.registerNewPeerStreamId(streamId)) { + uint32_t requestN) { + if (!ensureNotInResumption()) { return; } + // we ignore messages for streams which don't exist + if (auto stateMachine = getStreamStateMachine(streamId)) { + stateMachine->handleRequestN(requestN); + } +} - if (!isNewStreamFrame(frameType)) { - auto msg = folly::sformat( - "Unexpected frame {} for stream {}", toString(frameType), streamId); - VLOG(1) << msg; - closeWithError(Frame_ERROR::connectionError(std::move(msg))); +void RSocketStateMachine::onCancelFrame(StreamId streamId) { + if (!ensureNotInResumption()) { return; } + // we ignore messages for streams which don't exist + if (auto stateMachine = getStreamStateMachine(streamId)) { + stateMachine->handleCancel(); + } +} - auto saveStreamToken = [&](const Payload& payload) { - if (coldResumeHandler_) { - auto streamType = getStreamType(frameType); - CHECK(streamType != StreamType::FNF); - auto streamToken = coldResumeHandler_->generateStreamToken( - payload, streamId, streamType); - resumeManager_->onStreamOpen( - streamId, RequestOriginator::REMOTE, streamToken, streamType); - } - }; +void RSocketStateMachine::onPayloadFrame( + StreamId streamId, + Payload payload, + bool flagsFollows, + bool flagsComplete, + bool flagsNext) { + if (!ensureNotInResumption()) { + return; + } + // we ignore messages for streams which don't exist + if (auto stateMachine = getStreamStateMachine(streamId)) { + stateMachine->handlePayload( + std::move(payload), flagsComplete, flagsNext, flagsFollows); + } +} - if (frameType == FrameType::REQUEST_CHANNEL) { - Frame_REQUEST_CHANNEL frame; - if (!deserializeFrameOrError(frame, std::move(serializedFrame))) { - return; - } - VLOG(3) << mode_ << " In: " << frame; - auto stateMachine = - streamsFactory_.createChannelResponder(frame.requestN_, streamId); - saveStreamToken(frame.payload_); - auto requestSink = requestResponder_->handleRequestChannelCore( - std::move(frame.payload_), streamId, stateMachine); - stateMachine->subscribe(requestSink); - } else if (frameType == FrameType::REQUEST_STREAM) { - Frame_REQUEST_STREAM frame; - if (!deserializeFrameOrError(frame, std::move(serializedFrame))) { - return; - } - VLOG(3) << mode_ << " In: " << frame; - auto stateMachine = - streamsFactory_.createStreamResponder(frame.requestN_, streamId); - saveStreamToken(frame.payload_); - requestResponder_->handleRequestStreamCore( - std::move(frame.payload_), streamId, stateMachine); - } else if (frameType == FrameType::REQUEST_RESPONSE) { - Frame_REQUEST_RESPONSE frame; - if (!deserializeFrameOrError(frame, std::move(serializedFrame))) { - return; - } - VLOG(3) << mode_ << " In: " << frame; - auto stateMachine = - streamsFactory_.createRequestResponseResponder(streamId); - saveStreamToken(frame.payload_); - requestResponder_->handleRequestResponseCore( - std::move(frame.payload_), streamId, stateMachine); - } else if (frameType == FrameType::REQUEST_FNF) { - Frame_REQUEST_FNF frame; - if (!deserializeFrameOrError(frame, std::move(serializedFrame))) { - return; - } - VLOG(3) << mode_ << " In: " << frame; - // no stream tracking is necessary - requestResponder_->handleFireAndForget(std::move(frame.payload_), streamId); +void RSocketStateMachine::onRequestStreamFrame( + StreamId streamId, + uint32_t requestN, + Payload payload, + bool flagsFollows) { + if (!ensureNotInResumption() || !isNewStreamId(streamId)) { + return; + } + auto stateMachine = + std::make_shared(shared_from_this(), streamId, requestN); + const auto result = streams_.emplace(streamId, stateMachine); + DCHECK(result.second); // ensured by calling isNewStreamId + stateMachine->handlePayload(std::move(payload), false, false, flagsFollows); +} + +void RSocketStateMachine::onRequestChannelFrame( + StreamId streamId, + uint32_t requestN, + Payload payload, + bool flagsComplete, + bool flagsNext, + bool flagsFollows) { + if (!ensureNotInResumption() || !isNewStreamId(streamId)) { + return; + } + auto stateMachine = std::make_shared( + shared_from_this(), streamId, requestN); + const auto result = streams_.emplace(streamId, stateMachine); + DCHECK(result.second); // ensured by calling isNewStreamId + stateMachine->handlePayload( + std::move(payload), flagsComplete, flagsNext, flagsFollows); +} + +void RSocketStateMachine::onRequestResponseFrame( + StreamId streamId, + Payload payload, + bool flagsFollows) { + if (!ensureNotInResumption() || !isNewStreamId(streamId)) { + return; + } + auto stateMachine = + std::make_shared(shared_from_this(), streamId); + const auto result = streams_.emplace(streamId, stateMachine); + DCHECK(result.second); // ensured by calling isNewStreamId + stateMachine->handlePayload(std::move(payload), false, false, flagsFollows); +} + +void RSocketStateMachine::onFireAndForgetFrame( + StreamId streamId, + Payload payload, + bool flagsFollows) { + if (!ensureNotInResumption() || !isNewStreamId(streamId)) { + return; + } + auto stateMachine = + std::make_shared(shared_from_this(), streamId); + const auto result = streams_.emplace(streamId, stateMachine); + DCHECK(result.second); // ensured by calling isNewStreamId + stateMachine->handlePayload(std::move(payload), false, false, flagsFollows); +} + +bool RSocketStateMachine::isNewStreamId(StreamId streamId) { + if (frameSerializer_->protocolVersion() > ProtocolVersion{0, 0} && + !registerNewPeerStreamId(streamId)) { + return false; + } + return true; +} + +std::shared_ptr> +RSocketStateMachine::onNewStreamReady( + StreamId streamId, + StreamType streamType, + Payload payload, + std::shared_ptr> response) { + if (coldResumeHandler_ && streamType != StreamType::FNF) { + auto streamToken = + coldResumeHandler_->generateStreamToken(payload, streamId, streamType); + resumeManager_->onStreamOpen( + streamId, RequestOriginator::REMOTE, streamToken, streamType); + } + + switch (streamType) { + case StreamType::CHANNEL: + return requestResponder_->handleRequestChannel( + std::move(payload), streamId, std::move(response)); + + case StreamType::STREAM: + requestResponder_->handleRequestStream( + std::move(payload), streamId, std::move(response)); + return nullptr; + + case StreamType::REQUEST_RESPONSE: + // the other overload method should be called + CHECK(false); + folly::assume_unreachable(); + + case StreamType::FNF: + requestResponder_->handleFireAndForget(std::move(payload), streamId); + return nullptr; + + default: + CHECK(false) << "unknown value: " << streamType; + folly::assume_unreachable(); + } +} + +void RSocketStateMachine::onNewStreamReady( + StreamId streamId, + StreamType streamType, + Payload payload, + std::shared_ptr> response) { + CHECK(streamType == StreamType::REQUEST_RESPONSE); + + if (coldResumeHandler_) { + auto streamToken = + coldResumeHandler_->generateStreamToken(payload, streamId, streamType); + resumeManager_->onStreamOpen( + streamId, RequestOriginator::REMOTE, streamToken, streamType); } + requestResponder_->handleRequestResponse( + std::move(payload), streamId, std::move(response)); } void RSocketStateMachine::sendKeepalive(std::unique_ptr data) { @@ -773,9 +993,8 @@ void RSocketStateMachine::sendKeepalive( std::unique_ptr data) { Frame_KEEPALIVE pingFrame( flags, resumeManager_->impliedPosition(), std::move(data)); - VLOG(3) << "Out: " << pingFrame; - outputFrameOrEnqueue( - frameSerializer_->serializeOut(std::move(pingFrame), isResumable_)); + VLOG(3) << mode_ << " Out: " << pingFrame; + outputFrameOrEnqueue(frameSerializer_->serializeOut(std::move(pingFrame))); stats_->keepaliveSent(); } @@ -790,12 +1009,13 @@ bool RSocketStateMachine::resumeFromPositionOrClose( DCHECK(!isDisconnected()); DCHECK(mode_ == RSocketMode::SERVER); - bool clientPositionExist = (clientPosition == kUnspecifiedResumePosition) || + const bool clientPositionExist = + (clientPosition == kUnspecifiedResumePosition) || clientPosition <= resumeManager_->impliedPosition(); if (clientPositionExist && resumeManager_->isPositionAvailable(serverPosition)) { - Frame_RESUME_OK resumeOkFrame(resumeManager_->impliedPosition()); + Frame_RESUME_OK resumeOkFrame{resumeManager_->impliedPosition()}; VLOG(3) << "Out: " << resumeOkFrame; frameTransport_->outputFrameOrDrop( frameSerializer_->serializeOut(std::move(resumeOkFrame))); @@ -803,7 +1023,7 @@ bool RSocketStateMachine::resumeFromPositionOrClose( return true; } - auto message = folly::to( + auto const msg = folly::to( "Cannot resume server, client lastServerPosition=", serverPosition, " firstClientPosition=", @@ -811,7 +1031,7 @@ bool RSocketStateMachine::resumeFromPositionOrClose( " is not available. Last reset position is ", resumeManager_->firstSentPosition()); - closeWithError(Frame_ERROR::connectionError(std::move(message))); + closeWithError(Frame_ERROR::connectionError(msg)); return false; } @@ -825,7 +1045,8 @@ void RSocketStateMachine::resumeFromPosition(ResumePosition position) { } resumeManager_->sendFramesFromPosition(position, *frameTransport_); - for (auto& frame : streamState_.moveOutputPendingFrames()) { + auto frames = consumePendingOutputFrames(); + for (auto& frame : frames) { outputFrameOrEnqueue(std::move(frame)); } @@ -834,35 +1055,31 @@ void RSocketStateMachine::resumeFromPosition(ResumePosition position) { } } -void RSocketStateMachine::outputFrameOrEnqueue( - std::unique_ptr frame) { +bool RSocketStateMachine::shouldQueue() { // if we are resuming we cant send any frames until we receive RESUME_OK - if (!isDisconnected() && !resumeCallback_) { - outputFrame(std::move(frame)); - } else { - streamState_.enqueueOutputPendingFrame(std::move(frame)); - } + return isDisconnected() || resumeCallback_; } void RSocketStateMachine::fireAndForget(Payload request) { - auto const streamId = streamsFactory().getNextStreamId(); - Frame_REQUEST_FNF frame{streamId, FrameFlags::EMPTY, std::move(request)}; - outputFrameOrEnqueue(std::move(frame)); + auto const streamId = getNextStreamId(); + Frame_REQUEST_FNF frame{streamId, FrameFlags::EMPTY_, std::move(request)}; + outputFrameOrEnqueue(frameSerializer_->serializeOut(std::move(frame))); } void RSocketStateMachine::metadataPush(std::unique_ptr metadata) { Frame_METADATA_PUSH metadataPushFrame{std::move(metadata)}; - outputFrameOrEnqueue(std::move(metadataPushFrame)); + outputFrameOrEnqueue( + frameSerializer_->serializeOut(std::move(metadataPushFrame))); } void RSocketStateMachine::outputFrame(std::unique_ptr frame) { DCHECK(!isDisconnected()); - auto frameType = frameSerializer_->peekFrameType(*frame); + const auto frameType = frameSerializer_->peekFrameType(*frame); stats_->frameWritten(frameType); if (isResumable_) { - auto streamIdPtr = frameSerializer_->peekStreamId(*frame); + auto streamIdPtr = frameSerializer_->peekStreamId(*frame, false); CHECK(streamIdPtr) << "Error in serialized frame."; resumeManager_->trackSentFrame( *frame, frameType, *streamIdPtr, getConsumerAllowance(*streamIdPtr)); @@ -884,76 +1101,25 @@ bool RSocketStateMachine::isClosed() const { return isClosed_; } -void RSocketStateMachine::setFrameSerializer( - std::unique_ptr frameSerializer) { - CHECK(frameSerializer); - // serializer is not interchangeable, it would screw up resumability - // CHECK(!frameSerializer_); - frameSerializer_ = std::move(frameSerializer); -} - void RSocketStateMachine::writeNewStream( StreamId streamId, StreamType streamType, uint32_t initialRequestN, - Payload payload, - bool completed) { + Payload payload) { if (coldResumeHandler_ && streamType != StreamType::FNF) { - auto streamToken = + const auto streamToken = coldResumeHandler_->generateStreamToken(payload, streamId, streamType); resumeManager_->onStreamOpen( streamId, RequestOriginator::LOCAL, streamToken, streamType); } - switch (streamType) { - case StreamType::CHANNEL: - outputFrameOrEnqueue(Frame_REQUEST_CHANNEL( - streamId, - completed ? FrameFlags::COMPLETE : FrameFlags::EMPTY, - initialRequestN, - std::move(payload))); - break; - - case StreamType::STREAM: - outputFrameOrEnqueue(Frame_REQUEST_STREAM( - streamId, FrameFlags::EMPTY, initialRequestN, std::move(payload))); - break; - - case StreamType::REQUEST_RESPONSE: - outputFrameOrEnqueue(Frame_REQUEST_RESPONSE( - streamId, FrameFlags::EMPTY, std::move(payload))); - break; - - case StreamType::FNF: - outputFrameOrEnqueue( - Frame_REQUEST_FNF(streamId, FrameFlags::EMPTY, std::move(payload))); - break; - - default: - CHECK(false); // unknown type - } + StreamsWriterImpl::writeNewStream( + streamId, streamType, initialRequestN, std::move(payload)); } -void RSocketStateMachine::writeRequestN(Frame_REQUEST_N&& frame) { - outputFrameOrEnqueue(std::move(frame)); -} - -void RSocketStateMachine::writeCancel(Frame_CANCEL&& frame) { - outputFrameOrEnqueue(std::move(frame)); -} - -void RSocketStateMachine::writePayload(Frame_PAYLOAD&& frame) { - outputFrameOrEnqueue(std::move(frame)); -} - -void RSocketStateMachine::writeError(Frame_ERROR&& frame) { - outputFrameOrEnqueue(std::move(frame)); -} - -void RSocketStateMachine::onStreamClosed( - StreamId streamId, - StreamCompletionSignal signal) { - endStream(streamId, signal); +void RSocketStateMachine::onStreamClosed(StreamId streamId) { + streams_.erase(streamId); + resumeManager_->onStreamClosed(streamId); } bool RSocketStateMachine::ensureOrAutodetectFrameSerializer( @@ -977,23 +1143,94 @@ bool RSocketStateMachine::ensureOrAutodetectFrameSerializer( VLOG(2) << "detected protocol version" << serializer->protocolVersion(); frameSerializer_ = std::move(serializer); + frameSerializer_->preallocateFrameSizeField() = + frameTransport_ && frameTransport_->isConnectionFramed(); + return true; } size_t RSocketStateMachine::getConsumerAllowance(StreamId streamId) const { - size_t consumerAllowance = 0; - auto it = streamState_.streams_.find(streamId); - if (it != streamState_.streams_.end()) { - consumerAllowance = it->second->getConsumerAllowance(); - } - return consumerAllowance; + auto const it = streams_.find(streamId); + return it != streams_.end() ? it->second->getConsumerAllowance() : 0; } -void RSocketStateMachine::registerSet(std::shared_ptr set) { - connectionSet_ = std::move(set); +void RSocketStateMachine::registerCloseCallback( + RSocketStateMachine::CloseCallback* callback) { + closeCallback_ = callback; } DuplexConnection* RSocketStateMachine::getConnection() { return frameTransport_ ? frameTransport_->getConnection() : nullptr; } + +void RSocketStateMachine::setProtocolVersionOrThrow( + ProtocolVersion version, + const std::shared_ptr& transport) { + CHECK(version != ProtocolVersion::Unknown); + + // TODO(lehecka): this is a temporary guard to make sure the transport is + // explicitly closed when exceptions are thrown. The right solution is to + // automatically close duplex connection in the destructor when unique_ptr + // is released + auto transportGuard = folly::makeGuard([&] { transport->close(); }); + + if (frameSerializer_) { + if (frameSerializer_->protocolVersion() != version) { + // serializer is not interchangeable, it would screw up resumability + throw std::runtime_error{"Protocol version mismatch"}; + } + } else { + auto frameSerializer = FrameSerializer::createFrameSerializer(version); + if (!frameSerializer) { + throw std::runtime_error{"Invalid protocol version"}; + } + + frameSerializer_ = std::move(frameSerializer); + frameSerializer_->preallocateFrameSizeField() = + frameTransport_ && frameTransport_->isConnectionFramed(); + } + + transportGuard.dismiss(); +} + +StreamId RSocketStateMachine::getNextStreamId() { + constexpr auto limit = + static_cast(std::numeric_limits::max() - 2); + + auto const streamId = nextStreamId_; + if (streamId >= limit) { + throw std::runtime_error{"Ran out of stream IDs"}; + } + + CHECK_EQ(0, streams_.count(streamId)) + << "Next stream ID already exists in the streams map"; + + nextStreamId_ += 2; + return streamId; +} + +void RSocketStateMachine::setNextStreamId(StreamId streamId) { + nextStreamId_ = streamId + 2; +} + +bool RSocketStateMachine::registerNewPeerStreamId(StreamId streamId) { + DCHECK_NE(0, streamId); + if (nextStreamId_ % 2 == streamId % 2) { + // if this is an unknown stream to the socket and this socket is + // generating such stream ids, it is an incoming frame on the stream which + // no longer exist + return false; + } + if (streamId <= lastPeerStreamId_) { + // receiving frame for a stream which no longer exists + return false; + } + lastPeerStreamId_ = streamId; + return true; +} + +bool RSocketStateMachine::hasStreams() const { + return !streams_.empty(); +} + } // namespace rsocket diff --git a/rsocket/statemachine/RSocketStateMachine.h b/rsocket/statemachine/RSocketStateMachine.h index 0c4e69944..71deeb322 100644 --- a/rsocket/statemachine/RSocketStateMachine.h +++ b/rsocket/statemachine/RSocketStateMachine.h @@ -1,8 +1,20 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include +#include #include #include "rsocket/ColdResumeHandler.h" @@ -11,29 +23,31 @@ #include "rsocket/RSocketParameters.h" #include "rsocket/ResumeManager.h" #include "rsocket/framing/FrameProcessor.h" +#include "rsocket/framing/FrameSerializer.h" #include "rsocket/internal/Common.h" #include "rsocket/internal/KeepaliveTimer.h" -#include "rsocket/statemachine/StreamState.h" -#include "rsocket/statemachine/StreamsFactory.h" +#include "rsocket/statemachine/StreamFragmentAccumulator.h" +#include "rsocket/statemachine/StreamStateMachineBase.h" #include "rsocket/statemachine/StreamsWriter.h" +#include "yarpl/flowable/Subscriber.h" +#include "yarpl/flowable/Subscription.h" +#include "yarpl/single/SingleObserver.h" namespace rsocket { class ClientResumeStatusCallback; -class ConnectionSet; class DuplexConnection; -class FrameSerializer; class FrameTransport; class Frame_ERROR; class KeepaliveTimer; class RSocketConnectionEvents; class RSocketParameters; class RSocketResponder; +class RSocketResponderCore; class RSocketStateMachine; class RSocketStats; class ResumeManager; -class StreamState; -class StreamStateMachineBase; +class RSocketStateMachineTest; class FrameSink { public: @@ -60,9 +74,18 @@ class FrameSink { class RSocketStateMachine final : public FrameSink, public FrameProcessor, - public StreamsWriter, + public StreamsWriterImpl, public std::enable_shared_from_this { public: + RSocketStateMachine( + std::shared_ptr requestResponder, + std::unique_ptr keepaliveTimer, + RSocketMode mode, + std::shared_ptr stats, + std::shared_ptr connectionEvents, + std::shared_ptr resumeManager, + std::shared_ptr coldResumeHandler); + RSocketStateMachine( std::shared_ptr requestResponder, std::unique_ptr keepaliveTimer, @@ -75,18 +98,18 @@ class RSocketStateMachine final ~RSocketStateMachine(); /// Create a new connection as a server. - void connectServer(yarpl::Reference, const SetupParameters&); + void connectServer(std::shared_ptr, const SetupParameters&); /// Resume a connection as a server. - bool resumeServer(yarpl::Reference, const ResumeParameters&); + bool resumeServer(std::shared_ptr, const ResumeParameters&); /// Connect as a client. Sends a SETUP frame. - void connectClient(yarpl::Reference, SetupParameters); + void connectClient(std::shared_ptr, SetupParameters); /// Resume a connection as a client. Sends a RESUME frame. void resumeClient( ResumeIdentificationToken, - yarpl::Reference, + std::shared_ptr, std::unique_ptr, ProtocolVersion); @@ -107,39 +130,18 @@ class RSocketStateMachine final /// Close the connection and all of its streams. void close(folly::exception_wrapper, StreamCompletionSignal); - /// A contract exposed to StreamAutomatonBase, modelled after Subscriber - /// and Subscription contracts, while omitting flow control related signals. + void requestStream( + Payload request, + std::shared_ptr> responseSink); - /// Adds a stream stateMachine to the connection. - /// - /// This signal corresponds to Subscriber::onSubscribe. - /// - /// No frames will be issued as a result of this call. Stream stateMachine - /// must take care of writing appropriate frames to the connection, using - /// ::writeFrame after calling this method. - void addStream(StreamId, yarpl::Reference); + std::shared_ptr> requestChannel( + Payload request, + bool hasInitialRequest, + std::shared_ptr> responseSink); - /// Indicates that the stream should be removed from the connection. - /// - /// No frames will be issued as a result of this call. Stream stateMachine - /// must take care of writing appropriate frames to the connection, using - /// ::writeFrame, prior to calling this method. - /// - /// This signal corresponds to Subscriber::{onComplete,onError} and - /// Subscription::cancel. - /// Per ReactiveStreams specification: - /// 1. no other signal can be delivered during or after this one, - /// 2. "unsubscribe handshake" guarantees that the signal will be delivered - /// at least once, even if the stateMachine initiated stream closure, - /// 3. per "unsubscribe handshake", the stateMachine must deliver - /// corresponding - /// terminal signal to the connection. - /// - /// Additionally, in order to simplify implementation of stream stateMachine: - /// 4. the signal bound with a particular StreamId is idempotent and may be - /// delivered multiple times as long as the caller holds shared_ptr to - /// ConnectionAutomaton. - void endStream(StreamId, StreamCompletionSignal); + void requestResponse( + Payload payload, + std::shared_ptr> responseSink); /// Send a REQUEST_FNF frame. void fireAndForget(Payload); @@ -150,21 +152,73 @@ class RSocketStateMachine final /// Send a KEEPALIVE frame, with the RESPOND flag set. void sendKeepalive(std::unique_ptr) override; - /// Register the connection set that's holding this state machine. - void registerSet(std::shared_ptr); + class CloseCallback { + public: + virtual ~CloseCallback() = default; + virtual void remove(RSocketStateMachine&) = 0; + }; - StreamsFactory& streamsFactory() { - return streamsFactory_; - } + /// Register a callback to be called when the StateMachine is closed. + /// It will be used to inform the containers, i.e. ConnectionSet or + /// wangle::ConnectionManager, to don't store the StateMachine anymore. + void registerCloseCallback(CloseCallback* callback); DuplexConnection* getConnection(); + // Has active requests? + bool hasStreams() const; + private: - void connect(yarpl::Reference, ProtocolVersion); + // connection scope signals + void onKeepAliveFrame( + ResumePosition resumePosition, + std::unique_ptr data, + bool keepAliveRespond); + void onMetadataPushFrame(std::unique_ptr metadata); + void onResumeOkFrame(ResumePosition resumePosition); + void onErrorFrame(StreamId streamId, ErrorCode errorCode, Payload payload); + + // stream scope signals + void onRequestNFrame(StreamId streamId, uint32_t requestN); + void onCancelFrame(StreamId streamId); + void onPayloadFrame( + StreamId streamId, + Payload payload, + bool flagsFollows, + bool flagsComplete, + bool flagsNext); + + void onRequestStreamFrame( + StreamId streamId, + uint32_t requestN, + Payload payload, + bool flagsFollows); + void onRequestChannelFrame( + StreamId streamId, + uint32_t requestN, + Payload payload, + bool flagsComplete, + bool flagsNext, + bool flagsFollows); + void + onRequestResponseFrame(StreamId streamId, Payload payload, bool flagsFollows); + void + onFireAndForgetFrame(StreamId streamId, Payload payload, bool flagsFollows); + void onSetupFrame(); + void onResumeFrame(); + void onReservedFrame(); + void onLeaseFrame(); + void onExtFrame(); + void onUnexpectedFrame(StreamId streamId); + + std::shared_ptr getStreamStateMachine( + StreamId streamId); + + void connect(std::shared_ptr); /// Terminate underlying connection and connect new connection void reconnect( - yarpl::Reference, + std::shared_ptr, std::unique_ptr); void setResumable(bool); @@ -180,57 +234,35 @@ class RSocketStateMachine final uint32_t getKeepaliveTime() const; - void setFrameSerializer(std::unique_ptr); + void sendPendingFrames() override; - void sendPendingFrames(); - - /// Send a frame to the output. Will buffer the frame if the state machine is - /// disconnected or in the process of resuming. - void outputFrameOrEnqueue(std::unique_ptr); - - template - void outputFrameOrEnqueue(T&& frame) { - VLOG(3) << mode_ << " Out: " << frame; - outputFrameOrEnqueue( - frameSerializer_->serializeOut(std::forward(frame))); + // Should buffer the frame if the state machine is disconnected or in the + // process of resuming. + bool shouldQueue() override; + RSocketStats& stats() override { + return *stats_; } - template - bool deserializeFrameOrError( - TFrame& frame, - std::unique_ptr buf) { - if (frameSerializer_->deserializeFrom(frame, std::move(buf))) { - return true; - } - closeWithError(Frame_ERROR::connectionError("Invalid frame")); - return false; + FrameSerializer& serializer() override { + return *frameSerializer_; } template bool deserializeFrameOrError( - bool resumable, TFrame& frame, std::unique_ptr buf) { - if (frameSerializer_->deserializeFrom(frame, std::move(buf), resumable)) { + if (frameSerializer_->deserializeFrom(frame, std::move(buf))) { return true; } closeWithError(Frame_ERROR::connectionError("Invalid frame")); return false; } - /// Performs the same actions as ::endStream without propagating closure - /// signal to the underlying connection. - /// - /// The call is idempotent and returns false iff a stream has not been found. - bool endStreamInternal(StreamId streamId, StreamCompletionSignal signal); - // FrameProcessor. void processFrame(std::unique_ptr) override; void onTerminal(folly::exception_wrapper) override; - void handleConnectionFrame(FrameType, std::unique_ptr); - void handleStreamFrame(StreamId, FrameType, std::unique_ptr); - void handleUnknownStream(StreamId, FrameType, std::unique_ptr); + void handleFrame(StreamId, FrameType, std::unique_ptr); void closeStreams(StreamCompletionSignal); void closeFrameTransport(folly::exception_wrapper); @@ -238,27 +270,43 @@ class RSocketStateMachine final void sendKeepalive(FrameFlags, std::unique_ptr); void resumeFromPosition(ResumePosition); - void outputFrame(std::unique_ptr); + void outputFrame(std::unique_ptr) override; void writeNewStream( StreamId streamId, StreamType streamType, uint32_t initialRequestN, - Payload payload, - bool completed) override; - void writeRequestN(Frame_REQUEST_N&&) override; - void writeCancel(Frame_CANCEL&&) override; - - void writePayload(Frame_PAYLOAD&&) override; - void writeError(Frame_ERROR&&) override; + Payload payload) override; - void onStreamClosed(StreamId streamId, StreamCompletionSignal signal) + std::shared_ptr> onNewStreamReady( + StreamId streamId, + StreamType streamType, + Payload payload, + std::shared_ptr> response) override; + void onNewStreamReady( + StreamId streamId, + StreamType streamType, + Payload payload, + std::shared_ptr> response) override; + void onStreamClosed(StreamId) override; + bool ensureOrAutodetectFrameSerializer(const folly::IOBuf& firstFrame); + bool ensureNotInResumption(); size_t getConsumerAllowance(StreamId) const; + void setProtocolVersionOrThrow( + ProtocolVersion version, + const std::shared_ptr& transport); + + bool isNewStreamId(StreamId streamId); + bool registerNewPeerStreamId(StreamId streamId); + StreamId getNextStreamId(); + + void setNextStreamId(StreamId streamId); + /// Client/server mode this state machine is operating in. const RSocketMode mode_; @@ -273,14 +321,17 @@ class RSocketStateMachine final std::shared_ptr stats_; - /// Per-stream frame buffer between the state machine and the FrameTransport. - StreamState streamState_; + /// Map of all individual stream state machines. + std::unordered_map> + streams_; + StreamId nextStreamId_; + StreamId lastPeerStreamId_{0}; // Manages all state needed for warm/cold resumption. std::shared_ptr resumeManager_; - std::shared_ptr requestResponder_; - yarpl::Reference frameTransport_; + const std::shared_ptr requestResponder_; + std::shared_ptr frameTransport_; std::unique_ptr frameSerializer_; const std::unique_ptr keepaliveTimer_; @@ -288,11 +339,11 @@ class RSocketStateMachine final std::unique_ptr resumeCallback_; std::shared_ptr coldResumeHandler_; - StreamsFactory streamsFactory_; - std::shared_ptr connectionEvents_; - /// Back reference to the set that's holding this state machine. - std::weak_ptr connectionSet_; + CloseCallback* closeCallback_{nullptr}; + + friend class RSocketStateMachineTest; }; -} + +} // namespace rsocket diff --git a/rsocket/statemachine/RequestResponseRequester.cpp b/rsocket/statemachine/RequestResponseRequester.cpp index c2aada8c9..2d39be17b 100644 --- a/rsocket/statemachine/RequestResponseRequester.cpp +++ b/rsocket/statemachine/RequestResponseRequester.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/statemachine/RequestResponseRequester.h" @@ -7,26 +19,23 @@ namespace rsocket { -using namespace yarpl; -using namespace yarpl::flowable; - void RequestResponseRequester::subscribe( - yarpl::Reference> subscriber) { - DCHECK(!isTerminated()); + std::shared_ptr> subscriber) { + DCHECK(state_ != State::CLOSED); DCHECK(!consumingSubscriber_); consumingSubscriber_ = std::move(subscriber); - consumingSubscriber_->onSubscribe(this->ref_from_this(this)); + consumingSubscriber_->onSubscribe(shared_from_this()); if (state_ == State::NEW) { state_ = State::REQUESTED; newStream(StreamType::REQUEST_RESPONSE, 1, std::move(initialPayload_)); - } else { - if (auto subscriber = std::move(consumingSubscriber_)) { - subscriber->onError( - std::runtime_error("cannot request more than 1 item")); - } - closeStream(StreamCompletionSignal::ERROR); + return; } + + if (auto subscriber = std::move(consumingSubscriber_)) { + subscriber->onError(std::runtime_error("cannot request more than 1 item")); + } + removeFromWriter(); } void RequestResponseRequester::cancel() noexcept { @@ -34,17 +43,17 @@ void RequestResponseRequester::cancel() noexcept { switch (state_) { case State::NEW: state_ = State::CLOSED; - closeStream(StreamCompletionSignal::CANCEL); + removeFromWriter(); break; case State::REQUESTED: { state_ = State::CLOSED; - cancelStream(); - closeStream(StreamCompletionSignal::CANCEL); + writeCancel(); + removeFromWriter(); } break; case State::CLOSED: break; } - consumingSubscriber_ = nullptr; + consumingSubscriber_.reset(); } void RequestResponseRequester::endStream(StreamCompletionSignal signal) { @@ -66,8 +75,7 @@ void RequestResponseRequester::endStream(StreamCompletionSignal signal) { } } -void RequestResponseRequester::handleError( - folly::exception_wrapper errorPayload) { +void RequestResponseRequester::handleError(folly::exception_wrapper ew) { switch (state_) { case State::NEW: // Cannot receive a frame before sending the initial request. @@ -76,9 +84,9 @@ void RequestResponseRequester::handleError( case State::REQUESTED: state_ = State::CLOSED; if (auto subscriber = std::move(consumingSubscriber_)) { - subscriber->onError(errorPayload); + subscriber->onError(std::move(ew)); } - closeStream(StreamCompletionSignal::ERROR); + removeFromWriter(); break; case State::CLOSED: break; @@ -87,34 +95,41 @@ void RequestResponseRequester::handleError( void RequestResponseRequester::handlePayload( Payload&& payload, - bool complete, - bool flagsNext) { - switch (state_) { - case State::NEW: - // Cannot receive a frame before sending the initial request. - CHECK(false); - break; - case State::REQUESTED: - state_ = State::CLOSED; - break; - case State::CLOSED: - // should not be receiving frames when closed - // if we ended up here, we broke some internal invariant of the class - CHECK(false); - break; + bool /*flagsComplete*/, + bool flagsNext, + bool flagsFollows) { + // (State::NEW) Cannot receive a frame before sending the initial request. + // (State::CLOSED) should not be receiving frames when closed + // if we fail here, we broke some internal invariant of the class + CHECK(state_ == State::REQUESTED); + + payloadFragments_.addPayload(std::move(payload), flagsNext, false); + + if (flagsFollows) { + // there will be more fragments to come + return; } - if (payload || flagsNext) { - consumingSubscriber_->onSuccess(std::move(payload)); + bool finalFlagsNext, finalFlagsComplete; + Payload finalPayload; + + std::tie(finalPayload, finalFlagsNext, finalFlagsComplete) = + payloadFragments_.consumePayloadAndFlags(); + + state_ = State::CLOSED; + + if (finalPayload || finalFlagsNext) { + consumingSubscriber_->onSuccess(std::move(finalPayload)); consumingSubscriber_ = nullptr; - } else if (!complete) { - errorStream("payload, NEXT or COMPLETE flag expected"); - return; + } else if (!finalFlagsComplete) { + writeInvalidError("Payload, NEXT or COMPLETE flag expected"); + endStream(StreamCompletionSignal::ERROR); } - closeStream(StreamCompletionSignal::COMPLETE); + removeFromWriter(); } size_t RequestResponseRequester::getConsumerAllowance() const { return (state_ == State::REQUESTED) ? 1 : 0; } -} + +} // namespace rsocket diff --git a/rsocket/statemachine/RequestResponseRequester.h b/rsocket/statemachine/RequestResponseRequester.h index 4b84201e3..be17cf546 100644 --- a/rsocket/statemachine/RequestResponseRequester.h +++ b/rsocket/statemachine/RequestResponseRequester.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -11,9 +23,10 @@ namespace rsocket { /// Implementation of stream stateMachine that represents a RequestResponse /// requester -class RequestResponseRequester : public StreamStateMachineBase, - public yarpl::single::SingleSubscription, - public yarpl::enable_get_ref { +class RequestResponseRequester + : public StreamStateMachineBase, + public yarpl::single::SingleSubscription, + public std::enable_shared_from_this { public: RequestResponseRequester( std::shared_ptr writer, @@ -23,13 +36,17 @@ class RequestResponseRequester : public StreamStateMachineBase, initialPayload_(std::move(payload)) {} void subscribe( - yarpl::Reference> subscriber); + std::shared_ptr> subscriber); private: void cancel() noexcept override; - void handlePayload(Payload&& payload, bool complete, bool flagsNext) override; - void handleError(folly::exception_wrapper errorPayload) override; + void handlePayload( + Payload&& payload, + bool flagsComplete, + bool flagsNext, + bool flagsFollows) override; + void handleError(folly::exception_wrapper ew) override; void endStream(StreamCompletionSignal signal) override; @@ -40,12 +57,14 @@ class RequestResponseRequester : public StreamStateMachineBase, NEW, REQUESTED, CLOSED, - } state_{State::NEW}; + }; + + State state_{State::NEW}; /// The observer that will consume payloads. - yarpl::Reference> consumingSubscriber_; + std::shared_ptr> consumingSubscriber_; /// Initial payload which has to be sent with 1st request. Payload initialPayload_; }; -} +} // namespace rsocket diff --git a/rsocket/statemachine/RequestResponseResponder.cpp b/rsocket/statemachine/RequestResponseResponder.cpp index cc9bfef11..ca51ff54b 100644 --- a/rsocket/statemachine/RequestResponseResponder.cpp +++ b/rsocket/statemachine/RequestResponseResponder.cpp @@ -1,32 +1,33 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/statemachine/RequestResponseResponder.h" -#include "rsocket/Payload.h" - namespace rsocket { -using namespace yarpl; -using namespace yarpl::flowable; - void RequestResponseResponder::onSubscribe( - Reference subscription) noexcept { -#ifdef DEBUG - DCHECK(!gotOnSubscribe_.exchange(true)) << "Already called onSubscribe()"; -#endif - - if (StreamStateMachineBase::isTerminated()) { + std::shared_ptr subscription) { + DCHECK(State::NEW != state_); + if (state_ == State::CLOSED) { subscription->cancel(); return; } producingSubscription_ = std::move(subscription); } -void RequestResponseResponder::onSuccess(Payload response) noexcept { -#ifdef DEBUG - DCHECK(gotOnSubscribe_.load()) << "didnt call onSubscribe"; - DCHECK(!gotTerminating_.exchange(true)) << "Already called onSuccess/onError"; -#endif +void RequestResponseResponder::onSuccess(Payload response) { + DCHECK(State::NEW != state_); if (!producingSubscription_) { return; } @@ -34,59 +35,93 @@ void RequestResponseResponder::onSuccess(Payload response) noexcept { switch (state_) { case State::RESPONDING: { state_ = State::CLOSED; - writePayload(std::move(response), true); + writePayload(std::move(response), true /* complete */); producingSubscription_ = nullptr; - closeStream(StreamCompletionSignal::COMPLETE); + removeFromWriter(); break; } case State::CLOSED: break; + + case State::NEW: + default: + // class is internally misused + CHECK(false); } } -void RequestResponseResponder::onError(folly::exception_wrapper ex) noexcept { -#ifdef DEBUG - DCHECK(gotOnSubscribe_.load()) << "didnt call onSubscribe"; - DCHECK(!gotTerminating_.exchange(true)) << "Already called onSuccess/onError"; -#endif - +void RequestResponseResponder::onError(folly::exception_wrapper ex) { + DCHECK(State::NEW != state_); producingSubscription_ = nullptr; switch (state_) { case State::RESPONDING: { state_ = State::CLOSED; - applicationError(ex.get_exception()->what()); - closeStream(StreamCompletionSignal::APPLICATION_ERROR); + if (!ex.with_exception([this](rsocket::ErrorWithPayload& err) { + writeApplicationError(std::move(err.payload)); + })) { + writeApplicationError(ex.get_exception()->what()); + } + removeFromWriter(); } break; case State::CLOSED: break; + + case State::NEW: + default: + // class is internally misused + CHECK(false); } } -void RequestResponseResponder::endStream(StreamCompletionSignal signal) { +void RequestResponseResponder::handleCancel() { switch (state_) { case State::RESPONDING: - // Spontaneous ::endStream signal means an error. - DCHECK(StreamCompletionSignal::COMPLETE != signal); - DCHECK(StreamCompletionSignal::CANCEL != signal); state_ = State::CLOSED; + removeFromWriter(); break; + case State::NEW: case State::CLOSED: break; } - if (auto subscription = std::move(producingSubscription_)) { - subscription->cancel(); +} + +void RequestResponseResponder::handlePayload( + Payload&& payload, + bool /*flagsComplete*/, + bool /*flagsNext*/, + bool flagsFollows) { + payloadFragments_.addPayloadIgnoreFlags(std::move(payload)); + + if (flagsFollows) { + // there will be more fragments to come + return; } - StreamStateMachineBase::endStream(signal); + + CHECK(state_ == State::NEW); + Payload finalPayload = payloadFragments_.consumePayloadIgnoreFlags(); + + state_ = State::RESPONDING; + onNewStreamReady( + StreamType::REQUEST_RESPONSE, + std::move(finalPayload), + shared_from_this()); } -void RequestResponseResponder::handleCancel() { +void RequestResponseResponder::endStream(StreamCompletionSignal signal) { switch (state_) { + case State::NEW: case State::RESPONDING: + // Spontaneous ::endStream signal means an error. + DCHECK(StreamCompletionSignal::COMPLETE != signal); + DCHECK(StreamCompletionSignal::CANCEL != signal); state_ = State::CLOSED; - closeStream(StreamCompletionSignal::CANCEL); break; case State::CLOSED: break; } + if (auto subscription = std::move(producingSubscription_)) { + subscription->cancel(); + } } -} + +} // namespace rsocket diff --git a/rsocket/statemachine/RequestResponseResponder.h b/rsocket/statemachine/RequestResponseResponder.h index 86057c244..3e7a5e37b 100644 --- a/rsocket/statemachine/RequestResponseResponder.h +++ b/rsocket/statemachine/RequestResponseResponder.h @@ -1,46 +1,61 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once +#include "rsocket/Payload.h" #include "rsocket/statemachine/StreamStateMachineBase.h" -#include "yarpl/flowable/Subscriber.h" #include "yarpl/single/SingleObserver.h" #include "yarpl/single/SingleSubscription.h" -#include - namespace rsocket { /// Implementation of stream stateMachine that represents a RequestResponse /// responder -class RequestResponseResponder : public StreamStateMachineBase, - public yarpl::single::SingleObserver { +class RequestResponseResponder + : public StreamStateMachineBase, + public yarpl::single::SingleObserver, + public std::enable_shared_from_this { public: RequestResponseResponder( std::shared_ptr writer, StreamId streamId) : StreamStateMachineBase(std::move(writer), streamId) {} - private: - void onSubscribe(yarpl::Reference - subscription) noexcept override; - void onSuccess(Payload) noexcept override; - void onError(folly::exception_wrapper) noexcept override; + void onSubscribe(std::shared_ptr) override; + void onSuccess(Payload) override; + void onError(folly::exception_wrapper) override; + void handlePayload( + Payload&& payload, + bool flagsComplete, + bool flagsNext, + bool flagsFollows) override; void handleCancel() override; void endStream(StreamCompletionSignal) override; + private: /// State of the Subscription responder. enum class State : uint8_t { + NEW, RESPONDING, CLOSED, - } state_{State::RESPONDING}; + }; - yarpl::Reference producingSubscription_; -#ifdef DEBUG - std::atomic gotOnSubscribe_{false}; - std::atomic gotTerminating_{false}; -#endif + std::shared_ptr producingSubscription_; + State state_{State::NEW}; }; -} + +} // namespace rsocket diff --git a/rsocket/statemachine/StreamFragmentAccumulator.cpp b/rsocket/statemachine/StreamFragmentAccumulator.cpp new file mode 100644 index 000000000..07c7a3986 --- /dev/null +++ b/rsocket/statemachine/StreamFragmentAccumulator.cpp @@ -0,0 +1,64 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/statemachine/StreamFragmentAccumulator.h" + +namespace rsocket { + +StreamFragmentAccumulator::StreamFragmentAccumulator() + : flagsComplete(false), flagsNext(false) {} + +void StreamFragmentAccumulator::addPayloadIgnoreFlags(Payload p) { + if (p.metadata) { + if (!fragments.metadata) { + fragments.metadata = std::move(p.metadata); + } else { + fragments.metadata->prev()->appendChain(std::move(p.metadata)); + } + } + + if (p.data) { + if (!fragments.data) { + fragments.data = std::move(p.data); + } else { + fragments.data->prev()->appendChain(std::move(p.data)); + } + } +} + +void StreamFragmentAccumulator::addPayload( + Payload p, + bool next, + bool complete) { + flagsNext |= next; + flagsComplete |= complete; + addPayloadIgnoreFlags(std::move(p)); +} + +Payload StreamFragmentAccumulator::consumePayloadIgnoreFlags() { + flagsComplete = false; + flagsNext = false; + return std::move(fragments); +} + +std::tuple +StreamFragmentAccumulator::consumePayloadAndFlags() { + auto ret = std::make_tuple( + std::move(fragments), bool(flagsNext), bool(flagsComplete)); + flagsComplete = false; + flagsNext = false; + return ret; +} + +} /* namespace rsocket */ diff --git a/rsocket/statemachine/StreamFragmentAccumulator.h b/rsocket/statemachine/StreamFragmentAccumulator.h new file mode 100644 index 000000000..0ed5227d8 --- /dev/null +++ b/rsocket/statemachine/StreamFragmentAccumulator.h @@ -0,0 +1,41 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "rsocket/Payload.h" + +namespace rsocket { + +class StreamFragmentAccumulator { + public: + StreamFragmentAccumulator(); + + void addPayloadIgnoreFlags(Payload p); + void addPayload(Payload p, bool next, bool complete); + + Payload consumePayloadIgnoreFlags(); + std::tuple consumePayloadAndFlags(); + + bool anyFragments() const { + return fragments.data || fragments.metadata; + } + + private: + bool flagsComplete : 1; + bool flagsNext : 1; + Payload fragments; +}; + +} /* namespace rsocket */ diff --git a/rsocket/statemachine/StreamRequester.cpp b/rsocket/statemachine/StreamRequester.cpp index c5e70a350..52e407be9 100644 --- a/rsocket/statemachine/StreamRequester.cpp +++ b/rsocket/statemachine/StreamRequester.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/statemachine/StreamRequester.h" @@ -10,72 +22,66 @@ void StreamRequester::setRequested(size_t n) { addImplicitAllowance(n); } -void StreamRequester::request(int64_t n) noexcept { - if (n == 0) { +void StreamRequester::request(int64_t signedN) { + if (signedN <= 0 || consumerClosed()) { return; } - if(!requested_) { - requested_ = true; + const size_t n = signedN; - auto initialN = - n > Frame_REQUEST_N::kMaxRequestN ? Frame_REQUEST_N::kMaxRequestN : n; - auto remainingN = n > Frame_REQUEST_N::kMaxRequestN - ? n - Frame_REQUEST_N::kMaxRequestN - : 0; + if (requested_) { + generateRequest(n); + return; + } - // Send as much as possible with the initial request. - CHECK_GE(Frame_REQUEST_N::kMaxRequestN, initialN); + requested_ = true; - // We must inform ConsumerBase about an implicit allowance we have - // requested from the remote end. - addImplicitAllowance(initialN); - newStream( - StreamType::STREAM, - static_cast(initialN), - std::move(initialPayload_)); + // We must inform ConsumerBase about an implicit allowance we have requested + // from the remote end. + auto const initial = std::min(n, kMaxRequestN); + addImplicitAllowance(initial); + newStream(StreamType::STREAM, initial, std::move(initialPayload_)); - // Pump the remaining allowance into the ConsumerBase _after_ sending the - // initial request. - if (remainingN) { - generateRequest(remainingN); - } - return; + // Pump the remaining allowance into the ConsumerBase _after_ sending the + // initial request. + if (n > initial) { + generateRequest(n - initial); } - - generateRequest(n); } -void StreamRequester::cancel() noexcept { +void StreamRequester::cancel() { VLOG(5) << "StreamRequester::cancel(requested_=" << requested_ << ")"; + if (consumerClosed()) { + return; + } + cancelConsumer(); if (requested_) { - cancelConsumer(); - cancelStream(); + writeCancel(); } - closeStream(StreamCompletionSignal::CANCEL); -} - -void StreamRequester::endStream(StreamCompletionSignal signal) { - VLOG(5) << "StreamRequester::endStream()"; - ConsumerBase::endStream(signal); + removeFromWriter(); } void StreamRequester::handlePayload( Payload&& payload, bool complete, - bool next) { - CHECK(requested_); - processPayload(std::move(payload), next); + bool next, + bool follows) { + if (!requested_) { + handleError(std::runtime_error("Haven't sent REQUEST_STREAM yet")); + return; + } + bool finalComplete = + processFragmentedPayload(std::move(payload), next, complete, follows); - if (complete) { + if (finalComplete) { completeConsumer(); - closeStream(StreamCompletionSignal::COMPLETE); + removeFromWriter(); } } -void StreamRequester::handleError(folly::exception_wrapper errorPayload) { - CHECK(requested_); - errorConsumer(std::move(errorPayload)); - closeStream(StreamCompletionSignal::ERROR); -} +void StreamRequester::handleError(folly::exception_wrapper ew) { + errorConsumer(std::move(ew)); + removeFromWriter(); } + +} // namespace rsocket diff --git a/rsocket/statemachine/StreamRequester.h b/rsocket/statemachine/StreamRequester.h index ffbb38f5a..696b81472 100644 --- a/rsocket/statemachine/StreamRequester.h +++ b/rsocket/statemachine/StreamRequester.h @@ -1,26 +1,26 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include -#include "rsocket/internal/Allowance.h" #include "rsocket/statemachine/ConsumerBase.h" -namespace folly { -class exception_wrapper; -} - namespace rsocket { -enum class StreamCompletionSignal; - /// Implementation of stream stateMachine that represents a Stream requester class StreamRequester : public ConsumerBase { - using Base = ConsumerBase; - public: - // initialization of the ExecutorBase will be ignored for any of the - // derived classes StreamRequester( std::shared_ptr writer, StreamId streamId, @@ -28,24 +28,24 @@ class StreamRequester : public ConsumerBase { : ConsumerBase(std::move(writer), streamId), initialPayload_(std::move(payload)) {} - void setRequested(size_t n); - - private: - // implementation from ConsumerBase::Subscription - void request(int64_t) noexcept override; - void cancel() noexcept override; + void setRequested(size_t); - void handlePayload(Payload&& payload, bool complete, bool flagsNext) override; - void handleError(folly::exception_wrapper errorPayload) override; + void request(int64_t) override; + void cancel() override; - void endStream(StreamCompletionSignal) override; + void handlePayload( + Payload&& payload, + bool flagsComplete, + bool flagsNext, + bool flagsFollows) override; + void handleError(folly::exception_wrapper ew) override; - /// An allowance accumulated before the stream is initialised. - /// Remaining part of the allowance is forwarded to the ConsumerBase. - Allowance initialResponseAllowance_; - - /// Initial payload which has to be sent with 1st request. + private: + /// Payload to be sent with the first request. Payload initialPayload_; + + /// Whether request() has been called. bool requested_{false}; }; -} + +} // namespace rsocket diff --git a/rsocket/statemachine/StreamResponder.cpp b/rsocket/statemachine/StreamResponder.cpp index c0e38859a..9dfa8a6a3 100644 --- a/rsocket/statemachine/StreamResponder.cpp +++ b/rsocket/statemachine/StreamResponder.cpp @@ -1,51 +1,102 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/statemachine/StreamResponder.h" namespace rsocket { -using namespace yarpl; -using namespace yarpl::flowable; - void StreamResponder::onSubscribe( - Reference subscription) noexcept { + std::shared_ptr subscription) { publisherSubscribe(std::move(subscription)); } -void StreamResponder::onNext(Payload response) noexcept { - checkPublisherOnNext(); - if (!publisherClosed()) { - writePayload(std::move(response), false); +void StreamResponder::onNext(Payload response) { + if (publisherClosed()) { + return; } + writePayload(std::move(response)); } -void StreamResponder::onComplete() noexcept { - if (!publisherClosed()) { - publisherComplete(); - completeStream(); - closeStream(StreamCompletionSignal::COMPLETE); +void StreamResponder::onComplete() { + if (publisherClosed()) { + return; } + publisherComplete(); + writeComplete(); + removeFromWriter(); } -void StreamResponder::onError(folly::exception_wrapper ex) noexcept { - if (!publisherClosed()) { - publisherComplete(); - applicationError(ex.get_exception()->what()); - closeStream(StreamCompletionSignal::ERROR); +void StreamResponder::onError(folly::exception_wrapper ew) { + if (publisherClosed()) { + return; + } + publisherComplete(); + if (!ew.with_exception([this](rsocket::ErrorWithPayload& err) { + writeApplicationError(std::move(err.payload)); + })) { + writeApplicationError(ew.get_exception()->what()); } + removeFromWriter(); } -void StreamResponder::endStream(StreamCompletionSignal signal) { - terminatePublisher(); - StreamStateMachineBase::endStream(signal); +void StreamResponder::handleRequestN(uint32_t n) { + processRequestN(n); } -void StreamResponder::handleCancel() { - closeStream(StreamCompletionSignal::CANCEL); - publisherComplete(); +void StreamResponder::handleError(folly::exception_wrapper) { + handleCancel(); } -void StreamResponder::handleRequestN(uint32_t n) { - processRequestN(n); +void StreamResponder::handlePayload( + Payload&& payload, + bool /*flagsComplete*/, + bool /*flagsNext*/, + bool flagsFollows) { + payloadFragments_.addPayloadIgnoreFlags(std::move(payload)); + + if (flagsFollows) { + // there will be more fragments to come + return; + } + + Payload finalPayload = payloadFragments_.consumePayloadIgnoreFlags(); + + if (newStream_) { + newStream_ = false; + onNewStreamReady( + StreamType::STREAM, std::move(finalPayload), shared_from_this()); + } else { + // per rsocket spec, ignore unexpected frame (payload) if it makes no sense + // in the semantic of the stream + } } + +void StreamResponder::handleCancel() { + if (publisherClosed()) { + return; + } + terminatePublisher(); + removeFromWriter(); } + +void StreamResponder::endStream(StreamCompletionSignal signal) { + if (publisherClosed()) { + return; + } + terminatePublisher(); + writeApplicationError(to_string(signal)); + removeFromWriter(); +} + +} // namespace rsocket diff --git a/rsocket/statemachine/StreamResponder.h b/rsocket/statemachine/StreamResponder.h index 678452f7a..09b445eda 100644 --- a/rsocket/statemachine/StreamResponder.h +++ b/rsocket/statemachine/StreamResponder.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -11,7 +23,8 @@ namespace rsocket { /// Implementation of stream stateMachine that represents a Stream responder class StreamResponder : public StreamStateMachineBase, public PublisherBase, - public yarpl::flowable::Subscriber { + public yarpl::flowable::Subscriber, + public std::enable_shared_from_this { public: StreamResponder( std::shared_ptr writer, @@ -20,17 +33,24 @@ class StreamResponder : public StreamStateMachineBase, : StreamStateMachineBase(std::move(writer), streamId), PublisherBase(initialRequestN) {} - protected: - void handleCancel() override; - void handleRequestN(uint32_t n) override; + void onSubscribe(std::shared_ptr) override; + void onNext(Payload) override; + void onComplete() override; + void onError(folly::exception_wrapper) override; - private: - void onSubscribe(yarpl::Reference - subscription) noexcept override; - void onNext(Payload) noexcept override; - void onComplete() noexcept override; - void onError(folly::exception_wrapper) noexcept override; + void handlePayload( + Payload&& payload, + bool flagsComplete, + bool flagsNext, + bool flagsFollows) override; + void handleRequestN(uint32_t) override; + void handleError(folly::exception_wrapper) override; + void handleCancel() override; void endStream(StreamCompletionSignal) override; + + private: + bool newStream_{true}; }; -} + +} // namespace rsocket diff --git a/rsocket/statemachine/StreamState.cpp b/rsocket/statemachine/StreamState.cpp deleted file mode 100644 index 7cdbd4edd..000000000 --- a/rsocket/statemachine/StreamState.cpp +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "rsocket/statemachine/StreamState.h" - -#include "rsocket/RSocketStats.h" - -namespace rsocket { - -StreamState::StreamState(RSocketStats& stats) : stats_(stats) {} - -StreamState::~StreamState() { - onClearFrames(); -} - -void StreamState::enqueueOutputPendingFrame( - std::unique_ptr frame) { - auto length = frame->computeChainDataLength(); - stats_.streamBufferChanged(1, static_cast(length)); - dataLength_ += length; - outputFrames_.push_back(std::move(frame)); -} - -std::deque> -StreamState::moveOutputPendingFrames() { - onClearFrames(); - return std::move(outputFrames_); -} - -void StreamState::onClearFrames() { - auto numFrames = outputFrames_.size(); - if (numFrames != 0) { - stats_.streamBufferChanged( - -static_cast(numFrames), -static_cast(dataLength_)); - dataLength_ = 0; - } -} -} diff --git a/rsocket/statemachine/StreamState.h b/rsocket/statemachine/StreamState.h deleted file mode 100644 index 14907aaf4..000000000 --- a/rsocket/statemachine/StreamState.h +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include -#include -#include - -#include "rsocket/statemachine/StreamStateMachineBase.h" -#include "yarpl/Refcounted.h" - -namespace rsocket { - -class RSocketStateMachine; -class RSocketStats; -class StreamStateMachineBase; - -class StreamState { - public: - explicit StreamState(RSocketStats& stats); - ~StreamState(); - - void enqueueOutputPendingFrame(std::unique_ptr frame); - - std::deque> moveOutputPendingFrames(); - - std::unordered_map> - streams_; - - private: - /// Called to update stats when outputFrames_ is about to be cleared. - void onClearFrames(); - - RSocketStats& stats_; - - /// Total data length of all IOBufs in outputFrames_. - uint64_t dataLength_{0}; - - std::deque> outputFrames_; -}; -} diff --git a/rsocket/statemachine/StreamStateMachineBase.cpp b/rsocket/statemachine/StreamStateMachineBase.cpp index 31874164f..f0988fff7 100644 --- a/rsocket/statemachine/StreamStateMachineBase.cpp +++ b/rsocket/statemachine/StreamStateMachineBase.cpp @@ -1,24 +1,31 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/statemachine/StreamStateMachineBase.h" - #include - #include "rsocket/statemachine/RSocketStateMachine.h" #include "rsocket/statemachine/StreamsWriter.h" namespace rsocket { -void StreamStateMachineBase::handlePayload(Payload&&, bool, bool) { - VLOG(4) << "Unexpected handlePayload"; -} - void StreamStateMachineBase::handleRequestN(uint32_t) { VLOG(4) << "Unexpected handleRequestN"; } void StreamStateMachineBase::handleError(folly::exception_wrapper) { - closeStream(StreamCompletionSignal::ERROR); + endStream(StreamCompletionSignal::ERROR); + removeFromWriter(); } void StreamStateMachineBase::handleCancel() { @@ -29,49 +36,65 @@ size_t StreamStateMachineBase::getConsumerAllowance() const { return 0; } -void StreamStateMachineBase::endStream(StreamCompletionSignal) { - isTerminated_ = true; -} - void StreamStateMachineBase::newStream( StreamType streamType, uint32_t initialRequestN, - Payload payload, - bool completed) { + Payload payload) { writer_->writeNewStream( - streamId_, streamType, initialRequestN, std::move(payload), completed); + streamId_, streamType, initialRequestN, std::move(payload)); +} + +void StreamStateMachineBase::writeRequestN(uint32_t n) { + writer_->writeRequestN(Frame_REQUEST_N{streamId_, n}); +} + +void StreamStateMachineBase::writeCancel() { + writer_->writeCancel(Frame_CANCEL{streamId_}); } void StreamStateMachineBase::writePayload(Payload&& payload, bool complete) { auto const flags = - FrameFlags::NEXT | (complete ? FrameFlags::COMPLETE : FrameFlags::EMPTY); + FrameFlags::NEXT | (complete ? FrameFlags::COMPLETE : FrameFlags::EMPTY_); Frame_PAYLOAD frame{streamId_, flags, std::move(payload)}; writer_->writePayload(std::move(frame)); } -void StreamStateMachineBase::writeRequestN(uint32_t n) { - writer_->writeRequestN(Frame_REQUEST_N{streamId_, n}); +void StreamStateMachineBase::writeComplete() { + writer_->writePayload(Frame_PAYLOAD::complete(streamId_)); } -void StreamStateMachineBase::applicationError(std::string msg) { - writer_->writeError(Frame_ERROR::applicationError(streamId_, std::move(msg))); +void StreamStateMachineBase::writeApplicationError(folly::StringPiece msg) { + writer_->writeError(Frame_ERROR::applicationError(streamId_, msg)); } -void StreamStateMachineBase::errorStream(std::string msg) { - writer_->writeError(Frame_ERROR::invalid(streamId_, std::move(msg))); - closeStream(StreamCompletionSignal::ERROR); +void StreamStateMachineBase::writeApplicationError(Payload&& payload) { + writer_->writeError( + Frame_ERROR::applicationError(streamId_, std::move(payload))); } -void StreamStateMachineBase::cancelStream() { - writer_->writeCancel(Frame_CANCEL{streamId_}); +void StreamStateMachineBase::writeInvalidError(folly::StringPiece msg) { + writer_->writeError(Frame_ERROR::invalid(streamId_, msg)); } -void StreamStateMachineBase::completeStream() { - writer_->writePayload(Frame_PAYLOAD::complete(streamId_)); +void StreamStateMachineBase::removeFromWriter() { + writer_->onStreamClosed(streamId_); + // TODO: set writer_ to nullptr } -void StreamStateMachineBase::closeStream(StreamCompletionSignal signal) { - writer_->onStreamClosed(streamId_, signal); - // TODO: set writer_ to nullptr +std::shared_ptr> +StreamStateMachineBase::onNewStreamReady( + StreamType streamType, + Payload payload, + std::shared_ptr> response) { + return writer_->onNewStreamReady( + streamId_, streamType, std::move(payload), std::move(response)); } + +void StreamStateMachineBase::onNewStreamReady( + StreamType streamType, + Payload payload, + std::shared_ptr> response) { + writer_->onNewStreamReady( + streamId_, streamType, std::move(payload), std::move(response)); } +} // namespace rsocket diff --git a/rsocket/statemachine/StreamStateMachineBase.h b/rsocket/statemachine/StreamStateMachineBase.h index 2a6ad74b7..012c3b7fa 100644 --- a/rsocket/statemachine/StreamStateMachineBase.h +++ b/rsocket/statemachine/StreamStateMachineBase.h @@ -1,13 +1,26 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include - #include +#include "rsocket/framing/FrameHeader.h" #include "rsocket/internal/Common.h" -#include "yarpl/Refcounted.h" +#include "rsocket/statemachine/StreamFragmentAccumulator.h" +#include "yarpl/Flowable.h" +#include "yarpl/Single.h" namespace folly { class IOBuf; @@ -22,17 +35,21 @@ struct Payload; /// /// The instances might be destroyed on a different thread than they were /// created. -class StreamStateMachineBase : public virtual yarpl::Refcounted { +class StreamStateMachineBase { public: StreamStateMachineBase( std::shared_ptr writer, StreamId streamId) - : writer_{std::move(writer)}, streamId_(streamId) {} + : writer_(std::move(writer)), streamId_(streamId) {} virtual ~StreamStateMachineBase() = default; - virtual void handlePayload(Payload&& payload, bool complete, bool flagsNext); + virtual void handlePayload( + Payload&& payload, + bool complete, + bool flagsNext, + bool flagsFollows) = 0; virtual void handleRequestN(uint32_t n); - virtual void handleError(folly::exception_wrapper errorPayload); + virtual void handleError(folly::exception_wrapper); virtual void handleCancel(); virtual size_t getConsumerAllowance() const; @@ -48,32 +65,41 @@ class StreamStateMachineBase : public virtual yarpl::Refcounted { /// 3. per "unsubscribe handshake", the state machine must deliver /// corresponding /// terminal signal to the connection. - virtual void endStream(StreamCompletionSignal signal); + virtual void endStream(StreamCompletionSignal) {} protected: - bool isTerminated() const { - return isTerminated_; - } + void + newStream(StreamType streamType, uint32_t initialRequestN, Payload payload); + + void writeRequestN(uint32_t); + void writeCancel(); + + void writePayload(Payload&& payload, bool complete = false); + void writeComplete(); + void writeApplicationError(folly::StringPiece); + void writeApplicationError(Payload&& payload); + void writeInvalidError(folly::StringPiece); + + void removeFromWriter(); + + std::shared_ptr> onNewStreamReady( + StreamType streamType, + Payload payload, + std::shared_ptr> response); - void newStream( + void onNewStreamReady( StreamType streamType, - uint32_t initialRequestN, Payload payload, - bool completed = false); - void writePayload(Payload&& payload, bool complete); - void writeRequestN(uint32_t n); - void applicationError(std::string errorPayload); - void errorStream(std::string errorPayload); - void cancelStream(); - void completeStream(); - void closeStream(StreamCompletionSignal signal); + std::shared_ptr> response); /// A partially-owning pointer to the connection, the stream runs on. /// It is declared as const to allow only ctor to initialize it for thread /// safety of the dtor. const std::shared_ptr writer_; + StreamFragmentAccumulator payloadFragments_; + + private: const StreamId streamId_; - // TODO: remove and nulify the writer_ instead - bool isTerminated_{false}; }; -} + +} // namespace rsocket diff --git a/rsocket/statemachine/StreamsFactory.cpp b/rsocket/statemachine/StreamsFactory.cpp deleted file mode 100644 index 0fdd5ea9c..000000000 --- a/rsocket/statemachine/StreamsFactory.cpp +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "rsocket/statemachine/StreamsFactory.h" - -#include "rsocket/statemachine/ChannelRequester.h" -#include "rsocket/statemachine/ChannelResponder.h" -#include "rsocket/statemachine/RSocketStateMachine.h" -#include "rsocket/statemachine/RequestResponseRequester.h" -#include "rsocket/statemachine/RequestResponseResponder.h" -#include "rsocket/statemachine/StreamRequester.h" -#include "rsocket/statemachine/StreamResponder.h" - -#include "yarpl/flowable/Flowable.h" -#include "yarpl/single/Singles.h" - -namespace rsocket { - -using namespace yarpl; - -StreamsFactory::StreamsFactory( - RSocketStateMachine& connection, - RSocketMode mode) - : connection_(connection), - nextStreamId_( - mode == RSocketMode::CLIENT - ? 1 /*Streams initiated by a client MUST use - odd-numbered stream identifiers*/ - : 2 /*streams initiated by the server MUST use - even-numbered stream identifiers*/) {} - -static void subscribeToErrorFlowable( - Reference> responseSink) { - yarpl::flowable::Flowables::error( - std::runtime_error("state machine is disconnected/closed")) - ->subscribe(std::move(responseSink)); -} - -static void subscribeToErrorSingle( - Reference> responseSink) { - yarpl::single::Singles::error( - std::runtime_error("state machine is disconnected/closed")) - ->subscribe(std::move(responseSink)); -} - -Reference> -StreamsFactory::createChannelRequester( - Reference> responseSink) { - if (connection_.isDisconnected()) { - subscribeToErrorFlowable(std::move(responseSink)); - return nullptr; - } - - auto const streamId = getNextStreamId(); - auto stateMachine = yarpl::make_ref( - connection_.shared_from_this(), streamId); - connection_.addStream(streamId, stateMachine); - stateMachine->subscribe(std::move(responseSink)); - return stateMachine; -} - -void StreamsFactory::createStreamRequester( - Payload request, - Reference> responseSink) { - if (connection_.isDisconnected()) { - subscribeToErrorFlowable(std::move(responseSink)); - return; - } - - auto const streamId = getNextStreamId(); - auto stateMachine = yarpl::make_ref( - connection_.shared_from_this(), streamId, std::move(request)); - connection_.addStream(streamId, stateMachine); - stateMachine->subscribe(std::move(responseSink)); -} - -void StreamsFactory::createStreamRequester( - Reference> responseSink, - StreamId streamId, - size_t n) { - if (connection_.isDisconnected()) { - subscribeToErrorFlowable(std::move(responseSink)); - return; - } - - auto stateMachine = yarpl::make_ref( - connection_.shared_from_this(), streamId, Payload()); - // Set requested to true (since cold resumption) - stateMachine->setRequested(n); - connection_.addStream(streamId, stateMachine); - stateMachine->subscribe(std::move(responseSink)); -} - -void StreamsFactory::createRequestResponseRequester( - Payload payload, - Reference> responseSink) { - if (connection_.isDisconnected()) { - subscribeToErrorSingle(std::move(responseSink)); - return; - } - - auto const streamId = getNextStreamId(); - auto stateMachine = yarpl::make_ref( - connection_.shared_from_this(), streamId, std::move(payload)); - connection_.addStream(streamId, stateMachine); - stateMachine->subscribe(std::move(responseSink)); -} - -StreamId StreamsFactory::getNextStreamId() { - StreamId streamId = nextStreamId_; - CHECK(streamId <= std::numeric_limits::max() - 2); - nextStreamId_ += 2; - return streamId; -} - -void StreamsFactory::setNextStreamId(StreamId streamId) { - nextStreamId_ = streamId + 2; -} - -bool StreamsFactory::registerNewPeerStreamId(StreamId streamId) { - DCHECK(streamId != 0); - if (nextStreamId_ % 2 == streamId % 2) { - // if this is an unknown stream to the socket and this socket is - // generating - // such stream ids, it is an incoming frame on the stream which no longer - // exist - return false; - } - if (streamId <= lastPeerStreamId_) { - // receiving frame for a stream which no longer exists - return false; - } - lastPeerStreamId_ = streamId; - return true; -} - -Reference StreamsFactory::createChannelResponder( - uint32_t initialRequestN, - StreamId streamId) { - auto stateMachine = yarpl::make_ref( - connection_.shared_from_this(), streamId, initialRequestN); - connection_.addStream(streamId, stateMachine); - return stateMachine; -} - -Reference> -StreamsFactory::createStreamResponder( - uint32_t initialRequestN, - StreamId streamId) { - auto stateMachine = yarpl::make_ref( - connection_.shared_from_this(), streamId, initialRequestN); - connection_.addStream(streamId, stateMachine); - return stateMachine; -} - -Reference> -StreamsFactory::createRequestResponseResponder(StreamId streamId) { - auto stateMachine = yarpl::make_ref( - connection_.shared_from_this(), streamId); - connection_.addStream(streamId, stateMachine); - return stateMachine; -} -} diff --git a/rsocket/statemachine/StreamsFactory.h b/rsocket/statemachine/StreamsFactory.h deleted file mode 100644 index 91d68b1ec..000000000 --- a/rsocket/statemachine/StreamsFactory.h +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "rsocket/internal/Common.h" -#include "yarpl/flowable/Subscriber.h" -#include "yarpl/flowable/Subscription.h" -#include "yarpl/single/SingleObserver.h" - -namespace folly { -class Executor; -} - -namespace rsocket { - -class RSocketStateMachine; -class ChannelResponder; -struct Payload; - -class StreamsFactory { - public: - StreamsFactory(RSocketStateMachine& connection, RSocketMode mode); - - yarpl::Reference> createChannelRequester( - yarpl::Reference> responseSink); - - void createStreamRequester( - Payload request, - yarpl::Reference> responseSink); - - void createStreamRequester( - yarpl::Reference> responseSink, - StreamId streamId, - size_t n); - - void createRequestResponseRequester( - Payload payload, - yarpl::Reference> responseSink); - - // TODO: the return type should not be the stateMachine type, but something - // generic - yarpl::Reference createChannelResponder( - uint32_t initialRequestN, - StreamId streamId); - - yarpl::Reference> createStreamResponder( - uint32_t initialRequestN, - StreamId streamId); - - yarpl::Reference> - createRequestResponseResponder(StreamId streamId); - - bool registerNewPeerStreamId(StreamId streamId); - StreamId getNextStreamId(); - - void setNextStreamId(StreamId streamId); - - private: - RSocketStateMachine& connection_; - StreamId nextStreamId_; - StreamId lastPeerStreamId_{0}; -}; -} // reactivesocket diff --git a/rsocket/statemachine/StreamsWriter.cpp b/rsocket/statemachine/StreamsWriter.cpp new file mode 100644 index 000000000..5e2279d70 --- /dev/null +++ b/rsocket/statemachine/StreamsWriter.cpp @@ -0,0 +1,197 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/statemachine/StreamsWriter.h" + +#include "rsocket/RSocketStats.h" +#include "rsocket/framing/FrameSerializer.h" + +namespace rsocket { + +void StreamsWriterImpl::outputFrameOrEnqueue( + std::unique_ptr frame) { + if (shouldQueue()) { + enqueuePendingOutputFrame(std::move(frame)); + } else { + outputFrame(std::move(frame)); + } +} + +void StreamsWriterImpl::sendPendingFrames() { + // We are free to try to send frames again. Not all frames might be sent if + // the connection breaks, the rest of them will queue up again. + auto frames = consumePendingOutputFrames(); + for (auto& frame : frames) { + outputFrameOrEnqueue(std::move(frame)); + } +} + +void StreamsWriterImpl::enqueuePendingOutputFrame( + std::unique_ptr frame) { + auto const length = frame->computeChainDataLength(); + stats().streamBufferChanged(1, static_cast(length)); + pendingSize_ += length; + pendingOutputFrames_.push_back(std::move(frame)); +} + +std::deque> +StreamsWriterImpl::consumePendingOutputFrames() { + if (auto const numFrames = pendingOutputFrames_.size()) { + stats().streamBufferChanged( + -static_cast(numFrames), -static_cast(pendingSize_)); + pendingSize_ = 0; + } + return std::move(pendingOutputFrames_); +} + +void StreamsWriterImpl::writeNewStream( + StreamId streamId, + StreamType streamType, + uint32_t initialRequestN, + Payload payload) { + // for simplicity, require that sent buffers don't consist of chains + writeFragmented( + [&](Payload p, FrameFlags flags) { + switch (streamType) { + case StreamType::CHANNEL: + outputFrameOrEnqueue( + serializer().serializeOut(Frame_REQUEST_CHANNEL( + streamId, flags, initialRequestN, std::move(p)))); + break; + case StreamType::STREAM: + outputFrameOrEnqueue(serializer().serializeOut(Frame_REQUEST_STREAM( + streamId, flags, initialRequestN, std::move(p)))); + break; + case StreamType::REQUEST_RESPONSE: + outputFrameOrEnqueue(serializer().serializeOut( + Frame_REQUEST_RESPONSE(streamId, flags, std::move(p)))); + break; + case StreamType::FNF: + outputFrameOrEnqueue(serializer().serializeOut( + Frame_REQUEST_FNF(streamId, flags, std::move(p)))); + break; + default: + CHECK(false) << "invalid stream type " << toString(streamType); + } + }, + streamId, + FrameFlags::EMPTY_, + std::move(payload)); +} + +void StreamsWriterImpl::writeRequestN(Frame_REQUEST_N&& frame) { + outputFrameOrEnqueue(serializer().serializeOut(std::move(frame))); +} + +void StreamsWriterImpl::writeCancel(Frame_CANCEL&& frame) { + outputFrameOrEnqueue(serializer().serializeOut(std::move(frame))); +} + +void StreamsWriterImpl::writePayload(Frame_PAYLOAD&& f) { + Frame_PAYLOAD frame = std::move(f); + auto const streamId = frame.header_.streamId; + auto const initialFlags = frame.header_.flags; + + writeFragmented( + [this, streamId](Payload p, FrameFlags flags) { + outputFrameOrEnqueue(serializer().serializeOut( + Frame_PAYLOAD(streamId, flags, std::move(p)))); + }, + streamId, + initialFlags, + std::move(frame.payload_)); +} + +void StreamsWriterImpl::writeError(Frame_ERROR&& frame) { + // TODO: implement fragmentation for writeError as well + outputFrameOrEnqueue(serializer().serializeOut(std::move(frame))); +} + +// The max amount of user data transmitted per frame - eg the size +// of the data and metadata combined, plus the size of the frame header. +// This assumes that the frame header will never be more than 512 bytes in +// size. A CHECK in FrameTransportImpl enforces this. The idea is that +// 16M is so much larger than the ~500 bytes possibly wasted that it won't +// be noticeable (0.003% wasted at most) +constexpr size_t GENEROUS_MAX_FRAME_SIZE = 0xFFFFFF - 512; + +// writeFragmented takes a `payload` and splits it up into chunks which +// are sent as fragmented requests. The first fragmented payload is +// given to writeInitialFrame, which is expected to write the initial +// "REQUEST_" or "PAYLOAD" frame of a stream or response. writeFragmented +// then writes the rest of the frames as payloads. +// +// writeInitialFrame +// - called with the payload of the first frame to send, and any additional +// flags (eg, addFlags with FOLLOWS, if there are more frames to write) +// streamId +// - The stream ID to write additional fragments with +// addFlags +// - All flags that writeInitialFrame wants to write the first frame with, +// and all flags that subsequent fragmented payloads will be sent with +// payload +// - The unsplit payload to send, possibly in multiple fragments +template +void StreamsWriterImpl::writeFragmented( + WriteInitialFrame writeInitialFrame, + StreamId const streamId, + FrameFlags const addFlags, + Payload payload) { + folly::IOBufQueue metaQueue{folly::IOBufQueue::cacheChainLength()}; + folly::IOBufQueue dataQueue{folly::IOBufQueue::cacheChainLength()}; + + // have to keep track of "did the full payload even have a metadata", because + // the rsocket protocol makes a distinction between a zero-length metadata + // and a null metadata. + bool const haveNonNullMeta = !!payload.metadata; + metaQueue.append(std::move(payload.metadata)); + dataQueue.append(std::move(payload.data)); + + bool isFirstFrame = true; + + while (true) { + Payload sendme; + + // chew off some metadata (splitAtMost will never return a null pointer, + // safe to compute length on it always) + if (haveNonNullMeta) { + sendme.metadata = metaQueue.splitAtMost(GENEROUS_MAX_FRAME_SIZE); + DCHECK_GE( + GENEROUS_MAX_FRAME_SIZE, sendme.metadata->computeChainDataLength()); + } + sendme.data = dataQueue.splitAtMost( + GENEROUS_MAX_FRAME_SIZE - + (haveNonNullMeta ? sendme.metadata->computeChainDataLength() : 0)); + + auto const metaLeft = metaQueue.chainLength(); + auto const dataLeft = dataQueue.chainLength(); + auto const moreFragments = metaLeft || dataLeft; + auto const flags = + (moreFragments ? FrameFlags::FOLLOWS : FrameFlags::EMPTY_) | addFlags; + + if (isFirstFrame) { + isFirstFrame = false; + writeInitialFrame(std::move(sendme), flags); + } else { + outputFrameOrEnqueue(serializer().serializeOut( + Frame_PAYLOAD(streamId, flags, std::move(sendme)))); + } + + if (!moreFragments) { + break; + } + } +} + +} // namespace rsocket diff --git a/rsocket/statemachine/StreamsWriter.h b/rsocket/statemachine/StreamsWriter.h index ddcb3d650..7ecf1da87 100644 --- a/rsocket/statemachine/StreamsWriter.h +++ b/rsocket/statemachine/StreamsWriter.h @@ -1,12 +1,33 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once +#include + +#include +#include #include "rsocket/Payload.h" +#include "rsocket/framing/Frame.h" +#include "rsocket/framing/FrameType.h" #include "rsocket/internal/Common.h" namespace rsocket { +class RSocketStats; +class FrameSerializer; + /// The interface for writing stream related frames on the wire. class StreamsWriter { public: @@ -16,8 +37,7 @@ class StreamsWriter { StreamId streamId, StreamType streamType, uint32_t initialRequestN, - Payload payload, - bool TEMP_completed) = 0; + Payload payload) = 0; virtual void writeRequestN(Frame_REQUEST_N&&) = 0; virtual void writeCancel(Frame_CANCEL&&) = 0; @@ -25,6 +45,63 @@ class StreamsWriter { virtual void writePayload(Frame_PAYLOAD&&) = 0; virtual void writeError(Frame_ERROR&&) = 0; - virtual void onStreamClosed(StreamId, StreamCompletionSignal) = 0; + virtual void onStreamClosed(StreamId) = 0; + + virtual std::shared_ptr> + onNewStreamReady( + StreamId streamId, + StreamType streamType, + Payload payload, + std::shared_ptr> response) = 0; + virtual void onNewStreamReady( + StreamId streamId, + StreamType streamType, + Payload payload, + std::shared_ptr> response) = 0; }; -} + +class StreamsWriterImpl : public StreamsWriter { + public: + void writeNewStream( + StreamId streamId, + StreamType streamType, + uint32_t initialRequestN, + Payload payload) override; + + void writeRequestN(Frame_REQUEST_N&&) override; + void writeCancel(Frame_CANCEL&&) override; + + void writePayload(Frame_PAYLOAD&&) override; + + // TODO: writeFragmentedError + void writeError(Frame_ERROR&&) override; + + protected: + // note: onStreamClosed() method is also still pure + virtual void outputFrame(std::unique_ptr) = 0; + virtual FrameSerializer& serializer() = 0; + virtual RSocketStats& stats() = 0; + virtual bool shouldQueue() = 0; + + template + void writeFragmented( + WriteInitialFrame, + StreamId const, + FrameFlags const, + Payload payload); + + /// Send a frame to the output, or queue it if shouldQueue() + virtual void sendPendingFrames(); + void outputFrameOrEnqueue(std::unique_ptr); + void enqueuePendingOutputFrame(std::unique_ptr frame); + std::deque> consumePendingOutputFrames(); + + private: + /// A queue of frames that are slated to be sent out. + std::deque> pendingOutputFrames_; + + /// The byte size of all pending output frames. + size_t pendingSize_{0}; +}; + +} // namespace rsocket diff --git a/tck-test/BaseSubscriber.cpp b/rsocket/tck-test/BaseSubscriber.cpp similarity index 73% rename from tck-test/BaseSubscriber.cpp rename to rsocket/tck-test/BaseSubscriber.cpp index 24dbb1510..c0df54613 100644 --- a/tck-test/BaseSubscriber.cpp +++ b/rsocket/tck-test/BaseSubscriber.cpp @@ -1,6 +1,18 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "tck-test/BaseSubscriber.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/tck-test/BaseSubscriber.h" #include @@ -30,9 +42,9 @@ void BaseSubscriber::awaitAtLeast(int numItems) { } void BaseSubscriber::awaitNoEvents(int waitTime) { - int valuesCount = valuesCount_; - bool completed = completed_; - bool errored = errored_; + const int valuesCount = valuesCount_; + const bool completed = completed_; + const bool errored = errored_; /* sleep override */ std::this_thread::sleep_for(std::chrono::milliseconds(waitTime)); if (valuesCount != valuesCount_ || completed != completed_ || @@ -57,7 +69,7 @@ void BaseSubscriber::assertError() { void BaseSubscriber::assertValues( const std::vector>& values) { assertValueCount(values.size()); - std::unique_lock lock(mutex_); + const std::unique_lock lock(mutex_); for (size_t i = 0; i < values.size(); i++) { if (values_[i] != values[i]) { throw std::runtime_error(folly::sformat( @@ -71,7 +83,7 @@ void BaseSubscriber::assertValues( } void BaseSubscriber::assertValueCount(size_t valueCount) { - std::unique_lock lock(mutex_); + const std::unique_lock lock(mutex_); if (values_.size() != valueCount) { throw std::runtime_error(folly::sformat( "Did not receive expected number of values! Expected={} Actual={}", @@ -81,7 +93,7 @@ void BaseSubscriber::assertValueCount(size_t valueCount) { } void BaseSubscriber::assertReceivedAtLeast(size_t valueCount) { - std::unique_lock lock(mutex_); + const std::unique_lock lock(mutex_); if (values_.size() < valueCount) { throw std::runtime_error(folly::sformat( "Did not receive the minimum number of values! Expected={} Actual={}", @@ -108,5 +120,5 @@ void BaseSubscriber::assertCanceled() { } } -} // tck -} // reactivesocket +} // namespace tck +} // namespace rsocket diff --git a/tck-test/BaseSubscriber.h b/rsocket/tck-test/BaseSubscriber.h similarity index 61% rename from tck-test/BaseSubscriber.h rename to rsocket/tck-test/BaseSubscriber.h index 3a357b645..fdda649ff 100644 --- a/tck-test/BaseSubscriber.h +++ b/rsocket/tck-test/BaseSubscriber.h @@ -1,3 +1,17 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #pragma once #include @@ -10,7 +24,7 @@ namespace rsocket { namespace tck { -class BaseSubscriber : public virtual yarpl::Refcounted { +class BaseSubscriber { public: virtual void request(int n) = 0; virtual void cancel() = 0; @@ -31,7 +45,8 @@ class BaseSubscriber : public virtual yarpl::Refcounted { std::atomic canceled_{false}; //////////////////////////////////////////////////////////////////////////// - std::mutex mutex_; // all variables below has to be protected with the mutex + mutable std::mutex + mutex_; // all variables below has to be protected with the mutex std::vector> values_; std::condition_variable valuesCV_; @@ -45,5 +60,5 @@ class BaseSubscriber : public virtual yarpl::Refcounted { //////////////////////////////////////////////////////////////////////////// }; -} // tck -} // reactivesocket +} // namespace tck +} // namespace rsocket diff --git a/tck-test/FlowableSubscriber.cpp b/rsocket/tck-test/FlowableSubscriber.cpp similarity index 60% rename from tck-test/FlowableSubscriber.cpp rename to rsocket/tck-test/FlowableSubscriber.cpp index ca88d21d6..33b72b7fb 100644 --- a/tck-test/FlowableSubscriber.cpp +++ b/rsocket/tck-test/FlowableSubscriber.cpp @@ -1,4 +1,18 @@ -#include "tck-test/FlowableSubscriber.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/tck-test/FlowableSubscriber.h" #include @@ -29,7 +43,7 @@ void FlowableSubscriber::cancel() { } void FlowableSubscriber::onSubscribe( - yarpl::Reference subscription) noexcept { + std::shared_ptr subscription) noexcept { VLOG(4) << "OnSubscribe in FlowableSubscriber"; subscription_ = subscription; if (initialRequestN_ > 0) { @@ -40,10 +54,10 @@ void FlowableSubscriber::onSubscribe( void FlowableSubscriber::onNext(Payload element) noexcept { LOG(INFO) << "... received onNext from Publisher: " << element; { - std::unique_lock lock(mutex_); - std::string data = + const std::unique_lock lock(mutex_); + const std::string data = element.data ? element.data->moveToFbString().toStdString() : ""; - std::string metadata = element.metadata + const std::string metadata = element.metadata ? element.metadata->moveToFbString().toStdString() : ""; values_.push_back(std::make_pair(data, metadata)); @@ -55,7 +69,7 @@ void FlowableSubscriber::onNext(Payload element) noexcept { void FlowableSubscriber::onComplete() noexcept { LOG(INFO) << "... received onComplete from Publisher"; { - std::unique_lock lock(mutex_); + const std::unique_lock lock(mutex_); completed_ = true; } @@ -65,12 +79,12 @@ void FlowableSubscriber::onComplete() noexcept { void FlowableSubscriber::onError(folly::exception_wrapper ex) noexcept { LOG(INFO) << "... received onError from Publisher"; { - std::unique_lock lock(mutex_); + const std::unique_lock lock(mutex_); errors_.push_back(std::move(ex)); errored_ = true; } terminatedCV_.notify_one(); } -} // tck -} // reactivesocket +} // namespace tck +} // namespace rsocket diff --git a/rsocket/tck-test/FlowableSubscriber.h b/rsocket/tck-test/FlowableSubscriber.h new file mode 100644 index 000000000..3de091023 --- /dev/null +++ b/rsocket/tck-test/FlowableSubscriber.h @@ -0,0 +1,47 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "rsocket/tck-test/BaseSubscriber.h" + +#include "yarpl/Flowable.h" + +namespace rsocket { +namespace tck { + +class FlowableSubscriber : public BaseSubscriber, + public yarpl::flowable::Subscriber { + public: + explicit FlowableSubscriber(int initialRequestN = 0); + + // Inherited from BaseSubscriber + void request(int n) override; + void cancel() override; + + protected: + // Inherited from flowable::Subscriber + void onSubscribe(std::shared_ptr + subscription) noexcept override; + void onNext(Payload element) noexcept override; + void onComplete() noexcept override; + void onError(folly::exception_wrapper ex) noexcept override; + + private: + std::shared_ptr subscription_; + int initialRequestN_{0}; +}; + +} // namespace tck +} // namespace rsocket diff --git a/tck-test/MarbleProcessor.cpp b/rsocket/tck-test/MarbleProcessor.cpp similarity index 73% rename from tck-test/MarbleProcessor.cpp rename to rsocket/tck-test/MarbleProcessor.cpp index 93b60b34f..62038cac1 100644 --- a/tck-test/MarbleProcessor.cpp +++ b/rsocket/tck-test/MarbleProcessor.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "MarbleProcessor.h" @@ -44,7 +56,7 @@ std::map> getArgMap( } return argMap; } -} +} // namespace namespace rsocket { namespace tck { @@ -67,29 +79,26 @@ MarbleProcessor::MarbleProcessor(const std::string marble) } } -std::tuple MarbleProcessor::run( - yarpl::Reference> subscriber, +void MarbleProcessor::run( + yarpl::flowable::Subscriber& subscriber, int64_t requested) { canSend_ += requested; - if (index_ > marble_.size()) { - return std::make_tuple(requested, true); - } - while (true) { - auto c = marble_[index_]; + while (canSend_ > 0 && index_ < marble_.size()) { + const auto c = marble_[index_]; switch (c) { case '#': LOG(INFO) << "Sending onError"; - subscriber->onError(std::runtime_error("Marble Error")); - return std::make_tuple(requested, true); + subscriber.onError(std::runtime_error("Marble Error")); + break; case '|': LOG(INFO) << "Sending onComplete"; - subscriber->onComplete(); - return std::make_tuple(requested, true); - default: { + subscriber.onComplete(); + break; + default: if (canSend_ > 0) { Payload payload; - auto it = argMap_.find(folly::to(c)); + const auto it = argMap_.find(folly::to(c)); LOG(INFO) << "Sending data " << c; if (it != argMap_.end()) { LOG(INFO) << folly::sformat( @@ -102,22 +111,20 @@ std::tuple MarbleProcessor::run( payload = Payload(folly::to(c), folly::to(c)); } - subscriber->onNext(std::move(payload)); + subscriber.onNext(std::move(payload)); canSend_--; - } else { - return std::make_tuple(requested, false); } - } + break; } index_++; } } void MarbleProcessor::run( - yarpl::Reference> + std::shared_ptr> subscriber) { while (true) { - auto c = marble_[index_]; + const auto c = marble_[index_]; switch (c) { case '#': LOG(INFO) << "Sending onError"; @@ -129,7 +136,7 @@ void MarbleProcessor::run( return; default: { Payload payload; - auto it = argMap_.find(folly::to(c)); + const auto it = argMap_.find(folly::to(c)); LOG(INFO) << "Sending data " << c; if (it != argMap_.end()) { LOG(INFO) << folly::sformat( @@ -150,5 +157,5 @@ void MarbleProcessor::run( } } -} // tck -} // reactivesocket +} // namespace tck +} // namespace rsocket diff --git a/rsocket/tck-test/MarbleProcessor.h b/rsocket/tck-test/MarbleProcessor.h new file mode 100644 index 000000000..77217b63b --- /dev/null +++ b/rsocket/tck-test/MarbleProcessor.h @@ -0,0 +1,50 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "rsocket/Payload.h" +#include "yarpl/Flowable.h" +#include "yarpl/Single.h" + +namespace rsocket { +namespace tck { + +class MarbleProcessor { + public: + explicit MarbleProcessor(const std::string /* marble */); + + void run( + yarpl::flowable::Subscriber& subscriber, + int64_t requested); + + void run(std::shared_ptr> + subscriber); + + private: + std::string marble_; + + // Stores a mapping from marble character to Payload (data, metadata) + std::map> argMap_; + + // Keeps an account of how many messages can be sent. This could be done + // with Allowance + std::atomic canSend_{0}; + + size_t index_{0}; +}; + +} // namespace tck +} // namespace rsocket diff --git a/tck-test/SingleSubscriber.cpp b/rsocket/tck-test/SingleSubscriber.cpp similarity index 53% rename from tck-test/SingleSubscriber.cpp rename to rsocket/tck-test/SingleSubscriber.cpp index 72b514445..2da91f5b1 100644 --- a/tck-test/SingleSubscriber.cpp +++ b/rsocket/tck-test/SingleSubscriber.cpp @@ -1,6 +1,18 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -#include "tck-test/SingleSubscriber.h" +#include "rsocket/tck-test/SingleSubscriber.h" #include @@ -24,7 +36,7 @@ void SingleSubscriber::cancel() { } void SingleSubscriber::onSubscribe( - yarpl::Reference subscription) noexcept { + std::shared_ptr subscription) noexcept { VLOG(4) << "OnSubscribe in SingleSubscriber"; subscription_ = subscription; } @@ -32,10 +44,10 @@ void SingleSubscriber::onSubscribe( void SingleSubscriber::onSuccess(Payload element) noexcept { LOG(INFO) << "... received onSuccess from Publisher: " << element; { - std::unique_lock lock(mutex_); - std::string data = + const std::unique_lock lock(mutex_); + const std::string data = element.data ? element.data->moveToFbString().toStdString() : ""; - std::string metadata = element.metadata + const std::string metadata = element.metadata ? element.metadata->moveToFbString().toStdString() : ""; values_.push_back(std::make_pair(data, metadata)); @@ -43,7 +55,7 @@ void SingleSubscriber::onSuccess(Payload element) noexcept { } valuesCV_.notify_one(); { - std::unique_lock lock(mutex_); + const std::unique_lock lock(mutex_); completed_ = true; } terminatedCV_.notify_one(); @@ -52,12 +64,12 @@ void SingleSubscriber::onSuccess(Payload element) noexcept { void SingleSubscriber::onError(folly::exception_wrapper ex) noexcept { LOG(INFO) << "... received onError from Publisher"; { - std::unique_lock lock(mutex_); + const std::unique_lock lock(mutex_); errors_.push_back(std::move(ex)); errored_ = true; } terminatedCV_.notify_one(); } -} // tck -} // reactivesocket +} // namespace tck +} // namespace rsocket diff --git a/rsocket/tck-test/SingleSubscriber.h b/rsocket/tck-test/SingleSubscriber.h new file mode 100644 index 000000000..8b8b8556f --- /dev/null +++ b/rsocket/tck-test/SingleSubscriber.h @@ -0,0 +1,43 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "rsocket/tck-test/BaseSubscriber.h" + +#include "yarpl/Single.h" + +namespace rsocket { +namespace tck { + +class SingleSubscriber : public BaseSubscriber, + public yarpl::single::SingleObserver { + public: + // Inherited from BaseSubscriber + void request(int n) override; + void cancel() override; + + protected: + // Inherited from flowable::Subscriber + void onSubscribe(std::shared_ptr + subscription) noexcept override; + void onSuccess(Payload element) noexcept override; + void onError(folly::exception_wrapper ex) noexcept override; + + private: + std::shared_ptr subscription_; +}; + +} // namespace tck +} // namespace rsocket diff --git a/tck-test/TestFileParser.cpp b/rsocket/tck-test/TestFileParser.cpp similarity index 71% rename from tck-test/TestFileParser.cpp rename to rsocket/tck-test/TestFileParser.cpp index 309a481c0..9306960c3 100644 --- a/tck-test/TestFileParser.cpp +++ b/rsocket/tck-test/TestFileParser.cpp @@ -1,6 +1,18 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -#include "tck-test/TestFileParser.h" +#include "rsocket/tck-test/TestFileParser.h" #include #include diff --git a/rsocket/tck-test/TestFileParser.h b/rsocket/tck-test/TestFileParser.h new file mode 100644 index 000000000..7830934fb --- /dev/null +++ b/rsocket/tck-test/TestFileParser.h @@ -0,0 +1,42 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "rsocket/tck-test/TestSuite.h" + +namespace rsocket { +namespace tck { + +class TestFileParser { + public: + explicit TestFileParser(const std::string& fileName); + + TestSuite parse(); + + private: + void parseCommand(const std::string& command); + void addCurrentTest(); + + std::ifstream input_; + int currentLine_; + + TestSuite testSuite_; + Test currentTest_; +}; + +} // namespace tck +} // namespace rsocket diff --git a/tck-test/TestInterpreter.cpp b/rsocket/tck-test/TestInterpreter.cpp similarity index 81% rename from tck-test/TestInterpreter.cpp rename to rsocket/tck-test/TestInterpreter.cpp index 7f76994f0..f74eab68c 100644 --- a/tck-test/TestInterpreter.cpp +++ b/rsocket/tck-test/TestInterpreter.cpp @@ -1,16 +1,28 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "tck-test/TestInterpreter.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/tck-test/TestInterpreter.h" #include #include #include #include "rsocket/RSocket.h" +#include "rsocket/tck-test/FlowableSubscriber.h" +#include "rsocket/tck-test/SingleSubscriber.h" +#include "rsocket/tck-test/TypedCommands.h" #include "rsocket/transports/tcp/TcpConnectionFactory.h" -#include "tck-test/FlowableSubscriber.h" -#include "tck-test/SingleSubscriber.h" -#include "tck-test/TypedCommands.h" using namespace folly; using namespace yarpl; @@ -34,25 +46,25 @@ bool TestInterpreter::run() { "Executing command: [{}] {}", i, command.name()); ++i; if (command.name() == "subscribe") { - auto subscribe = command.as(); + const auto subscribe = command.as(); handleSubscribe(subscribe); } else if (command.name() == "request") { - auto request = command.as(); + const auto request = command.as(); handleRequest(request); } else if (command.name() == "await") { - auto await = command.as(); + const auto await = command.as(); handleAwait(await); } else if (command.name() == "cancel") { - auto cancel = command.as(); + const auto cancel = command.as(); handleCancel(cancel); } else if (command.name() == "assert") { - auto assert = command.as(); + const auto assert = command.as(); handleAssert(assert); } else if (command.name() == "disconnect") { - auto disconnect = command.as(); + const auto disconnect = command.as(); handleDisconnect(disconnect); } else if (command.name() == "resume") { - auto resume = command.as(); + const auto resume = command.as(); handleResume(resume); } else { LOG(ERROR) << "unknown command " << command.name(); @@ -107,7 +119,7 @@ void TestInterpreter::handleSubscribe(const SubscribeCommand& command) { testSubscribers_.end()); if (command.isRequestResponseType()) { - auto testSubscriber = make_ref(); + auto testSubscriber = std::make_shared(); testSubscribers_[command.clientId() + command.id()] = testSubscriber; testClient_[command.clientId()] ->requester @@ -115,7 +127,7 @@ void TestInterpreter::handleSubscribe(const SubscribeCommand& command) { Payload(command.payloadData(), command.payloadMetadata())) ->subscribe(std::move(testSubscriber)); } else if (command.isRequestStreamType()) { - auto testSubscriber = make_ref(); + auto testSubscriber = std::make_shared(); testSubscribers_[command.clientId() + command.id()] = testSubscriber; testClient_[command.clientId()] ->requester @@ -185,9 +197,9 @@ void TestInterpreter::handleAssert(const AssertCommand& command) { } } -yarpl::Reference TestInterpreter::getSubscriber( +std::shared_ptr TestInterpreter::getSubscriber( const std::string& id) { - auto found = testSubscribers_.find(id); + const auto found = testSubscribers_.find(id); if (found == testSubscribers_.end()) { throw std::runtime_error("unable to find test subscriber with provided id"); } diff --git a/tck-test/TestInterpreter.h b/rsocket/tck-test/TestInterpreter.h similarity index 64% rename from tck-test/TestInterpreter.h rename to rsocket/tck-test/TestInterpreter.h index 97943d896..d57fb76c6 100644 --- a/tck-test/TestInterpreter.h +++ b/rsocket/tck-test/TestInterpreter.h @@ -1,17 +1,29 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include -#include #include +#include #include "rsocket/Payload.h" #include "rsocket/RSocket.h" #include "rsocket/RSocketRequester.h" -#include "tck-test/BaseSubscriber.h" -#include "tck-test/TestSuite.h" +#include "rsocket/tck-test/BaseSubscriber.h" +#include "rsocket/tck-test/TestSuite.h" namespace folly { class EventBase; @@ -57,13 +69,13 @@ class TestInterpreter { void handleDisconnect(const DisconnectCommand& command); void handleResume(const ResumeCommand& command); - yarpl::Reference getSubscriber(const std::string& id); + std::shared_ptr getSubscriber(const std::string& id); folly::ScopedEventBaseThread worker_; folly::SocketAddress address_; const Test& test_; std::map interactionIdToType_; - std::map> testSubscribers_; + std::map> testSubscribers_; std::map> testClient_; }; diff --git a/rsocket/tck-test/TestSuite.cpp b/rsocket/tck-test/TestSuite.cpp new file mode 100644 index 000000000..8e921f347 --- /dev/null +++ b/rsocket/tck-test/TestSuite.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/tck-test/TestSuite.h" + +#include + +namespace rsocket { +namespace tck { + +bool TestCommand::valid() const { + // there has to be a name to the test and at least 1 param + return params_.size() >= 1; +} + +void Test::addCommand(TestCommand command) { + CHECK(command.valid()); + commands_.push_back(std::move(command)); +} + +} // namespace tck +} // namespace rsocket diff --git a/tck-test/TestSuite.h b/rsocket/tck-test/TestSuite.h similarity index 68% rename from tck-test/TestSuite.h rename to rsocket/tck-test/TestSuite.h index 61490b00f..f705e0d13 100644 --- a/tck-test/TestSuite.h +++ b/rsocket/tck-test/TestSuite.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once diff --git a/tck-test/TypedCommands.h b/rsocket/tck-test/TypedCommands.h similarity index 85% rename from tck-test/TypedCommands.h rename to rsocket/tck-test/TypedCommands.h index c25726652..d32b144ac 100644 --- a/tck-test/TypedCommands.h +++ b/rsocket/tck-test/TypedCommands.h @@ -1,11 +1,23 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include #include -#include "tck-test/TestSuite.h" +#include "rsocket/tck-test/TestSuite.h" namespace rsocket { namespace tck { diff --git a/tck-test/client.cpp b/rsocket/tck-test/client.cpp similarity index 72% rename from tck-test/client.cpp rename to rsocket/tck-test/client.cpp index 901d7b336..acb5068c3 100644 --- a/tck-test/client.cpp +++ b/rsocket/tck-test/client.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -8,8 +20,8 @@ #include "rsocket/RSocket.h" -#include "tck-test/TestFileParser.h" -#include "tck-test/TestInterpreter.h" +#include "rsocket/tck-test/TestFileParser.h" +#include "rsocket/tck-test/TestInterpreter.h" #include "rsocket/transports/tcp/TcpConnectionFactory.h" diff --git a/tck-test/clientResumptiontest.txt b/rsocket/tck-test/clientResumptiontest.txt similarity index 100% rename from tck-test/clientResumptiontest.txt rename to rsocket/tck-test/clientResumptiontest.txt diff --git a/tck-test/clienttest.txt b/rsocket/tck-test/clienttest.txt similarity index 100% rename from tck-test/clienttest.txt rename to rsocket/tck-test/clienttest.txt diff --git a/tck-test/server.cpp b/rsocket/tck-test/server.cpp similarity index 72% rename from tck-test/server.cpp rename to rsocket/tck-test/server.cpp index caf982d42..2988cfe28 100644 --- a/tck-test/server.cpp +++ b/rsocket/tck-test/server.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -16,7 +28,7 @@ #include "rsocket/transports/tcp/TcpConnectionAcceptor.h" -#include "tck-test/MarbleProcessor.h" +#include "rsocket/tck-test/MarbleProcessor.h" using namespace folly; using namespace rsocket; @@ -75,42 +87,46 @@ class ServerResponder : public RSocketResponder { marbles_ = parseMarbles(FLAGS_test_file); } - yarpl::Reference> handleRequestStream( + std::shared_ptr> handleRequestStream( Payload request, StreamId) override { LOG(INFO) << "handleRequestStream " << request; - std::string data = request.data->moveToFbString().toStdString(); - std::string metadata = request.metadata->moveToFbString().toStdString(); - auto it = marbles_.streamMarbles.find(std::make_pair(data, metadata)); + const std::string data = request.data->moveToFbString().toStdString(); + const std::string metadata = + request.metadata->moveToFbString().toStdString(); + const auto it = marbles_.streamMarbles.find(std::make_pair(data, metadata)); if (it == marbles_.streamMarbles.end()) { - return yarpl::flowable::Flowables::error( + return yarpl::flowable::Flowable::error( std::logic_error("No MarbleHandler found")); } else { - auto marbleProcessor = std::make_shared(it->second); + const auto marbleProcessor = + std::make_shared(it->second); auto lambda = [marbleProcessor]( - Reference> subscriber, - int64_t requested) mutable { + auto& subscriber, int64_t requested) mutable { return marbleProcessor->run(subscriber, requested); }; return Flowable::create(std::move(lambda)); } } - yarpl::Reference> handleRequestResponse( + std::shared_ptr> handleRequestResponse( Payload request, StreamId) override { LOG(INFO) << "handleRequestResponse " << request; - std::string data = request.data->moveToFbString().toStdString(); - std::string metadata = request.metadata->moveToFbString().toStdString(); - auto it = marbles_.reqRespMarbles.find(std::make_pair(data, metadata)); + const std::string data = request.data->moveToFbString().toStdString(); + const std::string metadata = + request.metadata->moveToFbString().toStdString(); + const auto it = + marbles_.reqRespMarbles.find(std::make_pair(data, metadata)); if (it == marbles_.reqRespMarbles.end()) { return yarpl::single::Singles::error( std::logic_error("No MarbleHandler found")); } else { - auto marbleProcessor = std::make_shared(it->second); + const auto marbleProcessor = + std::make_shared(it->second); auto lambda = [marbleProcessor]( - yarpl::Reference> + std::shared_ptr> subscriber) { subscriber->onSubscribe(SingleSubscriptions::empty()); return marbleProcessor->run(subscriber); @@ -138,7 +154,7 @@ class ServiceHandler : public RSocketServiceHandler { folly::Expected, RSocketException> onResume(ResumeIdentificationToken token) override { - auto itr = store_->find(token); + const auto itr = store_->find(token); CHECK(itr != store_->end()); return itr->second; }; @@ -169,10 +185,10 @@ int main(int argc, char* argv[]) { opts.threads = 1; // RSocket server accepting on TCP - auto rs = RSocket::createServer( + const auto rs = RSocket::createServer( std::make_unique(std::move(opts))); - auto rawRs = rs.get(); + const auto rawRs = rs.get(); auto serverThread = std::thread( [=] { rawRs->startAndPark(std::make_shared()); }); diff --git a/tck-test/serverResumptiontest.txt b/rsocket/tck-test/serverResumptiontest.txt similarity index 100% rename from tck-test/serverResumptiontest.txt rename to rsocket/tck-test/serverResumptiontest.txt diff --git a/tck-test/servertest.txt b/rsocket/tck-test/servertest.txt similarity index 100% rename from tck-test/servertest.txt rename to rsocket/tck-test/servertest.txt diff --git a/test/ColdResumptionTest.cpp b/rsocket/test/ColdResumptionTest.cpp similarity index 82% rename from test/ColdResumptionTest.cpp rename to rsocket/test/ColdResumptionTest.cpp index 78a554023..d07fcb7b4 100644 --- a/test/ColdResumptionTest.cpp +++ b/rsocket/test/ColdResumptionTest.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include @@ -17,10 +29,10 @@ DEFINE_int32(num_clients, 5, "Number of clients to parallely cold-resume"); using namespace rsocket; using namespace rsocket::tests; using namespace rsocket::tests::client_server; -using namespace yarpl; using namespace yarpl::flowable; -typedef std::map>> HelloSubscribers; +typedef std::map>> + HelloSubscribers; namespace { class HelloSubscriber : public BaseSubscriber { @@ -78,14 +90,14 @@ class HelloResumeHandler : public ColdResumeHandler { : subscribers_(std::move(subscribers)) {} std::string generateStreamToken(const Payload& payload, StreamId, StreamType) - override { - auto streamToken = + const override { + const auto streamToken = payload.data->cloneAsValue().moveToFbString().toStdString(); VLOG(3) << "Generated token: " << streamToken; return streamToken; } - Reference> handleRequesterResumeStream( + std::shared_ptr> handleRequesterResumeStream( std::string streamToken, size_t consumerAllowance) override { CHECK(subscribers_.find(streamToken) != subscribers_.end()); @@ -97,7 +109,7 @@ class HelloResumeHandler : public ColdResumeHandler { private: HelloSubscribers subscribers_; }; -} +} // namespace std::unique_ptr createResumedClient( folly::EventBase* evb, @@ -118,10 +130,10 @@ std::unique_ptr createResumedClient( kDefaultKeepaliveInterval, nullptr, /* stats */ nullptr, /* connectionEvents */ - ProtocolVersion::Current(), + ProtocolVersion::Latest, stateMachineEvb) .get(); - } catch (RSocketException ex) { + } catch (const RSocketException& ex) { retries--; VLOG(1) << "Creation of resumed client failed. Exception " << ex.what() << ". Retries Left: " << retries; @@ -154,7 +166,7 @@ void coldResumer(uint32_t port, uint32_t client_num) { auto resumeManager = std::make_shared(RSocketStats::noop()); { - auto firstSub = make_ref(0); + auto firstSub = std::make_shared(0); { auto coldResumeHandler = std::make_shared( HelloSubscribers({{firstPayload, firstSub}})); @@ -176,7 +188,7 @@ void coldResumer(uint32_t port, uint32_t client_num) { } } worker.getEventBase()->runInEventBaseThreadAndWait( - [ client_num, &firstLatestValue, firstSub = std::move(firstSub) ]() { + [client_num, &firstLatestValue, firstSub = std::move(firstSub)]() { firstLatestValue = firstSub->getLatestValue(); VLOG(1) << folly::sformat( "client{} {}", client_num, firstLatestValue); @@ -185,8 +197,8 @@ void coldResumer(uint32_t port, uint32_t client_num) { } { - auto firstSub = yarpl::make_ref(firstLatestValue); - auto secondSub = yarpl::make_ref(0); + auto firstSub = std::make_shared(firstLatestValue); + auto secondSub = std::make_shared(0); { auto coldResumeHandler = std::make_shared( HelloSubscribers({{firstPayload, firstSub}})); @@ -211,25 +223,26 @@ void coldResumer(uint32_t port, uint32_t client_num) { std::this_thread::yield(); } } - worker.getEventBase()->runInEventBaseThreadAndWait([ - client_num, - &firstLatestValue, - firstSub = std::move(firstSub), - &secondLatestValue, - secondSub = std::move(secondSub) - ]() { - firstLatestValue = firstSub->getLatestValue(); - secondLatestValue = secondSub->getLatestValue(); - VLOG(1) << folly::sformat("client{} {}", client_num, firstLatestValue); - VLOG(1) << folly::sformat("client{} {}", client_num, secondLatestValue); - VLOG(1) << folly::sformat("client{} Second Resume", client_num); - }); + worker.getEventBase()->runInEventBaseThreadAndWait( + [client_num, + &firstLatestValue, + firstSub = std::move(firstSub), + &secondLatestValue, + secondSub = std::move(secondSub)]() { + firstLatestValue = firstSub->getLatestValue(); + secondLatestValue = secondSub->getLatestValue(); + VLOG(1) << folly::sformat( + "client{} {}", client_num, firstLatestValue); + VLOG(1) << folly::sformat( + "client{} {}", client_num, secondLatestValue); + VLOG(1) << folly::sformat("client{} Second Resume", client_num); + }); } { - auto firstSub = yarpl::make_ref(firstLatestValue); - auto secondSub = yarpl::make_ref(secondLatestValue); - auto thirdSub = yarpl::make_ref(0); + auto firstSub = std::make_shared(firstLatestValue); + auto secondSub = std::make_shared(secondLatestValue); + auto thirdSub = std::make_shared(0); auto coldResumeHandler = std::make_shared(HelloSubscribers( {{firstPayload, firstSub}, {secondPayload, secondSub}})); @@ -258,7 +271,7 @@ void coldResumer(uint32_t port, uint32_t client_num) { } } -TEST(ColdResumptionTest, SuccessfulResumption) { +TEST(ColdResumptionTest, DISABLED_SuccessfulResumption) { auto server = makeResumableServer(std::make_shared()); auto port = *server->listeningPort(); @@ -288,7 +301,7 @@ TEST(ColdResumptionTest, DifferentEvb) { auto resumeManager = std::make_shared(RSocketStats::noop()); { - auto firstSub = make_ref(0); + auto firstSub = std::make_shared(0); { auto coldResumeHandler = std::make_shared( HelloSubscribers({{payload, firstSub}})); @@ -311,7 +324,7 @@ TEST(ColdResumptionTest, DifferentEvb) { } } SMWorker.getEventBase()->runInEventBaseThreadAndWait( - [&latestValue, firstSub = std::move(firstSub) ]() { + [&latestValue, firstSub = std::move(firstSub)]() { latestValue = firstSub->getLatestValue(); VLOG(1) << latestValue; VLOG(1) << "First Resume"; @@ -319,7 +332,7 @@ TEST(ColdResumptionTest, DifferentEvb) { } { - auto firstSub = yarpl::make_ref(latestValue); + auto firstSub = std::make_shared(latestValue); { auto coldResumeHandler = std::make_shared( HelloSubscribers({{payload, firstSub}})); @@ -359,7 +372,7 @@ TEST(ColdResumptionTest, DisconnectResumption) { auto token = ResumeIdentificationToken::generateNew(); auto resumeManager = std::make_shared(RSocketStats::noop()); - auto sub = make_ref(0); + auto sub = std::make_shared(0); auto crh = std::make_shared(HelloSubscribers({{payload, sub}})); std::shared_ptr client; @@ -373,7 +386,7 @@ TEST(ColdResumptionTest, DisconnectResumption) { std::this_thread::yield(); } - auto resumedSub = make_ref(7); + auto resumedSub = std::make_shared(7); auto resumedCrh = std::make_shared( HelloSubscribers({{payload, resumedSub}})); diff --git a/test/ConnectionEventsTest.cpp b/rsocket/test/ConnectionEventsTest.cpp similarity index 73% rename from test/ConnectionEventsTest.cpp rename to rsocket/test/ConnectionEventsTest.cpp index 584de01cc..c6ec2ba27 100644 --- a/test/ConnectionEventsTest.cpp +++ b/rsocket/test/ConnectionEventsTest.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -94,14 +106,13 @@ TEST(ConnectionEventsTest, DifferentEvb) { auto clientConnEvents = std::make_shared>(); - EXPECT_CALL(*clientConnEvents, onConnected()).WillOnce( - Invoke([evb = SMWorker.getEventBase()]() { - EXPECT_TRUE(evb->isInEventBaseThread()); - })); + EXPECT_CALL(*clientConnEvents, onConnected()) + .WillOnce(Invoke([evb = SMWorker.getEventBase()]() { + EXPECT_TRUE(evb->isInEventBaseThread()); + })); // create server supporting resumption - auto server = makeResumableServer( - std::make_shared()); + auto server = makeResumableServer(std::make_shared()); // create resumable client auto client = makeWarmResumableClient( @@ -122,23 +133,23 @@ TEST(ConnectionEventsTest, DifferentEvb) { } // disconnect - EXPECT_CALL(*clientConnEvents, onDisconnected(_)).WillOnce( - InvokeWithoutArgs([evb = SMWorker.getEventBase()]() { + EXPECT_CALL(*clientConnEvents, onDisconnected(_)) + .WillOnce(InvokeWithoutArgs([evb = SMWorker.getEventBase()]() { EXPECT_TRUE(evb->isInEventBaseThread()); })); - EXPECT_CALL(*clientConnEvents, onStreamsPaused()).WillOnce( - Invoke([evb = SMWorker.getEventBase()]() { + EXPECT_CALL(*clientConnEvents, onStreamsPaused()) + .WillOnce(Invoke([evb = SMWorker.getEventBase()]() { EXPECT_TRUE(evb->isInEventBaseThread()); })); client->disconnect(std::runtime_error("Test triggered disconnect")); // resume - EXPECT_CALL(*clientConnEvents, onConnected()).WillOnce( - Invoke([evb = SMWorker.getEventBase()]() { + EXPECT_CALL(*clientConnEvents, onConnected()) + .WillOnce(Invoke([evb = SMWorker.getEventBase()]() { EXPECT_TRUE(evb->isInEventBaseThread()); })); - EXPECT_CALL(*clientConnEvents, onStreamsResumed()).WillOnce( - Invoke([evb = SMWorker.getEventBase()]() { + EXPECT_CALL(*clientConnEvents, onStreamsResumed()) + .WillOnce(Invoke([evb = SMWorker.getEventBase()]() { EXPECT_TRUE(evb->isInEventBaseThread()); })); EXPECT_NO_THROW(client->resume().get()); @@ -149,19 +160,19 @@ TEST(ConnectionEventsTest, DifferentEvb) { ts->assertValueCount(10); // disconnect - EXPECT_CALL(*clientConnEvents, onDisconnected(_)).WillOnce( - InvokeWithoutArgs([evb = SMWorker.getEventBase()]() { + EXPECT_CALL(*clientConnEvents, onDisconnected(_)) + .WillOnce(InvokeWithoutArgs([evb = SMWorker.getEventBase()]() { EXPECT_TRUE(evb->isInEventBaseThread()); })); - EXPECT_CALL(*clientConnEvents, onStreamsPaused()).WillOnce( - Invoke([evb = SMWorker.getEventBase()]() { + EXPECT_CALL(*clientConnEvents, onStreamsPaused()) + .WillOnce(Invoke([evb = SMWorker.getEventBase()]() { EXPECT_TRUE(evb->isInEventBaseThread()); })); client->disconnect(std::runtime_error("Test triggered disconnect")); // relinquish resources - EXPECT_CALL(*clientConnEvents, onClosed(_)).WillOnce( - InvokeWithoutArgs([evb = SMWorker.getEventBase()]() { + EXPECT_CALL(*clientConnEvents, onClosed(_)) + .WillOnce(InvokeWithoutArgs([evb = SMWorker.getEventBase()]() { EXPECT_TRUE(evb->isInEventBaseThread()); })); } diff --git a/test/PayloadTest.cpp b/rsocket/test/PayloadTest.cpp similarity index 71% rename from test/PayloadTest.cpp rename to rsocket/test/PayloadTest.cpp index c9435ee61..3b9e8496b 100644 --- a/test/PayloadTest.cpp +++ b/rsocket/test/PayloadTest.cpp @@ -1,13 +1,26 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include + #include + #include "rsocket/Payload.h" #include "rsocket/framing/Frame.h" -#include "rsocket/framing/FrameSerializer_v0_1.h" +#include "rsocket/framing/FrameSerializer_v1_0.h" -using namespace ::testing; using namespace ::rsocket; TEST(PayloadTest, EmptyMetadata) { @@ -24,17 +37,6 @@ TEST(PayloadTest, Clear) { ASSERT_FALSE(p); } -TEST(PayloadTest, GiantMetadata) { - constexpr auto metadataSize = std::numeric_limits::max(); - - auto metadata = folly::IOBuf::wrapBuffer(&metadataSize, sizeof(metadataSize)); - folly::io::Cursor cur(metadata.get()); - - EXPECT_THROW( - FrameSerializerV0_1::deserializeMetadataFrom(cur, FrameFlags::METADATA), - std::runtime_error); -} - TEST(PayloadTest, Clone) { Payload orig("data", "metadata"); diff --git a/test/RSocketClientServerTest.cpp b/rsocket/test/RSocketClientServerTest.cpp similarity index 80% rename from test/RSocketClientServerTest.cpp rename to rsocket/test/RSocketClientServerTest.cpp index 8f6e3c55c..13c4a9218 100644 --- a/test/RSocketClientServerTest.cpp +++ b/rsocket/test/RSocketClientServerTest.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "RSocketTests.h" @@ -46,11 +58,12 @@ TEST(RSocketClientServer, ConnectManyAsync) { auto clientFuture = makeClientAsync( workers[workerId].getEventBase(), *server->listeningPort()) - .then([&executed](std::shared_ptr client) { - ++executed; - return client; - }) - .onError([&](folly::exception_wrapper ex) { + .thenValue( + [&executed](std::shared_ptr client) { + ++executed; + return client; + }) + .thenError([&](folly::exception_wrapper ex) { LOG(ERROR) << "error: " << ex.what(); ++executed; return std::shared_ptr(nullptr); @@ -86,7 +99,7 @@ TEST(RSocketClientServer, ClientClosesOnWorker) { auto client = makeClient(worker.getEventBase(), *server->listeningPort()); // Move the client to the worker thread. - worker.getEventBase()->runInEventBaseThread([c = std::move(client)]{}); + worker.getEventBase()->runInEventBaseThread([c = std::move(client)] {}); } /// Test that sending garbage to the server doesn't crash it. @@ -98,7 +111,9 @@ TEST(RSocketClientServer, ServerGetsGarbage) { auto factory = std::make_shared(*worker.getEventBase(), address); - auto result = factory->connect().get(); + auto result = + factory->connect(ProtocolVersion::Latest, ResumeStatus::NEW_SESSION) + .get(); auto connection = std::move(result.connection); auto evb = &result.eventBase; diff --git a/rsocket/test/RSocketClientTest.cpp b/rsocket/test/RSocketClientTest.cpp new file mode 100644 index 000000000..5a96d09ae --- /dev/null +++ b/rsocket/test/RSocketClientTest.cpp @@ -0,0 +1,72 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "RSocketTests.h" + +#include +#include + +#include "rsocket/test/test_utils/MockDuplexConnection.h" +#include "rsocket/transports/tcp/TcpConnectionFactory.h" + +using namespace rsocket; +using namespace testing; +using namespace yarpl::single; + +TEST(RSocketClient, ConnectFails) { + folly::ScopedEventBaseThread worker; + + folly::SocketAddress address; + address.setFromHostPort("localhost", 1); + auto client = + RSocket::createConnectedClient(std::make_unique( + *worker.getEventBase(), std::move(address))); + + std::move(client) + .thenValue([&](auto&&) { FAIL() << "the test needs to fail"; }) + .thenError( + folly::tag_t{}, + [&](const std::exception&) { + LOG(INFO) << "connection failed as expected"; + }) + .get(); +} + +TEST(RSocketClient, PreallocatedBytesInFrames) { + auto connection = std::make_unique(); + EXPECT_CALL(*connection, isFramed()).WillRepeatedly(Return(true)); + + // SETUP frame and FIRE_N_FORGET frame send + EXPECT_CALL(*connection, send_(_)) + .Times(2) + .WillRepeatedly( + Invoke([](std::unique_ptr& serializedFrame) { + // we should have headroom preallocated for the frame size field + EXPECT_EQ( + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest) + ->frameLengthFieldSize(), + serializedFrame->headroom()); + })); + + folly::ScopedEventBaseThread worker; + + worker.getEventBase()->runInEventBaseThread([&] { + auto client = RSocket::createClientFromConnection( + std::move(connection), *worker.getEventBase()); + + client->getRequester() + ->fireAndForget(Payload("hello")) + ->subscribe(SingleObservers::create()); + }); +} diff --git a/test/RSocketTests.cpp b/rsocket/test/RSocketTests.cpp similarity index 78% rename from test/RSocketTests.cpp rename to rsocket/test/RSocketTests.cpp index 15b628883..c3d309f0e 100644 --- a/test/RSocketTests.cpp +++ b/rsocket/test/RSocketTests.cpp @@ -1,9 +1,22 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/test/RSocketTests.h" -#include "rsocket/transports/tcp/TcpConnectionAcceptor.h" +#include "rsocket/internal/WarmResumeManager.h" #include "rsocket/test/test_utils/GenericRequestResponseHandler.h" +#include "rsocket/transports/tcp/TcpConnectionAcceptor.h" namespace rsocket { namespace tests { @@ -57,7 +70,7 @@ folly::Future> makeClientAsync( kDefaultKeepaliveInterval, std::move(stats), std::shared_ptr(), - std::shared_ptr(), + ResumeManager::makeEmpty(), std::shared_ptr(), stateMachineEvb); } @@ -67,30 +80,30 @@ std::unique_ptr makeClient( uint16_t port, folly::EventBase* stateMachineEvb, std::shared_ptr stats) { - return makeClientAsync( - eventBase, port, stateMachineEvb, std::move(stats)).get(); + return makeClientAsync(eventBase, port, stateMachineEvb, std::move(stats)) + .get(); } namespace { struct DisconnectedResponder : public rsocket::RSocketResponder { DisconnectedResponder() {} - yarpl::Reference> + std::shared_ptr> handleRequestResponse(rsocket::Payload, rsocket::StreamId) override { CHECK(false); return nullptr; } - yarpl::Reference> + std::shared_ptr> handleRequestStream(rsocket::Payload, rsocket::StreamId) override { CHECK(false); return nullptr; } - yarpl::Reference> + std::shared_ptr> handleRequestChannel( rsocket::Payload, - yarpl::Reference>, + std::shared_ptr>, rsocket::StreamId) override { CHECK(false); return nullptr; @@ -104,14 +117,13 @@ struct DisconnectedResponder : public rsocket::RSocketResponder { CHECK(false); } - ~DisconnectedResponder() {} + ~DisconnectedResponder() override {} }; -} +} // namespace std::unique_ptr makeDisconnectedClient( folly::EventBase* eventBase) { - auto server = - makeServer(std::make_shared()); + auto server = makeServer(std::make_shared()); auto client = makeClient(eventBase, *server->listeningPort()); client->disconnect().get(); @@ -133,7 +145,7 @@ std::unique_ptr makeWarmResumableClient( kDefaultKeepaliveInterval, RSocketStats::noop(), std::move(connectionEvents), - std::shared_ptr(), + std::make_shared(RSocketStats::noop()), std::shared_ptr(), stateMachineEvb) .get(); diff --git a/rsocket/test/RSocketTests.h b/rsocket/test/RSocketTests.h new file mode 100644 index 000000000..147238901 --- /dev/null +++ b/rsocket/test/RSocketTests.h @@ -0,0 +1,178 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "rsocket/RSocket.h" + +#include "rsocket/transports/tcp/TcpConnectionFactory.h" + +namespace rsocket { +namespace tests { +namespace client_server { + +class RSocketStatsFlowControl : public RSocketStats { + public: + void frameWritten(FrameType frameType) { + if (frameType == FrameType::REQUEST_N) { + ++writeRequestN_; + } + } + + void frameRead(FrameType frameType) { + if (frameType == FrameType::REQUEST_N) { + ++readRequestN_; + } + } + + public: + int writeRequestN_{0}; + int readRequestN_{0}; +}; + +std::unique_ptr getConnFactory( + folly::EventBase* eventBase, + uint16_t port); + +std::unique_ptr makeServer( + std::shared_ptr responder, + std::shared_ptr stats = RSocketStats::noop()); + +std::unique_ptr makeResumableServer( + std::shared_ptr serviceHandler); + +std::unique_ptr makeClient( + folly::EventBase* eventBase, + uint16_t port, + folly::EventBase* stateMachineEvb = nullptr, + std::shared_ptr stats = RSocketStats::noop()); + +std::unique_ptr makeDisconnectedClient( + folly::EventBase* eventBase); + +folly::Future> makeClientAsync( + folly::EventBase* eventBase, + uint16_t port, + folly::EventBase* stateMachineEvb = nullptr, + std::shared_ptr stats = RSocketStats::noop()); + +std::unique_ptr makeWarmResumableClient( + folly::EventBase* eventBase, + uint16_t port, + std::shared_ptr connectionEvents = nullptr, + folly::EventBase* stateMachineEvb = nullptr); + +std::unique_ptr makeColdResumableClient( + folly::EventBase* eventBase, + uint16_t port, + ResumeIdentificationToken token, + std::shared_ptr resumeManager, + std::shared_ptr resumeHandler, + folly::EventBase* stateMachineEvb = nullptr); + +} // namespace client_server + +struct RSocketPayloadUtils { + // ~30 megabytes, for metadata+data + static constexpr size_t LargeRequestSize = 15 * 1024 * 1024; + static std::string makeLongString(size_t size, std::string pattern) { + while (pattern.size() < size) { + pattern += pattern; + } + return pattern; + } + + // Builds up an IOBuf consisting of chunks with the following sizes, and then + // the rest tacked on the end in one big iobuf chunk + static std::unique_ptr buildIOBufFromString( + std::vector const& sizes, + std::string const& from) { + folly::IOBufQueue bufQueue{folly::IOBufQueue::cacheChainLength()}; + size_t fromCursor = 0; + size_t remaining = from.size(); + for (auto size : sizes) { + if (remaining == 0) + break; + if (size > remaining) { + size = remaining; + } + + bufQueue.append( + folly::IOBuf::copyBuffer(from.c_str() + fromCursor, size)); + + fromCursor += size; + remaining -= size; + } + + if (remaining) { + bufQueue.append( + folly::IOBuf::copyBuffer(from.c_str() + fromCursor, remaining)); + } + + CHECK_EQ(bufQueue.chainLength(), from.size()); + + auto ret = bufQueue.move(); + int numChainElems = 1; + auto currentChainElem = ret.get()->next(); + while (currentChainElem != ret.get()) { + numChainElems++; + currentChainElem = currentChainElem->next(); + } + CHECK_GE(numChainElems, sizes.size()); + + // verify that the returned buffer has identical data + auto str = ret->cloneAsValue().moveToFbString().toStdString(); + CHECK_EQ(str.size(), from.size()); + CHECK(str == from); + + return ret; + } + + static void checkSameStrings( + std::string const& got, + std::string const& expect, + std::string const& context) { + CHECK_EQ(got.size(), expect.size()) + << "Got mismatched size " << context << " string (" << got.size() + << " vs " << expect.size() << ")"; + CHECK(got == expect) << context << " mismatch between got and expected"; + } + + static void checkSameStrings( + std::unique_ptr const& got, + std::string const& expect, + std::string const& context) { + CHECK_EQ(got->computeChainDataLength(), expect.size()) + << "Mismatched size " << context << ", got " + << got->computeChainDataLength() << " vs expect " << expect.size(); + + size_t expect_cursor = 0; + + for (auto range : *got) { + for (auto got_chr : range) { + // perform redundant check to avoid gtest's CHECK overhead + if (got_chr != expect[expect_cursor]) { + CHECK_EQ(got_chr, expect[expect_cursor]) + << "mismatch at byte " << expect_cursor; + } + expect_cursor++; + } + } + } +}; + +} // namespace tests +} // namespace rsocket diff --git a/test/RequestChannelTest.cpp b/rsocket/test/RequestChannelTest.cpp similarity index 54% rename from test/RequestChannelTest.cpp rename to rsocket/test/RequestChannelTest.cpp index 10d4950f5..4a815166a 100644 --- a/test/RequestChannelTest.cpp +++ b/rsocket/test/RequestChannelTest.cpp @@ -1,10 +1,23 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include #include #include "RSocketTests.h" +#include "rsocket/test/test_utils/GenericRequestResponseHandler.h" #include "yarpl/Flowable.h" #include "yarpl/flowable/TestSubscriber.h" @@ -20,13 +33,13 @@ using namespace rsocket::tests::client_server; class TestHandlerHello : public rsocket::RSocketResponder { public: /// Handles a new inbound Stream requested by the other end. - yarpl::Reference> handleRequestChannel( + std::shared_ptr> + handleRequestChannel( rsocket::Payload initialPayload, - yarpl::Reference> request, - rsocket::StreamId) override { + std::shared_ptr> stream, + rsocket::StreamId /*streamId*/) override { // say "Hello" to each name on the input stream - return request->map([initialPayload = std::move(initialPayload)]( - Payload p) { + return stream->map([initialPayload = std::move(initialPayload)](Payload p) { std::stringstream ss; ss << "[" << initialPayload.cloneDataToString() << "] " << "Hello " << p.moveDataToString() << "!"; @@ -46,7 +59,8 @@ TEST(RequestChannelTest, Hello) { auto ts = TestSubscriber::create(); requester ->requestChannel( - Flowables::justN({"/hello", "Bob", "Jane"})->map([](std::string v) { + Payload("/hello"), + Flowable<>::justN({"Bob", "Jane"})->map([](std::string v) { return Payload(v); })) ->map([](auto p) { return p.moveDataToString(); }) @@ -69,10 +83,11 @@ TEST(RequestChannelTest, HelloNoFlowControl) { worker.getEventBase(), *server->listeningPort(), nullptr, stats); auto requester = client->getRequester(); - auto ts = TestSubscriber::create(); + auto ts = TestSubscriber::create(1000); requester ->requestChannel( - Flowables::justN({"/hello", "Bob", "Jane"})->map([](std::string v) { + Payload("/hello"), + Flowable<>::justN({"Bob", "Jane"})->map([](std::string v) { return Payload(v); })) ->map([](auto p) { return p.moveDataToString(); }) @@ -99,7 +114,7 @@ TEST(RequestChannelTest, RequestOnDisconnectedClient) { bool did_call_on_error = false; folly::Baton<> wait_for_on_error; - auto instream = Flowables::empty(); + auto instream = Flowable::empty(); requester->requestChannel(instream)->subscribe( [](auto /* payload */) { // onNext shouldn't be called @@ -115,7 +130,7 @@ TEST(RequestChannelTest, RequestOnDisconnectedClient) { }); wait_for_on_error.timed_wait(std::chrono::milliseconds(100)); - ASSERT(did_call_on_error); + ASSERT_TRUE(did_call_on_error); } class TestChannelResponder : public rsocket::RSocketResponder { @@ -126,9 +141,9 @@ class TestChannelResponder : public rsocket::RSocketResponder { : rangeEnd_{rangeEnd}, testSubscriber_{TestSubscriber::create(initialSubReq)} {} - yarpl::Reference> handleRequestChannel( + std::shared_ptr> handleRequestChannel( rsocket::Payload initialPayload, - yarpl::Reference> requestStream, + std::shared_ptr> requestStream, rsocket::StreamId) override { // add initial payload to testSubscriber values list testSubscriber_->manuallyPush(initialPayload.moveDataToString()); @@ -136,7 +151,7 @@ class TestChannelResponder : public rsocket::RSocketResponder { requestStream->map([](auto p) { return p.moveDataToString(); }) ->subscribe(testSubscriber_); - return Flowables::range(1, rangeEnd_)->map([&](int64_t v) { + return Flowable<>::range(1, rangeEnd_)->map([&](int64_t v) { std::stringstream ss; ss << "Responder stream: " << v << " of " << rangeEnd_; std::string s = ss.str(); @@ -144,13 +159,13 @@ class TestChannelResponder : public rsocket::RSocketResponder { }); } - Reference> getChannelSubscriber() { + std::shared_ptr> getChannelSubscriber() { return testSubscriber_; } private: int64_t rangeEnd_; - Reference> testSubscriber_; + std::shared_ptr> testSubscriber_; }; TEST(RequestChannelTest, CompleteRequesterResponderContinues) { @@ -171,23 +186,24 @@ TEST(RequestChannelTest, CompleteRequesterResponderContinues) { int64_t requesterRangeEnd = 10; auto requesterFlowable = - Flowables::range(1, requesterRangeEnd)->map([=](int64_t v) { + Flowable<>::range(1, requesterRangeEnd)->map([=](int64_t v) { std::stringstream ss; ss << "Requester stream: " << v << " of " << requesterRangeEnd; std::string s = ss.str(); return Payload(s, "metadata"); }); - requester->requestChannel(requesterFlowable) + requester->requestChannel(Payload("Initial Request"), requesterFlowable) ->map([](auto p) { return p.moveDataToString(); }) ->subscribe(requestSubscriber); // finish streaming from Requester responderSubscriber->awaitTerminalEvent(); responderSubscriber->assertSuccess(); - responderSubscriber->assertValueCount(10); - responderSubscriber->assertValueAt(0, "Requester stream: 1 of 10"); - responderSubscriber->assertValueAt(9, "Requester stream: 10 of 10"); + responderSubscriber->assertValueCount(11); + responderSubscriber->assertValueAt(0, "Initial Request"); + responderSubscriber->assertValueAt(1, "Requester stream: 1 of 10"); + responderSubscriber->assertValueAt(10, "Requester stream: 10 of 10"); // Requester stream is closed, Responder continues requestSubscriber->request(50); @@ -216,7 +232,7 @@ TEST(RequestChannelTest, CompleteResponderRequesterContinues) { int64_t requesterRangeEnd = 100; auto requesterFlowable = - Flowables::range(1, requesterRangeEnd)->map([=](int64_t v) { + Flowable<>::range(1, requesterRangeEnd)->map([=](int64_t v) { std::stringstream ss; ss << "Requester stream: " << v << " of " << requesterRangeEnd; std::string s = ss.str(); @@ -244,8 +260,8 @@ TEST(RequestChannelTest, CompleteResponderRequesterContinues) { } TEST(RequestChannelTest, FlowControl) { - int64_t responderRange = 10; - int64_t responderSubscriberInitialRequest = 0; + constexpr int64_t responderRange = 10; + constexpr int64_t responderSubscriberInitialRequest = 0; auto responder = std::make_shared( responderRange, responderSubscriberInitialRequest); @@ -255,35 +271,38 @@ TEST(RequestChannelTest, FlowControl) { auto client = makeClient(worker.getEventBase(), *server->listeningPort()); auto requester = client->getRequester(); - auto requestSubscriber = TestSubscriber::create(1); + auto requestSubscriber = TestSubscriber::create(0); auto responderSubscriber = responder->getChannelSubscriber(); - int64_t requesterRangeEnd = 10; + constexpr int64_t requesterRangeEnd = 10; auto requesterFlowable = - Flowables::range(1, requesterRangeEnd)->map([&](int64_t v) { + Flowable<>::range(1, requesterRangeEnd)->map([&](int64_t v) { std::stringstream ss; ss << "Requester stream: " << v << " of " << requesterRangeEnd; std::string s = ss.str(); return Payload(s, "metadata"); }); - requester->requestChannel(requesterFlowable) + requester->requestChannel(Payload("Initial Request"), requesterFlowable) ->map([](auto p) { return p.moveDataToString(); }) ->subscribe(requestSubscriber); + // Wait till the Channel is created responderSubscriber->awaitValueCount(1); - requestSubscriber->awaitValueCount(1); - for (int i = 2; i <= 10; i++) { + for (int i = 1; i <= 10; i++) { requestSubscriber->request(1); - responderSubscriber->request(1); - - responderSubscriber->awaitValueCount(i); requestSubscriber->awaitValueCount(i); - requestSubscriber->assertValueCount(i); - responderSubscriber->assertValueCount(i); + } + + for (int i = 1; i <= 10; i++) { + responderSubscriber->request(1); + // the channel initial payload was pushed to responderSubscriber so we + // need to add this one item to expected + responderSubscriber->awaitValueCount(i + 1); + responderSubscriber->assertValueCount(i + 1); } requestSubscriber->awaitTerminalEvent(); @@ -295,8 +314,9 @@ TEST(RequestChannelTest, FlowControl) { requestSubscriber->assertValueAt(0, "Responder stream: 1 of 10"); requestSubscriber->assertValueAt(9, "Responder stream: 10 of 10"); - responderSubscriber->assertValueAt(0, "Requester stream: 1 of 10"); - responderSubscriber->assertValueAt(9, "Requester stream: 10 of 10"); + responderSubscriber->assertValueAt(0, "Initial Request"); + responderSubscriber->assertValueAt(1, "Requester stream: 1 of 10"); + responderSubscriber->assertValueAt(10, "Requester stream: 10 of 10"); } class TestChannelResponderFailure : public rsocket::RSocketResponder { @@ -304,9 +324,9 @@ class TestChannelResponderFailure : public rsocket::RSocketResponder { TestChannelResponderFailure() : testSubscriber_{TestSubscriber::create()} {} - yarpl::Reference> handleRequestChannel( + std::shared_ptr> handleRequestChannel( rsocket::Payload initialPayload, - yarpl::Reference> requestStream, + std::shared_ptr> requestStream, rsocket::StreamId) override { // add initial payload to testSubscriber values list testSubscriber_->manuallyPush(initialPayload.moveDataToString()); @@ -314,16 +334,16 @@ class TestChannelResponderFailure : public rsocket::RSocketResponder { requestStream->map([](auto p) { return p.moveDataToString(); }) ->subscribe(testSubscriber_); - return Flowables::error( + return Flowable::error( std::runtime_error("A wild Error appeared!")); } - Reference> getChannelSubscriber() { + std::shared_ptr> getChannelSubscriber() { return testSubscriber_; } private: - Reference> testSubscriber_; + std::shared_ptr> testSubscriber_; }; TEST(RequestChannelTest, FailureOnResponderRequesterSees) { @@ -340,22 +360,151 @@ TEST(RequestChannelTest, FailureOnResponderRequesterSees) { int64_t requesterRangeEnd = 10; auto requesterFlowable = - Flowables::range(1, requesterRangeEnd)->map([&](int64_t v) { + Flowable<>::range(1, requesterRangeEnd)->map([&](int64_t v) { std::stringstream ss; ss << "Requester stream: " << v << " of " << requesterRangeEnd; std::string s = ss.str(); return Payload(s, "metadata"); }); - requester->requestChannel(requesterFlowable) + requester->requestChannel(Payload("Initial Request"), requesterFlowable) ->map([](auto p) { return p.moveDataToString(); }) ->subscribe(requestSubscriber); // failure streaming from Responder requestSubscriber->awaitTerminalEvent(); - requestSubscriber->assertOnErrorMessage("A wild Error appeared!"); + requestSubscriber->assertOnErrorMessage("ErrorWithPayload"); + EXPECT_TRUE(requestSubscriber->getException().with_exception( + [](ErrorWithPayload& err) { + EXPECT_STREQ( + "A wild Error appeared!", err.payload.moveDataToString().c_str()); + })); responderSubscriber->awaitTerminalEvent(); - responderSubscriber->assertValueAt(0, "Requester stream: 1 of 10"); - responderSubscriber->assertValueAt(9, "Requester stream: 10 of 10"); + responderSubscriber->assertSuccess(); + responderSubscriber->assertValueCount(1); + responderSubscriber->assertValueAt(0, "Initial Request"); +} + +struct LargePayloadChannelHandler : public rsocket::RSocketResponder { + LargePayloadChannelHandler(std::string const& data, std::string const& meta) + : data(data), meta(meta) {} + + std::shared_ptr> handleRequestChannel( + Payload initialPayload, + std::shared_ptr> stream, + StreamId) override { + RSocketPayloadUtils::checkSameStrings( + initialPayload.data, data, "data received in initial payload"); + RSocketPayloadUtils::checkSameStrings( + initialPayload.metadata, meta, "metadata received in initial payload"); + + return stream->map([&](Payload payload) { + RSocketPayloadUtils::checkSameStrings( + payload.data, data, "data received in server stream"); + RSocketPayloadUtils::checkSameStrings( + payload.metadata, meta, "metadata received in server stream"); + return payload; + }); + } + + std::string const& data; + std::string const& meta; +}; + +TEST(RequestChannelTest, TestLargePayload) { + LOG(INFO) << "Building up large data/metadata, this may take a moment..."; + std::string const niceLongData = RSocketPayloadUtils::makeLongString( + RSocketPayloadUtils::LargeRequestSize, "ABCDEFGH"); + std::string const niceLongMeta = RSocketPayloadUtils::makeLongString( + RSocketPayloadUtils::LargeRequestSize, "12345678"); + + LOG(INFO) << "Built meta size: " << niceLongMeta.size() + << " data size: " << niceLongData.size(); + + folly::ScopedEventBaseThread worker; + auto handler = + std::make_shared(niceLongData, niceLongMeta); + auto server = makeServer(handler); + + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + + auto checkForSizePattern = [&](std::vector const& meta_sizes, + std::vector const& data_sizes) { + auto to = TestSubscriber::create(); + + auto seedPayload = Payload( + RSocketPayloadUtils::buildIOBufFromString(data_sizes, niceLongData), + RSocketPayloadUtils::buildIOBufFromString(meta_sizes, niceLongMeta)); + + auto makePayload = [&] { + return Payload(seedPayload.data->clone(), seedPayload.metadata->clone()); + }; + + auto requests = + yarpl::flowable::Flowable::create([&](auto& subscriber, + int64_t num) { + while (num--) { + subscriber.onNext(makePayload()); + } + })->take(3); + + requester->requestChannel(std::move(requests)) + ->map([&](Payload p) { + RSocketPayloadUtils::checkSameStrings( + p.data, niceLongData, "data received on client"); + RSocketPayloadUtils::checkSameStrings( + p.metadata, niceLongMeta, "metadata received on client"); + return 0; + }) + ->subscribe(to); + to->awaitTerminalEvent(std::chrono::seconds{20}); + to->assertValueCount(2); + to->assertSuccess(); + }; + + // All in one big chunk + checkForSizePattern({}, {}); + + // Small chunk, big chunk, small chunk + checkForSizePattern({100, 5 * 1024 * 1024, 100}, {100, 5 * 1024 * 1024, 100}); +} + +TEST(RequestChannelTest, MultiSubscribe) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared()); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + + auto ts = TestSubscriber::create(); + auto stream = + requester + ->requestChannel( + Payload("/hello"), + Flowable<>::justN({"Bob", "Jane"})->map([](std::string v) { + return Payload(v); + })) + ->map([](auto p) { return p.moveDataToString(); }); + + // First subscribe + stream->subscribe(ts); + ts->awaitTerminalEvent(); + ts->assertSuccess(); + ts->assertValueCount(2); + // assert that we echo back the 2nd and 3rd request values + // with the 1st initial payload prepended to each + ts->assertValueAt(0, "[/hello] Hello Bob!"); + ts->assertValueAt(1, "[/hello] Hello Jane!"); + + // Second subscribe + ts = TestSubscriber::create(); + stream->subscribe(ts); + ts->awaitTerminalEvent(); + ts->assertSuccess(); + ts->assertValueCount(2); + // assert that we echo back the 2nd and 3rd request values + // with the 1st initial payload prepended to each + ts->assertValueAt(0, "[/hello] Hello Bob!"); + ts->assertValueAt(1, "[/hello] Hello Jane!"); } diff --git a/test/RequestResponseTest.cpp b/rsocket/test/RequestResponseTest.cpp similarity index 51% rename from test/RequestResponseTest.cpp rename to rsocket/test/RequestResponseTest.cpp index d50bb01d5..313ff124d 100644 --- a/test/RequestResponseTest.cpp +++ b/rsocket/test/RequestResponseTest.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -10,7 +22,6 @@ #include "yarpl/Single.h" #include "yarpl/single/SingleTestObserver.h" -using namespace yarpl; using namespace yarpl::single; using namespace rsocket; using namespace rsocket::tests; @@ -23,8 +34,9 @@ class TestHandlerCancel : public rsocket::RSocketResponder { std::shared_ptr> onCancel, std::shared_ptr> onSubscribe) : onCancel_(std::move(onCancel)), onSubscribe_(std::move(onSubscribe)) {} - Reference> handleRequestResponse(Payload request, StreamId) - override { + std::shared_ptr> handleRequestResponse( + Payload request, + StreamId) override { // used to signal to the client when the subscribe is received onSubscribe_->post(); // used to block this responder thread until a cancel is sent from client @@ -33,32 +45,29 @@ class TestHandlerCancel : public rsocket::RSocketResponder { // used to signal to the client once we receive a cancel auto onCancel = onCancel_; auto requestString = request.moveDataToString(); - return Single::create( - [ name = std::move(requestString), cancelFromClient, onCancel ]( - auto subscriber) mutable { - std::thread([ - subscriber = std::move(subscriber), - name = std::move(name), - cancelFromClient, - onCancel - ]() { - auto subscription = SingleSubscriptions::create( - [cancelFromClient] { cancelFromClient->post(); }); - subscriber->onSubscribe(subscription); - // simulate slow processing or IO being done - // and block this current background thread - // until we are cancelled - cancelFromClient->wait(); - if (subscription->isCancelled()) { - // this is used by the unit test to assert the cancel was - // received - onCancel->post(); - } else { - // if not cancelled would do work and emit here - } - }) - .detach(); - }); + return Single::create([name = std::move(requestString), + cancelFromClient, + onCancel](auto subscriber) mutable { + std::thread([subscriber = std::move(subscriber), + name = std::move(name), + cancelFromClient, + onCancel]() { + auto subscription = SingleSubscriptions::create( + [cancelFromClient] { cancelFromClient->post(); }); + subscriber->onSubscribe(subscription); + // simulate slow processing or IO being done + // and block this current background thread + // until we are cancelled + cancelFromClient->timed_wait(std::chrono::seconds(1)); + if (subscription->isCancelled()) { + // this is used by the unit test to assert the cancel was + // received + onCancel->post(); + } else { + // if not cancelled would do work and emit here + } + }).detach(); + }); } private: @@ -133,7 +142,10 @@ TEST(RequestResponseTest, FailureInResponse) { ->map(payload_to_stringpair) ->subscribe(to); to->awaitTerminalEvent(); - to->assertOnErrorMessage("whew!"); + to->assertOnErrorMessage("ErrorWithPayload"); + EXPECT_TRUE(to->getException().with_exception([](ErrorWithPayload& err) { + EXPECT_STREQ("whew!", err.payload.moveDataToString().c_str()); + })); } TEST(RequestResponseTest, RequestOnDisconnectedClient) { @@ -155,7 +167,7 @@ TEST(RequestResponseTest, RequestOnDisconnectedClient) { }); wait_for_on_error.timed_wait(std::chrono::milliseconds(100)); - ASSERT(did_call_on_error); + ASSERT_TRUE(did_call_on_error); } // TODO: test that multiple requests on a requestResponse @@ -198,3 +210,96 @@ TEST(RequestResponseTest, FailureOnRequest) { to->awaitTerminalEvent(); EXPECT_TRUE(to->getError()); } + +struct LargePayloadReqRespHandler : public rsocket::RSocketResponder { + LargePayloadReqRespHandler(std::string const& data, std::string const& meta) + : data(data), meta(meta) {} + + std::shared_ptr> handleRequestResponse( + Payload payload, + StreamId) override { + RSocketPayloadUtils::checkSameStrings( + payload.data, data, "data received in payload"); + RSocketPayloadUtils::checkSameStrings( + payload.metadata, meta, "metadata received in payload"); + + return yarpl::single::Single::create( + [p = std::move(payload)](auto sub) mutable { + sub->onSubscribe(yarpl::single::SingleSubscriptions::empty()); + sub->onSuccess(std::move(p)); + }); + } + + std::string const& data; + std::string const& meta; +}; + +TEST(RequestResponseTest, TestLargePayload) { + VLOG(1) << "Building up large data/metadata, this may take a moment..."; + std::string niceLongData = RSocketPayloadUtils::makeLongString( + RSocketPayloadUtils::LargeRequestSize, "ABCDEFGH"); + std::string niceLongMeta = RSocketPayloadUtils::makeLongString( + RSocketPayloadUtils::LargeRequestSize, "12345678"); + VLOG(1) << "Built meta size: " << niceLongMeta.size() + << " data size: " << niceLongData.size(); + + auto checkForSizePattern = [&](std::vector const& meta_sizes, + std::vector const& data_sizes) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared( + niceLongData, niceLongMeta)); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + auto to = SingleTestObserver::create(); + + requester + ->requestResponse(Payload( + RSocketPayloadUtils::buildIOBufFromString(data_sizes, niceLongData), + RSocketPayloadUtils::buildIOBufFromString( + meta_sizes, niceLongMeta))) + ->map([&](Payload p) { + RSocketPayloadUtils::checkSameStrings( + p.data, niceLongData, "data (received on client)"); + RSocketPayloadUtils::checkSameStrings( + p.metadata, niceLongMeta, "metadata (received on client)"); + return 0; + }) + ->subscribe(to); + to->awaitTerminalEvent(); + to->assertSuccess(); + }; + + // All in one big chunk + checkForSizePattern({}, {}); + + // Small chunk, big chunk, small chunk + checkForSizePattern( + {100, 10 * 1024 * 1024, 100}, {100, 10 * 1024 * 1024, 100}); +} + +TEST(RequestResponseTest, MultiSubscribe) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared( + [](StringPair const& request) { + return payload_response( + "Hello, " + request.first + " " + request.second + "!", ":)"); + })); + + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + + auto to = SingleTestObserver::create(); + auto single = requester->requestResponse(Payload("Jane", "Doe")) + ->map(payload_to_stringpair); + + // Subscribe once + single->subscribe(to); + to->awaitTerminalEvent(); + to->assertOnSuccessValue({"Hello, Jane Doe!", ":)"}); + + // Subscribe twice + to = SingleTestObserver::create(); + single->subscribe(to); + to->awaitTerminalEvent(); + to->assertOnSuccessValue({"Hello, Jane Doe!", ":)"}); +} diff --git a/rsocket/test/RequestStreamTest.cpp b/rsocket/test/RequestStreamTest.cpp new file mode 100644 index 000000000..22633697c --- /dev/null +++ b/rsocket/test/RequestStreamTest.cpp @@ -0,0 +1,357 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "RSocketTests.h" +#include "yarpl/Flowable.h" +#include "yarpl/flowable/TestSubscriber.h" + +using namespace yarpl::flowable; +using namespace rsocket; +using namespace rsocket::tests; +using namespace rsocket::tests::client_server; + +namespace { +class TestHandlerSync : public rsocket::RSocketResponder { + public: + std::shared_ptr> handleRequestStream( + Payload request, + StreamId) override { + // string from payload data + auto requestString = request.moveDataToString(); + + return Flowable<>::range(1, 10)->map( + [name = std::move(requestString)](int64_t v) { + std::stringstream ss; + ss << "Hello " << name << " " << v << "!"; + std::string s = ss.str(); + return Payload(s, "metadata"); + }); + } +}; + +TEST(RequestStreamTest, HelloSync) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared()); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + auto ts = TestSubscriber::create(); + requester->requestStream(Payload("Bob")) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(ts); + ts->awaitTerminalEvent(); + ts->assertSuccess(); + ts->assertValueCount(10); + ts->assertValueAt(0, "Hello Bob 1!"); + ts->assertValueAt(9, "Hello Bob 10!"); +} + +TEST(RequestStreamTest, HelloFlowControl) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared()); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + auto ts = TestSubscriber::create(5); + requester->requestStream(Payload("Bob")) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(ts); + + ts->awaitValueCount(5); + + ts->assertValueCount(5); + ts->assertValueAt(0, "Hello Bob 1!"); + ts->assertValueAt(4, "Hello Bob 5!"); + + ts->request(5); + + ts->awaitValueCount(10); + + ts->assertValueCount(10); + ts->assertValueAt(5, "Hello Bob 6!"); + ts->assertValueAt(9, "Hello Bob 10!"); + + ts->awaitTerminalEvent(); + ts->assertSuccess(); +} + +TEST(RequestStreamTest, HelloNoFlowControl) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared()); + auto stats = std::make_shared(); + auto client = makeClient( + worker.getEventBase(), *server->listeningPort(), nullptr, stats); + auto requester = client->getRequester(); + auto ts = TestSubscriber::create(1000); + requester->requestStream(Payload("Bob")) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(ts); + ts->awaitTerminalEvent(); + ts->assertSuccess(); + ts->assertValueCount(10); + ts->assertValueAt(0, "Hello Bob 1!"); + ts->assertValueAt(9, "Hello Bob 10!"); + + // Make sure that the initial requestN in the Stream Request Frame + // is already enough and no other requestN messages are sent. + EXPECT_EQ(stats->writeRequestN_, 0); +} + +class TestHandlerAsync : public rsocket::RSocketResponder { + public: + explicit TestHandlerAsync(folly::Executor& executor) : executor_(executor) {} + + std::shared_ptr> handleRequestStream( + Payload request, + StreamId) override { + // string from payload data + auto requestString = request.moveDataToString(); + + return Flowable<>::range(1, 40) + ->map([name = std::move(requestString)](int64_t v) { + std::stringstream ss; + ss << "Hello " << name << " " << v << "!"; + std::string s = ss.str(); + return Payload(s, "metadata"); + }) + ->subscribeOn(executor_); + } + + private: + folly::Executor& executor_; +}; +} // namespace + +TEST(RequestStreamTest, HelloAsync) { + folly::ScopedEventBaseThread worker; + folly::ScopedEventBaseThread worker2; + auto server = + makeServer(std::make_shared(*worker2.getEventBase())); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + auto ts = TestSubscriber::create(); + requester->requestStream(Payload("Bob")) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(ts); + ts->awaitTerminalEvent(); + ts->assertSuccess(); + ts->assertValueCount(40); + ts->assertValueAt(0, "Hello Bob 1!"); + ts->assertValueAt(39, "Hello Bob 40!"); +} + +TEST(RequestStreamTest, RequestOnDisconnectedClient) { + folly::ScopedEventBaseThread worker; + auto client = makeDisconnectedClient(worker.getEventBase()); + auto requester = client->getRequester(); + + bool did_call_on_error = false; + folly::Baton<> wait_for_on_error; + + requester->requestStream(Payload("foo", "bar")) + ->subscribe( + [](auto /* payload */) { + // onNext shouldn't be called + FAIL(); + }, + [&](folly::exception_wrapper) { + did_call_on_error = true; + wait_for_on_error.post(); + }, + []() { + // onComplete shouldn't be called + FAIL(); + }); + + wait_for_on_error.timed_wait(std::chrono::milliseconds(100)); + ASSERT_TRUE(did_call_on_error); +} + +class TestHandlerResponder : public rsocket::RSocketResponder { + public: + std::shared_ptr> handleRequestStream(Payload, StreamId) + override { + return Flowable::error( + std::runtime_error("A wild Error appeared!")); + } +}; + +TEST(RequestStreamTest, HandleError) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared()); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + auto ts = TestSubscriber::create(); + requester->requestStream(Payload("Bob")) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(ts); + ts->awaitTerminalEvent(); + // Hide the user error from the logs + ts->assertOnErrorMessage("ErrorWithPayload"); + EXPECT_TRUE(ts->getException().with_exception([](ErrorWithPayload& err) { + EXPECT_STREQ( + "A wild Error appeared!", err.payload.moveDataToString().c_str()); + })); +} + +class TestErrorAfterOnNextResponder : public rsocket::RSocketResponder { + public: + std::shared_ptr> handleRequestStream( + Payload request, + StreamId) override { + // string from payload data + auto requestString = request.moveDataToString(); + + return Flowable::create( + [name = std::move(requestString)]( + Subscriber& subscriber, int64_t requested) { + EXPECT_GT(requested, 1); + subscriber.onNext(Payload(name, "meta")); + subscriber.onNext(Payload(name, "meta")); + subscriber.onNext(Payload(name, "meta")); + subscriber.onNext(Payload(name, "meta")); + subscriber.onError(std::runtime_error("A wild Error appeared!")); + }); + } +}; + +TEST(RequestStreamTest, HandleErrorMidStream) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared()); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + auto ts = TestSubscriber::create(); + requester->requestStream(Payload("Bob")) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(ts); + ts->awaitTerminalEvent(); + ts->assertValueCount(4); + ts->assertOnErrorMessage("ErrorWithPayload"); + EXPECT_TRUE(ts->getException().with_exception([](ErrorWithPayload& err) { + EXPECT_STREQ( + "A wild Error appeared!", err.payload.moveDataToString().c_str()); + })); +} + +struct LargePayloadStreamHandler : public rsocket::RSocketResponder { + LargePayloadStreamHandler( + std::string const& data, + std::string const& meta, + Payload const& seedPayload) + : data(data), meta(meta), seedPayload(seedPayload) {} + + std::shared_ptr> handleRequestStream( + Payload initialPayload, + StreamId) override { + RSocketPayloadUtils::checkSameStrings( + initialPayload.data, data, "data received in initial payload"); + RSocketPayloadUtils::checkSameStrings( + initialPayload.metadata, meta, "metadata received in initial payload"); + + return yarpl::flowable::Flowable::create([&](auto& subscriber, + int64_t num) { + while (num--) { + auto p = Payload( + seedPayload.data->clone(), seedPayload.metadata->clone()); + subscriber.onNext(std::move(p)); + } + }) + ->take(3); + } + + std::string const& data; + std::string const& meta; + Payload const& seedPayload; +}; + +TEST(RequestStreamTest, TestLargePayload) { + LOG(INFO) << "Building up large data/metadata, this may take a moment..."; + // ~20 megabytes per frame (metadata + data) + std::string const niceLongData = RSocketPayloadUtils::makeLongString( + RSocketPayloadUtils::LargeRequestSize, "ABCDEFGH"); + std::string const niceLongMeta = RSocketPayloadUtils::makeLongString( + RSocketPayloadUtils::LargeRequestSize, "12345678"); + + LOG(INFO) << "Built meta size: " << niceLongMeta.size() + << " data size: " << niceLongData.size(); + + auto checkForSizePattern = [&](std::vector const& meta_sizes, + std::vector const& data_sizes) { + folly::ScopedEventBaseThread worker; + auto seedPayload = Payload( + RSocketPayloadUtils::buildIOBufFromString(data_sizes, niceLongData), + RSocketPayloadUtils::buildIOBufFromString(meta_sizes, niceLongMeta)); + auto makePayload = [&] { + return Payload(seedPayload.data->clone(), seedPayload.metadata->clone()); + }; + + auto handler = std::make_shared( + niceLongData, niceLongMeta, seedPayload); + auto server = makeServer(handler); + + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + + auto to = TestSubscriber::create(); + + requester->requestStream(makePayload()) + ->map([&](Payload p) { + RSocketPayloadUtils::checkSameStrings( + p.data, niceLongData, "data received on client"); + RSocketPayloadUtils::checkSameStrings( + p.metadata, niceLongMeta, "metadata received on client"); + return 0; + }) + ->subscribe(to); + to->awaitTerminalEvent(std::chrono::seconds{20}); + to->assertValueCount(3); + to->assertSuccess(); + }; + + // All in one big chunk + checkForSizePattern({}, {}); + + // Small chunk, big chunk, small chunk + checkForSizePattern({100, 5 * 1024 * 1024, 100}, {100, 5 * 1024 * 1024, 100}); +} + +TEST(RequestStreamTest, MultiSubscribe) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared()); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + auto ts = TestSubscriber::create(); + auto stream = requester->requestStream(Payload("Bob"))->map([](auto p) { + return p.moveDataToString(); + }); + + // First subscribe + stream->subscribe(ts); + ts->awaitTerminalEvent(); + ts->assertSuccess(); + ts->assertValueCount(10); + ts->assertValueAt(0, "Hello Bob 1!"); + ts->assertValueAt(9, "Hello Bob 10!"); + + // Second subscribe + ts = TestSubscriber::create(); + stream->subscribe(ts); + ts->awaitTerminalEvent(); + ts->assertSuccess(); + ts->assertValueCount(10); + ts->assertValueAt(0, "Hello Bob 1!"); + ts->assertValueAt(9, "Hello Bob 10!"); +} diff --git a/rsocket/test/RequestStreamTest_concurrency.cpp b/rsocket/test/RequestStreamTest_concurrency.cpp new file mode 100644 index 000000000..ce9b7e0b6 --- /dev/null +++ b/rsocket/test/RequestStreamTest_concurrency.cpp @@ -0,0 +1,153 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "RSocketTests.h" +#include "yarpl/Flowable.h" +#include "yarpl/flowable/TestSubscriber.h" + +#include "yarpl/test_utils/Mocks.h" + +using namespace yarpl::flowable; +using namespace rsocket; +using namespace rsocket::tests::client_server; + +struct LockstepBatons { + folly::Baton<> onSecondPayloadSent; + folly::Baton<> onCancelSent; + folly::Baton<> onCancelReceivedToserver; + folly::Baton<> onCancelReceivedToclient; + folly::Baton<> onRequestReceived; + folly::Baton<> clientFinished; + folly::Baton<> serverFinished; +}; + +using namespace ::testing; + +constexpr std::chrono::milliseconds timeout{100}; + +class LockstepAsyncHandler : public rsocket::RSocketResponder { + LockstepBatons& batons_; + Sequence& subscription_seq_; + folly::ScopedEventBaseThread worker_; + + public: + LockstepAsyncHandler(LockstepBatons& batons, Sequence& subscription_seq) + : batons_(batons), subscription_seq_(subscription_seq) {} + + std::shared_ptr> handleRequestStream(Payload p, StreamId) + override { + EXPECT_EQ(p.moveDataToString(), "initial"); + + auto step1 = Flowable::empty()->doOnComplete([this]() { + this->batons_.onRequestReceived.timed_wait(timeout); + VLOG(3) << "SERVER: sending onNext(foo)"; + }); + + auto step2 = Flowable<>::justOnce(Payload("foo"))->doOnComplete([this]() { + this->batons_.onCancelSent.timed_wait(timeout); + this->batons_.onCancelReceivedToserver.timed_wait(timeout); + VLOG(3) << "SERVER: sending onNext(bar)"; + }); + + auto step3 = Flowable<>::justOnce(Payload("bar"))->doOnComplete([this]() { + this->batons_.onSecondPayloadSent.post(); + VLOG(3) << "SERVER: sending onComplete()"; + }); + + auto generator = Flowable<>::concat(step1, step2, step3) + ->doOnComplete([this]() { + VLOG(3) << "SERVER: posting serverFinished"; + this->batons_.serverFinished.post(); + }) + ->subscribeOn(*worker_.getEventBase()); + + // checked once the subscription is destroyed + auto requestCheckpoint = std::make_shared>(); + EXPECT_CALL(*requestCheckpoint, Call(2)) + .InSequence(this->subscription_seq_) + .WillOnce(Invoke([=](auto n) { + VLOG(3) << "SERVER: got request(" << n << ")"; + EXPECT_EQ(n, 2); + this->batons_.onRequestReceived.post(); + })); + + auto cancelCheckpoint = std::make_shared>(); + EXPECT_CALL(*cancelCheckpoint, Call()) + .InSequence(this->subscription_seq_) + .WillOnce(Invoke([=] { + VLOG(3) << "SERVER: received cancel()"; + this->batons_.onCancelReceivedToclient.post(); + this->batons_.onCancelReceivedToserver.post(); + })); + + return generator + ->doOnRequest( + [requestCheckpoint](auto n) { requestCheckpoint->Call(n); }) + ->doOnCancel([cancelCheckpoint] { cancelCheckpoint->Call(); }); + } +}; + +TEST(RequestStreamTest, OperationsAfterCancel) { + LockstepBatons batons; + Sequence server_seq; + Sequence client_seq; + + auto server = + makeServer(std::make_shared(batons, server_seq)); + folly::ScopedEventBaseThread worker; + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + + auto subscriber_mock = std::make_shared< + testing::StrictMock>>(0); + + std::shared_ptr subscription; + EXPECT_CALL(*subscriber_mock, onSubscribe_(_)) + .InSequence(client_seq) + .WillOnce(Invoke([&](auto s) { + VLOG(3) << "CLIENT: got onSubscribe(), sending request(2)"; + EXPECT_NE(s, nullptr); + subscription = s; + subscription->request(2); + })); + EXPECT_CALL(*subscriber_mock, onNext_("foo")) + .InSequence(client_seq) + .WillOnce(Invoke([&](auto) { + EXPECT_NE(subscription, nullptr); + VLOG(3) << "CLIENT: got onNext(foo), sending cancel()"; + subscription->cancel(); + batons.onCancelSent.post(); + batons.onCancelReceivedToclient.timed_wait(timeout); + batons.onSecondPayloadSent.timed_wait(timeout); + batons.clientFinished.post(); + })); + + // shouldn't receive 'bar', we canceled syncronously with the Subscriber + // had 'cancel' been called in a different thread with no synchronization, + // the client's Subscriber _could_ have received 'bar' + + VLOG(3) << "RUNNER: doing requestStream()"; + requester->requestStream(Payload("initial")) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(subscriber_mock); + + batons.clientFinished.timed_wait(timeout); + batons.serverFinished.timed_wait(timeout); + VLOG(3) << "RUNNER: finished!"; +} diff --git a/rsocket/test/Test.cpp b/rsocket/test/Test.cpp new file mode 100644 index 000000000..512a2281e --- /dev/null +++ b/rsocket/test/Test.cpp @@ -0,0 +1,24 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +int main(int argc, char** argv) { + FLAGS_logtostderr = true; + testing::InitGoogleMock(&argc, argv); + folly::init(&argc, &argv); + return RUN_ALL_TESTS(); +} diff --git a/test/WarmResumeManagerTest.cpp b/rsocket/test/WarmResumeManagerTest.cpp similarity index 95% rename from test/WarmResumeManagerTest.cpp rename to rsocket/test/WarmResumeManagerTest.cpp index c5e69ce83..861860595 100644 --- a/test/WarmResumeManagerTest.cpp +++ b/rsocket/test/WarmResumeManagerTest.cpp @@ -1,6 +1,17 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -#include #include #include diff --git a/test/WarmResumptionTest.cpp b/rsocket/test/WarmResumptionTest.cpp similarity index 79% rename from test/WarmResumptionTest.cpp rename to rsocket/test/WarmResumptionTest.cpp index 3f21082dd..2481456db 100644 --- a/test/WarmResumptionTest.cpp +++ b/rsocket/test/WarmResumptionTest.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -32,8 +44,8 @@ TEST(WarmResumptionTest, SuccessfulResumption) { } auto result = client->disconnect(std::runtime_error("Test triggered disconnect")) - .then([&] { return client->resume(); }); - EXPECT_NO_THROW(result.get()); + .thenValue([&](auto&&) { return client->resume(); }); + EXPECT_NO_THROW(std::move(result).get()); ts->request(3); ts->awaitTerminalEvent(); ts->assertSuccess(); @@ -59,9 +71,10 @@ TEST(WarmResumptionTest, FailedResumption1) { } client->disconnect(std::runtime_error("Test triggered disconnect")) - .then([&] { return client->resume(); }) - .then([] { FAIL() << "Resumption succeeded when it should not"; }) - .onError([listeningPort, &worker](folly::exception_wrapper) { + .thenValue([&](auto&&) { return client->resume(); }) + .thenValue( + [](auto&&) { FAIL() << "Resumption succeeded when it should not"; }) + .thenError([listeningPort, &worker](folly::exception_wrapper) { folly::ScopedEventBaseThread worker2; auto newClient = makeWarmResumableClient(worker2.getEventBase(), listeningPort); @@ -106,10 +119,11 @@ TEST(WarmResumptionTest, FailedResumption2) { std::shared_ptr newClient; client->disconnect(std::runtime_error("Test triggered disconnect")) - .then([&] { return client->resume(); }) - .then([] { FAIL() << "Resumption succeeded when it should not"; }) - .onError([listeningPort, newTs, &newClient, &worker2]( - folly::exception_wrapper) { + .thenValue([&](auto&&) { return client->resume(); }) + .thenValue( + [](auto&&) { FAIL() << "Resumption succeeded when it should not"; }) + .thenError([listeningPort, newTs, &newClient, &worker2]( + folly::exception_wrapper) { newClient = makeWarmResumableClient(worker2.getEventBase(), listeningPort); newClient->getRequester() @@ -150,8 +164,8 @@ TEST(WarmResumptionTest, DifferentEvb) { } auto result = client->disconnect(std::runtime_error("Test triggered disconnect")) - .then([&] { return client->resume(); }); - EXPECT_NO_THROW(result.get()); + .thenValue([&](auto&&) { return client->resume(); }); + EXPECT_NO_THROW(std::move(result).get()); ts->request(3); ts->awaitTerminalEvent(); ts->assertSuccess(); diff --git a/test/framing/FrameTest.cpp b/rsocket/test/framing/FrameTest.cpp similarity index 74% rename from test/framing/FrameTest.cpp rename to rsocket/test/framing/FrameTest.cpp index 6bd478227..3efb79449 100644 --- a/test/framing/FrameTest.cpp +++ b/rsocket/test/framing/FrameTest.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include @@ -8,29 +20,13 @@ #include "rsocket/framing/Frame.h" #include "rsocket/framing/FrameSerializer.h" -using namespace ::testing; using namespace ::rsocket; -// TODO(stupaq): tests with malformed frames - -template -Frame reserialize_resume(bool resumable, Args... args) { - Frame givenFrame, newFrame; - givenFrame = Frame(std::forward(args)...); - auto frameSerializer = FrameSerializer::createFrameSerializer( - ProtocolVersion::Current()); - EXPECT_TRUE(frameSerializer->deserializeFrom( - newFrame, - frameSerializer->serializeOut(std::move(givenFrame), resumable), - resumable)); - return newFrame; -} - template Frame reserialize(Args... args) { Frame givenFrame = Frame(std::forward(args)...); - auto frameSerializer = FrameSerializer::createFrameSerializer( - ProtocolVersion::Current()); + auto frameSerializer = + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest); auto serializedFrame = frameSerializer->serializeOut(std::move(givenFrame)); Frame newFrame; EXPECT_TRUE( @@ -60,8 +56,8 @@ TEST(FrameTest, Frame_REQUEST_STREAM) { expectHeader(FrameType::REQUEST_STREAM, flags, streamId, frame); EXPECT_EQ(requestN, frame.requestN_); - EXPECT_TRUE(folly::IOBufEqual()(*metadata, *frame.payload_.metadata)); - EXPECT_TRUE(folly::IOBufEqual()(*data, *frame.payload_.data)); + EXPECT_TRUE(folly::IOBufEqualTo()(*metadata, *frame.payload_.metadata)); + EXPECT_TRUE(folly::IOBufEqualTo()(*data, *frame.payload_.data)); } TEST(FrameTest, Frame_REQUEST_CHANNEL) { @@ -75,8 +71,8 @@ TEST(FrameTest, Frame_REQUEST_CHANNEL) { expectHeader(FrameType::REQUEST_CHANNEL, flags, streamId, frame); EXPECT_EQ(requestN, frame.requestN_); - EXPECT_TRUE(folly::IOBufEqual()(*metadata, *frame.payload_.metadata)); - EXPECT_TRUE(folly::IOBufEqual()(*data, *frame.payload_.data)); + EXPECT_TRUE(folly::IOBufEqualTo()(*metadata, *frame.payload_.metadata)); + EXPECT_TRUE(folly::IOBufEqualTo()(*data, *frame.payload_.data)); } TEST(FrameTest, Frame_REQUEST_N) { @@ -84,14 +80,14 @@ TEST(FrameTest, Frame_REQUEST_N) { uint32_t requestN = 24; auto frame = reserialize(streamId, requestN); - expectHeader(FrameType::REQUEST_N, FrameFlags::EMPTY, streamId, frame); + expectHeader(FrameType::REQUEST_N, FrameFlags::EMPTY_, streamId, frame); EXPECT_EQ(requestN, frame.requestN_); } TEST(FrameTest, Frame_CANCEL) { uint32_t streamId = 42; auto frame = reserialize(streamId); - expectHeader(FrameType::CANCEL, FrameFlags::EMPTY, streamId, frame); + expectHeader(FrameType::CANCEL, FrameFlags::EMPTY_, streamId, frame); } TEST(FrameTest, Frame_PAYLOAD) { @@ -103,8 +99,8 @@ TEST(FrameTest, Frame_PAYLOAD) { streamId, flags, Payload(data->clone(), metadata->clone())); expectHeader(FrameType::PAYLOAD, flags, streamId, frame); - EXPECT_TRUE(folly::IOBufEqual()(*metadata, *frame.payload_.metadata)); - EXPECT_TRUE(folly::IOBufEqual()(*data, *frame.payload_.data)); + EXPECT_TRUE(folly::IOBufEqualTo()(*metadata, *frame.payload_.metadata)); + EXPECT_TRUE(folly::IOBufEqualTo()(*data, *frame.payload_.data)); } TEST(FrameTest, Frame_PAYLOAD_NoMeta) { @@ -116,7 +112,7 @@ TEST(FrameTest, Frame_PAYLOAD_NoMeta) { expectHeader(FrameType::PAYLOAD, flags, streamId, frame); EXPECT_FALSE(frame.payload_.metadata); - EXPECT_TRUE(folly::IOBufEqual()(*data, *frame.payload_.data)); + EXPECT_TRUE(folly::IOBufEqualTo()(*data, *frame.payload_.data)); } TEST(FrameTest, Frame_ERROR) { @@ -130,22 +126,8 @@ TEST(FrameTest, Frame_ERROR) { expectHeader(FrameType::ERROR, flags, streamId, frame); EXPECT_EQ(errorCode, frame.errorCode_); - EXPECT_TRUE(folly::IOBufEqual()(*metadata, *frame.payload_.metadata)); - EXPECT_TRUE(folly::IOBufEqual()(*data, *frame.payload_.data)); -} - -TEST(FrameTest, Frame_KEEPALIVE_resume) { - uint32_t streamId = 0; - ResumePosition position = 101; - auto flags = FrameFlags::KEEPALIVE_RESPOND; - auto data = folly::IOBuf::copyBuffer("424242"); - auto frame = - reserialize_resume(true, flags, position, data->clone()); - - expectHeader( - FrameType::KEEPALIVE, FrameFlags::KEEPALIVE_RESPOND, streamId, frame); - EXPECT_EQ(position, frame.position_); - EXPECT_TRUE(folly::IOBufEqual()(*data, *frame.data_)); + EXPECT_TRUE(folly::IOBufEqualTo()(*metadata, *frame.payload_.metadata)); + EXPECT_TRUE(folly::IOBufEqualTo()(*data, *frame.payload_.data)); } TEST(FrameTest, Frame_KEEPALIVE) { @@ -153,23 +135,22 @@ TEST(FrameTest, Frame_KEEPALIVE) { ResumePosition position = 101; auto flags = FrameFlags::KEEPALIVE_RESPOND; auto data = folly::IOBuf::copyBuffer("424242"); - auto frame = reserialize_resume( - false, flags, position, data->clone()); + auto frame = reserialize(flags, position, data->clone()); expectHeader( FrameType::KEEPALIVE, FrameFlags::KEEPALIVE_RESPOND, streamId, frame); // Default position - auto currProtVersion = ProtocolVersion::Current(); + auto currProtVersion = ProtocolVersion::Latest; if (currProtVersion == ProtocolVersion(0, 1)) { EXPECT_EQ(0, frame.position_); } else if (currProtVersion == ProtocolVersion(1, 0)) { EXPECT_EQ(position, frame.position_); } - EXPECT_TRUE(folly::IOBufEqual()(*data, *frame.data_)); + EXPECT_TRUE(folly::IOBufEqualTo()(*data, *frame.data_)); } TEST(FrameTest, Frame_SETUP) { - FrameFlags flags = FrameFlags::EMPTY; + FrameFlags flags = FrameFlags::EMPTY_; uint16_t versionMajor = 4; uint16_t versionMinor = 5; uint32_t keepaliveTime = Frame_SETUP::kMaxKeepaliveTime; @@ -196,11 +177,11 @@ TEST(FrameTest, Frame_SETUP) { EXPECT_EQ(ResumeIdentificationToken(), frame.token_); EXPECT_EQ("md", frame.metadataMimeType_); EXPECT_EQ("d", frame.dataMimeType_); - EXPECT_TRUE(folly::IOBufEqual()(*data, *frame.payload_.data)); + EXPECT_TRUE(folly::IOBufEqualTo()(*data, *frame.payload_.data)); } TEST(FrameTest, Frame_SETUP_resume) { - FrameFlags flags = FrameFlags::EMPTY | FrameFlags::RESUME_ENABLE; + FrameFlags flags = FrameFlags::EMPTY_ | FrameFlags::RESUME_ENABLE; uint16_t versionMajor = 0; uint16_t versionMinor = 0; uint32_t keepaliveTime = Frame_SETUP::kMaxKeepaliveTime; @@ -226,11 +207,11 @@ TEST(FrameTest, Frame_SETUP_resume) { EXPECT_EQ(token, frame.token_); EXPECT_EQ("md", frame.metadataMimeType_); EXPECT_EQ("d", frame.dataMimeType_); - EXPECT_TRUE(folly::IOBufEqual()(*data, *frame.payload_.data)); + EXPECT_TRUE(folly::IOBufEqualTo()(*data, *frame.payload_.data)); } TEST(FrameTest, Frame_LEASE) { - FrameFlags flags = FrameFlags::EMPTY; + FrameFlags flags = FrameFlags::EMPTY_; uint32_t ttl = Frame_LEASE::kMaxTtl; auto numberOfRequests = Frame_LEASE::kMaxNumRequests; auto frame = reserialize(ttl, numberOfRequests); @@ -249,8 +230,8 @@ TEST(FrameTest, Frame_REQUEST_RESPONSE) { streamId, flags, Payload(data->clone(), metadata->clone())); expectHeader(FrameType::REQUEST_RESPONSE, flags, streamId, frame); - EXPECT_TRUE(folly::IOBufEqual()(*metadata, *frame.payload_.metadata)); - EXPECT_TRUE(folly::IOBufEqual()(*data, *frame.payload_.data)); + EXPECT_TRUE(folly::IOBufEqualTo()(*metadata, *frame.payload_.metadata)); + EXPECT_TRUE(folly::IOBufEqualTo()(*data, *frame.payload_.data)); } TEST(FrameTest, Frame_REQUEST_FNF) { @@ -262,8 +243,8 @@ TEST(FrameTest, Frame_REQUEST_FNF) { streamId, flags, Payload(data->clone(), metadata->clone())); expectHeader(FrameType::REQUEST_FNF, flags, streamId, frame); - EXPECT_TRUE(folly::IOBufEqual()(*metadata, *frame.payload_.metadata)); - EXPECT_TRUE(folly::IOBufEqual()(*data, *frame.payload_.data)); + EXPECT_TRUE(folly::IOBufEqualTo()(*metadata, *frame.payload_.metadata)); + EXPECT_TRUE(folly::IOBufEqualTo()(*data, *frame.payload_.data)); } TEST(FrameTest, Frame_METADATA_PUSH) { @@ -272,11 +253,11 @@ TEST(FrameTest, Frame_METADATA_PUSH) { auto frame = reserialize(metadata->clone()); expectHeader(FrameType::METADATA_PUSH, flags, 0, frame); - EXPECT_TRUE(folly::IOBufEqual()(*metadata, *frame.metadata_)); + EXPECT_TRUE(folly::IOBufEqualTo()(*metadata, *frame.metadata_)); } TEST(FrameTest, Frame_RESUME) { - FrameFlags flags = FrameFlags::EMPTY; + FrameFlags flags = FrameFlags::EMPTY_; uint16_t versionMajor = 4; uint16_t versionMinor = 5; ResumeIdentificationToken token = ResumeIdentificationToken::generateNew(); @@ -300,10 +281,24 @@ TEST(FrameTest, Frame_RESUME) { } TEST(FrameTest, Frame_RESUME_OK) { - FrameFlags flags = FrameFlags::EMPTY; + FrameFlags flags = FrameFlags::EMPTY_; ResumePosition position = 6; auto frame = reserialize(position); expectHeader(FrameType::RESUME_OK, flags, 0, frame); EXPECT_EQ(position, frame.position_); } + +TEST(FrameTest, Frame_PreallocatedFrameLengthField) { + uint32_t streamId = 42; + FrameFlags flags = FrameFlags::COMPLETE; + auto data = folly::IOBuf::copyBuffer("424242"); + auto frameSerializer = + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest); + frameSerializer->preallocateFrameSizeField() = true; + + auto frame = Frame_PAYLOAD(streamId, flags, Payload(data->clone())); + auto serializedFrame = frameSerializer->serializeOut(std::move(frame)); + + EXPECT_LT(0, serializedFrame->headroom()); +} diff --git a/test/framing/FrameTransportTest.cpp b/rsocket/test/framing/FrameTransportTest.cpp similarity index 63% rename from test/framing/FrameTransportTest.cpp rename to rsocket/test/framing/FrameTransportTest.cpp index ee6f3c67e..48017d682 100644 --- a/test/framing/FrameTransportTest.cpp +++ b/rsocket/test/framing/FrameTransportTest.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include @@ -15,7 +27,7 @@ namespace { * Compare a `const folly::IOBuf&` against a `const std::string&`. */ MATCHER_P(IOBufStringEq, s, "") { - return folly::IOBufEqual()(*arg, *folly::IOBuf::copyBuffer(s)); + return folly::IOBufEqualTo()(*arg, *folly::IOBuf::copyBuffer(s)); } } // namespace @@ -24,7 +36,7 @@ TEST(FrameTransport, Close) { auto connection = std::make_unique>(); EXPECT_CALL(*connection, setInput_(_)); - auto transport = yarpl::make_ref(std::move(connection)); + auto transport = std::make_shared(std::move(connection)); transport->setFrameProcessor( std::make_shared>()); transport->close(); @@ -37,7 +49,7 @@ TEST(FrameTransport, SimpleNoQueue) { EXPECT_CALL(*connection, send_(IOBufStringEq("Hello"))); EXPECT_CALL(*connection, send_(IOBufStringEq("World"))); - auto transport = yarpl::make_ref(std::move(connection)); + auto transport = std::make_shared(std::move(connection)); transport->setFrameProcessor( std::make_shared>()); @@ -52,7 +64,7 @@ TEST(FrameTransport, InputSendsError) { auto connection = std::make_unique>([](auto input) { auto subscription = - yarpl::make_ref>(); + std::make_shared>(); EXPECT_CALL(*subscription, request_(_)); EXPECT_CALL(*subscription, cancel_()); @@ -60,7 +72,7 @@ TEST(FrameTransport, InputSendsError) { input->onError(std::runtime_error("Oops")); }); - auto transport = yarpl::make_ref(std::move(connection)); + auto transport = std::make_shared(std::move(connection)); auto processor = std::make_shared>(); EXPECT_CALL(*processor, onTerminal_(_)); diff --git a/test/framing/FramedReaderTest.cpp b/rsocket/test/framing/FramedReaderTest.cpp similarity index 67% rename from test/framing/FramedReaderTest.cpp rename to rsocket/test/framing/FramedReaderTest.cpp index 41bcd4d19..d3b6f9e0c 100644 --- a/test/framing/FramedReaderTest.cpp +++ b/rsocket/test/framing/FramedReaderTest.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include @@ -11,7 +23,7 @@ using namespace yarpl::mocks; TEST(FramedReader, TinyFrame) { auto version = std::make_shared(ProtocolVersion::Latest); - auto reader = yarpl::make_ref(version); + auto reader = std::make_shared(version); // Not using hex string-literal as std::string ctor hits '\x00' and stops // reading. @@ -22,10 +34,10 @@ TEST(FramedReader, TinyFrame) { buf->writableData()[2] = '\x00'; buf->writableData()[3] = '\x02'; - reader->onSubscribe(yarpl::flowable::Subscription::empty()); + reader->onSubscribe(yarpl::flowable::Subscription::create()); reader->onNext(std::move(buf)); - auto subscriber = yarpl::make_ref< + auto subscriber = std::make_shared< StrictMock>>>(); EXPECT_CALL(*subscriber, onSubscribe_(_)); EXPECT_CALL(*subscriber, onError_(_)); @@ -37,14 +49,14 @@ TEST(FramedReader, TinyFrame) { TEST(FramedReader, CantDetectVersion) { auto version = std::make_shared(ProtocolVersion::Unknown); - auto reader = yarpl::make_ref(version); + auto reader = std::make_shared(version); auto buf = folly::IOBuf::copyBuffer("ABCDEFGHIJKLMNOP"); - reader->onSubscribe(yarpl::flowable::Subscription::empty()); + reader->onSubscribe(yarpl::flowable::Subscription::create()); reader->onNext(std::move(buf)); - auto subscriber = yarpl::make_ref< + auto subscriber = std::make_shared< StrictMock>>>(); EXPECT_CALL(*subscriber, onSubscribe_(_)); EXPECT_CALL(*subscriber, onError_(_)); @@ -56,15 +68,15 @@ TEST(FramedReader, CantDetectVersion) { TEST(FramedReader, SubscriberCompleteAfterError) { auto version = std::make_shared(ProtocolVersion::Latest); - auto reader = yarpl::make_ref(version); + auto reader = std::make_shared(version); - auto subscription = yarpl::make_ref>(); + auto subscription = std::make_shared>(); EXPECT_CALL(*subscription, request_(_)); EXPECT_CALL(*subscription, cancel_()); reader->onSubscribe(subscription); - auto subscriber = yarpl::make_ref< + auto subscriber = std::make_shared< StrictMock>>>(); EXPECT_CALL(*subscriber, onSubscribe_(_)); EXPECT_CALL(*subscriber, onError_(_)) @@ -79,15 +91,15 @@ TEST(FramedReader, SubscriberCompleteAfterError) { TEST(FramedReader, SubscriberErrorAfterError) { auto version = std::make_shared(ProtocolVersion::Latest); - auto reader = yarpl::make_ref(version); + auto reader = std::make_shared(version); - auto subscription = yarpl::make_ref>(); + auto subscription = std::make_shared>(); EXPECT_CALL(*subscription, request_(_)); EXPECT_CALL(*subscription, cancel_()); reader->onSubscribe(subscription); - auto subscriber = yarpl::make_ref< + auto subscriber = std::make_shared< StrictMock>>>(); EXPECT_CALL(*subscriber, onSubscribe_(_)); EXPECT_CALL(*subscriber, onError_(_)) diff --git a/rsocket/test/framing/FramerTest.cpp b/rsocket/test/framing/FramerTest.cpp new file mode 100644 index 000000000..81fe97ffd --- /dev/null +++ b/rsocket/test/framing/FramerTest.cpp @@ -0,0 +1,71 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/framing/Framer.h" +#include +#include + +using namespace rsocket; +using namespace testing; + +class FramerMock : public Framer { + public: + explicit FramerMock(ProtocolVersion protocolVersion = ProtocolVersion::Latest) + : Framer(protocolVersion, true) {} + + MOCK_METHOD1(error, void(const char*)); + MOCK_METHOD1(onFrame_, void(std::unique_ptr&)); + + void onFrame(std::unique_ptr frame) override { + onFrame_(frame); + } +}; + +MATCHER_P(isIOBuffEq, n, "") { + return folly::IOBufEqualTo()(arg, n); +} + +TEST(Framer, TinyFrame) { + FramerMock framer; + + // Not using hex string-literal as std::string ctor hits '\x00' and stops + // reading. + auto buf = folly::IOBuf::createCombined(4); + buf->append(4); + buf->writableData()[0] = '\x00'; + buf->writableData()[1] = '\x00'; + buf->writableData()[2] = '\x00'; + buf->writableData()[3] = '\x02'; + + EXPECT_CALL(framer, error(_)); + framer.addFrameChunk(std::move(buf)); +} + +TEST(Framer, CantDetectVersion) { + FramerMock framer(ProtocolVersion::Unknown); + + EXPECT_CALL(framer, error(_)); + + auto buf = folly::IOBuf::copyBuffer("ABCDEFGHIJKLMNOP"); + framer.addFrameChunk(std::move(buf)); +} + +TEST(Framer, ParseFrame) { + FramerMock framer; + + auto buf = folly::IOBuf::copyBuffer("ABCDEFGHIJKLMNOP"); + EXPECT_CALL(framer, onFrame_(Pointee(isIOBuffEq(*buf)))); + + framer.addFrameChunk(framer.prependSize(std::move(buf))); +} diff --git a/test/fuzzer_testcases/frame_fuzzer/id_000000,sig_11,src_000000,op_havoc,rep_2 b/rsocket/test/fuzzer_testcases/frame_fuzzer/id_000000,sig_11,src_000000,op_havoc,rep_2 similarity index 100% rename from test/fuzzer_testcases/frame_fuzzer/id_000000,sig_11,src_000000,op_havoc,rep_2 rename to rsocket/test/fuzzer_testcases/frame_fuzzer/id_000000,sig_11,src_000000,op_havoc,rep_2 diff --git a/test/fuzzer_testcases/frame_fuzzer/id_000001,sig_11,src_000000,op_havoc,rep_16 b/rsocket/test/fuzzer_testcases/frame_fuzzer/id_000001,sig_11,src_000000,op_havoc,rep_16 similarity index 100% rename from test/fuzzer_testcases/frame_fuzzer/id_000001,sig_11,src_000000,op_havoc,rep_16 rename to rsocket/test/fuzzer_testcases/frame_fuzzer/id_000001,sig_11,src_000000,op_havoc,rep_16 diff --git a/test/fuzzers/frame_fuzzer.cpp b/rsocket/test/fuzzers/frame_fuzzer.cpp similarity index 78% rename from test/fuzzers/frame_fuzzer.cpp rename to rsocket/test/fuzzers/frame_fuzzer.cpp index fa3b1814a..eb9d04efc 100644 --- a/test/fuzzers/frame_fuzzer.cpp +++ b/rsocket/test/fuzzers/frame_fuzzer.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include #include @@ -30,11 +42,10 @@ struct FuzzerConnectionAcceptor : rsocket::ConnectionAcceptor { struct FuzzerDuplexConnection : rsocket::DuplexConnection { using Subscriber = rsocket::DuplexConnection::Subscriber; - using DuplexSubscriber = rsocket::DuplexConnection::DuplexSubscriber; FuzzerDuplexConnection() {} - void setInput(yarpl::Reference sub) override { + void setInput(std::shared_ptr sub) override { VLOG(1) << "FuzzerDuplexConnection::setInput()" << std::endl; input_sub = sub; } @@ -44,7 +55,7 @@ struct FuzzerDuplexConnection : rsocket::DuplexConnection { << folly::humanify(buf->moveToFbString()) << "\")" << std::endl; } - yarpl::Reference input_sub; + std::shared_ptr input_sub; }; struct NoopSubscription : yarpl::flowable::Subscription { @@ -91,7 +102,7 @@ int main(int argc, char* argv[]) { evb.loopOnce(); CHECK(input_sub); - auto input_subscription = yarpl::make_ref(); + auto input_subscription = std::make_shared(); input_sub->onSubscribe(input_subscription); std::string fuzz_input = get_stdin(); diff --git a/test/handlers/HelloServiceHandler.cpp b/rsocket/test/handlers/HelloServiceHandler.cpp similarity index 59% rename from test/handlers/HelloServiceHandler.cpp rename to rsocket/test/handlers/HelloServiceHandler.cpp index 130f7a79e..aa7e07ca5 100644 --- a/test/handlers/HelloServiceHandler.cpp +++ b/rsocket/test/handlers/HelloServiceHandler.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/test/handlers/HelloServiceHandler.h" #include "rsocket/test/handlers/HelloStreamRequestHandler.h" diff --git a/test/handlers/HelloServiceHandler.h b/rsocket/test/handlers/HelloServiceHandler.h similarity index 61% rename from test/handlers/HelloServiceHandler.h rename to rsocket/test/handlers/HelloServiceHandler.h index 2802c9b19..55cf86d53 100644 --- a/test/handlers/HelloServiceHandler.h +++ b/rsocket/test/handlers/HelloServiceHandler.h @@ -1,7 +1,20 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once +#include #include "rsocket/RSocketServiceHandler.h" namespace rsocket { diff --git a/rsocket/test/handlers/HelloStreamRequestHandler.cpp b/rsocket/test/handlers/HelloStreamRequestHandler.cpp new file mode 100644 index 000000000..a4bf8a34d --- /dev/null +++ b/rsocket/test/handlers/HelloStreamRequestHandler.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "HelloStreamRequestHandler.h" +#include +#include +#include "yarpl/Flowable.h" + +using namespace yarpl::flowable; + +namespace rsocket { +namespace tests { +/// Handles a new inbound Stream requested by the other end. +std::shared_ptr> +HelloStreamRequestHandler::handleRequestStream( + rsocket::Payload request, + rsocket::StreamId) { + VLOG(3) << "HelloStreamRequestHandler.handleRequestStream " << request; + + // string from payload data + auto requestString = request.moveDataToString(); + + return Flowable<>::range(1, 10)->map( + [name = std::move(requestString)](int64_t v) { + return Payload(folly::to(v), "metadata"); + }); +} +} // namespace tests +} // namespace rsocket diff --git a/rsocket/test/handlers/HelloStreamRequestHandler.h b/rsocket/test/handlers/HelloStreamRequestHandler.h new file mode 100644 index 000000000..3aa48fb08 --- /dev/null +++ b/rsocket/test/handlers/HelloStreamRequestHandler.h @@ -0,0 +1,31 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "rsocket/RSocketResponder.h" +#include "yarpl/Flowable.h" + +namespace rsocket { +namespace tests { + +class HelloStreamRequestHandler : public RSocketResponder { + public: + /// Handles a new inbound Stream requested by the other end. + std::shared_ptr> + handleRequestStream(rsocket::Payload request, rsocket::StreamId streamId) + override; +}; +} // namespace tests +} // namespace rsocket diff --git a/test/internal/AllowanceTest.cpp b/rsocket/test/internal/AllowanceTest.cpp similarity index 63% rename from test/internal/AllowanceTest.cpp rename to rsocket/test/internal/AllowanceTest.cpp index a619266d0..d2e77a36d 100644 --- a/test/internal/AllowanceTest.cpp +++ b/rsocket/test/internal/AllowanceTest.cpp @@ -1,10 +1,21 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "rsocket/internal/Allowance.h" #include #include -#include "rsocket/internal/Allowance.h" -using namespace ::testing; using namespace ::rsocket; TEST(AllowanceTest, Finite) { diff --git a/test/internal/ConnectionSetTest.cpp b/rsocket/test/internal/ConnectionSetTest.cpp similarity index 56% rename from test/internal/ConnectionSetTest.cpp rename to rsocket/test/internal/ConnectionSetTest.cpp index fe98d631b..ea0e215aa 100644 --- a/test/internal/ConnectionSetTest.cpp +++ b/rsocket/test/internal/ConnectionSetTest.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include @@ -12,7 +24,6 @@ #include "rsocket/statemachine/RSocketStateMachine.h" using namespace rsocket; -using namespace testing; namespace { @@ -23,11 +34,11 @@ std::shared_ptr makeStateMachine(folly::EventBase* evb) { RSocketMode::SERVER, RSocketStats::noop(), std::make_shared(), - nullptr /* resumeManager */, + ResumeManager::makeEmpty(), nullptr /* coldResumeHandler */ - ); -} + ); } +} // namespace TEST(ConnectionSet, ImmediateDtor) { ConnectionSet set; @@ -37,9 +48,9 @@ TEST(ConnectionSet, CloseViaMachine) { folly::EventBase evb; auto machine = makeStateMachine(&evb); - auto set = std::make_shared(); - set->insert(machine, &evb); - machine->registerSet(set); + ConnectionSet set; + set.insert(machine, &evb); + machine->registerCloseCallback(&set); machine->close({}, StreamCompletionSignal::CANCEL); } @@ -48,7 +59,7 @@ TEST(ConnectionSet, CloseViaSetDtor) { folly::EventBase evb; auto machine = makeStateMachine(&evb); - auto set = std::make_shared(); - set->insert(machine, &evb); - machine->registerSet(set); + ConnectionSet set; + set.insert(machine, &evb); + machine->registerCloseCallback(&set); } diff --git a/test/internal/KeepaliveTimerTest.cpp b/rsocket/test/internal/KeepaliveTimerTest.cpp similarity index 65% rename from test/internal/KeepaliveTimerTest.cpp rename to rsocket/test/internal/KeepaliveTimerTest.cpp index ff8ec00fa..e721691a3 100644 --- a/test/internal/KeepaliveTimerTest.cpp +++ b/rsocket/test/internal/KeepaliveTimerTest.cpp @@ -1,6 +1,17 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -#include #include #include #include @@ -19,7 +30,7 @@ class MockConnectionAutomaton : public FrameSink { public: // MOCK_METHOD doesn't take functions with unique_ptr args. // A workaround for sendKeepalive method. - virtual void sendKeepalive(std::unique_ptr b) override { + void sendKeepalive(std::unique_ptr b) override { sendKeepalive_(b); } MOCK_METHOD1(sendKeepalive_, void(std::unique_ptr&)); @@ -30,7 +41,7 @@ class MockConnectionAutomaton : public FrameSink { disconnectOrCloseWithError_(error); } }; -} +} // namespace TEST(FollyKeepaliveTimerTest, StartStopWithResponse) { auto connectionAutomaton = @@ -44,11 +55,11 @@ TEST(FollyKeepaliveTimerTest, StartStopWithResponse) { timer.start(connectionAutomaton); - timer.sendKeepalive(); + timer.sendKeepalive(*connectionAutomaton); timer.keepaliveReceived(); - timer.sendKeepalive(); + timer.sendKeepalive(*connectionAutomaton); timer.stop(); } @@ -66,9 +77,9 @@ TEST(FollyKeepaliveTimerTest, NoResponse) { timer.start(connectionAutomaton); - timer.sendKeepalive(); + timer.sendKeepalive(*connectionAutomaton); - timer.sendKeepalive(); + timer.sendKeepalive(*connectionAutomaton); timer.stop(); } diff --git a/rsocket/test/internal/ResumeIdentificationToken.cpp b/rsocket/test/internal/ResumeIdentificationToken.cpp new file mode 100644 index 000000000..48d4b172c --- /dev/null +++ b/rsocket/test/internal/ResumeIdentificationToken.cpp @@ -0,0 +1,29 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "rsocket/framing/ResumeIdentificationToken.h" + +using namespace rsocket; + +TEST(ResumeIdentificationTokenTest, Conversion) { + for (int i = 0; i < 10; i++) { + auto token = ResumeIdentificationToken::generateNew(); + auto token2 = ResumeIdentificationToken(token.str()); + CHECK_EQ(token, token2); + CHECK_EQ(token.str(), token2.str()); + } +} diff --git a/test/internal/SetupResumeAcceptorTest.cpp b/rsocket/test/internal/SetupResumeAcceptorTest.cpp similarity index 61% rename from test/internal/SetupResumeAcceptorTest.cpp rename to rsocket/test/internal/SetupResumeAcceptorTest.cpp index 6c8f0a95c..5365f4964 100644 --- a/test/internal/SetupResumeAcceptorTest.cpp +++ b/rsocket/test/internal/SetupResumeAcceptorTest.cpp @@ -1,9 +1,22 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include +#include "rsocket/framing/FrameSerializer.h" #include "rsocket/framing/FrameTransportImpl.h" #include "rsocket/internal/SetupResumeAcceptor.h" #include "rsocket/test/test_utils/MockDuplexConnection.h" @@ -19,10 +32,10 @@ namespace { * Make a legitimate-looking SETUP frame. */ Frame_SETUP makeSetup() { - auto version = ProtocolVersion::Current(); + auto version = ProtocolVersion::Latest; Frame_SETUP frame; - frame.header_ = FrameHeader{FrameType::SETUP, FrameFlags::EMPTY, 0}; + frame.header_ = FrameHeader{FrameType::SETUP, FrameFlags::EMPTY_, 0}; frame.versionMajor_ = version.major; frame.versionMinor_ = version.minor; frame.keepaliveTime_ = Frame_SETUP::kMaxKeepaliveTime; @@ -39,7 +52,7 @@ Frame_SETUP makeSetup() { */ Frame_RESUME makeResume() { Frame_RESUME frame; - frame.header_ = FrameHeader{FrameType::RESUME, FrameFlags::EMPTY, 0}; + frame.header_ = FrameHeader{FrameType::RESUME, FrameFlags::EMPTY_, 0}; frame.versionMajor_ = 1; frame.versionMinor_ = 0; frame.token_ = ResumeIdentificationToken::generateNew(); @@ -48,15 +61,12 @@ Frame_RESUME makeResume() { return frame; } -void setupFail(yarpl::Reference transport, SetupParameters) { - transport->close(); +void setupFail(std::unique_ptr, SetupParameters) { FAIL() << "setupFail() was called"; } -bool resumeFail(yarpl::Reference transport, ResumeParameters) { - transport->close(); - ADD_FAILURE() << "resumeFail() was called"; - return false; +void resumeFail(std::unique_ptr, ResumeParameters) { + FAIL() << "resumeFail() was called"; } } // namespace @@ -75,12 +85,12 @@ TEST(SetupResumeAcceptor, CloseWithActiveConnection) { folly::EventBase evb; SetupResumeAcceptor acceptor{&evb}; - yarpl::Reference outerInput; + std::shared_ptr outerInput; auto connection = std::make_unique>([&](auto input) { outerInput = input; - input->onSubscribe(yarpl::flowable::Subscription::empty()); + input->onSubscribe(yarpl::flowable::Subscription::create()); }); ON_CALL(*connection, send_(_)).WillByDefault(Invoke([](auto&) { FAIL(); })); @@ -101,7 +111,7 @@ TEST(SetupResumeAcceptor, EarlyComplete) { auto connection = std::make_unique>([](auto input) { - input->onSubscribe(yarpl::flowable::Subscription::empty()); + input->onSubscribe(yarpl::flowable::Subscription::create()); input->onComplete(); }); @@ -116,7 +126,7 @@ TEST(SetupResumeAcceptor, EarlyError) { auto connection = std::make_unique>([](auto input) { - input->onSubscribe(yarpl::flowable::Subscription::empty()); + input->onSubscribe(yarpl::flowable::Subscription::create()); input->onError(std::runtime_error("Whoops")); }); @@ -132,8 +142,8 @@ TEST(SetupResumeAcceptor, SingleSetup) { auto connection = std::make_unique>([](auto input) { auto serializer = - FrameSerializer::createFrameSerializer(ProtocolVersion::Current()); - input->onSubscribe(yarpl::flowable::Subscription::empty()); + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest); + input->onSubscribe(yarpl::flowable::Subscription::create()); input->onNext(serializer->serializeOut(makeSetup())); input->onComplete(); }); @@ -142,10 +152,7 @@ TEST(SetupResumeAcceptor, SingleSetup) { acceptor.accept( std::move(connection), - [&](auto transport, auto) { - transport->close(); - setupCalled = true; - }, + [&](auto, auto) { setupCalled = true; }, resumeFail); evb.loop(); @@ -160,20 +167,20 @@ TEST(SetupResumeAcceptor, InvalidSetup) { auto connection = std::make_unique>([](auto input) { auto serializer = - FrameSerializer::createFrameSerializer(ProtocolVersion::Current()); + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest); // Bogus keepalive time that can't be deserialized. auto setup = makeSetup(); setup.keepaliveTime_ = -5; - input->onSubscribe(yarpl::flowable::Subscription::empty()); + input->onSubscribe(yarpl::flowable::Subscription::create()); input->onNext(serializer->serializeOut(std::move(setup))); input->onComplete(); }); EXPECT_CALL(*connection, send_(_)).WillOnce(Invoke([](auto& buf) { auto serializer = - FrameSerializer::createFrameSerializer(ProtocolVersion::Current()); + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest); Frame_ERROR frame; EXPECT_TRUE(serializer->deserializeFrom(frame, buf->clone())); EXPECT_EQ(frame.errorCode_, ErrorCode::CONNECTION_ERROR); @@ -188,18 +195,19 @@ TEST(SetupResumeAcceptor, RejectedSetup) { folly::EventBase evb; SetupResumeAcceptor acceptor{&evb}; + auto serializer = + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest); + auto connection = - std::make_unique>([](auto input) { - auto serializer = - FrameSerializer::createFrameSerializer(ProtocolVersion::Current()); - input->onSubscribe(yarpl::flowable::Subscription::empty()); + std::make_unique>([&](auto input) { + input->onSubscribe(yarpl::flowable::Subscription::create()); input->onNext(serializer->serializeOut(makeSetup())); input->onComplete(); }); EXPECT_CALL(*connection, send_(_)).WillOnce(Invoke([](auto& buf) { auto serializer = - FrameSerializer::createFrameSerializer(ProtocolVersion::Current()); + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest); Frame_ERROR frame; EXPECT_TRUE(serializer->deserializeFrom(frame, buf->clone())); EXPECT_EQ(frame.errorCode_, ErrorCode::REJECTED_SETUP); @@ -209,9 +217,10 @@ TEST(SetupResumeAcceptor, RejectedSetup) { acceptor.accept( std::move(connection), - [&](auto, auto) { + [&](std::unique_ptr connection, auto) { setupCalled = true; - throw std::runtime_error("Oops"); + connection->send( + serializer->serializeOut(Frame_ERROR::rejectedSetup("Oops"))); }, resumeFail); @@ -224,18 +233,19 @@ TEST(SetupResumeAcceptor, RejectedResume) { folly::EventBase evb; SetupResumeAcceptor acceptor{&evb}; + auto serializer = + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest); + auto connection = - std::make_unique>([](auto input) { - auto serializer = - FrameSerializer::createFrameSerializer(ProtocolVersion::Current()); - input->onSubscribe(yarpl::flowable::Subscription::empty()); + std::make_unique>([&](auto input) { + input->onSubscribe(yarpl::flowable::Subscription::create()); input->onNext(serializer->serializeOut(makeResume())); input->onComplete(); }); EXPECT_CALL(*connection, send_(_)).WillOnce(Invoke([](auto& buf) { auto serializer = - FrameSerializer::createFrameSerializer(ProtocolVersion::Current()); + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest); Frame_ERROR frame; EXPECT_TRUE(serializer->deserializeFrom(frame, buf->clone())); EXPECT_EQ(frame.errorCode_, ErrorCode::REJECTED_RESUME); @@ -243,12 +253,62 @@ TEST(SetupResumeAcceptor, RejectedResume) { bool resumeCalled = false; - acceptor.accept(std::move(connection), setupFail, [&](auto, auto) { - resumeCalled = true; - throw std::runtime_error("Cant resume"); - }); + acceptor.accept( + std::move(connection), + setupFail, + [&](std::unique_ptr connection, auto) { + resumeCalled = true; + connection->send(serializer->serializeOut( + Frame_ERROR::rejectedResume("Cant resume"))); + }); evb.loop(); EXPECT_TRUE(resumeCalled); } + +TEST(SetupResumeAcceptor, SetupBadVersion) { + folly::EventBase evb; + SetupResumeAcceptor acceptor{&evb}; + + auto serializer = + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest); + + auto connection = + std::make_unique>([&](auto input) { + input->onSubscribe(yarpl::flowable::Subscription::create()); + + auto setup = makeSetup(); + setup.versionMajor_ = 57; + setup.versionMinor_ = 39; + + input->onNext(serializer->serializeOut(std::move(setup))); + input->onComplete(); + }); + + acceptor.accept(std::move(connection), setupFail, resumeFail); + evb.loop(); +} + +TEST(SetupResumeAcceptor, ResumeBadVersion) { + folly::EventBase evb; + SetupResumeAcceptor acceptor{&evb}; + + auto serializer = + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest); + + auto connection = + std::make_unique>([&](auto input) { + input->onSubscribe(yarpl::flowable::Subscription::create()); + + auto resume = makeResume(); + resume.versionMajor_ = 57; + resume.versionMinor_ = 39; + + input->onNext(serializer->serializeOut(std::move(resume))); + input->onComplete(); + }); + + acceptor.accept(std::move(connection), setupFail, resumeFail); + evb.loop(); +} diff --git a/test/internal/SwappableEventBaseTest.cpp b/rsocket/test/internal/SwappableEventBaseTest.cpp similarity index 76% rename from test/internal/SwappableEventBaseTest.cpp rename to rsocket/test/internal/SwappableEventBaseTest.cpp index f80a0eb9d..b02c7f391 100644 --- a/test/internal/SwappableEventBaseTest.cpp +++ b/rsocket/test/internal/SwappableEventBaseTest.cpp @@ -1,6 +1,17 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -#include #include #include @@ -20,7 +31,7 @@ struct DidExecTracker { const std::string file; const std::string name; DidExecTracker(int line, std::string file, std::string name) - : line(line), file(file), name(name) {} + : line(line), file(file), name(name) {} MOCK_METHOD0(mark, void()); }; @@ -28,10 +39,18 @@ struct DETMarkedOnce : public ::testing::CardinalityInterface { explicit DETMarkedOnce(DidExecTracker const& det) : det(det) {} DidExecTracker const& det; - int ConservativeLowerBound() const override { return 1; } - int ConservativeUpperBound() const override { return 1; } - bool IsSatisfiedByCallCount(int cc) const override { return cc == 1; } - bool IsSaturatedByCallCount(int cc) const override { return cc == 1; } + int ConservativeLowerBound() const override { + return 1; + } + int ConservativeUpperBound() const override { + return 1; + } + bool IsSatisfiedByCallCount(int cc) const override { + return cc == 1; + } + bool IsSaturatedByCallCount(int cc) const override { + return cc == 1; + } void DescribeTo(std::ostream* os) const override { *os << "is called exactly once on "; @@ -43,19 +62,19 @@ ::testing::Cardinality MarkedOnce(DidExecTracker const& det) { } class SwappableEbTest : public ::testing::Test { -public: + public: std::vector> ebs; std::vector> did_exec_trackers; void loop_ebs() { { ::testing::InSequence s; - for(auto tracker : did_exec_trackers) { + for (auto tracker : did_exec_trackers) { EXPECT_CALL(*tracker, mark()).Times(MarkedOnce(*tracker)); } } - for(auto& eb : ebs) { + for (auto& eb : ebs) { ASSERT_TRUE(eb->loop()); } @@ -64,10 +83,9 @@ class SwappableEbTest : public ::testing::Test { } std::shared_ptr make_did_exec_tracker_impl( - int line, - std::string const& file, - std::string const& name - ) { + int line, + std::string const& file, + std::string const& name) { did_exec_trackers.emplace_back(new DidExecTracker(line, file, name)); return did_exec_trackers.back(); } diff --git a/rsocket/test/statemachine/RSocketStateMachineTest.cpp b/rsocket/test/statemachine/RSocketStateMachineTest.cpp new file mode 100644 index 000000000..9d9909398 --- /dev/null +++ b/rsocket/test/statemachine/RSocketStateMachineTest.cpp @@ -0,0 +1,430 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/statemachine/RSocketStateMachine.h" +#include +#include +#include +#include +#include +#include "rsocket/RSocketConnectionEvents.h" +#include "rsocket/RSocketResponder.h" +#include "rsocket/framing/FrameSerializer_v1_0.h" +#include "rsocket/framing/FrameTransportImpl.h" +#include "rsocket/internal/Common.h" +#include "rsocket/statemachine/ChannelRequester.h" +#include "rsocket/statemachine/ChannelResponder.h" +#include "rsocket/statemachine/RequestResponseResponder.h" +#include "rsocket/test/test_utils/MockDuplexConnection.h" +#include "rsocket/test/test_utils/MockStreamsWriter.h" + +using namespace testing; +using namespace yarpl::mocks; +using namespace yarpl::single; + +namespace rsocket { + +class ResponderMock : public RSocketResponder { + public: + MOCK_METHOD1( + handleRequestResponse_, + std::shared_ptr>(StreamId)); + MOCK_METHOD1( + handleRequestStream_, + std::shared_ptr>(StreamId)); + MOCK_METHOD2( + handleRequestChannel_, + std::shared_ptr>( + std::shared_ptr> requestStream, + StreamId streamId)); + + std::shared_ptr> handleRequestResponse(Payload, StreamId id) + override { + return handleRequestResponse_(id); + } + + std::shared_ptr> handleRequestStream( + Payload, + StreamId id) override { + return handleRequestStream_(id); + } + + std::shared_ptr> handleRequestChannel( + Payload, + std::shared_ptr> requestStream, + StreamId streamId) override { + return handleRequestChannel_(requestStream, streamId); + } +}; + +struct ConnectionEventsMock : public RSocketConnectionEvents { + MOCK_METHOD1(onDisconnected, void(const folly::exception_wrapper&)); + MOCK_METHOD0(onStreamsPaused, void()); +}; + +class RSocketStateMachineTest : public Test { + public: + auto createClient( + std::unique_ptr connection, + std::shared_ptr responder) { + EXPECT_CALL(*connection, setInput_(_)); + EXPECT_CALL(*connection, isFramed()); + + auto transport = + std::make_shared(std::move(connection)); + + auto stateMachine = std::make_shared( + std::move(responder), + nullptr, + RSocketMode::CLIENT, + nullptr, + nullptr, + ResumeManager::makeEmpty(), + nullptr); + + SetupParameters setupParameters; + setupParameters.resumable = false; // Not resumable! + stateMachine->connectClient( + std::move(transport), std::move(setupParameters)); + + return stateMachine; + } + + auto createServer( + std::unique_ptr connection, + std::shared_ptr responder, + folly::Optional resumeToken = folly::none, + std::shared_ptr connectionEvents = nullptr) { + auto transport = + std::make_shared(std::move(connection)); + + auto stateMachine = std::make_shared( + std::move(responder), + nullptr, + RSocketMode::SERVER, + nullptr, + std::move(connectionEvents), + ResumeManager::makeEmpty(), + nullptr); + + if (resumeToken) { + SetupParameters setupParameters; + setupParameters.resumable = true; + setupParameters.token = *resumeToken; + stateMachine->connectServer(std::move(transport), setupParameters); + } else { + SetupParameters setupParameters; + setupParameters.resumable = false; + stateMachine->connectServer(std::move(transport), setupParameters); + } + + return stateMachine; + } + + const std::unordered_map>& + getStreams(RSocketStateMachine& stateMachine) { + return stateMachine.streams_; + } + + void setupRequestStream( + RSocketStateMachine& stateMachine, + StreamId streamId, + uint32_t requestN, + Payload payload) { + stateMachine.onRequestStreamFrame( + streamId, requestN, std::move(payload), false); + } + + void setupRequestChannel( + RSocketStateMachine& stateMachine, + StreamId streamId, + uint32_t requestN, + Payload payload) { + stateMachine.onRequestChannelFrame( + streamId, requestN, std::move(payload), false, true, false); + } + + void setupRequestResponse( + RSocketStateMachine& stateMachine, + StreamId streamId, + Payload payload) { + stateMachine.onRequestResponseFrame(streamId, std::move(payload), false); + } + + void setupFireAndForget( + RSocketStateMachine& stateMachine, + StreamId streamId, + Payload payload) { + stateMachine.onFireAndForgetFrame(streamId, std::move(payload), false); + } +}; + +TEST_F(RSocketStateMachineTest, RequestStream) { + auto connection = std::make_unique>(); + // Setup frame and request stream frame + EXPECT_CALL(*connection, send_(_)).Times(2); + + auto stateMachine = + createClient(std::move(connection), std::make_shared()); + + auto subscriber = std::make_shared>>(1000); + EXPECT_CALL(*subscriber, onSubscribe_(_)); + EXPECT_CALL(*subscriber, onComplete_()); + + stateMachine->requestStream(Payload{}, subscriber); + + auto& streams = getStreams(*stateMachine); + ASSERT_EQ(1, streams.size()); + + // This line causes: subscriber.onComplete() + streams.at(1)->endStream(StreamCompletionSignal::CANCEL); + + stateMachine->close({}, StreamCompletionSignal::CONNECTION_END); +} + +TEST_F(RSocketStateMachineTest, RequestStream_EarlyClose) { + auto connection = std::make_unique>(); + // Setup frame, two request stream frames, one extra frame + EXPECT_CALL(*connection, send_(_)).Times(3); + + auto stateMachine = + createClient(std::move(connection), std::make_shared()); + + auto subscriber = std::make_shared>>(1000); + EXPECT_CALL(*subscriber, onSubscribe_(_)).Times(2); + EXPECT_CALL(*subscriber, onComplete_()); + + stateMachine->requestStream(Payload{}, subscriber); + + // Second stream + stateMachine->requestStream(Payload{}, subscriber); + + auto& streams = getStreams(*stateMachine); + ASSERT_EQ(2, streams.size()); + + // Close the stream + auto writer = std::dynamic_pointer_cast(stateMachine); + writer->onStreamClosed(1); + + // Push more data to the closed stream + auto processor = std::dynamic_pointer_cast(stateMachine); + FrameSerializerV1_0 serializer; + processor->processFrame( + serializer.serializeOut(Frame_PAYLOAD(1, FrameFlags::COMPLETE, {}))); + + // Second stream should still be valid + ASSERT_EQ(1, streams.size()); + + streams.at(3)->endStream(StreamCompletionSignal::CANCEL); + stateMachine->close({}, StreamCompletionSignal::CONNECTION_END); +} + +TEST_F(RSocketStateMachineTest, RequestChannel) { + auto connection = std::make_unique>(); + // Setup frame and request channel frame + EXPECT_CALL(*connection, send_(_)).Times(2); + + auto stateMachine = + createClient(std::move(connection), std::make_shared()); + + auto in = std::make_shared>>(1000); + EXPECT_CALL(*in, onSubscribe_(_)); + EXPECT_CALL(*in, onComplete_()); + + auto out = stateMachine->requestChannel(Payload{}, true, in); + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, cancel_()); + out->onSubscribe(subscription); + + auto& streams = getStreams(*stateMachine); + ASSERT_EQ(1, streams.size()); + + // This line causes: in.onComplete() and outSubscription.cancel() + streams.at(1)->endStream(StreamCompletionSignal::CANCEL); + + stateMachine->close({}, StreamCompletionSignal::CONNECTION_END); +} + +TEST_F(RSocketStateMachineTest, RequestResponse) { + auto connection = std::make_unique>(); + // Setup frame and request channel frame + EXPECT_CALL(*connection, send_(_)).Times(2); + + auto stateMachine = + createClient(std::move(connection), std::make_shared()); + + auto in = std::make_shared>(); + stateMachine->requestResponse(Payload{}, in); + + auto& streams = getStreams(*stateMachine); + ASSERT_EQ(1, streams.size()); + + // This line closes the stream + streams.at(1)->handlePayload(Payload{"test", "123"}, true, false, false); + + stateMachine->close({}, StreamCompletionSignal::CONNECTION_END); +} + +TEST_F(RSocketStateMachineTest, RespondStream) { + auto connection = std::make_unique>(); + int requestCount = 5; + // Payload frames plus a SETUP frame and an ERROR frame + EXPECT_CALL(*connection, send_(_)).Times(requestCount + 2); + + int sendCount = 0; + auto responder = std::make_shared>(); + EXPECT_CALL(*responder, handleRequestStream_(_)) + .WillOnce(Return( + yarpl::flowable::Flowable::fromGenerator([&sendCount]() { + ++sendCount; + return Payload{}; + }))); + + auto stateMachine = createClient(std::move(connection), responder); + setupRequestStream(*stateMachine, 2, requestCount, Payload{}); + EXPECT_EQ(requestCount, sendCount); + + auto& streams = getStreams(*stateMachine); + EXPECT_EQ(1, streams.size()); + + // releases connection and the responder + stateMachine->close({}, StreamCompletionSignal::CONNECTION_END); +} + +TEST_F(RSocketStateMachineTest, RespondChannel) { + auto connection = std::make_unique>(); + int requestCount = 5; + // + the cancel frame when the stateMachine gets destroyed + EXPECT_CALL(*connection, send_(_)).Times(requestCount + 1); + + int sendCount = 0; + auto responder = std::make_shared>(); + EXPECT_CALL(*responder, handleRequestChannel_(_, _)) + .WillOnce(Return( + yarpl::flowable::Flowable::fromGenerator([&sendCount]() { + ++sendCount; + return Payload{}; + }))); + + auto stateMachine = createClient(std::move(connection), responder); + setupRequestChannel(*stateMachine, 2, requestCount, Payload{}); + EXPECT_EQ(requestCount, sendCount); + + auto& streams = getStreams(*stateMachine); + EXPECT_EQ(1, streams.size()); + + // releases connection and the responder + stateMachine->close({}, StreamCompletionSignal::CONNECTION_END); +} + +TEST_F(RSocketStateMachineTest, RespondRequest) { + auto connection = std::make_unique>(); + EXPECT_CALL(*connection, send_(_)).Times(2); + + int sendCount = 0; + auto responder = std::make_shared>(); + EXPECT_CALL(*responder, handleRequestResponse_(_)) + .WillOnce(Return(Singles::fromGenerator([&sendCount]() { + ++sendCount; + return Payload{}; + }))); + + auto stateMachine = createClient(std::move(connection), responder); + setupRequestResponse(*stateMachine, 2, Payload{}); + EXPECT_EQ(sendCount, 1); + + auto& streams = getStreams(*stateMachine); + EXPECT_EQ(0, streams.size()); // already completed + + // releases connection and the responder + stateMachine->close({}, StreamCompletionSignal::CONNECTION_END); +} + +TEST_F(RSocketStateMachineTest, StreamImmediateCancel) { + auto connection = std::make_unique>(); + // Only send a SETUP frame. A REQUEST_STREAM frame should never be sent. + EXPECT_CALL(*connection, send_(_)); + + auto stateMachine = + createClient(std::move(connection), std::make_shared()); + + auto subscriber = std::make_shared>>(); + EXPECT_CALL(*subscriber, onSubscribe_(_)) + .WillOnce(Invoke( + [](std::shared_ptr subscription) { + subscription->cancel(); + })); + + stateMachine->requestStream(Payload{}, subscriber); + + auto& streams = getStreams(*stateMachine); + ASSERT_EQ(0, streams.size()); + + stateMachine->close({}, StreamCompletionSignal::CONNECTION_END); +} + +TEST_F(RSocketStateMachineTest, TransportOnNextClose) { + auto connection = std::make_unique>(); + // Only SETUP frame gets sent. + EXPECT_CALL(*connection, setInput_(_)); + EXPECT_CALL(*connection, isFramed()); + EXPECT_CALL(*connection, send_(_)); + + auto transport = std::make_shared(std::move(connection)); + auto stateMachine = std::make_shared( + std::make_shared>(), + nullptr, + RSocketMode::CLIENT, + nullptr, + nullptr, + ResumeManager::makeEmpty(), + nullptr); + + SetupParameters params; + params.resumable = false; + stateMachine->connectClient(transport, std::move(params)); + + auto rawTransport = transport.get(); + + // Leak the cycle. + stateMachine.reset(); + transport.reset(); + + FrameSerializerV1_0 serializer; + auto buf = serializer.serializeOut(Frame_ERROR::connectionError("Hah!")); + rawTransport->onNext(std::move(buf)); +} + +TEST_F(RSocketStateMachineTest, ResumeWithCurrentConnection) { + auto resumeToken = ResumeIdentificationToken::generateNew(); + + auto eventsMock = std::make_shared(); + auto stateMachine = createServer( + std::make_unique>(), + std::make_shared(), + resumeToken, + eventsMock); + + EXPECT_CALL(*eventsMock, onDisconnected(_)).Times(0); + EXPECT_CALL(*eventsMock, onStreamsPaused()).Times(0); + + ResumeParameters resumeParams{resumeToken, 0, 0, ProtocolVersion::Latest}; + auto transport = std::make_shared( + std::make_unique>()); + stateMachine->resumeServer(transport, resumeParams); + + stateMachine->close({}, StreamCompletionSignal::CONNECTION_END); +} + +} // namespace rsocket diff --git a/rsocket/test/statemachine/StreamResponderTest.cpp b/rsocket/test/statemachine/StreamResponderTest.cpp new file mode 100644 index 000000000..5c57937ff --- /dev/null +++ b/rsocket/test/statemachine/StreamResponderTest.cpp @@ -0,0 +1,154 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "rsocket/statemachine/StreamResponder.h" +#include "rsocket/test/test_utils/MockStreamsWriter.h" + +using namespace rsocket; +using namespace testing; +using namespace yarpl::mocks; + +TEST(StreamResponder, OnComplete) { + auto writer = std::make_shared>(); + auto responder = std::make_shared(writer, 1u, 0); + + EXPECT_CALL(*writer, writePayload_(_)).Times(3); + EXPECT_CALL(*writer, onStreamClosed(1u)); + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, request_(_)); + + responder->onSubscribe(subscription); + ASSERT_FALSE(responder->publisherClosed()); + + subscription->request(2); + + responder->onNext(Payload{}); + ASSERT_FALSE(responder->publisherClosed()); + + responder->onNext(Payload{}); + ASSERT_FALSE(responder->publisherClosed()); + + responder->onComplete(); + ASSERT_TRUE(responder->publisherClosed()); +} + +TEST(StreamResponder, OnError) { + auto writer = std::make_shared>(); + auto responder = std::make_shared(writer, 1u, 0); + + EXPECT_CALL(*writer, writePayload_(_)).Times(2); + EXPECT_CALL(*writer, writeError_(_)); + EXPECT_CALL(*writer, onStreamClosed(1u)); + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, request_(_)); + + responder->onSubscribe(subscription); + ASSERT_FALSE(responder->publisherClosed()); + + subscription->request(2); + + responder->onNext(Payload{}); + ASSERT_FALSE(responder->publisherClosed()); + + responder->onNext(Payload{}); + ASSERT_FALSE(responder->publisherClosed()); + + responder->onError(std::runtime_error{"Test"}); + ASSERT_TRUE(responder->publisherClosed()); +} + +TEST(StreamResponder, HandleError) { + auto writer = std::make_shared>(); + auto responder = std::make_shared(writer, 1u, 0); + + EXPECT_CALL(*writer, writePayload_(_)).Times(2); + EXPECT_CALL(*writer, onStreamClosed(1u)); + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, request_(_)); + EXPECT_CALL(*subscription, cancel_()); + + responder->onSubscribe(subscription); + ASSERT_FALSE(responder->publisherClosed()); + + subscription->request(2); + + responder->onNext(Payload{}); + ASSERT_FALSE(responder->publisherClosed()); + + responder->onNext(Payload{}); + ASSERT_FALSE(responder->publisherClosed()); + + responder->handleError(std::runtime_error("Test")); + ASSERT_TRUE(responder->publisherClosed()); +} + +TEST(StreamResponder, HandleCancel) { + auto writer = std::make_shared>(); + auto responder = std::make_shared(writer, 1u, 0); + + EXPECT_CALL(*writer, writePayload_(_)).Times(2); + EXPECT_CALL(*writer, onStreamClosed(1u)); + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, request_(_)); + EXPECT_CALL(*subscription, cancel_()); + + responder->onSubscribe(subscription); + ASSERT_FALSE(responder->publisherClosed()); + + subscription->request(2); + + responder->onNext(Payload{}); + ASSERT_FALSE(responder->publisherClosed()); + + responder->onNext(Payload{}); + ASSERT_FALSE(responder->publisherClosed()); + + responder->handleCancel(); + ASSERT_TRUE(responder->publisherClosed()); +} + +TEST(StreamResponder, EndStream) { + auto writer = std::make_shared>(); + auto responder = std::make_shared(writer, 1u, 0); + + EXPECT_CALL(*writer, writePayload_(_)).Times(2); + EXPECT_CALL(*writer, writeError_(_)); + EXPECT_CALL(*writer, onStreamClosed(1u)); + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, request_(_)); + EXPECT_CALL(*subscription, cancel_()); + + responder->onSubscribe(subscription); + ASSERT_FALSE(responder->publisherClosed()); + + subscription->request(2); + + responder->onNext(Payload{}); + ASSERT_FALSE(responder->publisherClosed()); + + responder->onNext(Payload{}); + ASSERT_FALSE(responder->publisherClosed()); + + responder->endStream(StreamCompletionSignal::SOCKET_CLOSED); + ASSERT_TRUE(responder->publisherClosed()); +} diff --git a/rsocket/test/statemachine/StreamStateTest.cpp b/rsocket/test/statemachine/StreamStateTest.cpp new file mode 100644 index 000000000..3786708a9 --- /dev/null +++ b/rsocket/test/statemachine/StreamStateTest.cpp @@ -0,0 +1,332 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "rsocket/internal/Common.h" +#include "rsocket/statemachine/ChannelRequester.h" +#include "rsocket/statemachine/ChannelResponder.h" +#include "rsocket/statemachine/StreamStateMachineBase.h" +#include "rsocket/test/test_utils/MockStreamsWriter.h" + +using namespace rsocket; +using namespace testing; +using namespace yarpl::mocks; + +class TestStreamStateMachineBase : public StreamStateMachineBase { + public: + using StreamStateMachineBase::StreamStateMachineBase; + void handlePayload(Payload&&, bool, bool, bool) override { + // ignore... + } +}; + +// @see github.com/rsocket/rsocket/blob/master/Protocol.md#request-channel +TEST(StreamState, NewStateMachineBase) { + auto writer = std::make_shared>(); + EXPECT_CALL(*writer, onStreamClosed(_)); + + TestStreamStateMachineBase ssm(writer, 1u); + ssm.getConsumerAllowance(); + ssm.handleCancel(); + ssm.handleError(std::runtime_error("test")); + ssm.handlePayload(Payload{}, false, true, false); + ssm.handleRequestN(1); +} + +TEST(StreamState, ChannelRequesterOnError) { + auto writer = std::make_shared>(); + auto requester = std::make_shared(writer, 1u); + + EXPECT_CALL(*writer, writeNewStream_(1u, _, _, _)); + EXPECT_CALL(*writer, writeError_(_)); + EXPECT_CALL(*writer, onStreamClosed(1u)); + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, cancel_()).Times(0); + EXPECT_CALL(*subscription, request_(1)); + + auto mockSubscriber = + std::make_shared>>(1000); + EXPECT_CALL(*mockSubscriber, onSubscribe_(_)); + EXPECT_CALL(*mockSubscriber, onError_(_)); + requester->subscribe(mockSubscriber); + + yarpl::flowable::Subscriber* subscriber = requester.get(); + subscriber->onSubscribe(subscription); + + // Initial request to activate the channel + subscriber->onNext(Payload()); + + ASSERT_FALSE(requester->consumerClosed()); + ASSERT_FALSE(requester->publisherClosed()); + + subscriber->onError(std::runtime_error("test")); + + ASSERT_TRUE(requester->consumerClosed()); + ASSERT_TRUE(requester->publisherClosed()); +} + +TEST(StreamState, ChannelResponderOnError) { + auto writer = std::make_shared>(); + auto responder = std::make_shared(writer, 1u, 0u); + + EXPECT_CALL(*writer, writeError_(_)); + EXPECT_CALL(*writer, onStreamClosed(1u)); + EXPECT_CALL(*writer, writeRequestN_(_)); + + auto mockSubscriber = + std::make_shared>>(); + EXPECT_CALL(*mockSubscriber, onSubscribe_(_)); + EXPECT_CALL(*mockSubscriber, onError_(_)); + responder->subscribe(mockSubscriber); + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, cancel_()).Times(0); + yarpl::flowable::Subscriber* subscriber = responder.get(); + subscriber->onSubscribe(subscription); + + ASSERT_FALSE(responder->consumerClosed()); + ASSERT_FALSE(responder->publisherClosed()); + + subscriber->onError(std::runtime_error("test")); + + ASSERT_TRUE(responder->consumerClosed()); + ASSERT_TRUE(responder->publisherClosed()); +} + +TEST(StreamState, ChannelRequesterHandleError) { + auto writer = std::make_shared>(); + auto requester = std::make_shared(writer, 1u); + + EXPECT_CALL(*writer, writeNewStream_(1u, _, _, _)); + EXPECT_CALL(*writer, writeError_(_)).Times(0); + EXPECT_CALL(*writer, onStreamClosed(1u)).Times(0); + + auto mockSubscriber = + std::make_shared>>(1000); + EXPECT_CALL(*mockSubscriber, onSubscribe_(_)); + EXPECT_CALL(*mockSubscriber, onError_(_)); + requester->subscribe(mockSubscriber); + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, cancel_()); + EXPECT_CALL(*subscription, request_(1)); + + yarpl::flowable::Subscriber* subscriber = requester.get(); + subscriber->onSubscribe(subscription); + // Initial request to activate the channel + subscriber->onNext(Payload()); + + ASSERT_FALSE(requester->consumerClosed()); + ASSERT_FALSE(requester->publisherClosed()); + + ConsumerBase* consumer = requester.get(); + consumer->handleError(std::runtime_error("test")); + + ASSERT_TRUE(requester->consumerClosed()); + ASSERT_TRUE(requester->publisherClosed()); +} + +TEST(StreamState, ChannelResponderHandleError) { + auto writer = std::make_shared>(); + auto responder = std::make_shared(writer, 1u, 0u); + + EXPECT_CALL(*writer, writeError_(_)).Times(0); + EXPECT_CALL(*writer, onStreamClosed(1u)).Times(0); + EXPECT_CALL(*writer, writeRequestN_(_)); + + auto mockSubscriber = + std::make_shared>>(); + EXPECT_CALL(*mockSubscriber, onSubscribe_(_)); + EXPECT_CALL(*mockSubscriber, onError_(_)); + + responder->subscribe(mockSubscriber); + + // Initialize the responder + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, cancel_()); + EXPECT_CALL(*subscription, request_(1)).Times(0); + + yarpl::flowable::Subscriber* subscriber = responder.get(); + subscriber->onSubscribe(subscription); + + ASSERT_FALSE(responder->consumerClosed()); + ASSERT_FALSE(responder->publisherClosed()); + + ConsumerBase* consumer = responder.get(); + consumer->handleError(std::runtime_error("test")); + + ASSERT_TRUE(responder->consumerClosed()); + ASSERT_TRUE(responder->publisherClosed()); +} + +// https://github.com/rsocket/rsocket/blob/master/Protocol.md#cancel-from-requester-responder-terminates +TEST(StreamState, ChannelRequesterCancel) { + auto writer = std::make_shared>(); + auto requester = std::make_shared(writer, 1u); + + EXPECT_CALL(*writer, writeNewStream_(1u, _, _, _)); + EXPECT_CALL(*writer, writePayload_(_)).Times(2); + EXPECT_CALL(*writer, writeCancel_(_)); + EXPECT_CALL(*writer, onStreamClosed(1u)).Times(0); + + auto mockSubscriber = + std::make_shared>>(1000); + EXPECT_CALL(*mockSubscriber, onSubscribe_(_)); + requester->subscribe(mockSubscriber); + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, cancel_()).Times(0); + EXPECT_CALL(*subscription, request_(1)); + EXPECT_CALL(*subscription, request_(2)); + + yarpl::flowable::Subscriber* subscriber = requester.get(); + subscriber->onSubscribe(subscription); + // Initial request to activate the channel + subscriber->onNext(Payload()); + + ASSERT_FALSE(requester->consumerClosed()); + ASSERT_FALSE(requester->publisherClosed()); + + ConsumerBase* consumer = requester.get(); + consumer->cancel(); + + ASSERT_TRUE(requester->consumerClosed()); + ASSERT_FALSE(requester->publisherClosed()); + + // Still capable of using the producer side + StreamStateMachineBase* base = requester.get(); + base->handleRequestN(2u); + subscriber->onNext(Payload()); + subscriber->onNext(Payload()); +} + +TEST(StreamState, ChannelResponderCancel) { + auto writer = std::make_shared>(); + auto responder = std::make_shared(writer, 1u, 0u); + + EXPECT_CALL(*writer, writePayload_(_)).Times(2); + EXPECT_CALL(*writer, writeCancel_(_)); + EXPECT_CALL(*writer, writeRequestN_(_)); + + auto mockSubscriber = + std::make_shared>>(); + EXPECT_CALL(*mockSubscriber, onSubscribe_(_)); + + responder->subscribe(mockSubscriber); + + // Initialize the responder + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, cancel_()).Times(0); + EXPECT_CALL(*subscription, request_(2)); + + yarpl::flowable::Subscriber* subscriber = responder.get(); + subscriber->onSubscribe(subscription); + + ASSERT_FALSE(responder->consumerClosed()); + ASSERT_FALSE(responder->publisherClosed()); + + ConsumerBase* consumer = responder.get(); + consumer->cancel(); + + ASSERT_TRUE(responder->consumerClosed()); + ASSERT_FALSE(responder->publisherClosed()); + + // Still capable of using the producer side + StreamStateMachineBase* base = responder.get(); + base->handleRequestN(2u); + subscriber->onNext(Payload()); + subscriber->onNext(Payload()); +} + +TEST(StreamState, ChannelRequesterHandleCancel) { + auto writer = std::make_shared>(); + auto requester = std::make_shared(writer, 1u); + + EXPECT_CALL(*writer, writeNewStream_(1u, _, _, _)); + EXPECT_CALL(*writer, writePayload_(_)).Times(0); + EXPECT_CALL(*writer, onStreamClosed(1u)); + + auto mockSubscriber = + std::make_shared>>(1000); + EXPECT_CALL(*mockSubscriber, onSubscribe_(_)); + requester->subscribe(mockSubscriber); // cycle: requester <-> mockSubscriber + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, cancel_()); + EXPECT_CALL(*subscription, request_(1)); + + yarpl::flowable::Subscriber* subscriber = requester.get(); + subscriber->onSubscribe(subscription); + // Initial request to activate the channel + subscriber->onNext(Payload()); + + ASSERT_FALSE(requester->consumerClosed()); + ASSERT_FALSE(requester->publisherClosed()); + + ConsumerBase* consumer = requester.get(); + consumer->handleCancel(); + + ASSERT_TRUE(requester->publisherClosed()); + ASSERT_FALSE(requester->consumerClosed()); + + // As the publisher is closed, this payload will be dropped + subscriber->onNext(Payload()); + subscriber->onNext(Payload()); + + // Break the cycle: requester <-> mockSubscriber + EXPECT_CALL(*writer, writeCancel_(_)); + auto consumerSubscription = mockSubscriber->subscription(); + consumerSubscription->cancel(); +} + +TEST(StreamState, ChannelResponderHandleCancel) { + auto writer = std::make_shared>(); + auto responder = std::make_shared(writer, 1u, 0u); + + EXPECT_CALL(*writer, writePayload_(_)).Times(0); + EXPECT_CALL(*writer, writeRequestN_(_)); + EXPECT_CALL(*writer, onStreamClosed(1u)); + + auto mockSubscriber = + std::make_shared>>(); + EXPECT_CALL(*mockSubscriber, onSubscribe_(_)); + responder->subscribe(mockSubscriber); // cycle: responder <-> mockSubscriber + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, cancel_()); + + yarpl::flowable::Subscriber* subscriber = responder.get(); + subscriber->onSubscribe(subscription); + + ASSERT_FALSE(responder->consumerClosed()); + ASSERT_FALSE(responder->publisherClosed()); + + ConsumerBase* consumer = responder.get(); + consumer->handleCancel(); + + ASSERT_TRUE(responder->publisherClosed()); + ASSERT_FALSE(responder->consumerClosed()); + + // As the publisher is closed, this payload will be dropped + subscriber->onNext(Payload()); + subscriber->onNext(Payload()); + + // Break the cycle: responder <-> mockSubscriber + EXPECT_CALL(*writer, writeCancel_(_)); + auto consumerSubscription = mockSubscriber->subscription(); + consumerSubscription->cancel(); +} diff --git a/rsocket/test/statemachine/StreamsWriterTest.cpp b/rsocket/test/statemachine/StreamsWriterTest.cpp new file mode 100644 index 000000000..e764df8c6 --- /dev/null +++ b/rsocket/test/statemachine/StreamsWriterTest.cpp @@ -0,0 +1,86 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "rsocket/statemachine/ChannelRequester.h" +#include "rsocket/test/test_utils/MockStreamsWriter.h" + +using namespace rsocket; +using namespace testing; + +TEST(StreamsWriterTest, DelegateMock) { + auto writer = std::make_shared>(); + auto& impl = writer->delegateToImpl(); + EXPECT_CALL(impl, outputFrame_(_)); + EXPECT_CALL(impl, shouldQueue()).WillOnce(Return(false)); + EXPECT_CALL(*writer, writeNewStream_(_, _, _, _)); + + auto requester = std::make_shared(writer, 1u); + yarpl::flowable::Subscriber* subscriber = requester.get(); + subscriber->onSubscribe(yarpl::flowable::Subscription::create()); + subscriber->onNext(Payload()); +} + +TEST(StreamsWriterTest, NewStreamsMockWriterImpl) { + auto writer = std::make_shared>(); + EXPECT_CALL(*writer, outputFrame_(_)); + EXPECT_CALL(*writer, shouldQueue()).WillOnce(Return(false)); + + auto requester = std::make_shared(writer, 1u); + yarpl::flowable::Subscriber* subscriber = requester.get(); + subscriber->onSubscribe(yarpl::flowable::Subscription::create()); + subscriber->onNext(Payload()); +} + +TEST(StreamsWriterTest, QueueFrames) { + auto writer = std::make_shared>(); + auto& impl = writer->delegateToImpl(); + impl.shouldQueue_ = true; + + EXPECT_CALL(impl, outputFrame_(_)).Times(0); + EXPECT_CALL(impl, shouldQueue()).WillOnce(Return(true)); + EXPECT_CALL(*writer, writeNewStream_(_, _, _, _)); + + auto requester = std::make_shared(writer, 1u); + yarpl::flowable::Subscriber* subscriber = requester.get(); + subscriber->onSubscribe(yarpl::flowable::Subscription::create()); + subscriber->onNext(Payload()); +} + +TEST(StreamsWriterTest, FlushQueuedFrames) { + auto writer = std::make_shared>(); + auto& impl = writer->delegateToImpl(); + impl.shouldQueue_ = true; + + EXPECT_CALL(impl, outputFrame_(_)).Times(1); + EXPECT_CALL(impl, shouldQueue()).Times(3); + EXPECT_CALL(*writer, writeNewStream_(_, _, _, _)); + + auto requester = std::make_shared(writer, 1u); + yarpl::flowable::Subscriber* subscriber = requester.get(); + subscriber->onSubscribe(yarpl::flowable::Subscription::create()); + subscriber->onNext(Payload()); + + // Will queue again + impl.sendPendingFrames(); + + // Now send them actually + impl.shouldQueue_ = false; + impl.sendPendingFrames(); + // it will not send the pending frames twice + impl.sendPendingFrames(); +} diff --git a/test/test_utils/ColdResumeManager.cpp b/rsocket/test/test_utils/ColdResumeManager.cpp similarity index 91% rename from test/test_utils/ColdResumeManager.cpp rename to rsocket/test/test_utils/ColdResumeManager.cpp index faab3e0a8..cf53d1b40 100644 --- a/test/test_utils/ColdResumeManager.cpp +++ b/rsocket/test/test_utils/ColdResumeManager.cpp @@ -1,8 +1,21 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "ColdResumeManager.h" #include +#include #include @@ -18,7 +31,7 @@ constexpr folly::StringPiece REQUESTER = "Requester"; constexpr folly::StringPiece STREAM_TOKEN = "StreamToken"; constexpr folly::StringPiece PROD_ALLOWANCE = "ProducerAllowance"; constexpr folly::StringPiece CONS_ALLOWANCE = "ConsumerAllowance"; -} +} // namespace namespace rsocket { @@ -197,4 +210,4 @@ void ColdResumeManager::onStreamOpen( streamId, StreamResumeInfo(streamType, requester, streamToken)); } -} // reactivesocket +} // namespace rsocket diff --git a/test/test_utils/ColdResumeManager.h b/rsocket/test/test_utils/ColdResumeManager.h similarity index 65% rename from test/test_utils/ColdResumeManager.h rename to rsocket/test/test_utils/ColdResumeManager.h index 543da10c7..17cee3afb 100644 --- a/test/test_utils/ColdResumeManager.h +++ b/rsocket/test/test_utils/ColdResumeManager.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -44,11 +56,11 @@ class ColdResumeManager : public WarmResumeManager { void onStreamClosed(StreamId streamId) override; - const StreamResumeInfos& getStreamResumeInfos() override { + const StreamResumeInfos& getStreamResumeInfos() const override { return streamResumeInfos_; } - StreamId getLargestUsedStreamId() override { + StreamId getLargestUsedStreamId() const override { return largestUsedStreamId_; } @@ -61,4 +73,4 @@ class ColdResumeManager : public WarmResumeManager { // Largest used StreamId so far. StreamId largestUsedStreamId_{0}; }; -} +} // namespace rsocket diff --git a/test/test_utils/GenericRequestResponseHandler.h b/rsocket/test/test_utils/GenericRequestResponseHandler.h similarity index 62% rename from test/test_utils/GenericRequestResponseHandler.h rename to rsocket/test/test_utils/GenericRequestResponseHandler.h index 4ac0787b0..f3f79b6d1 100644 --- a/test/test_utils/GenericRequestResponseHandler.h +++ b/rsocket/test/test_utils/GenericRequestResponseHandler.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -35,17 +47,30 @@ struct GenericRequestResponseHandler : public rsocket::RSocketResponder { explicit GenericRequestResponseHandler(HandlerFunc&& func) : handler_(std::make_unique(std::move(func))) {} - yarpl::Reference> handleRequestResponse( + std::shared_ptr> handleRequestResponse( Payload request, StreamId) override { - auto data = request.moveDataToString(); - auto meta = request.moveMetadataToString(); + auto ioBufChainToString = [](std::unique_ptr buf) { + folly::IOBufQueue queue; + queue.append(std::move(buf)); + + std::string ret; + while (auto elem = queue.pop_front()) { + auto part = elem->moveToFbString(); + ret += part.toStdString(); + } + + return ret; + }; + + std::string data = ioBufChainToString(std::move(request.data)); + std::string meta = ioBufChainToString(std::move(request.metadata)); StringPair req(data, meta); Response resp = (*handler_)(req); return yarpl::single::Single::create( - [ resp = std::move(resp), this ](auto subscriber) { + [resp = std::move(resp), this](auto subscriber) { subscriber->onSubscribe(yarpl::single::SingleSubscriptions::empty()); if (resp->type == ResponseImpl::Type::PAYLOAD) { @@ -80,11 +105,13 @@ Response error_response(T const& err) { inline StringPair payload_to_stringpair(Payload p) { return StringPair(p.moveDataToString(), p.moveMetadataToString()); } -} -} +} // namespace tests +} // namespace rsocket -inline std::ostream& operator<<( +namespace std { +inline ostream& operator<<( std::ostream& os, rsocket::tests::StringPair const& payload) { return os << "('" << payload.first << "', '" << payload.second << "')"; } +} // namespace std diff --git a/test/test_utils/MockDuplexConnection.h b/rsocket/test/test_utils/MockDuplexConnection.h similarity index 50% rename from test/test_utils/MockDuplexConnection.h rename to rsocket/test/test_utils/MockDuplexConnection.h index de2dbb580..a64c64436 100644 --- a/test/test_utils/MockDuplexConnection.h +++ b/rsocket/test/test_utils/MockDuplexConnection.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -24,7 +36,7 @@ class MockDuplexConnection : public DuplexConnection { // DuplexConnection. - void setInput(yarpl::Reference in) override { + void setInput(std::shared_ptr in) override { setInput_(std::move(in)); } @@ -34,8 +46,9 @@ class MockDuplexConnection : public DuplexConnection { // Mocks. - MOCK_METHOD1(setInput_, void(yarpl::Reference)); + MOCK_METHOD1(setInput_, void(std::shared_ptr)); MOCK_METHOD1(send_, void(std::unique_ptr&)); + MOCK_CONST_METHOD0(isFramed, bool()); }; } // namespace rsocket diff --git a/rsocket/test/test_utils/MockFrameProcessor.h b/rsocket/test/test_utils/MockFrameProcessor.h new file mode 100644 index 000000000..385a143f1 --- /dev/null +++ b/rsocket/test/test_utils/MockFrameProcessor.h @@ -0,0 +1,40 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include +#include + +#include "rsocket/framing/FrameProcessor.h" + +namespace rsocket { + +class MockFrameProcessor : public FrameProcessor { + public: + void processFrame(std::unique_ptr buf) override { + processFrame_(buf); + } + + void onTerminal(folly::exception_wrapper ew) override { + onTerminal_(std::move(ew)); + } + + MOCK_METHOD1(processFrame_, void(std::unique_ptr&)); + MOCK_METHOD1(onTerminal_, void(folly::exception_wrapper)); +}; + +} // namespace rsocket diff --git a/test/test_utils/MockStats.h b/rsocket/test/test_utils/MockStats.h similarity index 58% rename from test/test_utils/MockStats.h rename to rsocket/test/test_utils/MockStats.h index b6c7be1c2..c1707299f 100644 --- a/test/test_utils/MockStats.h +++ b/rsocket/test/test_utils/MockStats.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -32,4 +44,4 @@ class MockStats : public RSocketStats { MOCK_METHOD2(resumeBufferChanged, void(int, int)); MOCK_METHOD2(streamBufferChanged, void(int64_t, int64_t)); }; -} +} // namespace rsocket diff --git a/rsocket/test/test_utils/MockStreamsWriter.h b/rsocket/test/test_utils/MockStreamsWriter.h new file mode 100644 index 000000000..4d6593c48 --- /dev/null +++ b/rsocket/test/test_utils/MockStreamsWriter.h @@ -0,0 +1,152 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "rsocket/RSocketStats.h" +#include "rsocket/framing/FrameSerializer_v1_0.h" +#include "rsocket/statemachine/StreamsWriter.h" + +namespace rsocket { + +class MockStreamsWriterImpl : public StreamsWriterImpl { + public: + MOCK_METHOD1(onStreamClosed, void(StreamId)); + MOCK_METHOD1(outputFrame_, void(folly::IOBuf*)); + MOCK_METHOD0(shouldQueue, bool()); + + MockStreamsWriterImpl() { + using namespace testing; + ON_CALL(*this, shouldQueue()).WillByDefault(Invoke([this]() { + return this->shouldQueue_; + })); + } + + void outputFrame(std::unique_ptr buf) override { + outputFrame_(buf.get()); + } + + FrameSerializer& serializer() override { + return frameSerializer; + } + + RSocketStats& stats() override { + return *stats_; + } + + std::shared_ptr> onNewStreamReady( + StreamId streamId, + StreamType streamType, + Payload payload, + std::shared_ptr> response) override { + // ignoring... + return nullptr; + } + + void onNewStreamReady( + StreamId streamId, + StreamType streamType, + Payload payload, + std::shared_ptr> response) + override { + // ignoring... + } + + using StreamsWriterImpl::sendPendingFrames; + + bool shouldQueue_{false}; + std::shared_ptr stats_ = RSocketStats::noop(); + FrameSerializerV1_0 frameSerializer; +}; + +class MockStreamsWriter : public StreamsWriter { + public: + MOCK_METHOD4(writeNewStream_, void(StreamId, StreamType, uint32_t, Payload&)); + MOCK_METHOD1(writeRequestN_, void(rsocket::Frame_REQUEST_N)); + MOCK_METHOD1(writeCancel_, void(rsocket::Frame_CANCEL)); + MOCK_METHOD1(writePayload_, void(rsocket::Frame_PAYLOAD&)); + MOCK_METHOD1(writeError_, void(rsocket::Frame_ERROR&)); + MOCK_METHOD1(onStreamClosed, void(rsocket::StreamId)); + + // Delegate the Mock calls to the implementation in StreamsWriterImpl. + MockStreamsWriterImpl& delegateToImpl() { + delegateToImpl_ = true; + using namespace testing; + ON_CALL(*this, onStreamClosed(_)) + .WillByDefault(Invoke(&impl_, &StreamsWriter::onStreamClosed)); + return impl_; + } + + void writeNewStream(StreamId id, StreamType type, uint32_t i, Payload p) + override { + writeNewStream_(id, type, i, p); + if (delegateToImpl_) { + impl_.writeNewStream(id, type, i, std::move(p)); + } + } + + void writeRequestN(rsocket::Frame_REQUEST_N&& request) override { + if (delegateToImpl_) { + impl_.writeRequestN(std::move(request)); + } + writeRequestN_(request); + } + + void writeCancel(rsocket::Frame_CANCEL&& cancel) override { + writeCancel_(cancel); + if (delegateToImpl_) { + impl_.writeCancel(std::move(cancel)); + } + } + + void writePayload(rsocket::Frame_PAYLOAD&& payload) override { + writePayload_(payload); + if (delegateToImpl_) { + impl_.writePayload(std::move(payload)); + } + } + + void writeError(rsocket::Frame_ERROR&& error) override { + writeError_(error); + if (delegateToImpl_) { + impl_.writeError(std::move(error)); + } + } + + std::shared_ptr> onNewStreamReady( + StreamId streamId, + StreamType streamType, + Payload payload, + std::shared_ptr> response) override { + // ignoring... + return nullptr; + } + + void onNewStreamReady( + StreamId streamId, + StreamType streamType, + Payload payload, + std::shared_ptr> response) + override { + // ignoring... + } + + protected: + MockStreamsWriterImpl impl_; + bool delegateToImpl_{false}; +}; + +} // namespace rsocket diff --git a/test/test_utils/PrintSubscriber.cpp b/rsocket/test/test_utils/PrintSubscriber.cpp similarity index 51% rename from test/test_utils/PrintSubscriber.cpp rename to rsocket/test/test_utils/PrintSubscriber.cpp index c25e134cc..baf3c14ef 100644 --- a/test/test_utils/PrintSubscriber.cpp +++ b/rsocket/test/test_utils/PrintSubscriber.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "PrintSubscriber.h" #include @@ -12,7 +24,7 @@ PrintSubscriber::~PrintSubscriber() { } void PrintSubscriber::onSubscribe( - yarpl::Reference subscription) noexcept { + std::shared_ptr subscription) noexcept { LOG(INFO) << "PrintSubscriber " << this << " onSubscribe"; subscription->request(std::numeric_limits::max()); } @@ -28,4 +40,4 @@ void PrintSubscriber::onComplete() noexcept { void PrintSubscriber::onError(folly::exception_wrapper ex) noexcept { LOG(INFO) << "PrintSubscriber " << this << " onError " << ex; } -} +} // namespace rsocket diff --git a/rsocket/test/test_utils/PrintSubscriber.h b/rsocket/test/test_utils/PrintSubscriber.h new file mode 100644 index 000000000..5a392c1dd --- /dev/null +++ b/rsocket/test/test_utils/PrintSubscriber.h @@ -0,0 +1,31 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "rsocket/Payload.h" +#include "yarpl/flowable/Subscriber.h" + +namespace rsocket { +class PrintSubscriber : public yarpl::flowable::Subscriber { + public: + ~PrintSubscriber(); + + void onSubscribe(std::shared_ptr + subscription) noexcept override; + void onNext(Payload element) noexcept override; + void onComplete() noexcept override; + void onError(folly::exception_wrapper ex) noexcept override; +}; +} // namespace rsocket diff --git a/test/test_utils/StatsPrinter.cpp b/rsocket/test/test_utils/StatsPrinter.cpp similarity index 71% rename from test/test_utils/StatsPrinter.cpp rename to rsocket/test/test_utils/StatsPrinter.cpp index e29cf094d..90f4f0511 100644 --- a/test/test_utils/StatsPrinter.cpp +++ b/rsocket/test/test_utils/StatsPrinter.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "StatsPrinter.h" #include @@ -65,4 +77,4 @@ void StatsPrinter::keepaliveSent() { void StatsPrinter::keepaliveReceived() { LOG(INFO) << "keepalive response received"; } -} +} // namespace rsocket diff --git a/test/test_utils/StatsPrinter.h b/rsocket/test/test_utils/StatsPrinter.h similarity index 60% rename from test/test_utils/StatsPrinter.h rename to rsocket/test/test_utils/StatsPrinter.h index 33fa91bbf..afe2fa8a0 100644 --- a/test/test_utils/StatsPrinter.h +++ b/rsocket/test/test_utils/StatsPrinter.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -30,4 +42,4 @@ class StatsPrinter : public RSocketStats { void keepaliveSent() override; void keepaliveReceived() override; }; -} +} // namespace rsocket diff --git a/test/transport/DuplexConnectionTest.cpp b/rsocket/test/transport/DuplexConnectionTest.cpp similarity index 76% rename from test/transport/DuplexConnectionTest.cpp rename to rsocket/test/transport/DuplexConnectionTest.cpp index f36e0fc10..d8bcd33b7 100644 --- a/test/transport/DuplexConnectionTest.cpp +++ b/rsocket/test/transport/DuplexConnectionTest.cpp @@ -1,3 +1,17 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "DuplexConnectionTest.h" #include @@ -9,14 +23,13 @@ namespace tests { using namespace folly; using namespace rsocket; using namespace ::testing; -using namespace yarpl::flowable; void makeMultipleSetInputGetOutputCalls( std::unique_ptr serverConnection, EventBase* serverEvb, std::unique_ptr clientConnection, EventBase* clientEvb) { - auto serverSubscriber = yarpl::make_ref< + auto serverSubscriber = std::make_shared< yarpl::mocks::MockSubscriber>>(); EXPECT_CALL(*serverSubscriber, onSubscribe_(_)); EXPECT_CALL(*serverSubscriber, onNext_(_)).Times(10); @@ -27,7 +40,7 @@ void makeMultipleSetInputGetOutputCalls( }); for (int i = 0; i < 10; ++i) { - auto clientSubscriber = yarpl::make_ref< + auto clientSubscriber = std::make_shared< yarpl::mocks::MockSubscriber>>(); EXPECT_CALL(*clientSubscriber, onSubscribe_(_)); EXPECT_CALL(*clientSubscriber, onNext_(_)); @@ -36,12 +49,12 @@ void makeMultipleSetInputGetOutputCalls( // Set another subscriber and receive messages clientConnection->setInput(clientSubscriber); // Get another subscriber and send messages - clientConnection->send(folly::IOBuf::copyBuffer("01234")); + clientConnection->send(folly::IOBuf::copyBuffer("0123456")); }); serverSubscriber->awaitFrames(1); serverEvb->runInEventBaseThreadAndWait( - [&] { serverConnection->send(folly::IOBuf::copyBuffer("43210")); }); + [&] { serverConnection->send(folly::IOBuf::copyBuffer("6543210")); }); clientSubscriber->awaitFrames(1); clientEvb->runInEventBaseThreadAndWait( @@ -57,9 +70,9 @@ void makeMultipleSetInputGetOutputCalls( subscriber->subscription()->cancel(); }); clientEvb->runInEventBaseThreadAndWait( - [connection = std::move(clientConnection)]{}); + [connection = std::move(clientConnection)] {}); serverEvb->runInEventBaseThreadAndWait( - [connection = std::move(serverConnection)]{}); + [connection = std::move(serverConnection)] {}); } /** @@ -70,7 +83,7 @@ void verifyInputAndOutputIsUntied( EventBase* serverEvb, std::unique_ptr clientConnection, EventBase* clientEvb) { - auto serverSubscriber = yarpl::make_ref< + auto serverSubscriber = std::make_shared< yarpl::mocks::MockSubscriber>>(); EXPECT_CALL(*serverSubscriber, onSubscribe_(_)); EXPECT_CALL(*serverSubscriber, onNext_(_)).Times(3); @@ -78,13 +91,13 @@ void verifyInputAndOutputIsUntied( serverEvb->runInEventBaseThreadAndWait( [&] { serverConnection->setInput(serverSubscriber); }); - auto clientSubscriber = yarpl::make_ref< + auto clientSubscriber = std::make_shared< yarpl::mocks::MockSubscriber>>(); EXPECT_CALL(*clientSubscriber, onSubscribe_(_)); clientEvb->runInEventBaseThreadAndWait([&] { clientConnection->setInput(clientSubscriber); - clientConnection->send(folly::IOBuf::copyBuffer("01234")); + clientConnection->send(folly::IOBuf::copyBuffer("0123456")); }); serverSubscriber->awaitFrames(1); @@ -95,25 +108,25 @@ void verifyInputAndOutputIsUntied( auto deleteSubscriber = std::move(clientSubscriber); } // Output is still active - clientConnection->send(folly::IOBuf::copyBuffer("01234")); + clientConnection->send(folly::IOBuf::copyBuffer("0123456")); }); serverSubscriber->awaitFrames(1); // Another client subscriber - clientSubscriber = yarpl::make_ref< + clientSubscriber = std::make_shared< yarpl::mocks::MockSubscriber>>(); EXPECT_CALL(*clientSubscriber, onSubscribe_(_)); EXPECT_CALL(*clientSubscriber, onNext_(_)); clientEvb->runInEventBaseThreadAndWait([&] { // Set new input subscriber clientConnection->setInput(clientSubscriber); - clientConnection->send(folly::IOBuf::copyBuffer("01234")); + clientConnection->send(folly::IOBuf::copyBuffer("0123456")); }); serverSubscriber->awaitFrames(1); // Still sending message from server to the client. serverEvb->runInEventBaseThreadAndWait( - [&] { serverConnection->send(folly::IOBuf::copyBuffer("43210")); }); + [&] { serverConnection->send(folly::IOBuf::copyBuffer("6543210")); }); clientSubscriber->awaitFrames(1); // Cleanup @@ -126,9 +139,9 @@ void verifyInputAndOutputIsUntied( subscriber->subscription()->cancel(); }); clientEvb->runInEventBaseThreadAndWait( - [connection = std::move(clientConnection)]{}); + [connection = std::move(clientConnection)] {}); serverEvb->runInEventBaseThreadAndWait( - [connection = std::move(serverConnection)]{}); + [connection = std::move(serverConnection)] {}); } void verifyClosingInputAndOutputDoesntCloseConnection( @@ -136,14 +149,14 @@ void verifyClosingInputAndOutputDoesntCloseConnection( folly::EventBase* serverEvb, std::unique_ptr clientConnection, folly::EventBase* clientEvb) { - auto serverSubscriber = yarpl::make_ref< + auto serverSubscriber = std::make_shared< yarpl::mocks::MockSubscriber>>(); EXPECT_CALL(*serverSubscriber, onSubscribe_(_)); serverEvb->runInEventBaseThreadAndWait( [&] { serverConnection->setInput(serverSubscriber); }); - auto clientSubscriber = yarpl::make_ref< + auto clientSubscriber = std::make_shared< yarpl::mocks::MockSubscriber>>(); EXPECT_CALL(*clientSubscriber, onSubscribe_(_)); @@ -160,7 +173,7 @@ void verifyClosingInputAndOutputDoesntCloseConnection( }); // Set new subscribers as the connection is not closed - serverSubscriber = yarpl::make_ref< + serverSubscriber = std::make_shared< yarpl::mocks::MockSubscriber>>(); EXPECT_CALL(*serverSubscriber, onSubscribe_(_)); EXPECT_CALL(*serverSubscriber, onNext_(_)).Times(1); @@ -171,7 +184,7 @@ void verifyClosingInputAndOutputDoesntCloseConnection( serverEvb->runInEventBaseThreadAndWait( [&] { serverConnection->setInput(serverSubscriber); }); - clientSubscriber = yarpl::make_ref< + clientSubscriber = std::make_shared< yarpl::mocks::MockSubscriber>>(); EXPECT_CALL(*clientSubscriber, onSubscribe_(_)); EXPECT_CALL(*clientSubscriber, onNext_(_)).Times(1); @@ -181,20 +194,20 @@ void verifyClosingInputAndOutputDoesntCloseConnection( clientEvb->runInEventBaseThreadAndWait([&] { clientConnection->setInput(clientSubscriber); - clientConnection->send(folly::IOBuf::copyBuffer("01234")); + clientConnection->send(folly::IOBuf::copyBuffer("0123456")); }); serverSubscriber->awaitFrames(1); // Wait till client is ready before sending message from server. serverEvb->runInEventBaseThreadAndWait( - [&] { serverConnection->send(folly::IOBuf::copyBuffer("43210")); }); + [&] { serverConnection->send(folly::IOBuf::copyBuffer("6543210")); }); clientSubscriber->awaitFrames(1); // Cleanup clientEvb->runInEventBaseThreadAndWait( - [connection = std::move(clientConnection)]{}); + [connection = std::move(clientConnection)] {}); serverEvb->runInEventBaseThreadAndWait( - [connection = std::move(serverConnection)]{}); + [connection = std::move(serverConnection)] {}); } } // namespace tests diff --git a/test/transport/DuplexConnectionTest.h b/rsocket/test/transport/DuplexConnectionTest.h similarity index 59% rename from test/transport/DuplexConnectionTest.h rename to rsocket/test/transport/DuplexConnectionTest.h index e013b6ffb..c370975e9 100644 --- a/test/transport/DuplexConnectionTest.h +++ b/rsocket/test/transport/DuplexConnectionTest.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once diff --git a/test/transport/TcpDuplexConnectionTest.cpp b/rsocket/test/transport/TcpDuplexConnectionTest.cpp similarity index 66% rename from test/transport/TcpDuplexConnectionTest.cpp rename to rsocket/test/transport/TcpDuplexConnectionTest.cpp index 360e54d06..ae17a51dd 100644 --- a/test/transport/TcpDuplexConnectionTest.cpp +++ b/rsocket/test/transport/TcpDuplexConnectionTest.cpp @@ -1,20 +1,32 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include #include +#include #include +#include "rsocket/test/transport/DuplexConnectionTest.h" #include "rsocket/transports/tcp/TcpConnectionAcceptor.h" #include "rsocket/transports/tcp/TcpConnectionFactory.h" -#include "test/transport/DuplexConnectionTest.h" namespace rsocket { namespace tests { using namespace folly; using namespace rsocket; -using namespace ::testing; /** * Synchronously create a server and a client. @@ -29,9 +41,12 @@ makeSingleClientServer( EventBase* clientEvb) { Promise serverPromise; - TcpConnectionAcceptor::Options options( - 0 /*port*/, 1 /*threads*/, 0 /*backlog*/); - auto server = std::make_unique(options); + TcpConnectionAcceptor::Options options; + options.address = folly::SocketAddress{"::", 0}; + options.threads = 1; + options.backlog = 0; + + auto server = std::make_unique(std::move(options)); server->start( [&serverPromise, &serverConnection, &serverEvb]( std::unique_ptr connection, EventBase& eventBase) { @@ -43,22 +58,22 @@ makeSingleClientServer( int16_t port = server->listeningPort().value(); auto client = std::make_unique( - *clientEvb, - SocketAddress("localhost", port, true)); - client->connect().then( - [&clientConnection]( - ConnectionFactory::ConnectedDuplexConnection connection) { + *clientEvb, SocketAddress("localhost", port, true)); + client->connect(ProtocolVersion::Latest, ResumeStatus::NEW_SESSION) + .thenValue([&clientConnection]( + ConnectionFactory::ConnectedDuplexConnection connection) { clientConnection = std::move(connection.connection); - }).wait(); + }) + .wait(); - serverPromise.getFuture().wait(); + serverPromise.getSemiFuture().wait(); return std::make_pair(std::move(server), std::move(client)); } TEST(TcpDuplexConnection, MultipleSetInputGetOutputCalls) { folly::ScopedEventBaseThread worker; std::unique_ptr serverConnection, clientConnection; - EventBase *serverEvb = nullptr; + EventBase* serverEvb = nullptr; auto keepAlive = makeSingleClientServer( serverConnection, &serverEvb, clientConnection, worker.getEventBase()); makeMultipleSetInputGetOutputCalls( @@ -71,7 +86,7 @@ TEST(TcpDuplexConnection, MultipleSetInputGetOutputCalls) { TEST(TcpDuplexConnection, InputAndOutputIsUntied) { folly::ScopedEventBaseThread worker; std::unique_ptr serverConnection, clientConnection; - EventBase *serverEvb = nullptr; + EventBase* serverEvb = nullptr; auto keepAlive = makeSingleClientServer( serverConnection, &serverEvb, clientConnection, worker.getEventBase()); verifyInputAndOutputIsUntied( @@ -84,7 +99,7 @@ TEST(TcpDuplexConnection, InputAndOutputIsUntied) { TEST(TcpDuplexConnection, ConnectionAndSubscribersAreUntied) { folly::ScopedEventBaseThread worker; std::unique_ptr serverConnection, clientConnection; - EventBase *serverEvb = nullptr; + EventBase* serverEvb = nullptr; auto keepAlive = makeSingleClientServer( serverConnection, &serverEvb, clientConnection, worker.getEventBase()); verifyClosingInputAndOutputDoesntCloseConnection( diff --git a/rsocket/transports/RSocketTransport.h b/rsocket/transports/RSocketTransport.h new file mode 100644 index 000000000..d86a4669a --- /dev/null +++ b/rsocket/transports/RSocketTransport.h @@ -0,0 +1,49 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace rsocket { +class RSocketTransportHandler { + public: + virtual ~RSocketTransportHandler() = default; + + // connection scope signals + virtual void onKeepAlive( + ResumePosition resumePosition, + std::unique_ptr data, + bool keepAliveRespond) = 0; + virtual void onMetadataPush(std::unique_ptr metadata) = 0; + virtual void onResumeOk(ResumePosition resumePosition); + virtual void onError(ErrorCode errorCode, Payload payload) = 0; + + // stream scope signals + virtual void onStreamRequestN(StreamId streamId, uint32_t requestN) = 0; + virtual void onStreamCancel(StreamId streamId) = 0; + virtual void onStreamError(StreamId streamId, Payload payload) = 0; + virtual void onStreamPayload( + StreamId streamId, + Payload payload, + bool flagsFollows, + bool flagsComplete, + bool flagsNext) = 0; +}; + +class RSocketTransport { + public: + virtual ~RSocketTransport() = default; + + // TODO: +}; +} // namespace rsocket diff --git a/rsocket/transports/tcp/TcpConnectionAcceptor.cpp b/rsocket/transports/tcp/TcpConnectionAcceptor.cpp index 0b0fff909..12ac289f9 100644 --- a/rsocket/transports/tcp/TcpConnectionAcceptor.cpp +++ b/rsocket/transports/tcp/TcpConnectionAcceptor.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/transports/tcp/TcpConnectionAcceptor.h" @@ -6,10 +18,7 @@ #include #include #include -#include -#include -#include "rsocket/framing/FramedDuplexConnection.h" #include "rsocket/transports/tcp/TcpDuplexConnection.h" namespace rsocket { @@ -18,22 +27,24 @@ class TcpConnectionAcceptor::SocketCallback : public folly::AsyncServerSocket::AcceptCallback { public: explicit SocketCallback(OnDuplexConnectionAccept& onAccept) - : onAccept_{onAccept} {} + : thread_{folly::sformat("rstcp-acceptor")}, onAccept_{onAccept} {} void connectionAccepted( - int fd, + folly::NetworkSocket fdNetworkSocket, const folly::SocketAddress& address) noexcept override { + int fd = fdNetworkSocket.toFd(); + VLOG(2) << "Accepting TCP connection from " << address << " on FD " << fd; folly::AsyncTransportWrapper::UniquePtr socket( - new folly::AsyncSocket(eventBase(), fd)); + new folly::AsyncSocket(eventBase(), folly::NetworkSocket::fromFd(fd))); auto connection = std::make_unique(std::move(socket)); onAccept_(std::move(connection), *eventBase()); } - void acceptError(const std::exception& ex) noexcept override { - VLOG(2) << "TCP error: " << ex.what(); + void acceptError(folly::exception_wrapper ex) noexcept override { + VLOG(2) << "TCP error: " << ex; } folly::EventBase* eventBase() const { @@ -48,8 +59,6 @@ class TcpConnectionAcceptor::SocketCallback OnDuplexConnectionAccept& onAccept_; }; -//////////////////////////////////////////////////////////////////////////////// - TcpConnectionAcceptor::TcpConnectionAcceptor(Options options) : options_(std::move(options)) {} @@ -60,25 +69,18 @@ TcpConnectionAcceptor::~TcpConnectionAcceptor() { } } -//////////////////////////////////////////////////////////////////////////////// - void TcpConnectionAcceptor::start(OnDuplexConnectionAccept onAccept) { if (onAccept_ != nullptr) { throw std::runtime_error("TcpConnectionAcceptor::start() already called"); } onAccept_ = std::move(onAccept); - serverThread_ = std::make_unique(); - serverThread_->getEventBase()->runInEventBaseThread( - [] { folly::setThreadName("TcpConnectionAcceptor.Listener"); }); + serverThread_ = + std::make_unique("rstcp-listener"); callbacks_.reserve(options_.threads); for (size_t i = 0; i < options_.threads; ++i) { callbacks_.push_back(std::make_unique(onAccept_)); - callbacks_[i]->eventBase()->runInEventBaseThread([i] { - folly::EventBaseManager::get()->getEventBase()->setName( - folly::sformat("TCPWrk.{}", i)); - }); } VLOG(1) << "Starting TCP listener on port " << options_.address.getPort() @@ -89,30 +91,26 @@ void TcpConnectionAcceptor::start(OnDuplexConnectionAccept onAccept) { // The AsyncServerSocket needs to be accessed from the listener thread only. // This will propagate out any exceptions the listener throws. - folly::via( - serverThread_->getEventBase(), - [this] { - serverSocket_->bind(options_.address); - - for (auto const& callback : callbacks_) { - serverSocket_->addAcceptCallback( - callback.get(), callback->eventBase()); - } - - serverSocket_->listen(options_.backlog); - serverSocket_->startAccepting(); - - for (auto& i : serverSocket_->getAddresses()) { - VLOG(1) << "Listening on " << i.describe(); - } - }) - .get(); + folly::via(serverThread_->getEventBase(), [this] { + serverSocket_->bind(options_.address); + + for (auto const& callback : callbacks_) { + serverSocket_->addAcceptCallback(callback.get(), callback->eventBase()); + } + + serverSocket_->listen(options_.backlog); + serverSocket_->startAccepting(); + + for (const auto& i : serverSocket_->getAddresses()) { + VLOG(1) << "Listening on " << i.describe(); + } + }).get(); } void TcpConnectionAcceptor::stop() { VLOG(1) << "Shutting down TCP listener"; - serverThread_->getEventBase()->runInEventBaseThread( + serverThread_->getEventBase()->runInEventBaseThreadAndWait( [serverSocket = std::move(serverSocket_)]() {}); } diff --git a/rsocket/transports/tcp/TcpConnectionAcceptor.h b/rsocket/transports/tcp/TcpConnectionAcceptor.h index 840536690..5d922d06e 100644 --- a/rsocket/transports/tcp/TcpConnectionAcceptor.h +++ b/rsocket/transports/tcp/TcpConnectionAcceptor.h @@ -1,15 +1,24 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include +#include #include "rsocket/ConnectionAcceptor.h" -namespace folly { -class ScopedEventBaseThread; -} - namespace rsocket { /** @@ -20,29 +29,19 @@ namespace rsocket { class TcpConnectionAcceptor : public ConnectionAcceptor { public: struct Options { - explicit Options( - uint16_t port_ = 8080, - size_t threads_ = 2, - int backlog_ = 10) - : address("::", port_), threads(threads_), backlog(backlog_) {} - /// Address to listen on - folly::SocketAddress address; + folly::SocketAddress address{"::", 8080}; /// Number of worker threads processing requests. - size_t threads; + size_t threads{2}; /// Number of connections to buffer before accept handlers process them. - int backlog; + int backlog{10}; }; - ////////////////////////////////////////////////////////////////////////////// - explicit TcpConnectionAcceptor(Options); ~TcpConnectionAcceptor(); - ////////////////////////////////////////////////////////////////////////////// - // ConnectionAcceptor overrides. /** @@ -63,6 +62,9 @@ class TcpConnectionAcceptor : public ConnectionAcceptor { private: class SocketCallback; + /// Options this acceptor has been configured with. + const Options options_; + /// The thread driving the AsyncServerSocket. std::unique_ptr serverThread_; @@ -75,8 +77,6 @@ class TcpConnectionAcceptor : public ConnectionAcceptor { /// The socket listening for new connections. folly::AsyncServerSocket::UniquePtr serverSocket_; - - /// Options this acceptor has been configured with. - Options options_; }; -} + +} // namespace rsocket diff --git a/rsocket/transports/tcp/TcpConnectionFactory.cpp b/rsocket/transports/tcp/TcpConnectionFactory.cpp index 68563a34d..b970cd756 100644 --- a/rsocket/transports/tcp/TcpConnectionFactory.cpp +++ b/rsocket/transports/tcp/TcpConnectionFactory.cpp @@ -1,7 +1,20 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/transports/tcp/TcpConnectionFactory.h" +#include #include #include #include @@ -9,8 +22,6 @@ #include "rsocket/transports/tcp/TcpDuplexConnection.h" -using namespace rsocket; - namespace rsocket { namespace { @@ -19,24 +30,37 @@ class ConnectCallback : public folly::AsyncSocket::ConnectCallback { public: ConnectCallback( folly::SocketAddress address, + const std::shared_ptr& sslContext, folly::Promise connectPromise) - : address_(address), connectPromise_{std::move(connectPromise)} { + : address_(address), connectPromise_(std::move(connectPromise)) { VLOG(2) << "Constructing ConnectCallback"; // Set up by ScopedEventBaseThread. auto evb = folly::EventBaseManager::get()->getExistingEventBase(); DCHECK(evb); - VLOG(3) << "Starting socket"; - socket_.reset(new folly::AsyncSocket(evb)); + if (sslContext) { +#if !FOLLY_OPENSSL_HAS_ALPN + // setAdvertisedNextProtocols() is unavailable +#error ALPN is required for rsockets. \ + Your version of OpenSSL is likely too old. +#else + VLOG(3) << "Starting SSL socket"; + sslContext->setAdvertisedNextProtocols({"rs"}); +#endif + socket_.reset(new folly::AsyncSSLSocket(sslContext, evb)); + } else { + VLOG(3) << "Starting socket"; + socket_.reset(new folly::AsyncSocket(evb)); + } VLOG(3) << "Attempting connection to " << address_; socket_->connect(this, address_); } - ~ConnectCallback() { + ~ConnectCallback() override { VLOG(2) << "Destroying ConnectCallback"; } @@ -59,7 +83,7 @@ class ConnectCallback : public folly::AsyncSocket::ConnectCallback { } private: - folly::SocketAddress address_; + const folly::SocketAddress address_; folly::AsyncSocket::UniquePtr socket_; folly::Promise connectPromise_; }; @@ -68,23 +92,22 @@ class ConnectCallback : public folly::AsyncSocket::ConnectCallback { TcpConnectionFactory::TcpConnectionFactory( folly::EventBase& eventBase, - folly::SocketAddress address) - : address_{std::move(address)}, eventBase_{&eventBase} { - VLOG(1) << "Constructing TcpConnectionFactory"; -} + folly::SocketAddress address, + std::shared_ptr sslContext) + : eventBase_(&eventBase), + address_(std::move(address)), + sslContext_(std::move(sslContext)) {} -TcpConnectionFactory::~TcpConnectionFactory() { - VLOG(1) << "Destroying TcpConnectionFactory"; -} +TcpConnectionFactory::~TcpConnectionFactory() = default; folly::Future -TcpConnectionFactory::connect() { +TcpConnectionFactory::connect(ProtocolVersion, ResumeStatus /* unused */) { folly::Promise connectPromise; auto connectFuture = connectPromise.getFuture(); eventBase_->runInEventBaseThread( - [ this, connectPromise = std::move(connectPromise) ]() mutable { - new ConnectCallback(address_, std::move(connectPromise)); + [this, promise = std::move(connectPromise)]() mutable { + new ConnectCallback(address_, sslContext_, std::move(promise)); }); return connectFuture; } diff --git a/rsocket/transports/tcp/TcpConnectionFactory.h b/rsocket/transports/tcp/TcpConnectionFactory.h index 2aee46fa5..283b50eb5 100644 --- a/rsocket/transports/tcp/TcpConnectionFactory.h +++ b/rsocket/transports/tcp/TcpConnectionFactory.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -8,6 +20,11 @@ #include "rsocket/ConnectionFactory.h" #include "rsocket/DuplexConnection.h" +namespace folly { + +class SSLContext; +} + namespace rsocket { class RSocketStats; @@ -19,7 +36,10 @@ class RSocketStats; */ class TcpConnectionFactory : public ConnectionFactory { public: - TcpConnectionFactory(folly::EventBase& eventBase, folly::SocketAddress); + TcpConnectionFactory( + folly::EventBase& eventBase, + folly::SocketAddress address, + std::shared_ptr sslContext = nullptr); virtual ~TcpConnectionFactory(); /** @@ -27,14 +47,17 @@ class TcpConnectionFactory : public ConnectionFactory { * * Each call to connect() creates a new AsyncSocket. */ - folly::Future connect() override; + folly::Future connect( + ProtocolVersion, + ResumeStatus resume) override; static std::unique_ptr createDuplexConnectionFromSocket( folly::AsyncTransportWrapper::UniquePtr socket, std::shared_ptr stats = std::shared_ptr()); private: - folly::SocketAddress address_; folly::EventBase* eventBase_; + const folly::SocketAddress address_; + std::shared_ptr sslContext_; }; } // namespace rsocket diff --git a/rsocket/transports/tcp/TcpDuplexConnection.cpp b/rsocket/transports/tcp/TcpDuplexConnection.cpp index e51bbf1e9..054e31768 100644 --- a/rsocket/transports/tcp/TcpDuplexConnection.cpp +++ b/rsocket/transports/tcp/TcpDuplexConnection.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/transports/tcp/TcpDuplexConnection.h" @@ -23,7 +35,7 @@ class TcpReaderWriter : public folly::AsyncTransportWrapper::WriteCallback, std::shared_ptr stats) : socket_(std::move(socket)), stats_(std::move(stats)) {} - ~TcpReaderWriter() { + ~TcpReaderWriter() override { CHECK(isClosed()); DCHECK(!inputSubscriber_); } @@ -32,8 +44,7 @@ class TcpReaderWriter : public folly::AsyncTransportWrapper::WriteCallback, return socket_.get(); } - void setInput( - yarpl::Reference inputSubscriber) { + void setInput(std::shared_ptr inputSubscriber) { if (inputSubscriber && isClosed()) { inputSubscriber->onComplete(); return; @@ -96,10 +107,9 @@ class TcpReaderWriter : public folly::AsyncTransportWrapper::WriteCallback, intrusive_ptr_release(this); } - void writeErr( - size_t, - const folly::AsyncSocketException& exn) noexcept override { - closeErr(folly::exception_wrapper{exn}); + void writeErr(size_t, const folly::AsyncSocketException& exn) noexcept + override { + closeErr(folly::exception_wrapper{folly::copy(exn)}); intrusive_ptr_release(this); } @@ -124,7 +134,7 @@ class TcpReaderWriter : public folly::AsyncTransportWrapper::WriteCallback, } void readErr(const folly::AsyncSocketException& exn) noexcept override { - closeErr(exn); + closeErr(folly::exception_wrapper{folly::copy(exn)}); intrusive_ptr_release(this); } @@ -142,7 +152,7 @@ class TcpReaderWriter : public folly::AsyncTransportWrapper::WriteCallback, folly::AsyncTransportWrapper::UniquePtr socket_; const std::shared_ptr stats_; - yarpl::Reference inputSubscriber_; + std::shared_ptr inputSubscriber_; int refCount_{0}; }; @@ -213,10 +223,10 @@ void TcpDuplexConnection::send(std::unique_ptr buf) { } void TcpDuplexConnection::setInput( - yarpl::Reference inputSubscriber) { + std::shared_ptr inputSubscriber) { // we don't care if the subscriber will call request synchronously inputSubscriber->onSubscribe( - yarpl::make_ref(tcpReaderWriter_)); + std::make_shared(tcpReaderWriter_)); tcpReaderWriter_->setInput(std::move(inputSubscriber)); } } // namespace rsocket diff --git a/rsocket/transports/tcp/TcpDuplexConnection.h b/rsocket/transports/tcp/TcpDuplexConnection.h index 4f1ba2444..5bfa9adec 100644 --- a/rsocket/transports/tcp/TcpDuplexConnection.h +++ b/rsocket/transports/tcp/TcpDuplexConnection.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -23,7 +35,7 @@ class TcpDuplexConnection : public DuplexConnection { void send(std::unique_ptr) override; - void setInput(yarpl::Reference) override; + void setInput(std::shared_ptr) override; // Only to be used for observation purposes. folly::AsyncTransportWrapper* getTransport(); @@ -32,4 +44,4 @@ class TcpDuplexConnection : public DuplexConnection { boost::intrusive_ptr tcpReaderWriter_; std::shared_ptr stats_; }; -} +} // namespace rsocket diff --git a/scripts/build_folly.sh b/scripts/build_folly.sh new file mode 100755 index 000000000..ebe67aa00 --- /dev/null +++ b/scripts/build_folly.sh @@ -0,0 +1,48 @@ +#!/bin/bash +# +# Copyright 2004-present Facebook. All Rights Reserved. +# +CHECKOUT_DIR=$1 +INSTALL_DIR=$2 +if [[ -z $INSTALL_DIR ]]; then + echo "usage: $0 CHECKOUT_DIR INSTALL_DIR" >&2 + exit 1 +fi + +# Convert INSTALL_DIR to an absolute path so it still refers to the same +# location after we cd into the build directory. +case "$INSTALL_DIR" in + /*) ;; + *) INSTALL_DIR="$PWD/$INSTALL_DIR" +esac + +# If folly was already installed, just return early +INSTALL_MARKER_FILE="$INSTALL_DIR/folly.installed" +if [[ -f $INSTALL_MARKER_FILE ]]; then + echo "folly was previously built" + exit 0 +fi + +set -e +set -x + +if [[ -d "$CHECKOUT_DIR" ]]; then + git -C "$CHECKOUT_DIR" fetch + git -C "$CHECKOUT_DIR" checkout master +else + git clone https://github.com/facebook/folly "$CHECKOUT_DIR" +fi + +mkdir -p "$CHECKOUT_DIR/_build" +cd "$CHECKOUT_DIR/_build" +if ! cmake \ + "-DCMAKE_PREFIX_PATH=${INSTALL_DIR}" \ + "-DCMAKE_INSTALL_PREFIX=${INSTALL_DIR}" \ + ..; then + echo "error configuring folly" >&2 + tail -n 100 CMakeFiles/CMakeError.log >&2 + exit 1 +fi +make -j4 +make install +touch "$INSTALL_MARKER_FILE" diff --git a/scripts/frame_fuzzer_test.sh b/scripts/frame_fuzzer_test.sh index 07965e54e..785cfa8d1 100755 --- a/scripts/frame_fuzzer_test.sh +++ b/scripts/frame_fuzzer_test.sh @@ -1,5 +1,7 @@ #!/usr/bin/env bash - +# +# Copyright 2004-present Facebook. All Rights Reserved. +# if [ ! -s ./build/frame_fuzzer ]; then echo "./build/frame_fuzzer binary not found!" exit 1 diff --git a/scripts/tck_test.sh b/scripts/tck_test.sh index 3c4550bfb..814b6e309 100755 --- a/scripts/tck_test.sh +++ b/scripts/tck_test.sh @@ -1,5 +1,7 @@ #!/bin/bash - +# +# Copyright 2004-present Facebook. All Rights Reserved. +# if [ "$#" -ne 4 ]; then echo "Illegal number of parameters - $#" exit 1 @@ -48,11 +50,11 @@ if [[ "$OSTYPE" == "darwin"* ]]; then timeout='gtimeout' fi -java_server="java -jar rsocket-tck-drivers-0.9.10.jar --server --host localhost --port 9898 --file tck-test/servertest.txt" -java_client="java -jar rsocket-tck-drivers-0.9.10.jar --client --host localhost --port 9898 --file tck-test/clienttest.txt" +java_server="java -jar rsocket-tck-drivers-0.9.10.jar --server --host localhost --port 9898 --file rsocket/tck-test/servertest.txt" +java_client="java -jar rsocket-tck-drivers-0.9.10.jar --client --host localhost --port 9898 --file rsocket/tck-test/clienttest.txt" -cpp_server="./build/tckserver -test_file tck-test/servertest.txt -rs_use_protocol_version 1.0" -cpp_client="./build/tckclient -test_file tck-test/clienttest.txt -rs_use_protocol_version 1.0" +cpp_server="./build/tckserver -test_file rsocket/tck-test/servertest.txt -rs_use_protocol_version 1.0" +cpp_client="./build/tckclient -test_file rsocket/tck-test/clienttest.txt -rs_use_protocol_version 1.0" server="${server_lang}_server" client="${client_lang}_client" diff --git a/tck-test/FlowableSubscriber.h b/tck-test/FlowableSubscriber.h deleted file mode 100644 index 6cd8a37d7..000000000 --- a/tck-test/FlowableSubscriber.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "tck-test/BaseSubscriber.h" - -#include "yarpl/Flowable.h" - -namespace rsocket { -namespace tck { - -class FlowableSubscriber : public BaseSubscriber, - public yarpl::flowable::Subscriber { - public: - explicit FlowableSubscriber(int initialRequestN = 0); - - // Inherited from BaseSubscriber - void request(int n) override; - void cancel() override; - - protected: - // Inherited from flowable::Subscriber - void onSubscribe(yarpl::Reference - subscription) noexcept override; - void onNext(Payload element) noexcept override; - void onComplete() noexcept override; - void onError(folly::exception_wrapper ex) noexcept override; - - private: - yarpl::Reference subscription_; - int initialRequestN_{0}; -}; - -} // tck -} // reactivesocket diff --git a/tck-test/MarbleProcessor.h b/tck-test/MarbleProcessor.h deleted file mode 100644 index 6da7d798c..000000000 --- a/tck-test/MarbleProcessor.h +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include "rsocket/Payload.h" -#include "yarpl/Flowable.h" -#include "yarpl/Single.h" - -namespace rsocket { -namespace tck { - -class MarbleProcessor { - public: - explicit MarbleProcessor(const std::string /* marble */); - - std::tuple run( - yarpl::Reference> - subscriber, - int64_t requested); - - void run(yarpl::Reference> - subscriber); - - private: - std::string marble_; - - // Stores a mapping from marble character to Payload (data, metadata) - std::map> argMap_; - - // Keeps an account of how many messages can be sent. This could be done - // with Allowance - std::atomic canSend_{0}; - - size_t index_{0}; -}; - -} // tck -} // reactivesocket diff --git a/tck-test/SingleSubscriber.h b/tck-test/SingleSubscriber.h deleted file mode 100644 index 91339a7b0..000000000 --- a/tck-test/SingleSubscriber.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "tck-test/BaseSubscriber.h" - -#include "yarpl/Single.h" - -namespace rsocket { -namespace tck { - -class SingleSubscriber : public BaseSubscriber, - public yarpl::single::SingleObserver { - public: - // Inherited from BaseSubscriber - void request(int n) override; - void cancel() override; - - protected: - // Inherited from flowable::Subscriber - void onSubscribe(yarpl::Reference - subscription) noexcept override; - void onSuccess(Payload element) noexcept override; - void onError(folly::exception_wrapper ex) noexcept override; - - private: - yarpl::Reference subscription_; -}; - -} // tck -} // reactivesocket diff --git a/tck-test/TestFileParser.h b/tck-test/TestFileParser.h deleted file mode 100644 index cd0166010..000000000 --- a/tck-test/TestFileParser.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include - -#include "tck-test/TestSuite.h" - -namespace rsocket { -namespace tck { - -class TestFileParser { - public: - explicit TestFileParser(const std::string& fileName); - - TestSuite parse(); - - private: - void parseCommand(const std::string& command); - void addCurrentTest(); - - std::ifstream input_; - int currentLine_; - - TestSuite testSuite_; - Test currentTest_; -}; - -} // tck -} // reactivesocket diff --git a/tck-test/TestSuite.cpp b/tck-test/TestSuite.cpp deleted file mode 100644 index 0e0dad89e..000000000 --- a/tck-test/TestSuite.cpp +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "tck-test/TestSuite.h" - -#include - -namespace rsocket { -namespace tck { - -bool TestCommand::valid() const { - // there has to be a name to the test and at least 1 param - return params_.size() >= 1; -} - -void Test::addCommand(TestCommand command) { - CHECK(command.valid()); - commands_.push_back(std::move(command)); -} - -} // tck -} // reactivesocket diff --git a/test/RSocketClientTest.cpp b/test/RSocketClientTest.cpp deleted file mode 100644 index 1317a2d8c..000000000 --- a/test/RSocketClientTest.cpp +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "RSocketTests.h" - -#include -#include - -#include "rsocket/transports/tcp/TcpConnectionFactory.h" - -using namespace rsocket; -using namespace rsocket::tests; -using namespace rsocket::tests::client_server; - -TEST(RSocketClient, ConnectFails) { - folly::ScopedEventBaseThread worker; - - folly::SocketAddress address; - address.setFromHostPort("localhost", 1); - auto client = RSocket::createConnectedClient( - std::make_unique(*worker.getEventBase(), - std::move(address))); - - client.then([&](auto&) { - FAIL() << "the test needs to fail"; - }).onError([&](const std::exception&) { - LOG(INFO) << "connection failed as expected"; - }).get(); -} diff --git a/test/RSocketTests.h b/test/RSocketTests.h deleted file mode 100644 index a6b7e2c6e..000000000 --- a/test/RSocketTests.h +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include - -#include "rsocket/RSocket.h" - -#include "rsocket/transports/tcp/TcpConnectionFactory.h" -#include "yarpl/test_utils/utils.h" - -namespace rsocket { -namespace tests { -namespace client_server { - -class RSocketStatsFlowControl : public RSocketStats { -public: - void frameWritten(FrameType frameType) { - if (frameType == FrameType::REQUEST_N) { - ++writeRequestN_; - } - } - - void frameRead(FrameType frameType) { - if (frameType == FrameType::REQUEST_N) { - ++readRequestN_; - } - } - -public: - int writeRequestN_{0}; - int readRequestN_{0}; -}; - -std::unique_ptr getConnFactory( - folly::EventBase* eventBase, - uint16_t port); - -std::unique_ptr makeServer( - std::shared_ptr responder, - std::shared_ptr stats = RSocketStats::noop()); - -std::unique_ptr makeResumableServer( - std::shared_ptr serviceHandler); - -std::unique_ptr makeClient( - folly::EventBase* eventBase, - uint16_t port, - folly::EventBase* stateMachineEvb = nullptr, - std::shared_ptr stats = RSocketStats::noop()); - -std::unique_ptr makeDisconnectedClient( - folly::EventBase* eventBase); - -folly::Future> makeClientAsync( - folly::EventBase* eventBase, - uint16_t port, - folly::EventBase* stateMachineEvb = nullptr, - std::shared_ptr stats = RSocketStats::noop()); - -std::unique_ptr makeWarmResumableClient( - folly::EventBase* eventBase, - uint16_t port, - std::shared_ptr connectionEvents = nullptr, - folly::EventBase* stateMachineEvb = nullptr); - -std::unique_ptr makeColdResumableClient( - folly::EventBase* eventBase, - uint16_t port, - ResumeIdentificationToken token, - std::shared_ptr resumeManager, - std::shared_ptr resumeHandler, - folly::EventBase* stateMachineEvb = nullptr); - -} // namespace client_server -} // namespace tests -} // namespace rsocket diff --git a/test/RequestStreamTest.cpp b/test/RequestStreamTest.cpp deleted file mode 100644 index dd8b35bf9..000000000 --- a/test/RequestStreamTest.cpp +++ /dev/null @@ -1,226 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include -#include -#include - -#include "RSocketTests.h" -#include "yarpl/Flowable.h" -#include "yarpl/flowable/TestSubscriber.h" - -using namespace yarpl; -using namespace yarpl::flowable; -using namespace rsocket; -using namespace rsocket::tests; -using namespace rsocket::tests::client_server; - -namespace { -class TestHandlerSync : public rsocket::RSocketResponder { - public: - Reference> handleRequestStream(Payload request, StreamId) - override { - // string from payload data - auto requestString = request.moveDataToString(); - - return Flowables::range(1, 10)->map([name = std::move(requestString)]( - int64_t v) { - std::stringstream ss; - ss << "Hello " << name << " " << v << "!"; - std::string s = ss.str(); - return Payload(s, "metadata"); - }); - } -}; - -TEST(RequestStreamTest, HelloSync) { - folly::ScopedEventBaseThread worker; - auto server = makeServer(std::make_shared()); - auto client = makeClient(worker.getEventBase(), *server->listeningPort()); - auto requester = client->getRequester(); - auto ts = TestSubscriber::create(); - requester->requestStream(Payload("Bob")) - ->map([](auto p) { return p.moveDataToString(); }) - ->subscribe(ts); - ts->awaitTerminalEvent(); - ts->assertSuccess(); - ts->assertValueCount(10); - ts->assertValueAt(0, "Hello Bob 1!"); - ts->assertValueAt(9, "Hello Bob 10!"); -} - -TEST(RequestStreamTest, HelloFlowControl) { - folly::ScopedEventBaseThread worker; - auto server = makeServer(std::make_shared()); - auto client = makeClient(worker.getEventBase(), *server->listeningPort()); - auto requester = client->getRequester(); - auto ts = TestSubscriber::create(5); - requester->requestStream(Payload("Bob")) - ->map([](auto p) { return p.moveDataToString(); }) - ->subscribe(ts); - - ts->awaitValueCount(5); - - ts->assertValueCount(5); - ts->assertValueAt(0, "Hello Bob 1!"); - ts->assertValueAt(4, "Hello Bob 5!"); - - ts->request(5); - - ts->awaitValueCount(10); - - ts->assertValueCount(10); - ts->assertValueAt(5, "Hello Bob 6!"); - ts->assertValueAt(9, "Hello Bob 10!"); - - ts->awaitTerminalEvent(); - ts->assertSuccess(); -} - -TEST(RequestStreamTest, HelloNoFlowControl) { - folly::ScopedEventBaseThread worker; - auto server = makeServer(std::make_shared()); - auto stats = std::make_shared(); - auto client = makeClient( - worker.getEventBase(), *server->listeningPort(), nullptr, stats); - auto requester = client->getRequester(); - auto ts = TestSubscriber::create(); - requester->requestStream(Payload("Bob")) - ->map([](auto p) { return p.moveDataToString(); }) - ->subscribe(ts); - ts->awaitTerminalEvent(); - ts->assertSuccess(); - ts->assertValueCount(10); - ts->assertValueAt(0, "Hello Bob 1!"); - ts->assertValueAt(9, "Hello Bob 10!"); - - // Make sure that the initial requestN in the Stream Request Frame - // is already enough and no other requestN messages are sent. - EXPECT_EQ(stats->writeRequestN_, 0); -} - -class TestHandlerAsync : public rsocket::RSocketResponder { - public: - Reference> handleRequestStream(Payload request, StreamId) - override { - // string from payload data - auto requestString = request.moveDataToString(); - - return Flowables::fromPublisher< - Payload>([requestString = std::move(requestString)]( - Reference> subscriber) { - std::thread([ - requestString = std::move(requestString), - subscriber = std::move(subscriber) - ]() { - Flowables::range(1, 40) - ->map([name = std::move(requestString)](int64_t v) { - std::stringstream ss; - ss << "Hello " << name << " " << v << "!"; - std::string s = ss.str(); - return Payload(s, "metadata"); - }) - ->subscribe(subscriber); - }) - .detach(); - }); - } -}; -} // namespace - -TEST(RequestStreamTest, HelloAsync) { - folly::ScopedEventBaseThread worker; - auto server = makeServer(std::make_shared()); - auto client = makeClient(worker.getEventBase(), *server->listeningPort()); - auto requester = client->getRequester(); - auto ts = TestSubscriber::create(); - requester->requestStream(Payload("Bob")) - ->map([](auto p) { return p.moveDataToString(); }) - ->subscribe(ts); - ts->awaitTerminalEvent(); - ts->assertSuccess(); - ts->assertValueCount(40); - ts->assertValueAt(0, "Hello Bob 1!"); - ts->assertValueAt(39, "Hello Bob 40!"); -} - -TEST(RequestStreamTest, RequestOnDisconnectedClient) { - folly::ScopedEventBaseThread worker; - auto client = makeDisconnectedClient(worker.getEventBase()); - auto requester = client->getRequester(); - - bool did_call_on_error = false; - folly::Baton<> wait_for_on_error; - - requester->requestStream(Payload("foo", "bar")) - ->subscribe( - [](auto /* payload */) { - // onNext shouldn't be called - FAIL(); - }, - [&](folly::exception_wrapper) { - did_call_on_error = true; - wait_for_on_error.post(); - }, - []() { - // onComplete shouldn't be called - FAIL(); - }); - - CHECK_WAIT(wait_for_on_error); - ASSERT(did_call_on_error); -} - -class TestHandlerResponder : public rsocket::RSocketResponder { - public: - Reference> handleRequestStream(Payload, StreamId) override { - return Flowables::error( - std::runtime_error("A wild Error appeared!")); - } -}; - -TEST(RequestStreamTest, HandleError) { - folly::ScopedEventBaseThread worker; - auto server = makeServer(std::make_shared()); - auto client = makeClient(worker.getEventBase(), *server->listeningPort()); - auto requester = client->getRequester(); - auto ts = TestSubscriber::create(); - requester->requestStream(Payload("Bob")) - ->map([](auto p) { return p.moveDataToString(); }) - ->subscribe(ts); - ts->awaitTerminalEvent(); - ts->assertOnErrorMessage("A wild Error appeared!"); -} - -class TestErrorAfterOnNextResponder : public rsocket::RSocketResponder { - public: - Reference> handleRequestStream(Payload request, StreamId) - override { - // string from payload data - auto requestString = request.moveDataToString(); - - return Flowable::create([name = std::move(requestString)]( - Reference> subscriber, int64_t requested) { - EXPECT_GT(requested, 1); - subscriber->onNext(Payload(name, "meta")); - subscriber->onNext(Payload(name, "meta")); - subscriber->onNext(Payload(name, "meta")); - subscriber->onNext(Payload(name, "meta")); - subscriber->onError(std::runtime_error("A wild Error appeared!")); - return std::make_tuple(int64_t(4), true); - }); - } -}; - -TEST(RequestStreamTest, HandleErrorMidStream) { - folly::ScopedEventBaseThread worker; - auto server = makeServer(std::make_shared()); - auto client = makeClient(worker.getEventBase(), *server->listeningPort()); - auto requester = client->getRequester(); - auto ts = TestSubscriber::create(); - requester->requestStream(Payload("Bob")) - ->map([](auto p) { return p.moveDataToString(); }) - ->subscribe(ts); - ts->awaitTerminalEvent(); - ts->assertValueCount(4); - ts->assertOnErrorMessage("A wild Error appeared!"); -} diff --git a/test/RequestStreamTest_concurrency.cpp b/test/RequestStreamTest_concurrency.cpp deleted file mode 100644 index f68b19af4..000000000 --- a/test/RequestStreamTest_concurrency.cpp +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include -#include -#include -#include - -#include "RSocketTests.h" -#include "yarpl/Flowable.h" -#include "yarpl/flowable/TestSubscriber.h" - -#include "yarpl/test_utils/Mocks.h" - -using namespace yarpl; -using namespace yarpl::flowable; -using namespace rsocket; -using namespace rsocket::tests; -using namespace rsocket::tests::client_server; - -struct LockstepBatons { - folly::Baton<> onSecondPayloadSent; - folly::Baton<> onCancelSent; - folly::Baton<> onCancelReceivedToserver; - folly::Baton<> onCancelReceivedToclient; - folly::Baton<> onRequestReceived; - folly::Baton<> clientFinished; - folly::Baton<> serverFinished; -}; - -using namespace yarpl::mocks; -using namespace ::testing; - -#define LOCKSTEP_DEBUG(expr) VLOG(3) << expr - -class LockstepAsyncHandler : public rsocket::RSocketResponder { - LockstepBatons& batons_; - Sequence& subscription_seq_; - - public: - LockstepAsyncHandler(LockstepBatons& batons, Sequence& subscription_seq) - : batons_(batons), subscription_seq_(subscription_seq){} - - Reference> handleRequestStream(Payload p, StreamId) - override { - EXPECT_EQ(p.moveDataToString(), "initial"); - return Flowables::fromPublisher( - [this](Reference> subscriber) { - auto subscription = make_ref>(); - - std::thread([=] { - CHECK_WAIT(this->batons_.onRequestReceived); - - LOCKSTEP_DEBUG("SERVER: sending onNext(foo)"); - subscriber->onNext(Payload("foo")); - CHECK_WAIT(this->batons_.onCancelSent); - CHECK_WAIT(this->batons_.onCancelReceivedToserver); - - LOCKSTEP_DEBUG("SERVER: sending onNext(bar)"); - subscriber->onNext(Payload("bar")); - this->batons_.onSecondPayloadSent.post(); - LOCKSTEP_DEBUG("SERVER: sending onComplete()"); - subscriber->onComplete(); - LOCKSTEP_DEBUG("SERVER: posting serverFinished"); - this->batons_.serverFinished.post(); - }).detach(); - - // checked once the subscription is destroyed - EXPECT_CALL(*subscription, request_(2)) - .InSequence(this->subscription_seq_) - .WillOnce(Invoke([=](auto n) { - LOCKSTEP_DEBUG("SERVER: got request(" << n << ")"); - EXPECT_EQ(n, 2); - this->batons_.onRequestReceived.post(); - })); - - EXPECT_CALL(*subscription, cancel_()) - .InSequence(this->subscription_seq_) - .WillOnce(Invoke([=] { - LOCKSTEP_DEBUG("SERVER: received cancel()"); - this->batons_.onCancelReceivedToclient.post(); - this->batons_.onCancelReceivedToserver.post(); - })); - - LOCKSTEP_DEBUG("SERVER: sending onSubscribe()"); - subscriber->onSubscribe(subscription); - }); - } -}; - -// FIXME: This hits an ASAN heap-use-after-free. Disabling for now, but we need -// to get back to this and fix it. -TEST(RequestStreamTest, OperationsAfterCancel) { - LockstepBatons batons; - Sequence server_seq; - Sequence client_seq; - - auto server = - makeServer(std::make_shared(batons, server_seq)); - folly::ScopedEventBaseThread worker; - auto client = makeClient(worker.getEventBase(), *server->listeningPort()); - auto requester = client->getRequester(); - - auto subscriber_mock = - make_ref>>( - 0); - - Reference subscription; - EXPECT_CALL(*subscriber_mock, onSubscribe_(_)) - .InSequence(client_seq) - .WillOnce(Invoke([&](auto s) { - LOCKSTEP_DEBUG("CLIENT: got onSubscribe(), sending request(2)"); - EXPECT_NE(s, nullptr); - subscription = s; - subscription->request(2); - })); - EXPECT_CALL(*subscriber_mock, onNext_("foo")) - .InSequence(client_seq) - .WillOnce(Invoke([&](auto) { - EXPECT_NE(subscription, nullptr); - LOCKSTEP_DEBUG("CLIENT: got onNext(foo), sending cancel()"); - subscription->cancel(); - batons.onCancelSent.post(); - CHECK_WAIT(batons.onCancelReceivedToclient); - CHECK_WAIT(batons.onSecondPayloadSent); - batons.clientFinished.post(); - })); - - // shouldn't receive 'bar', we canceled syncronously with the Subscriber - // had 'cancel' been called in a different thread with no synchronization, - // the client's Subscriber _could_ have received 'bar' - - LOCKSTEP_DEBUG("RUNNER: doing requestStream()"); - requester->requestStream(Payload("initial")) - ->map([](auto p) { return p.moveDataToString(); }) - ->subscribe(subscriber_mock); - - CHECK_WAIT(batons.clientFinished); - CHECK_WAIT(batons.serverFinished); - LOCKSTEP_DEBUG("RUNNER: finished!"); -} diff --git a/test/Test.cpp b/test/Test.cpp deleted file mode 100644 index 884d4736b..000000000 --- a/test/Test.cpp +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include -#include -#include - -int main(int argc, char** argv) { - FLAGS_logtostderr = true; - testing::InitGoogleMock(&argc, argv); - folly::init(&argc, &argv); - return RUN_ALL_TESTS(); -} diff --git a/test/handlers/HelloStreamRequestHandler.cpp b/test/handlers/HelloStreamRequestHandler.cpp deleted file mode 100644 index 1ab33eae3..000000000 --- a/test/handlers/HelloStreamRequestHandler.cpp +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "HelloStreamRequestHandler.h" - -#include - -#include - -#include "yarpl/Flowable.h" - -using namespace ::rsocket; -using namespace yarpl; -using namespace yarpl::flowable; - -namespace rsocket { -namespace tests { -/// Handles a new inbound Stream requested by the other end. -Reference> -HelloStreamRequestHandler::handleRequestStream( - rsocket::Payload request, - rsocket::StreamId) { - VLOG(3) << "HelloStreamRequestHandler.handleRequestStream " << request; - - // string from payload data - auto requestString = request.moveDataToString(); - - return Flowables::range(1, 10)->map([name = std::move(requestString)]( - int64_t v) { return Payload(folly::to(v), "metadata"); }); -} -} -} diff --git a/test/handlers/HelloStreamRequestHandler.h b/test/handlers/HelloStreamRequestHandler.h deleted file mode 100644 index 768641bda..000000000 --- a/test/handlers/HelloStreamRequestHandler.h +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "rsocket/RSocketResponder.h" -#include "yarpl/Flowable.h" - -namespace rsocket { -namespace tests { - -class HelloStreamRequestHandler : public RSocketResponder { - public: - /// Handles a new inbound Stream requested by the other end. - yarpl::Reference> - handleRequestStream(rsocket::Payload request, rsocket::StreamId streamId) - override; -}; -} -} diff --git a/test/internal/ResumeIdentificationToken.cpp b/test/internal/ResumeIdentificationToken.cpp deleted file mode 100644 index 9ffe18c69..000000000 --- a/test/internal/ResumeIdentificationToken.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include -#include -#include "rsocket/internal/Common.h" - -using namespace testing; -using namespace rsocket; - -TEST(ResumeIdentificationTokenTest, Conversion) { - for (int i = 0; i < 10; i++) { - auto token = ResumeIdentificationToken::generateNew(); - auto token2 = ResumeIdentificationToken(token.str()); - CHECK_EQ(token, token2); - CHECK_EQ(token.str(), token2.str()); - } -} diff --git a/test/statemachine/StreamStateTest.cpp b/test/statemachine/StreamStateTest.cpp deleted file mode 100644 index 142b425f2..000000000 --- a/test/statemachine/StreamStateTest.cpp +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include -#include - -#include "rsocket/statemachine/StreamState.h" -#include "test/test_utils/MockStats.h" - -using namespace rsocket; -using namespace testing; - -class StreamStateTest : public Test { - protected: - StrictMock stats_; - StreamState state_{stats_}; -}; - -TEST_F(StreamStateTest, Stats) { - auto frame1Size = 7, frame2Size = 11; - EXPECT_CALL(stats_, streamBufferChanged(1, frame1Size)); - state_.enqueueOutputPendingFrame( - folly::IOBuf::copyBuffer(std::string(frame1Size, 'x'))); - EXPECT_CALL(stats_, streamBufferChanged(1, frame2Size)); - state_.enqueueOutputPendingFrame( - folly::IOBuf::copyBuffer(std::string(frame2Size, 'x'))); - EXPECT_CALL(stats_, streamBufferChanged(-2, -(frame1Size + frame2Size))); - state_.moveOutputPendingFrames(); -} - -TEST_F(StreamStateTest, StatsUpdatedInDtor) { - auto frameSize = 7; - EXPECT_CALL(stats_, streamBufferChanged(1, frameSize)); - state_.enqueueOutputPendingFrame( - folly::IOBuf::copyBuffer(std::string(frameSize, 'x'))); - EXPECT_CALL(stats_, streamBufferChanged(-1, -frameSize)); -} diff --git a/test/test_utils/MockFrameProcessor.h b/test/test_utils/MockFrameProcessor.h deleted file mode 100644 index ab1393f0f..000000000 --- a/test/test_utils/MockFrameProcessor.h +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include - -#include -#include - -#include "rsocket/framing/FrameProcessor.h" - -namespace rsocket { - -class MockFrameProcessor : public FrameProcessor { -public: - void processFrame(std::unique_ptr buf) override { - processFrame_(buf); - } - - void onTerminal(folly::exception_wrapper ew) override { - onTerminal_(std::move(ew)); - } - - MOCK_METHOD1(processFrame_, void(std::unique_ptr&)); - MOCK_METHOD1(onTerminal_, void(folly::exception_wrapper)); -}; - -} diff --git a/test/test_utils/MockKeepaliveTimer.h b/test/test_utils/MockKeepaliveTimer.h deleted file mode 100644 index 272ddc357..000000000 --- a/test/test_utils/MockKeepaliveTimer.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include - -#include - -#include "rsocket/statemachine/RSocketStateMachine.h" -#include "test/deprecated/ReactiveSocket.h" - -namespace rsocket { -class MockKeepaliveTimer : public KeepaliveTimer { - public: - MOCK_METHOD1(start, void(const std::shared_ptr&)); - MOCK_METHOD0(stop, void()); - MOCK_METHOD0(keepaliveReceived, void()); - MOCK_METHOD0(keepaliveTime, std::chrono::milliseconds()); -}; -} diff --git a/test/test_utils/MockRequestHandler.h b/test/test_utils/MockRequestHandler.h deleted file mode 100644 index d0cfb7135..000000000 --- a/test/test_utils/MockRequestHandler.h +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include - -#include - -#include "rsocket/Payload.h" -#include "rsocket/temporary_home/RequestHandler.h" - -namespace rsocket { - -class MockRequestHandler : public RequestHandler { - public: - MOCK_METHOD3( - handleRequestChannel_, - yarpl::Reference>( - Payload& request, - StreamId streamId, - const yarpl::Reference>&)); - MOCK_METHOD3( - handleRequestStream_, - void( - Payload& request, - StreamId streamId, - const yarpl::Reference>&)); - MOCK_METHOD3( - handleRequestResponse_, - void( - Payload& request, - StreamId streamId, - const yarpl::Reference>&)); - MOCK_METHOD2( - handleFireAndForgetRequest_, - void(Payload& request, StreamId streamId)); - MOCK_METHOD1( - handleMetadataPush_, - void(std::unique_ptr& request)); - MOCK_METHOD1( - handleSetupPayload_, - std::shared_ptr(SetupParameters& request)); - MOCK_METHOD1(handleResume_, bool(ResumeParameters& resumeParams)); - - yarpl::Reference> handleRequestChannel( - Payload request, - StreamId streamId, - const yarpl::Reference>& - response) noexcept override { - return handleRequestChannel_(request, streamId, response); - } - - void handleRequestStream( - Payload request, - StreamId streamId, - const yarpl::Reference>& - response) noexcept override { - handleRequestStream_(request, streamId, response); - } - - void handleRequestResponse( - Payload request, - StreamId streamId, - const yarpl::Reference>& - response) noexcept override { - handleRequestResponse_(request, streamId, response); - } - - void handleFireAndForgetRequest( - Payload request, - StreamId streamId) noexcept override { - handleFireAndForgetRequest_(request, streamId); - } - - void handleMetadataPush( - std::unique_ptr request) noexcept override { - handleMetadataPush_(request); - } - - std::shared_ptr handleSetupPayload( - SetupParameters request) noexcept override { - return handleSetupPayload_(request); - } - - bool handleResume(ResumeParameters resumeParams) noexcept override { - return handleResume_(resumeParams); - } - - void handleCleanResume(yarpl::Reference - response) noexcept override {} - void handleDirtyResume(yarpl::Reference - response) noexcept override {} - - MOCK_METHOD1( - onSubscriptionPaused_, - void(const yarpl::Reference&)); - void onSubscriptionPaused( - const yarpl::Reference& - subscription) noexcept override { - onSubscriptionPaused_(std::move(subscription)); - } - void onSubscriptionResumed( - const yarpl::Reference& - subscription) noexcept override {} - void onSubscriberPaused( - const yarpl::Reference>& - subscriber) noexcept override {} - void onSubscriberResumed( - const yarpl::Reference>& - subscriber) noexcept override {} - - MOCK_METHOD0(socketOnConnected, void()); - - MOCK_METHOD1(socketOnClosed, void(folly::exception_wrapper& listener)); - MOCK_METHOD1(socketOnDisconnected, void(folly::exception_wrapper& listener)); -}; -} diff --git a/test/test_utils/PrintSubscriber.h b/test/test_utils/PrintSubscriber.h deleted file mode 100644 index 46436e9d9..000000000 --- a/test/test_utils/PrintSubscriber.h +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "rsocket/Payload.h" -#include "yarpl/flowable/Subscriber.h" - -namespace rsocket { -class PrintSubscriber : public yarpl::flowable::Subscriber { - public: - ~PrintSubscriber(); - - void onSubscribe(yarpl::Reference - subscription) noexcept override; - void onNext(Payload element) noexcept override; - void onComplete() noexcept override; - void onError(folly::exception_wrapper ex) noexcept override; -}; -} diff --git a/yarpl/CMakeLists.txt b/yarpl/CMakeLists.txt index 5fd769d01..f4159b82c 100644 --- a/yarpl/CMakeLists.txt +++ b/yarpl/CMakeLists.txt @@ -1,12 +1,13 @@ cmake_minimum_required (VERSION 3.2) - -# To debug the project, set the build type. -set(CMAKE_BUILD_TYPE Debug) - project (yarpl) # CMake Config -set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../cmake/) +set(CMAKE_MODULE_PATH + ${CMAKE_CURRENT_SOURCE_DIR}/../cmake/ + # For shipit-transformed builds + "${CMAKE_CURRENT_SOURCE_DIR}/../build/fbcode_builder/CMake" + ${CMAKE_MODULE_PATH} +) add_definitions(-std=c++14) option(BUILD_TESTS "BUILD_TESTS" ON) @@ -15,11 +16,32 @@ option(BUILD_TESTS "BUILD_TESTS" ON) set(CMAKE_EXPORT_COMPILE_COMMANDS 1) # Common configuration for all build modes. -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Wno-unused-parameter") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-weak-vtables -Wno-padded") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -momit-leaf-frame-pointer") +if (NOT MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Wno-unused-parameter") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-weak-vtables -Wno-padded") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer") + include(CheckCXXCompilerFlag) + CHECK_CXX_COMPILER_FLAG("-momit-leaf-frame-pointer" HAVE_OMIT_LEAF_FRAME_POINTER) + if(HAVE_OMIT_LEAF_FRAME_POINTER) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -momit-leaf-frame-pointer") + endif() +endif() + +if(YARPL_WRAP_SHARED_IN_LOCK) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DYARPL_WRAP_SHARED_IN_LOCK") + message("Compiler lacks support std::atomic; wrapping with a mutex") +elseif(YARPL_WRAP_SHARED_IN_ATOMIC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DYARPL_WRAP_SHARED_IN_ATOMIC") + message("Compiler lacks std::shared_ptr atomic overloads; wrapping in std::atomic") +else() + message("Compiler has atomic std::shared_ptr support") +endif() + + +if(${CMAKE_CXX_COMPILER_ID} MATCHES GNU) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -latomic") +endif() # The yarpl-tests binary constantly fails with an ASAN error in gtest internal # code on macOS. @@ -28,95 +50,108 @@ if(APPLE AND ${CMAKE_CXX_COMPILER_ID} MATCHES Clang) add_compile_options("-fno-sanitize=address,undefined") endif() -option(YARPL_REFCOUNT_DEBUGGING "Enable refcount debugging/leak checking in Yarpl" OFF) -if(YARPL_REFCOUNT_DEBUGGING) - add_compile_options(-DYARPL_REFCOUNT_DEBUGGING) -endif() - # Using NDEBUG in Release builds. set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG") -find_path(GLOG_INCLUDE_DIR glog/logging.h) -find_library(GLOG_LIBRARY glog) +find_package(Gflags REQUIRED) +find_package(Glog REQUIRED) +find_package(fmt CONFIG REQUIRED) IF(NOT FOLLY_VERSION) include(${CMAKE_CURRENT_SOURCE_DIR}/../cmake/InstallFolly.cmake) ENDIF() -message("glog include_dir <${GLOG_INCLUDE_DIR}> lib <${GLOG_LIBRARY}>") - -include_directories(SYSTEM ${GLOG_INCLUDE_DIR}) -include_directories(${CMAKE_SOURCE_DIR}) +include_directories(SYSTEM ${GFLAGS_INCLUDE_DIR}) # library source add_library( yarpl # public API - include/yarpl/Refcounted.h - src/yarpl/Refcounted.cpp + Refcounted.h + Common.h # Flowable public API - include/yarpl/Flowable.h - include/yarpl/flowable/EmitterFlowable.h - include/yarpl/flowable/Flowable.h - include/yarpl/flowable/FlowableOperator.h - include/yarpl/flowable/FlowableObserveOnOperator.h - include/yarpl/flowable/Flowable_FromObservable.h - include/yarpl/flowable/Flowables.h - include/yarpl/flowable/Subscriber.h - include/yarpl/flowable/Subscribers.h - include/yarpl/flowable/Subscription.h - include/yarpl/flowable/TestSubscriber.h - src/yarpl/flowable/sources/Subscription.cpp + Flowable.h + flowable/DeferFlowable.h + flowable/EmitterFlowable.h + flowable/Flowable.h + flowable/FlowableOperator.h + flowable/FlowableConcatOperators.h + flowable/FlowableDoOperator.h + flowable/FlowableObserveOnOperator.h + flowable/Flowable_FromObservable.h + flowable/Flowables.h + flowable/PublishProcessor.h + flowable/Subscriber.h + flowable/Subscription.h + flowable/TestSubscriber.h + flowable/Subscription.cpp + flowable/Flowables.cpp # Observable public API - include/yarpl/Observable.h - include/yarpl/observable/Observable.h - include/yarpl/observable/Observables.h - include/yarpl/observable/ObservableOperator.h - include/yarpl/observable/ObservableDoOperator.h - include/yarpl/observable/Observer.h - include/yarpl/observable/Observers.h - include/yarpl/observable/Subscription.h - include/yarpl/observable/Subscriptions.h - include/yarpl/observable/TestObserver.h - src/yarpl/observable/Subscriptions.cpp + Observable.h + observable/DeferObservable.h + observable/Observable.h + observable/Observables.h + observable/ObservableOperator.h + observable/ObservableConcatOperators.h + observable/ObservableDoOperator.h + observable/Observer.h + observable/Subscription.h + observable/TestObserver.h + observable/Subscription.cpp + observable/Observables.cpp # Single - include/yarpl/Single.h - include/yarpl/single/Single.h - include/yarpl/single/Singles.h - include/yarpl/single/SingleOperator.h - include/yarpl/single/SingleObserver.h - include/yarpl/single/SingleObservers.h - include/yarpl/single/SingleSubscription.h - include/yarpl/single/SingleSubscriptions.h - include/yarpl/single/SingleTestObserver.h + Single.h + single/Single.h + single/Singles.h + single/SingleOperator.h + single/SingleObserver.h + single/SingleObservers.h + single/SingleSubscription.h + single/SingleSubscriptions.h + single/SingleTestObserver.h # utils - include/yarpl/utils/type_traits.h - include/yarpl/utils/credits.h - src/yarpl/utils/credits.cpp) - + utils/credits.h + utils/credits.cpp) target_include_directories( - yarpl - PUBLIC "${PROJECT_SOURCE_DIR}/include" # allow include paths such as "yarpl/observable.h" - PUBLIC "${PROJECT_SOURCE_DIR}/src" # allow include paths such as "yarpl/flowable/FlowableRange.h" - ) + yarpl + PUBLIC + $ + $ +) + +message("yarpl source dir: ${CMAKE_CURRENT_SOURCE_DIR}") target_link_libraries( yarpl - folly - ${GLOG_LIBRARY}) - -install(TARGETS yarpl DESTINATION lib) -install(DIRECTORY include/yarpl DESTINATION include - FILES_MATCHING PATTERN "*.h") + PUBLIC Folly::folly glog::glog gflags + INTERFACE ${EXTRA_LINK_FLAGS}) + +include(CMakePackageConfigHelpers) +configure_package_config_file( + cmake/yarpl-config.cmake.in + yarpl-config.cmake + INSTALL_DESTINATION lib/cmake/yarpl +) +install(TARGETS yarpl EXPORT yarpl-exports DESTINATION lib) +install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} DESTINATION include FILES_MATCHING PATTERN "*.h") +install( + EXPORT yarpl-exports + NAMESPACE yarpl:: + DESTINATION lib/cmake/yarpl +) +install( + FILES ${CMAKE_CURRENT_BINARY_DIR}/yarpl-config.cmake + DESTINATION lib/cmake/yarpl +) + +# RSocket's tests also has dependency on this library +add_library( + yarpl-test-utils + test_utils/Tuple.cpp + test_utils/Tuple.h + test_utils/Mocks.h) if (BUILD_TESTS) - add_library( - yarpl-test-utils - test/yarpl/test_utils/Tuple.cpp - test/yarpl/test_utils/Tuple.h - test/yarpl/test_utils/utils.h - test/yarpl/test_utils/Mocks.h) - # Executable for experimenting. add_executable( yarpl-playground @@ -133,22 +168,24 @@ if (BUILD_TESTS) test/FlowableTest.cpp test/FlowableFlatMapTest.cpp test/Observable_test.cpp - test/RefcountedTest.cpp - test/ReferenceTest.cpp + test/PublishProcessorTest.cpp test/SubscribeObserveOnTests.cpp test/Single_test.cpp test/FlowableSubscriberTest.cpp test/credits-test.cpp test/yarpl-tests.cpp) + add_dependencies(yarpl-tests gmock) target_link_libraries( yarpl-tests yarpl yarpl-test-utils - ${GLOG_LIBRARY} + glog::glog + gflags # Inherited from rsocket-cpp CMake. - ${GMOCK_LIBS}) + ${GMOCK_LIBS} # This also needs the preceding `add_dependencies` + ) add_dependencies(yarpl-tests yarpl-test-utils gmock) diff --git a/yarpl/Common.h b/yarpl/Common.h new file mode 100644 index 000000000..be9d8287c --- /dev/null +++ b/yarpl/Common.h @@ -0,0 +1,69 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace yarpl { + +namespace observable { +template +class Observable; +} // namespace observable + +namespace flowable { +template +class Subscriber; + +// Exception thrown in case the downstream can't keep up. +class MissingBackpressureException : public std::runtime_error { + public: + MissingBackpressureException() + : std::runtime_error("BACK_PRESSURE: DROP (missing credits onNext)") {} +}; + +} // namespace flowable + +/** + *Strategy for backpressure when converting from Observable to Flowable. + */ +enum class BackpressureStrategy { + BUFFER, // Buffers all onNext values until the downstream consumes them. + DROP, // Drops the most recent onNext value if the downstream can't keep up. + ERROR, // Signals a MissingBackpressureException in case the downstream can't + // keep up. + LATEST, // Keeps only the latest onNext value, overwriting any previous value + // if the downstream can't keep up. + MISSING // OnNext events are written without any buffering or dropping. +}; + +template +class IBackpressureStrategy { + public: + virtual ~IBackpressureStrategy() = default; + + virtual void init( + std::shared_ptr> upstream, + std::shared_ptr> downstream) = 0; + + static std::shared_ptr> buffer(); + static std::shared_ptr> drop(); + static std::shared_ptr> error(); + static std::shared_ptr> latest(); + static std::shared_ptr> missing(); +}; + +} // namespace yarpl diff --git a/yarpl/Disposable.h b/yarpl/Disposable.h new file mode 100644 index 000000000..6cc5d8264 --- /dev/null +++ b/yarpl/Disposable.h @@ -0,0 +1,42 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace yarpl { + +/** + * Represents a disposable resource. + */ +class Disposable { + public: + Disposable() {} + virtual ~Disposable() = default; + Disposable(Disposable&&) = delete; + Disposable(const Disposable&) = delete; + Disposable& operator=(Disposable&&) = delete; + Disposable& operator=(const Disposable&) = delete; + + /** + * Dispose the resource, the operation should be idempotent. + */ + virtual void dispose() = 0; + + /** + * Returns true if this resource has been disposed. + * @return true if this resource has been disposed + */ + virtual bool isDisposed() = 0; +}; +} // namespace yarpl diff --git a/yarpl/Flowable.h b/yarpl/Flowable.h new file mode 100644 index 000000000..34014ddd7 --- /dev/null +++ b/yarpl/Flowable.h @@ -0,0 +1,25 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +// include all the things a developer needs for using Flowable +#include "yarpl/flowable/Flowable.h" +#include "yarpl/flowable/Flowables.h" +#include "yarpl/flowable/Subscriber.h" +#include "yarpl/flowable/Subscription.h" + +/** + * // TODO add documentation + */ diff --git a/yarpl/Observable.h b/yarpl/Observable.h new file mode 100644 index 000000000..d115d5160 --- /dev/null +++ b/yarpl/Observable.h @@ -0,0 +1,25 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +// include all the things a developer needs for using Observable +#include "yarpl/observable/Observable.h" +#include "yarpl/observable/Observables.h" +#include "yarpl/observable/Observer.h" +#include "yarpl/observable/Subscription.h" + +/** + * // TODO add documentation + */ diff --git a/yarpl/Refcounted.h b/yarpl/Refcounted.h new file mode 100644 index 000000000..ac0a4950d --- /dev/null +++ b/yarpl/Refcounted.h @@ -0,0 +1,85 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace yarpl { + +template +struct AtomicReference { + folly::Synchronized, std::mutex> ref; + + AtomicReference() = default; + + AtomicReference(std::shared_ptr&& r) { + *(ref.lock()) = std::move(r); + } +}; + +template +std::shared_ptr atomic_load(AtomicReference* ar) { + return *(ar->ref.lock()); +} + +template +std::shared_ptr atomic_exchange( + AtomicReference* ar, + std::shared_ptr r) { + auto refptr = ar->ref.lock(); + auto old = std::move(*refptr); + *refptr = std::move(r); + return old; +} + +template +std::shared_ptr atomic_exchange(AtomicReference* ar, std::nullptr_t) { + return atomic_exchange(ar, std::shared_ptr()); +} + +template +void atomic_store(AtomicReference* ar, std::shared_ptr r) { + *ar->ref.lock() = std::move(r); +} + +class enable_get_ref : public std::enable_shared_from_this { + private: + virtual void dummy_internal_get_ref() {} + + protected: + // materialize a reference to 'this', but a type even further derived from + // Derived, because C++ doesn't have covariant return types on methods + template + std::shared_ptr ref_from_this(As* ptr) { + // at runtime, ensure that the most derived class can indeed be + // converted into an 'as' + (void)ptr; // silence 'unused parameter' errors in Release builds + return std::static_pointer_cast(this->shared_from_this()); + } + + template + std::shared_ptr ref_from_this(As const* ptr) const { + // at runtime, ensure that the most derived class can indeed be + // converted into an 'as' + (void)ptr; // silence 'unused parameter' errors in Release builds + return std::static_pointer_cast(this->shared_from_this()); + } + + public: + virtual ~enable_get_ref() = default; +}; + +} /* namespace yarpl */ diff --git a/yarpl/Single.h b/yarpl/Single.h new file mode 100644 index 000000000..c5b737b5f --- /dev/null +++ b/yarpl/Single.h @@ -0,0 +1,35 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/Refcounted.h" + +// include all the things a developer needs for using Single +#include "yarpl/single/Single.h" +#include "yarpl/single/SingleObserver.h" +#include "yarpl/single/SingleObservers.h" +#include "yarpl/single/SingleSubscriptions.h" +#include "yarpl/single/Singles.h" + +/** + * Create a single with code such as this: + * + * auto a = Single::create([](std::shared_ptr> obs) { + * obs->onSubscribe(SingleSubscriptions::empty()); + * obs->onSuccess(1); + * }); + * + * // TODO add more documentation + */ diff --git a/yarpl/cmake/yarpl-config.cmake.in b/yarpl/cmake/yarpl-config.cmake.in new file mode 100644 index 000000000..d557b2135 --- /dev/null +++ b/yarpl/cmake/yarpl-config.cmake.in @@ -0,0 +1,13 @@ +# Copyright (c) 2018, Facebook, Inc. +# All rights reserved. + +@PACKAGE_INIT@ + +if(NOT TARGET yarpl::yarpl) + include("${PACKAGE_PREFIX_DIR}/lib/cmake/yarpl/yarpl-exports.cmake") +endif() + +set(YARPL_LIBRARIES yarpl::yarpl) +if (NOT yarpl_FIND_QUIETLY) + message(STATUS "Found YARPL: ${PACKAGE_PREFIX_DIR}") +endif() diff --git a/yarpl/examples/FlowableExamples.cpp b/yarpl/examples/FlowableExamples.cpp index 915764906..224726166 100644 --- a/yarpl/examples/FlowableExamples.cpp +++ b/yarpl/examples/FlowableExamples.cpp @@ -1,30 +1,38 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "FlowableExamples.h" - +#include #include #include #include #include - -#include - #include "yarpl/Flowable.h" -using namespace yarpl; using namespace yarpl::flowable; namespace { template auto printer() { - return Subscribers::create( + return Subscriber::create( [](T value) { std::cout << " next: " << value << std::endl; }, 2 /* low [optional] batch size for demo */); } -Reference> getData() { - return Flowables::range(2, 5); +std::shared_ptr> getData() { + return Flowable<>::range(2, 5); } std::string getThreadId() { @@ -34,7 +42,7 @@ std::string getThreadId() { } void fromPublisherExample() { - auto onSubscribe = [](Reference> subscriber) { + auto onSubscribe = [](std::shared_ptr> subscriber) { class Subscription : public ::yarpl::flowable::Subscription { public: virtual void request(int64_t delta) override { @@ -46,7 +54,7 @@ void fromPublisherExample() { } }; - auto subscription = make_ref(); + auto subscription = std::make_shared(); subscriber->onSubscribe(subscription); subscriber->onNext(1234); subscriber->onNext(5678); @@ -54,7 +62,7 @@ void fromPublisherExample() { subscriber->onComplete(); }; - Flowables::fromPublisher(std::move(onSubscribe)) + Flowable::fromPublisher(std::move(onSubscribe)) ->subscribe(printer()); } @@ -62,31 +70,31 @@ void fromPublisherExample() { void FlowableExamples::run() { std::cout << "create a flowable" << std::endl; - Flowables::range(2, 2); + Flowable<>::range(2, 2); std::cout << "get a flowable from a method" << std::endl; getData()->subscribe(printer()); std::cout << "just: single value" << std::endl; - Flowables::just(23)->subscribe(printer()); + Flowable<>::just(23)->subscribe(printer()); std::cout << "just: multiple values." << std::endl; - Flowables::justN({1, 4, 7, 11})->subscribe(printer()); + Flowable<>::justN({1, 4, 7, 11})->subscribe(printer()); std::cout << "just: string values." << std::endl; - Flowables::justN({"the", "quick", "brown", "fox"}) + Flowable<>::justN({"the", "quick", "brown", "fox"}) ->subscribe(printer()); std::cout << "range operator." << std::endl; - Flowables::range(1, 4)->subscribe(printer()); + Flowable<>::range(1, 4)->subscribe(printer()); std::cout << "map example: squares" << std::endl; - Flowables::range(1, 4) + Flowable<>::range(1, 4) ->map([](int64_t v) { return v * v; }) ->subscribe(printer()); std::cout << "map example: convert to string" << std::endl; - Flowables::range(1, 4) + Flowable<>::range(1, 4) ->map([](int64_t v) { return v * v; }) ->map([](int64_t v) { return v * v; }) ->map([](int64_t v) { return std::to_string(v); }) @@ -94,16 +102,15 @@ void FlowableExamples::run() { ->subscribe(printer()); std::cout << "take example: 3 out of 10 items" << std::endl; - Flowables::range(1, 11)->take(3)->subscribe(printer()); + Flowable<>::range(1, 11)->take(3)->subscribe(printer()); - auto flowable = Flowable::create([total = 0]( - Reference> subscriber, int64_t requested) mutable { - subscriber->onNext(12345678); - subscriber->onError(std::runtime_error("error")); - return std::make_tuple(int64_t{1}, false); - }); + auto flowable = Flowable::create( + [total = 0](auto& subscriber, int64_t requested) mutable { + subscriber.onNext(12345678); + subscriber.onError(std::runtime_error("error")); + }); - auto subscriber = Subscribers::create( + auto subscriber = Subscriber::create( [](int next) { std::cout << "@next: " << next << std::endl; }, [](folly::exception_wrapper ex) { std::cerr << " exception: " << ex << std::endl; @@ -115,7 +122,7 @@ void FlowableExamples::run() { folly::ScopedEventBaseThread worker; std::cout << "subscribe_on example" << std::endl; - Flowables::justN({"0: ", "1: ", "2: "}) + Flowable<>::justN({"0: ", "1: ", "2: "}) ->map([](const char* p) { return std::string(p); }) ->map([](std::string log) { return log + " on " + getThreadId(); }) ->subscribeOn(*worker.getEventBase()) diff --git a/yarpl/examples/FlowableExamples.h b/yarpl/examples/FlowableExamples.h index 0613efa5a..675140b8b 100644 --- a/yarpl/examples/FlowableExamples.h +++ b/yarpl/examples/FlowableExamples.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once diff --git a/yarpl/examples/yarpl-playground.cpp b/yarpl/examples/yarpl-playground.cpp index c3cfc0f99..5fbe2f4c5 100644 --- a/yarpl/examples/yarpl-playground.cpp +++ b/yarpl/examples/yarpl-playground.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include diff --git a/yarpl/flowable/AsyncGeneratorShim.h b/yarpl/flowable/AsyncGeneratorShim.h new file mode 100644 index 000000000..72d212c83 --- /dev/null +++ b/yarpl/flowable/AsyncGeneratorShim.h @@ -0,0 +1,165 @@ +// Copyright 2004-present Facebook. All Rights Reserved. + +#pragma once +#include +#include +#include +#include +#include +#include +#include "yarpl/flowable/Flowable.h" + +namespace yarpl { +namespace detail { +template +class AsyncGeneratorShim { + public: + AsyncGeneratorShim( + folly::coro::AsyncGenerator&& generator, + folly::SequencedExecutor* ex) + : generator_(std::move(generator)), + sharedState_(std::make_shared()) { + sharedState_->executor_ = folly::getKeepAliveToken(ex); + } + + void subscribe( + std::shared_ptr> subscriber) && { + class Subscription : public yarpl::flowable::Subscription { + public: + explicit Subscription(std::weak_ptr state) + : state_(std::move(state)) {} + + void request(int64_t n) override { + if (auto state = state_.lock()) { + state->executor_->add([n, state = std::move(state)]() { + if (state->requested_ == credits::kNoFlowControl || + n == credits::kNoFlowControl) { + state->requested_ = credits::kNoFlowControl; + } else { + state->requested_ += n; + } + state->baton_.post(); + }); + } + } + + void cancel() override { + if (auto state = state_.lock()) { + state->executor_->add([state = std::move(state)]() { + // requestCancellation will execute registered CancellationCallback + // inline, but CancellationCallback should be run in + // executor_ thread + state->cancelSource_.requestCancellation(); + state->baton_.post(); + }); + } + } + + private: + std::weak_ptr state_; + }; + sharedState_->executor_->add( + [keepAlive = sharedState_->executor_.copy(), + subscriber, + subscription = std::make_shared( + std::weak_ptr(sharedState_))]() mutable { + subscriber->onSubscribe(std::move(subscription)); + }); + auto executor = sharedState_->executor_.get(); + folly::coro::co_withCancellation( + sharedState_->cancelSource_.getToken(), + folly::coro::co_invoke( + [subscriber = std::move(subscriber), + self = std::move(*this)]() mutable -> folly::coro::Task { + while (true) { + while (self.sharedState_->requested_ == 0 && + !self.sharedState_->cancelSource_ + .isCancellationRequested()) { + co_await self.sharedState_->baton_; + self.sharedState_->baton_.reset(); + } + + if (self.sharedState_->cancelSource_ + .isCancellationRequested()) { + self.sharedState_->executor_->add( + [subscriber = std::move(subscriber)]() { + // destory subscriber on executor_ thread + }); + co_return; + } + + folly::Try value; + try { + auto item = co_await self.generator_.next(); + + if (item.has_value()) { + value.emplace(std::move(*item)); + } + } catch (const std::exception& ex) { + value.emplaceException(std::current_exception(), ex); + } catch (...) { + value.emplaceException(std::current_exception()); + } + + if (value.hasValue()) { + self.sharedState_->executor_->add( + [subscriber, + keepAlive = self.sharedState_->executor_.copy(), + value = std::move(value)]() mutable { + subscriber->onNext(std::move(value).value()); + }); + } else if (value.hasException()) { + self.sharedState_->executor_->add( + [subscriber = std::move(subscriber), + keepAlive = self.sharedState_->executor_.copy(), + value = std::move(value)]() mutable { + subscriber->onError(std::move(value).exception()); + }); + co_return; + } else { + self.sharedState_->executor_->add( + [subscriber = std::move(subscriber), + keepAlive = + self.sharedState_->executor_.copy()]() mutable { + subscriber->onComplete(); + }); + co_return; + } + + if (self.sharedState_->requested_ != credits::kNoFlowControl) { + self.sharedState_->requested_--; + } + } + })) + .scheduleOn(std::move(executor)) + .start(); + } + + private: + struct SharedState { + SharedState() = default; + explicit SharedState(folly::CancellationSource source) + : cancelSource_(std::move(source)) {} + folly::Executor::KeepAlive executor_; + int64_t requested_{0}; + folly::coro::Baton baton_{0}; + folly::CancellationSource cancelSource_; + }; + + folly::coro::AsyncGenerator generator_; + std::shared_ptr sharedState_; +}; +} // namespace detail + +template +std::shared_ptr> toFlowable( + folly::coro::AsyncGenerator gen, + folly::SequencedExecutor* ex = folly::getEventBase()) { + return yarpl::flowable::internal::flowableFromSubscriber( + [gen = std::move(gen), + ex](std::shared_ptr> subscriber) mutable { + detail::AsyncGeneratorShim(std::move(gen), ex) + .subscribe(std::move(subscriber)); + }); +} +} // namespace yarpl diff --git a/yarpl/flowable/CancelingSubscriber.h b/yarpl/flowable/CancelingSubscriber.h new file mode 100644 index 000000000..0933a6908 --- /dev/null +++ b/yarpl/flowable/CancelingSubscriber.h @@ -0,0 +1,47 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/flowable/Subscriber.h" + +#include + +namespace yarpl { +namespace flowable { + +/** + * A Subscriber that always cancels the subscription passed to it. + */ +template +class CancelingSubscriber final : public BaseSubscriber { + public: + void onSubscribeImpl() override { + this->cancel(); + } + + void onNextImpl(T) override { + throw std::logic_error{"CancelingSubscriber::onNext() can never be called"}; + } + void onCompleteImpl() override { + throw std::logic_error{ + "CancelingSubscriber::onComplete() can never be called"}; + } + void onErrorImpl(folly::exception_wrapper) override { + throw std::logic_error{ + "CancelingSubscriber::onError() can never be called"}; + } +}; +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/DeferFlowable.h b/yarpl/flowable/DeferFlowable.h new file mode 100644 index 000000000..b817c85f4 --- /dev/null +++ b/yarpl/flowable/DeferFlowable.h @@ -0,0 +1,49 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/flowable/Flowable.h" + +namespace yarpl { +namespace flowable { +namespace details { + +template +class DeferFlowable : public Flowable { + static_assert( + std::is_same, FlowableFactory>::value, + "undecayed"); + + public: + template + explicit DeferFlowable(F&& factory) : factory_(std::forward(factory)) {} + + virtual void subscribe(std::shared_ptr> subscriber) { + std::shared_ptr> flowable; + try { + flowable = factory_(); + } catch (const std::exception& ex) { + flowable = Flowable::error(ex, std::current_exception()); + } + flowable->subscribe(std::move(subscriber)); + } + + private: + FlowableFactory factory_; +}; + +} // namespace details +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/include/yarpl/flowable/EmitterFlowable.h b/yarpl/flowable/EmitterFlowable.h similarity index 50% rename from yarpl/include/yarpl/flowable/EmitterFlowable.h rename to yarpl/flowable/EmitterFlowable.h index ebc51f980..5c5089551 100644 --- a/yarpl/include/yarpl/flowable/EmitterFlowable.h +++ b/yarpl/flowable/EmitterFlowable.h @@ -1,27 +1,42 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include "yarpl/utils/credits.h" - -#include - #include #include #include #include +#include + +#include "yarpl/flowable/Flowable.h" +#include "yarpl/flowable/Subscriber.h" +#include "yarpl/utils/credits.h" namespace yarpl { namespace flowable { namespace details { template -class EmiterBase : public virtual Refcounted { +class EmitterBase { public: - ~EmiterBase() = default; + virtual ~EmitterBase() = default; - virtual std::tuple emit(Reference>, int64_t) = 0; + virtual std::tuple emit( + std::shared_ptr>, + int64_t) = 0; }; /** @@ -31,15 +46,19 @@ class EmiterBase : public virtual Refcounted { * of a request(n) call. */ template -class EmiterSubscription : public Subscription, public Subscriber { +class EmiterSubscription final : public Subscription, + public Subscriber, + public yarpl::enable_get_ref { constexpr static auto kCanceled = credits::kCanceled; constexpr static auto kNoFlowControl = credits::kNoFlowControl; public: EmiterSubscription( - Reference> emiter, - Reference> subscriber) - : emiter_(std::move(emiter)), subscriber_(std::move(subscriber)) { + std::shared_ptr> emitter, + std::shared_ptr> subscriber) + : emitter_(std::move(emitter)), subscriber_(std::move(subscriber)) {} + + void init() { subscriber_->onSubscribe(this->ref_from_this(this)); } @@ -48,11 +67,6 @@ class EmiterSubscription : public Subscription, public Subscriber { } void request(int64_t delta) override { - if (delta <= 0) { - auto message = "request(n): " + folly::to(delta) + " <= 0"; - throw std::logic_error(message); - } - while (true) { auto current = requested_.load(std::memory_order_relaxed); @@ -88,28 +102,43 @@ class EmiterSubscription : public Subscription, public Subscriber { } // Subscriber methods. - void onSubscribe(Reference) override { + void onSubscribe(std::shared_ptr) override { LOG(FATAL) << "Do not call this method"; } void onNext(T value) override { +#ifndef NDEBUG DCHECK(!hasFinished_) << "onComplete() or onError() already called"; - - subscriber_->onNext(std::move(value)); +#endif + if (subscriber_) { + subscriber_->onNext(std::move(value)); + } else { + DCHECK(requested_.load(std::memory_order_relaxed) == kCanceled); + } } void onComplete() override { +#ifndef NDEBUG DCHECK(!hasFinished_) << "onComplete() or onError() already called"; hasFinished_ = true; - - subscriber_->onComplete(); +#endif + if (subscriber_) { + subscriber_->onComplete(); + } else { + DCHECK(requested_.load(std::memory_order_relaxed) == kCanceled); + } } void onError(folly::exception_wrapper error) override { +#ifndef NDEBUG DCHECK(!hasFinished_) << "onComplete() or onError() already called"; hasFinished_ = true; - - subscriber_->onError(error); +#endif + if (subscriber_) { + subscriber_->onError(error); + } else { + DCHECK(requested_.load(std::memory_order_relaxed) == kCanceled); + } } private: @@ -152,15 +181,28 @@ class EmiterSubscription : public Subscription, public Subscriber { int64_t emitted; bool done; - std::tie(emitted, done) = emiter_->emit(this_subscriber, current); + std::tie(emitted, done) = emitter_->emit(this_subscriber, current); while (true) { current = requested_.load(std::memory_order_relaxed); - if (current == kCanceled || (current == kNoFlowControl && !done)) { + if (current == kCanceled) { break; } - - auto updated = done ? kCanceled : current - emitted; + int64_t updated; + // generally speaking updated will be number of credits lefted over + // after emitter_->emit(), so updated = current - emitted + // need to handle case where done = true and avoid doing arithmetic + // operation on kNoFlowControl + + // in asynchrnous emitter cases, might have emitted=kNoFlowControl + // this means that emitter will take the responsibility to send the + // whole conext and credits lefted over should be set to 0. + if (current == kNoFlowControl) { + updated = + done ? kCanceled : emitted == kNoFlowControl ? 0 : kNoFlowControl; + } else { + updated = done ? kCanceled : current - emitted; + } if (requested_.compare_exchange_strong(current, updated)) { break; } @@ -169,7 +211,7 @@ class EmiterSubscription : public Subscription, public Subscriber { } void release() { - emiter_.reset(); + emitter_.reset(); subscriber_.reset(); } @@ -179,34 +221,100 @@ class EmiterSubscription : public Subscription, public Subscriber { // value represents cancellation. Other -ve values aren't permitted. std::atomic_int_fast64_t requested_{0}; +#ifndef NDEBUG bool hasFinished_{false}; // onComplete or onError called +#endif // We don't want to recursively invoke process(); one loop should do. std::atomic_bool processing_{false}; - Reference> emiter_; - Reference> subscriber_; + std::shared_ptr> emitter_; + std::shared_ptr> subscriber_; +}; + +template +class TrackingSubscriber : public Subscriber { + public: + TrackingSubscriber( + Subscriber& subscriber, + int64_t +#ifndef NDEBUG + requested +#endif + ) + : inner_(&subscriber) +#ifndef NDEBUG + , + requested_(requested) +#endif + { + } + + void onSubscribe(std::shared_ptr s) override { + inner_->onSubscribe(std::move(s)); + } + + void onComplete() override { + completed_ = true; + inner_->onComplete(); + } + + void onError(folly::exception_wrapper ex) override { + completed_ = true; + inner_->onError(std::move(ex)); + } + + void onNext(T value) override { +#ifndef NDEBUG + auto old = requested_; + DCHECK(old > credits::consume(requested_, 1)) + << "cannot emit more than requested"; +#endif + emitted_++; + inner_->onNext(std::move(value)); + } + + auto getResult() { + return std::make_tuple(emitted_, completed_); + } + + private: + int64_t emitted_{0}; + bool completed_{false}; + Subscriber* inner_; +#ifndef NDEBUG + int64_t requested_; +#endif }; template -class EmitterWrapper : public EmiterBase, public Flowable { +class EmitterWrapper : public EmitterBase, public Flowable { + static_assert( + std::is_same, Emitter>::value, + "undecayed"); + public: - explicit EmitterWrapper(Emitter emitter) : emitter_(std::move(emitter)) {} + template + explicit EmitterWrapper(F&& emitter) : emitter_(std::forward(emitter)) {} - void subscribe(Reference> subscriber) override { - make_ref>(this->ref_from_this(this), std::move(subscriber)); + void subscribe(std::shared_ptr> subscriber) override { + auto ef = std::make_shared>( + this->ref_from_this(this), std::move(subscriber)); + ef->init(); } std::tuple emit( - Reference> subscriber, + std::shared_ptr> subscriber, int64_t requested) override { - return emitter_(std::move(subscriber), requested); + TrackingSubscriber trackingSubscriber(*subscriber, requested); + emitter_(trackingSubscriber, requested); + return trackingSubscriber.getResult(); } private: Emitter emitter_; }; -} // details -} // flowable -} // yarpl +} // namespace details +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/Flowable.h b/yarpl/flowable/Flowable.h new file mode 100644 index 000000000..9dff78b03 --- /dev/null +++ b/yarpl/flowable/Flowable.h @@ -0,0 +1,749 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include "yarpl/Disposable.h" +#include "yarpl/Refcounted.h" +#include "yarpl/flowable/Subscriber.h" +#include "yarpl/utils/credits.h" + +namespace yarpl { + +class TimeoutException; +namespace detail { +class TimeoutExceptionGenerator; +} + +namespace flowable { + +template +class Flowable; + +namespace details { + +template +struct IsFlowable : std::false_type {}; + +template +struct IsFlowable>> : std::true_type { + using ElemType = R; +}; + +} // namespace details + +template +class Flowable : public yarpl::enable_get_ref { + public: + virtual ~Flowable() = default; + + virtual void subscribe(std::shared_ptr>) = 0; + + /** + * Subscribe overload that accepts lambdas. + */ + template < + typename Next, + typename = typename std::enable_if< + folly::is_invocable&, T>::value>::type> + std::unique_ptr subscribe( + Next&& next, + int64_t batch = credits::kNoFlowControl) { + auto subscriber = + details::LambdaSubscriber::create(std::forward(next), batch); + subscribe(subscriber); + return std::make_unique>( + std::move(subscriber)); + } + + /** + * Subscribe overload that accepts lambdas. + * + * Takes an optional batch size for request_n. Default is no flow control. + */ + template < + typename Next, + typename Error, + typename = typename std::enable_if< + folly::is_invocable&, T>::value && + folly::is_invocable&, folly::exception_wrapper>:: + value>::type> + std::unique_ptr subscribe( + Next&& next, + Error&& e, + int64_t batch = credits::kNoFlowControl) { + auto subscriber = details::LambdaSubscriber::create( + std::forward(next), std::forward(e), batch); + subscribe(subscriber); + return std::make_unique>( + std::move(subscriber)); + } + + /** + * Subscribe overload that accepts lambdas. + * + * Takes an optional batch size for request_n. Default is no flow control. + */ + template < + typename Next, + typename Error, + typename Complete, + typename = typename std::enable_if< + folly::is_invocable&, T>::value && + folly::is_invocable&, folly::exception_wrapper>:: + value && + folly::is_invocable&>::value>::type> + std::unique_ptr subscribe( + Next&& next, + Error&& e, + Complete&& complete, + int64_t batch = credits::kNoFlowControl) { + auto subscriber = details::LambdaSubscriber::create( + std::forward(next), + std::forward(e), + std::forward(complete), + batch); + subscribe(subscriber); + return std::make_unique>( + std::move(subscriber)); + } + + void subscribe() { + subscribe(Subscriber::create()); + } + + // + // creator methods: + // + + // Creates Flowable which completes the subscriber right after it subscribes + static std::shared_ptr> empty(); + + // Creates Flowable which will never terminate the subscriber + static std::shared_ptr> never(); + + // Create Flowable which will imediatelly terminate the subscriber upon + // subscription with the provided error + static std::shared_ptr> error(folly::exception_wrapper ex); + + template + static std::shared_ptr> error(Ex&) { + static_assert( + std::is_lvalue_reference::value, + "use variant of error() method accepting also exception_ptr"); + } + + template + static std::shared_ptr> error(Ex& ex, std::exception_ptr ptr) { + return Flowable::error(folly::exception_wrapper(std::move(ptr), ex)); + } + + static std::shared_ptr> just(T value) { + auto lambda = [value = std::move(value)]( + Subscriber& subscriber, int64_t requested) mutable { + DCHECK_GT(requested, 0); + subscriber.onNext(std::move(value)); + subscriber.onComplete(); + }; + + return Flowable::create(std::move(lambda)); + } + + static std::shared_ptr> justN(std::initializer_list list) { + auto lambda = [v = std::vector(std::move(list)), i = size_t{0}]( + Subscriber& subscriber, int64_t requested) mutable { + while (i < v.size() && requested-- > 0) { + subscriber.onNext(v[i++]); + } + + if (i == v.size()) { + // TODO T27302402: Even though having two subscriptions exist + // concurrently for Emitters is not possible still. At least it possible + // to resubscribe and consume the same values again. + i = 0; + subscriber.onComplete(); + } + }; + + return Flowable::create(std::move(lambda)); + } + + // this will generate a flowable which can be subscribed to only once + static std::shared_ptr> justOnce(T value) { + auto lambda = [value = std::move(value), used = false]( + Subscriber& subscriber, int64_t) mutable { + if (used) { + subscriber.onError( + std::runtime_error("justOnce value was already used")); + return; + } + + used = true; + // # requested should be > 0. Ignoring the actual parameter. + subscriber.onNext(std::move(value)); + subscriber.onComplete(); + }; + + return Flowable::create(std::move(lambda)); + } + + template + static std::shared_ptr> fromGenerator(TGenerator&& generator); + + /** + * The Defer operator waits until a subscriber subscribes to it, and then it + * generates a Flowabe with a FlowableFactory function. It + * does this afresh for each subscriber, so although each subscriber may + * think it is subscribing to the same Flowable, in fact each subscriber + * gets its own individual sequence. + */ + template < + typename FlowableFactory, + typename = typename std::enable_if>, + std::decay_t&>::value>::type> + static std::shared_ptr> defer(FlowableFactory&&); + + template < + typename Function, + typename ErrorFunction = + folly::Function, + typename R = typename folly::invoke_result_t, + typename = typename std::enable_if&, + folly::exception_wrapper&&>::value>::type> + std::shared_ptr> map( + Function&& function, + ErrorFunction&& errormapFunc = [](folly::exception_wrapper&& ew) { + return std::move(ew); + }); + + template < + typename Function, + typename R = typename details::IsFlowable< + typename folly::invoke_result_t>::ElemType> + std::shared_ptr> flatMap(Function&& func); + + template + std::shared_ptr> filter(Function&& function); + + template < + typename Function, + typename R = typename folly::invoke_result_t> + std::shared_ptr> reduce(Function&& function); + + std::shared_ptr> take(int64_t); + + std::shared_ptr> skip(int64_t); + + std::shared_ptr> ignoreElements(); + + /* + * To instruct a Flowable to do its work on a particular Executor. + * the onSubscribe, request and cancel methods will be scheduled on the + * provided executor + */ + std::shared_ptr> subscribeOn(folly::Executor&); + + std::shared_ptr> observeOn(folly::Executor&); + + std::shared_ptr> observeOn(folly::Executor::KeepAlive<>); + + std::shared_ptr> concatWith(std::shared_ptr>); + + template + std::shared_ptr> concatWith( + std::shared_ptr> first, + Args... args) { + return concatWith(first)->concatWith(args...); + } + + template + static std::shared_ptr> concat( + std::shared_ptr> first, + Args... args) { + return first->concatWith(args...); + } + + template + using enableWrapRef = + typename std::enable_if::value, Q>::type; + + // Combines multiple Flowables so that they act like a + // single Flowable. The items + // emitted by the merged Flowables may interlieve. + template + enableWrapRef merge() { + return this->flatMap([](auto f) { return std::move(f); }); + } + + // function is invoked when onComplete occurs. + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOnSubscribe(Function&& function); + + // function is invoked when onNext occurs. + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&, const T&>::value>::type> + std::shared_ptr> doOnNext(Function&& function); + + // function is invoked when onError occurs. + template < + typename Function, + typename = typename std::enable_if&, + folly::exception_wrapper&>::value>::type> + std::shared_ptr> doOnError(Function&& function); + + // function is invoked when onComplete occurs. + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOnComplete(Function&& function); + + // function is invoked when either onComplete or onError occurs. + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOnTerminate(Function&& function); + + // the function is invoked for each of onNext, onCompleted, onError + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOnEach(Function&& function); + + // function is invoked when request(n) is called. + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&, int64_t>::value>::type> + std::shared_ptr> doOnRequest(Function&& function); + + // function is invoked when cancel is called. + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOnCancel(Function&& function); + + // the callbacks will be invoked of each of the signals + template < + typename OnNextFunc, + typename OnCompleteFunc, + typename = typename std::enable_if< + folly::is_invocable&, const T&>::value>:: + type, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOn( + OnNextFunc&& onNext, + OnCompleteFunc&& onComplete); + + // the callbacks will be invoked of each of the signals + template < + typename OnNextFunc, + typename OnCompleteFunc, + typename OnErrorFunc, + typename = typename std::enable_if< + folly::is_invocable&, const T&>::value>:: + type, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type, + typename = typename std::enable_if&, + folly::exception_wrapper&>::value>::type> + std::shared_ptr> + doOn(OnNextFunc&& onNext, OnCompleteFunc&& onComplete, OnErrorFunc&& onError); + + template < + typename ExceptionGenerator = yarpl::detail::TimeoutExceptionGenerator> + std::shared_ptr> timeout( + folly::EventBase& timerEvb, + std::chrono::milliseconds timeout, + std::chrono::milliseconds initTimeout, + ExceptionGenerator&& exnGen = ExceptionGenerator()); + + template < + typename Emitter, + typename = typename std::enable_if&, + Subscriber&, + int64_t>::value>::type> + static std::shared_ptr> create(Emitter&& emitter); + + template < + typename OnSubscribe, + typename = typename std::enable_if>>::value>::type> + // TODO(lehecka): enable this warning once mobile code is clear + // [[deprecated( + // "Flowable::fromPublisher is deprecated: Use PublishProcessor or " + // "contact rsocket team if you can't figure out what to replace it " + // "with")]] + static std::shared_ptr> fromPublisher(OnSubscribe&& function); +}; + +} // namespace flowable +} // namespace yarpl + +#include "yarpl/flowable/DeferFlowable.h" +#include "yarpl/flowable/EmitterFlowable.h" +#include "yarpl/flowable/FlowableOperator.h" + +namespace yarpl { +namespace flowable { + +template +template +std::shared_ptr> Flowable::create(Emitter&& emitter) { + return std::make_shared>>( + std::forward(emitter)); +} + +template +std::shared_ptr> Flowable::empty() { + class EmptyFlowable : public Flowable { + void subscribe(std::shared_ptr> subscriber) override { + subscriber->onSubscribe(Subscription::create()); + // does not wait for request(n) to complete + subscriber->onComplete(); + } + }; + return std::make_shared(); +} + +template +std::shared_ptr> Flowable::never() { + class NeverFlowable : public Flowable { + void subscribe(std::shared_ptr> subscriber) override { + subscriber->onSubscribe(Subscription::create()); + } + }; + return std::make_shared(); +} + +template +std::shared_ptr> Flowable::error(folly::exception_wrapper ex) { + class ErrorFlowable : public Flowable { + void subscribe(std::shared_ptr> subscriber) override { + subscriber->onSubscribe(Subscription::create()); + // does not wait for request(n) to error + subscriber->onError(ex_); + } + folly::exception_wrapper ex_; + + public: + explicit ErrorFlowable(folly::exception_wrapper ew) : ex_(std::move(ew)) {} + }; + return std::make_shared(std::move(ex)); +} + +namespace internal { +template +std::shared_ptr> flowableFromSubscriber(OnSubscribe&& function) { + return std::make_shared>>( + std::forward(function)); +} +} // namespace internal + +// TODO(lehecka): remove +template +template +std::shared_ptr> Flowable::fromPublisher( + OnSubscribe&& function) { + return internal::flowableFromSubscriber( + std::forward(function)); +} + +template +template +std::shared_ptr> Flowable::fromGenerator( + TGenerator&& generator) { + auto lambda = [generator = std::forward(generator)]( + Subscriber& subscriber, int64_t requested) mutable { + try { + while (requested-- > 0) { + subscriber.onNext(generator()); + } + } catch (const std::exception& ex) { + subscriber.onError( + folly::exception_wrapper(std::current_exception(), ex)); + } catch (...) { + subscriber.onError(std::runtime_error( + "Flowable::fromGenerator() threw from Subscriber:onNext()")); + } + }; + return Flowable::create(std::move(lambda)); +} // namespace flowable + +template +template +std::shared_ptr> Flowable::defer(FlowableFactory&& factory) { + return std::make_shared< + details::DeferFlowable>>( + std::forward(factory)); +} + +template +template +std::shared_ptr> Flowable::map( + Function&& function, + ErrorFunction&& errorFunction) { + return std::make_shared< + MapOperator, std::decay_t>>( + this->ref_from_this(this), + std::forward(function), + std::forward(errorFunction)); +} + +template +template +std::shared_ptr> Flowable::filter(Function&& function) { + return std::make_shared>>( + this->ref_from_this(this), std::forward(function)); +} + +template +template +std::shared_ptr> Flowable::reduce(Function&& function) { + return std::make_shared>>( + this->ref_from_this(this), std::forward(function)); +} + +template +std::shared_ptr> Flowable::take(int64_t limit) { + return std::make_shared>(this->ref_from_this(this), limit); +} + +template +std::shared_ptr> Flowable::skip(int64_t offset) { + return std::make_shared>(this->ref_from_this(this), offset); +} + +template +std::shared_ptr> Flowable::ignoreElements() { + return std::make_shared>(this->ref_from_this(this)); +} + +template +std::shared_ptr> Flowable::subscribeOn( + folly::Executor& executor) { + return std::make_shared>( + this->ref_from_this(this), executor); +} + +template +std::shared_ptr> Flowable::observeOn(folly::Executor& executor) { + return observeOn(folly::getKeepAliveToken(executor)); +} + +template +std::shared_ptr> Flowable::observeOn( + folly::Executor::KeepAlive<> executor) { + return std::make_shared>( + this->ref_from_this(this), std::move(executor)); +} + +template +template +std::shared_ptr> Flowable::flatMap(Function&& function) { + return std::make_shared>( + this->ref_from_this(this), std::forward(function)); +} + +template +std::shared_ptr> Flowable::concatWith( + std::shared_ptr> next) { + return std::make_shared>( + this->ref_from_this(this), std::move(next)); +} + +template +template +std::shared_ptr> Flowable::doOnSubscribe(Function&& function) { + return details::createDoOperator( + ref_from_this(this), + std::forward(function), + [](const T&) {}, + [](const auto&) {}, + [] {}, + [](const auto&) {}, // onRequest + [] {}); // onCancel +} + +template +template +std::shared_ptr> Flowable::doOnNext(Function&& function) { + return details::createDoOperator( + ref_from_this(this), + [] {}, + std::forward(function), + [](const auto&) {}, + [] {}, + [](const auto&) {}, // onRequest + [] {}); // onCancel +} + +template +template +std::shared_ptr> Flowable::doOnError(Function&& function) { + return details::createDoOperator( + ref_from_this(this), + [] {}, + [](const T&) {}, + std::forward(function), + [] {}, + [](const auto&) {}, // onRequest + [] {}); // onCancel +} + +template +template +std::shared_ptr> Flowable::doOnComplete(Function&& function) { + return details::createDoOperator( + ref_from_this(this), + [] {}, + [](const T&) {}, + [](const auto&) {}, + std::forward(function), + [](const auto&) {}, // onRequest + [] {}); // onCancel +} + +template +template +std::shared_ptr> Flowable::doOnTerminate(Function&& function) { + auto sharedFunction = std::make_shared>( + std::forward(function)); + return details::createDoOperator( + ref_from_this(this), + [] {}, + [](const T&) {}, + [sharedFunction](const auto&) { (*sharedFunction)(); }, + [sharedFunction]() { (*sharedFunction)(); }, + [](const auto&) {}, // onRequest + [] {}); // onCancel +} + +template +template +std::shared_ptr> Flowable::doOnEach(Function&& function) { + auto sharedFunction = std::make_shared>( + std::forward(function)); + return details::createDoOperator( + ref_from_this(this), + [] {}, + [sharedFunction](const T&) { (*sharedFunction)(); }, + [sharedFunction](const auto&) { (*sharedFunction)(); }, + [sharedFunction]() { (*sharedFunction)(); }, + [](const auto&) {}, // onRequest + [] {}); // onCancel +} + +template +template +std::shared_ptr> Flowable::doOn( + OnNextFunc&& onNext, + OnCompleteFunc&& onComplete) { + return details::createDoOperator( + ref_from_this(this), + [] {}, + std::forward(onNext), + [](const auto&) {}, + std::forward(onComplete), + [](const auto&) {}, // onRequest + [] {}); // onCancel +} + +template +template < + typename OnNextFunc, + typename OnCompleteFunc, + typename OnErrorFunc, + typename, + typename, + typename> +std::shared_ptr> Flowable::doOn( + OnNextFunc&& onNext, + OnCompleteFunc&& onComplete, + OnErrorFunc&& onError) { + return details::createDoOperator( + ref_from_this(this), + [] {}, + std::forward(onNext), + std::forward(onError), + std::forward(onComplete), + [](const auto&) {}, // onRequest + [] {}); // onCancel +} + +template +template +std::shared_ptr> Flowable::doOnRequest(Function&& function) { + return details::createDoOperator( + ref_from_this(this), + [] {}, // onSubscribe + [](const auto&) {}, // onNext + [](const auto&) {}, // onError + [] {}, // onComplete + std::forward(function), // onRequest + [] {}); // onCancel +} + +template +template +std::shared_ptr> Flowable::doOnCancel(Function&& function) { + return details::createDoOperator( + ref_from_this(this), + [] {}, // onSubscribe + [](const auto&) {}, // onNext + [](const auto&) {}, // onError + [] {}, // onComplete + [](const auto&) {}, // onRequest + std::forward(function)); // onCancel +} + +template +template +std::shared_ptr> Flowable::timeout( + folly::EventBase& timerEvb, + std::chrono::milliseconds starvationTimeout, + std::chrono::milliseconds initTimeout, + ExceptionGenerator&& exnGen) { + return std::make_shared>( + ref_from_this(this), + timerEvb, + starvationTimeout, + initTimeout, + std::forward(exnGen)); +} + +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/FlowableConcatOperators.h b/yarpl/flowable/FlowableConcatOperators.h new file mode 100644 index 000000000..56694146b --- /dev/null +++ b/yarpl/flowable/FlowableConcatOperators.h @@ -0,0 +1,189 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/flowable/FlowableOperator.h" + +namespace yarpl { +namespace flowable { +namespace details { + +template +class ConcatWithOperator : public FlowableOperator { + using Super = FlowableOperator; + + public: + ConcatWithOperator( + std::shared_ptr> first, + std::shared_ptr> second) + : first_(std::move(first)), second_(std::move(second)) { + CHECK(first_); + CHECK(second_); + } + + void subscribe(std::shared_ptr> subscriber) override { + auto subscription = + std::make_shared(subscriber, first_, second_); + subscription->init(); + } + + private: + class ForwardSubscriber; + + // Downstream will always point to this subscription + class ConcatWithSubscription + : public yarpl::flowable::Subscription, + public std::enable_shared_from_this { + public: + ConcatWithSubscription( + std::shared_ptr> subscriber, + std::shared_ptr> first, + std::shared_ptr> second) + : downSubscriber_(std::move(subscriber)), + first_(std::move(first)), + second_(std::move(second)) {} + + void init() { + upSubscriber_ = + std::make_shared(this->shared_from_this()); + first_->subscribe(upSubscriber_); + downSubscriber_->onSubscribe(this->shared_from_this()); + } + + void request(int64_t n) override { + credits::add(&requested_, n); + if (!upSubscriber_) { + if (auto second = std::exchange(second_, nullptr)) { + upSubscriber_ = std::make_shared( + this->shared_from_this(), requested_); + second->subscribe(upSubscriber_); + } + } else { + upSubscriber_->request(n); + } + } + + void cancel() override { + if (auto subscriber = std::move(upSubscriber_)) { + subscriber->cancel(); + } + first_.reset(); + second_.reset(); + downSubscriber_.reset(); + upSubscriber_.reset(); + } + + void onNext(T value) { + credits::consume(&requested_, 1); + downSubscriber_->onNext(std::move(value)); + } + + void onComplete() { + upSubscriber_.reset(); + if (auto first = std::move(first_)) { + if (requested_ > 0) { + if (auto second = std::exchange(second_, nullptr)) { + upSubscriber_ = std::make_shared( + this->shared_from_this(), requested_); + // TODO - T28771728 + // Concat should not call 'subscribe' on onComplete + second->subscribe(upSubscriber_); + } + } + } else { + if (auto downSubscriber = std::exchange(downSubscriber_, nullptr)) { + downSubscriber->onComplete(); + } + upSubscriber_.reset(); + } + } + + void onError(folly::exception_wrapper ew) { + downSubscriber_->onError(std::move(ew)); + first_.reset(); + second_.reset(); + downSubscriber_.reset(); + upSubscriber_.reset(); + } + + private: + std::shared_ptr> downSubscriber_; + std::shared_ptr> first_; + std::shared_ptr> second_; + std::shared_ptr upSubscriber_; + std::atomic requested_{0}; + }; + + class ForwardSubscriber : public yarpl::flowable::Subscriber, + public yarpl::flowable::Subscription { + public: + ForwardSubscriber( + std::shared_ptr concatWithSubscription, + uint32_t initialRequest = 0u) + : concatWithSubscription_(std::move(concatWithSubscription)), + initialRequest_(initialRequest) {} + + void request(int64_t n) override { + subscription_->request(n); + } + + void cancel() override { + if (auto subs = std::move(subscription_)) { + subs->cancel(); + } else { + canceled_ = true; + } + } + + void onSubscribe(std::shared_ptr subscription) override { + if (canceled_) { + subscription->cancel(); + return; + } + subscription_ = std::move(subscription); + if (auto req = std::exchange(initialRequest_, 0)) { + subscription_->request(req); + } + } + + void onComplete() override { + auto sub = std::exchange(concatWithSubscription_, nullptr); + sub->onComplete(); + } + + void onError(folly::exception_wrapper ew) override { + auto sub = std::exchange(concatWithSubscription_, nullptr); + sub->onError(std::move(ew)); + } + void onNext(T value) override { + concatWithSubscription_->onNext(std::move(value)); + } + + private: + std::shared_ptr concatWithSubscription_; + std::shared_ptr subscription_; + + uint32_t initialRequest_{0}; + bool canceled_{false}; + }; + + private: + const std::shared_ptr> first_; + const std::shared_ptr> second_; +}; + +} // namespace details +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/FlowableDoOperator.h b/yarpl/flowable/FlowableDoOperator.h new file mode 100644 index 000000000..256a345ba --- /dev/null +++ b/yarpl/flowable/FlowableDoOperator.h @@ -0,0 +1,190 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/flowable/FlowableOperator.h" + +namespace yarpl { +namespace flowable { +namespace details { + +template < + typename U, + typename OnSubscribeFunc, + typename OnNextFunc, + typename OnErrorFunc, + typename OnCompleteFunc, + typename OnRequestFunc, + typename OnCancelFunc> +class DoOperator : public FlowableOperator { + using Super = FlowableOperator; + static_assert( + std::is_same, OnSubscribeFunc>::value, + "undecayed"); + static_assert( + std::is_same, OnNextFunc>::value, + "undecayed"); + static_assert( + std::is_same, OnErrorFunc>::value, + "undecayed"); + static_assert( + std::is_same, OnCompleteFunc>::value, + "undecayed"); + static_assert( + std::is_same, OnRequestFunc>::value, + "undecayed"); + static_assert( + std::is_same, OnCancelFunc>::value, + "undecayed"); + + public: + template < + typename FSubscribe, + typename FNext, + typename FError, + typename FComplete, + typename FRequest, + typename FCancel> + DoOperator( + std::shared_ptr> upstream, + FSubscribe&& onSubscribeFunc, + FNext&& onNextFunc, + FError&& onErrorFunc, + FComplete&& onCompleteFunc, + FRequest&& onRequestFunc, + FCancel&& onCancelFunc) + : upstream_(std::move(upstream)), + onSubscribeFunc_(std::forward(onSubscribeFunc)), + onNextFunc_(std::forward(onNextFunc)), + onErrorFunc_(std::forward(onErrorFunc)), + onCompleteFunc_(std::forward(onCompleteFunc)), + onRequestFunc_(std::forward(onRequestFunc)), + onCancelFunc_(std::forward(onCancelFunc)) {} + + void subscribe(std::shared_ptr> subscriber) override { + auto subscription = std::make_shared( + this->ref_from_this(this), std::move(subscriber)); + upstream_->subscribe( + // Note: implicit cast to a reference to a subscriber. + subscription); + } + + private: + class DoSubscription : public Super::Subscription { + using SuperSub = typename Super::Subscription; + + public: + DoSubscription( + std::shared_ptr flowable, + std::shared_ptr> subscriber) + : SuperSub(std::move(subscriber)), flowable_(std::move(flowable)) {} + + void onSubscribeImpl() override { + if (auto flowable = yarpl::atomic_load(&flowable_)) { + flowable->onSubscribeFunc_(); + SuperSub::onSubscribeImpl(); + } + } + + void onNextImpl(U value) override { + if (auto flowable = yarpl::atomic_load(&flowable_)) { + const auto& valueRef = value; + flowable->onNextFunc_(valueRef); + SuperSub::subscriberOnNext(std::move(value)); + } + } + + void onErrorImpl(folly::exception_wrapper ex) override { + if (auto flowable = yarpl::atomic_load(&flowable_)) { + const auto& exRef = ex; + flowable->onErrorFunc_(exRef); + SuperSub::onErrorImpl(std::move(ex)); + } + } + + void onCompleteImpl() override { + if (auto flowable = yarpl::atomic_load(&flowable_)) { + flowable->onCompleteFunc_(); + SuperSub::onCompleteImpl(); + } + } + + void cancel() override { + if (auto flowable = yarpl::atomic_load(&flowable_)) { + flowable->onCancelFunc_(); + SuperSub::cancel(); + } + } + + void request(int64_t n) override { + if (auto flowable = yarpl::atomic_load(&flowable_)) { + flowable->onRequestFunc_(n); + SuperSub::request(n); + } + } + + void onTerminateImpl() override { + yarpl::atomic_exchange(&flowable_, nullptr); + SuperSub::onTerminateImpl(); + } + + private: + AtomicReference flowable_; + }; + + std::shared_ptr> upstream_; + OnSubscribeFunc onSubscribeFunc_; + OnNextFunc onNextFunc_; + OnErrorFunc onErrorFunc_; + OnCompleteFunc onCompleteFunc_; + OnRequestFunc onRequestFunc_; + OnCancelFunc onCancelFunc_; +}; + +template < + typename U, + typename OnSubscribeFunc, + typename OnNextFunc, + typename OnErrorFunc, + typename OnCompleteFunc, + typename OnRequestFunc, + typename OnCancelFunc> +inline auto createDoOperator( + std::shared_ptr> upstream, + OnSubscribeFunc&& onSubscribeFunc, + OnNextFunc&& onNextFunc, + OnErrorFunc&& onErrorFunc, + OnCompleteFunc&& onCompleteFunc, + OnRequestFunc&& onRequestFunc, + OnCancelFunc&& onCancelFunc) { + return std::make_shared, + std::decay_t, + std::decay_t, + std::decay_t, + std::decay_t, + std::decay_t>>( + std::move(upstream), + std::forward(onSubscribeFunc), + std::forward(onNextFunc), + std::forward(onErrorFunc), + std::forward(onCompleteFunc), + std::forward(onRequestFunc), + std::forward(onCancelFunc)); +} +} // namespace details +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/FlowableObserveOnOperator.h b/yarpl/flowable/FlowableObserveOnOperator.h new file mode 100644 index 000000000..359540980 --- /dev/null +++ b/yarpl/flowable/FlowableObserveOnOperator.h @@ -0,0 +1,124 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/flowable/Flowable.h" + +namespace yarpl { +namespace flowable { +namespace detail { + +template +class ObserveOnOperatorSubscriber; + +template +class ObserveOnOperatorSubscription : public yarpl::flowable::Subscription, + public yarpl::enable_get_ref { + public: + ObserveOnOperatorSubscription( + std::shared_ptr> subscriber, + std::shared_ptr subscription) + : subscriber_(std::move(subscriber)), + subscription_(std::move(subscription)) {} + + // all requesting methods are called from 'executor_' in the + // associated subscriber + void cancel() override { + auto self = this->ref_from_this(this); + + if (auto subscriber = std::move(subscriber_)) { + subscriber->inner_ = nullptr; + } + + subscription_->cancel(); + } + + void request(int64_t n) override { + subscription_->request(n); + } + + private: + std::shared_ptr> subscriber_; + std::shared_ptr subscription_; +}; + +template +class ObserveOnOperatorSubscriber : public yarpl::flowable::Subscriber, + public yarpl::enable_get_ref { + public: + ObserveOnOperatorSubscriber( + std::shared_ptr> inner, + folly::Executor::KeepAlive<> executor) + : inner_(std::move(inner)), executor_(std::move(executor)) {} + + // all signaling methods are called from upstream EB + void onSubscribe(std::shared_ptr subscription) override { + executor_->add([self = this->ref_from_this(this), + s = std::move(subscription)]() mutable { + auto sub = std::make_shared>( + self, std::move(s)); + self->inner_->onSubscribe(std::move(sub)); + }); + } + void onNext(T next) override { + executor_->add( + [self = this->ref_from_this(this), n = std::move(next)]() mutable { + if (auto& inner = self->inner_) { + inner->onNext(std::move(n)); + } + }); + } + void onComplete() override { + executor_->add([self = this->ref_from_this(this)]() mutable { + if (auto inner = std::exchange(self->inner_, nullptr)) { + inner->onComplete(); + } + }); + } + void onError(folly::exception_wrapper err) override { + executor_->add( + [self = this->ref_from_this(this), e = std::move(err)]() mutable { + if (auto inner = std::exchange(self->inner_, nullptr)) { + inner->onError(std::move(e)); + } + }); + } + + private: + friend class ObserveOnOperatorSubscription; + + std::shared_ptr> inner_; + folly::Executor::KeepAlive<> executor_; +}; + +template +class ObserveOnOperator : public yarpl::flowable::Flowable { + public: + ObserveOnOperator( + std::shared_ptr> upstream, + folly::Executor::KeepAlive<> executor) + : upstream_(std::move(upstream)), executor_(std::move(executor)) {} + + void subscribe(std::shared_ptr> subscriber) override { + upstream_->subscribe(std::make_shared>( + std::move(subscriber), folly::getKeepAliveToken(executor_.get()))); + } + + std::shared_ptr> upstream_; + folly::Executor::KeepAlive<> executor_; +}; +} // namespace detail +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/include/yarpl/flowable/FlowableOperator.h b/yarpl/flowable/FlowableOperator.h similarity index 61% rename from yarpl/include/yarpl/flowable/FlowableOperator.h rename to yarpl/flowable/FlowableOperator.h index a1b21af93..314ba7f2e 100644 --- a/yarpl/include/yarpl/flowable/FlowableOperator.h +++ b/yarpl/flowable/FlowableOperator.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -15,6 +27,7 @@ #include #include #include +#include namespace yarpl { namespace flowable { @@ -25,12 +38,8 @@ namespace flowable { * upstream Flowable, and are Flowables themselves. Multi-stage pipelines can * be built: a Flowable heading a sequence of Operators. */ -template +template class FlowableOperator : public Flowable { - public: - explicit FlowableOperator(Reference> upstream) - : upstream_(std::move(upstream)) {} - protected: /// An Operator's subscription. /// @@ -42,28 +51,33 @@ class FlowableOperator : public Flowable { class Subscription : public yarpl::flowable::Subscription, public BaseSubscriber { protected: - Subscription( - Reference flowable, - Reference> subscriber) - : flowableOperator_(std::move(flowable)), - subscriber_(std::move(subscriber)) { - CHECK(flowableOperator_); - CHECK(subscriber_); + explicit Subscription(std::shared_ptr> subscriber) + : subscriber_(std::move(subscriber)) { + CHECK(yarpl::atomic_load(&subscriber_)); } - const Reference& getFlowableOperator() { - return flowableOperator_; + // Subscriber will be provided by the init(Subscriber) call + Subscription() {} + + virtual void init(std::shared_ptr> subscriber) { + if (yarpl::atomic_load(&subscriber_)) { + subscriber->onSubscribe(yarpl::flowable::Subscription::create()); + subscriber->onError(std::runtime_error("already initialized")); + return; + } + subscriber_ = std::move(subscriber); } void subscriberOnNext(D value) { - if (auto subscriber = subscriber_.load()) { + if (auto subscriber = yarpl::atomic_load(&subscriber_)) { subscriber->onNext(std::move(value)); } } /// Terminates both ends of an operator normally. void terminate() { - auto subscriber = subscriber_.exchange(nullptr); + std::shared_ptr> null; + auto subscriber = yarpl::atomic_exchange(&subscriber_, null); BaseSubscriber::cancel(); if (subscriber) { subscriber->onComplete(); @@ -72,7 +86,8 @@ class FlowableOperator : public Flowable { /// Terminates both ends of an operator with an error. void terminateErr(folly::exception_wrapper ew) { - auto subscriber = subscriber_.exchange(nullptr); + std::shared_ptr> null; + auto subscriber = yarpl::atomic_exchange(&subscriber_, null); BaseSubscriber::cancel(); if (subscriber) { subscriber->onError(std::move(ew)); @@ -86,57 +101,64 @@ class FlowableOperator : public Flowable { } void cancel() override { - auto subscriber = subscriber_.exchange(nullptr); + std::shared_ptr> null; + auto subscriber = yarpl::atomic_exchange(&subscriber_, null); BaseSubscriber::cancel(); } // Subscriber. void onSubscribeImpl() override { - subscriber_->onSubscribe(this->ref_from_this(this)); + yarpl::atomic_load(&subscriber_)->onSubscribe(this->ref_from_this(this)); } void onCompleteImpl() override { - if (auto subscriber = subscriber_.exchange(nullptr)) { + std::shared_ptr> null; + if (auto subscriber = yarpl::atomic_exchange(&subscriber_, null)) { subscriber->onComplete(); } } void onErrorImpl(folly::exception_wrapper ew) override { - if (auto subscriber = subscriber_.exchange(nullptr)) { + std::shared_ptr> null; + if (auto subscriber = yarpl::atomic_exchange(&subscriber_, null)) { subscriber->onError(std::move(ew)); } } private: - /// The Flowable has the lambda, and other creation parameters. - Reference flowableOperator_; - /// This subscription controls the life-cycle of the subscriber. The /// subscriber is retained as long as calls on it can be made. (Note: the /// subscriber in turn maintains a reference on this subscription object /// until cancellation and/or completion.) AtomicReference> subscriber_; }; - - Reference> upstream_; }; -template < - typename U, - typename D, - typename F, - typename = typename std::enable_if::value>::type> -class MapOperator : public FlowableOperator> { - using ThisOperatorT = MapOperator; - using Super = FlowableOperator; +template +class MapOperator : public FlowableOperator { + using Super = FlowableOperator; + static_assert(std::is_same, F>::value, "undecayed"); + static_assert(folly::is_invocable_r::value, "not invocable"); + static_assert( + folly::is_invocable_r< + folly::exception_wrapper, + EF, + folly::exception_wrapper&&>::value, + "exception handler not invocable"); public: - MapOperator(Reference> upstream, F function) - : Super(std::move(upstream)), function_(std::move(function)) {} - - void subscribe(Reference> subscriber) override { - Super::upstream_->subscribe(make_ref( + template + MapOperator( + std::shared_ptr> upstream, + Func&& function, + ErrorFunc&& errFunction) + : upstream_(std::move(upstream)), + function_(std::forward(function)), + errFunction_(std::move(errFunction)) {} + + void subscribe(std::shared_ptr> subscriber) override { + upstream_->subscribe(std::make_shared( this->ref_from_this(this), std::move(subscriber))); } @@ -145,40 +167,62 @@ class MapOperator : public FlowableOperator> { class Subscription : public SuperSubscription { public: Subscription( - Reference flowable, - Reference> subscriber) - : SuperSubscription(std::move(flowable), std::move(subscriber)) {} + std::shared_ptr flowable, + std::shared_ptr> subscriber) + : SuperSubscription(std::move(subscriber)), + flowable_(std::move(flowable)) {} void onNextImpl(U value) override { try { - auto&& map = this->getFlowableOperator(); - this->subscriberOnNext(map->function_(std::move(value))); + if (auto flowable = yarpl::atomic_load(&flowable_)) { + this->subscriberOnNext(flowable->function_(std::move(value))); + } } catch (const std::exception& exn) { folly::exception_wrapper ew{std::current_exception(), exn}; this->terminateErr(std::move(ew)); } } + + void onErrorImpl(folly::exception_wrapper ew) override { + try { + if (auto flowable = yarpl::atomic_load(&flowable_)) { + SuperSubscription::onErrorImpl(flowable->errFunction_(std::move(ew))); + } + } catch (const std::exception& exn) { + this->terminateErr( + folly::exception_wrapper{std::current_exception(), exn}); + } + } + + void onTerminateImpl() override { + yarpl::atomic_exchange(&flowable_, nullptr); + SuperSubscription::onTerminateImpl(); + } + + private: + AtomicReference flowable_; }; + std::shared_ptr> upstream_; F function_; + EF errFunction_; }; -template < - typename U, - typename F, - typename = - typename std::enable_if::value>::type> -class FilterOperator : public FlowableOperator> { +template +class FilterOperator : public FlowableOperator { // for use in subclasses - using ThisOperatorT = FilterOperator; - using Super = FlowableOperator; + using Super = FlowableOperator; + static_assert(std::is_same, F>::value, "undecayed"); + static_assert(folly::is_invocable_r::value, "not invocable"); public: - FilterOperator(Reference> upstream, F function) - : Super(std::move(upstream)), function_(std::move(function)) {} + template + FilterOperator(std::shared_ptr> upstream, Func&& function) + : upstream_(std::move(upstream)), + function_(std::forward(function)) {} - void subscribe(Reference> subscriber) override { - Super::upstream_->subscribe(make_ref( + void subscribe(std::shared_ptr> subscriber) override { + upstream_->subscribe(std::make_shared( this->ref_from_this(this), std::move(subscriber))); } @@ -187,40 +231,49 @@ class FilterOperator : public FlowableOperator> { class Subscription : public SuperSubscription { public: Subscription( - Reference flowable, - Reference> subscriber) - : SuperSubscription(std::move(flowable), std::move(subscriber)) {} + std::shared_ptr flowable, + std::shared_ptr> subscriber) + : SuperSubscription(std::move(subscriber)), + flowable_(std::move(flowable)) {} void onNextImpl(U value) override { - auto&& filter = SuperSubscription::getFlowableOperator(); - if (filter->function_(value)) { - SuperSubscription::subscriberOnNext(std::move(value)); - } else { - SuperSubscription::request(1); + if (auto flowable = yarpl::atomic_load(&flowable_)) { + if (flowable->function_(value)) { + SuperSubscription::subscriberOnNext(std::move(value)); + } else { + SuperSubscription::request(1); + } } } + + void onTerminateImpl() override { + yarpl::atomic_exchange(&flowable_, nullptr); + SuperSubscription::onTerminateImpl(); + } + + private: + AtomicReference flowable_; }; + std::shared_ptr> upstream_; F function_; }; -template < - typename U, - typename D, - typename F, - typename = typename std::enable_if::value>, - typename = - typename std::enable_if::value>::type> -class ReduceOperator : public FlowableOperator> { - using ThisOperatorT = ReduceOperator; - using Super = FlowableOperator; +template +class ReduceOperator : public FlowableOperator { + using Super = FlowableOperator; + static_assert(std::is_same, F>::value, "undecayed"); + static_assert(std::is_assignable::value, "not assignable"); + static_assert(folly::is_invocable_r::value, "not invocable"); public: - ReduceOperator(Reference> upstream, F function) - : Super(std::move(upstream)), function_(std::move(function)) {} + template + ReduceOperator(std::shared_ptr> upstream, Func&& function) + : upstream_(std::move(upstream)), + function_(std::forward(function)) {} - void subscribe(Reference> subscriber) override { - Super::upstream_->subscribe(make_ref( + void subscribe(std::shared_ptr> subscriber) override { + upstream_->subscribe(std::make_shared( this->ref_from_this(this), std::move(subscriber))); } @@ -229,9 +282,10 @@ class ReduceOperator : public FlowableOperator> { class Subscription : public SuperSubscription { public: Subscription( - Reference flowable, - Reference> subscriber) - : SuperSubscription(std::move(flowable), std::move(subscriber)), + std::shared_ptr flowable, + std::shared_ptr> subscriber) + : SuperSubscription(std::move(subscriber)), + flowable_(std::move(flowable)), accInitialized_(false) {} void request(int64_t) override { @@ -240,9 +294,10 @@ class ReduceOperator : public FlowableOperator> { } void onNextImpl(U value) override { - auto&& reduce = SuperSubscription::getFlowableOperator(); if (accInitialized_) { - acc_ = reduce->function_(std::move(acc_), std::move(value)); + if (auto flowable = yarpl::atomic_load(&flowable_)) { + acc_ = flowable->function_(std::move(acc_), std::move(value)); + } } else { acc_ = std::move(value); accInitialized_ = true; @@ -256,38 +311,48 @@ class ReduceOperator : public FlowableOperator> { SuperSubscription::onCompleteImpl(); } + void onTerminateImpl() override { + yarpl::atomic_exchange(&flowable_, nullptr); + SuperSubscription::onTerminateImpl(); + } + private: + AtomicReference flowable_; bool accInitialized_; D acc_; }; + std::shared_ptr> upstream_; F function_; }; template -class TakeOperator : public FlowableOperator> { - using ThisOperatorT = TakeOperator; - using Super = FlowableOperator; +class TakeOperator : public FlowableOperator { + using Super = FlowableOperator; public: - TakeOperator(Reference> upstream, int64_t limit) - : Super(std::move(upstream)), limit_(limit) {} + TakeOperator(std::shared_ptr> upstream, int64_t limit) + : upstream_(std::move(upstream)), limit_(limit) {} - void subscribe(Reference> subscriber) override { - Super::upstream_->subscribe(make_ref( - this->ref_from_this(this), limit_, std::move(subscriber))); + void subscribe(std::shared_ptr> subscriber) override { + upstream_->subscribe( + std::make_shared(limit_, std::move(subscriber))); } private: using SuperSubscription = typename Super::Subscription; class Subscription : public SuperSubscription { public: - Subscription( - Reference flowable, - int64_t limit, - Reference> subscriber) - : SuperSubscription(std::move(flowable), std::move(subscriber)), - limit_(limit) {} + Subscription(int64_t limit, std::shared_ptr> subscriber) + : SuperSubscription(std::move(subscriber)), limit_(limit) {} + + void onSubscribeImpl() override { + SuperSubscription::onSubscribeImpl(); + + if (limit_ <= 0) { + SuperSubscription::terminate(); + } + } void onNextImpl(T value) override { if (limit_-- > 0) { @@ -314,33 +379,29 @@ class TakeOperator : public FlowableOperator> { int64_t limit_; }; + std::shared_ptr> upstream_; const int64_t limit_; }; template -class SkipOperator : public FlowableOperator> { - using ThisOperatorT = SkipOperator; - using Super = FlowableOperator; +class SkipOperator : public FlowableOperator { + using Super = FlowableOperator; public: - SkipOperator(Reference> upstream, int64_t offset) - : Super(std::move(upstream)), offset_(offset) {} + SkipOperator(std::shared_ptr> upstream, int64_t offset) + : upstream_(std::move(upstream)), offset_(offset) {} - void subscribe(Reference> subscriber) override { - Super::upstream_->subscribe(make_ref( - this->ref_from_this(this), offset_, std::move(subscriber))); + void subscribe(std::shared_ptr> subscriber) override { + upstream_->subscribe( + std::make_shared(offset_, std::move(subscriber))); } private: using SuperSubscription = typename Super::Subscription; class Subscription : public SuperSubscription { public: - Subscription( - Reference flowable, - int64_t offset, - Reference> subscriber) - : SuperSubscription(std::move(flowable), std::move(subscriber)), - offset_(offset) {} + Subscription(int64_t offset, std::shared_ptr> subscriber) + : SuperSubscription(std::move(subscriber)), offset_(offset) {} void onNextImpl(T value) override { if (offset_ > 0) { @@ -363,52 +424,50 @@ class SkipOperator : public FlowableOperator> { bool firstRequest_{true}; }; + std::shared_ptr> upstream_; const int64_t offset_; }; template -class IgnoreElementsOperator - : public FlowableOperator> { - using ThisOperatorT = IgnoreElementsOperator; - using Super = FlowableOperator; +class IgnoreElementsOperator : public FlowableOperator { + using Super = FlowableOperator; public: - explicit IgnoreElementsOperator(Reference> upstream) - : Super(std::move(upstream)) {} + explicit IgnoreElementsOperator(std::shared_ptr> upstream) + : upstream_(std::move(upstream)) {} - void subscribe(Reference> subscriber) override { - Super::upstream_->subscribe(make_ref( - this->ref_from_this(this), std::move(subscriber))); + void subscribe(std::shared_ptr> subscriber) override { + upstream_->subscribe(std::make_shared(std::move(subscriber))); } private: using SuperSubscription = typename Super::Subscription; class Subscription : public SuperSubscription { public: - Subscription( - Reference flowable, - Reference> subscriber) - : SuperSubscription(std::move(flowable), std::move(subscriber)) {} + Subscription(std::shared_ptr> subscriber) + : SuperSubscription(std::move(subscriber)) {} void onNextImpl(T) override {} }; + + std::shared_ptr> upstream_; }; template -class SubscribeOnOperator - : public FlowableOperator> { - using ThisOperatorT = SubscribeOnOperator; - using Super = FlowableOperator; +class SubscribeOnOperator : public FlowableOperator { + using Super = FlowableOperator; public: SubscribeOnOperator( - Reference> upstream, + std::shared_ptr> upstream, folly::Executor& executor) - : Super(std::move(upstream)), executor_(executor) {} + : upstream_(std::move(upstream)), executor_(executor) {} - void subscribe(Reference> subscriber) override { - Super::upstream_->subscribe(make_ref( - this->ref_from_this(this), executor_, std::move(subscriber))); + void subscribe(std::shared_ptr> subscriber) override { + executor_.add([this, self = this->ref_from_this(this), subscriber] { + upstream_->subscribe( + std::make_shared(executor_, std::move(subscriber))); + }); } private: @@ -416,20 +475,18 @@ class SubscribeOnOperator class Subscription : public SuperSubscription { public: Subscription( - Reference flowable, folly::Executor& executor, - Reference> subscriber) - : SuperSubscription(std::move(flowable), std::move(subscriber)), - executor_(executor) {} + std::shared_ptr> subscriber) + : SuperSubscription(std::move(subscriber)), executor_(executor) {} void request(int64_t delta) override { - executor_.add([ delta, this, self = this->ref_from_this(this) ] { + executor_.add([delta, this, self = this->ref_from_this(this)] { this->callSuperRequest(delta); }); } void cancel() override { - executor_.add([ this, self = this->ref_from_this(this) ] { + executor_.add([this, self = this->ref_from_this(this)] { this->callSuperCancel(); }); } @@ -452,16 +509,22 @@ class SubscribeOnOperator folly::Executor& executor_; }; + std::shared_ptr> upstream_; folly::Executor& executor_; }; template class FromPublisherOperator : public Flowable { + static_assert( + std::is_same, OnSubscribe>::value, + "undecayed"); + public: - explicit FromPublisherOperator(OnSubscribe function) - : function_(std::move(function)) {} + template + explicit FromPublisherOperator(F&& function) + : function_(std::forward(function)) {} - void subscribe(Reference> subscriber) override { + void subscribe(std::shared_ptr> subscriber) override { function_(std::move(subscriber)); } @@ -470,18 +533,17 @@ class FromPublisherOperator : public Flowable { }; template -class FlatMapOperator : public FlowableOperator> { - using ThisOperatorT = FlatMapOperator; - using Super = FlowableOperator; +class FlatMapOperator : public FlowableOperator { + using Super = FlowableOperator; public: FlatMapOperator( - Reference> upstream, - folly::Function>(T)> func) - : Super(std::move(upstream)), function_(std::move(func)) {} + std::shared_ptr> upstream, + folly::Function>(T)> func) + : upstream_(std::move(upstream)), function_(std::move(func)) {} - void subscribe(Reference> subscriber) override { - Super::upstream_->subscribe(make_ref( + void subscribe(std::shared_ptr> subscriber) override { + upstream_->subscribe(std::make_shared( this->ref_from_this(this), std::move(subscriber))); } @@ -492,9 +554,10 @@ class FlatMapOperator : public FlowableOperator> { public: FMSubscription( - Reference flowable, - Reference> subscriber) - : SuperSubscription(std::move(flowable), std::move(subscriber)) {} + std::shared_ptr flowable, + std::shared_ptr> subscriber) + : SuperSubscription(std::move(subscriber)), + flowable_(std::move(flowable)) {} void onSubscribeImpl() final { liveSubscribers_++; @@ -502,24 +565,23 @@ class FlatMapOperator : public FlowableOperator> { } void onNextImpl(T value) final { - auto&& flatMapOp = this->getFlowableOperator(); - Reference> mappedStream; + std::shared_ptr> mappedStream; try { - mappedStream = flatMapOp->function_(std::move(value)); + mappedStream = flowable_->function_(std::move(value)); } catch (const std::exception& exn) { folly::exception_wrapper ew{std::current_exception(), exn}; { std::lock_guard g(onErrorExGuard_); - onErrorEx_ = std::move(folly::exception_wrapper{ew}); + onErrorEx_ = ew; } // next iteration of drainLoop will cancel this subscriber as well drainLoop(); return; } - Reference mappedSubscriber = - yarpl::make_ref(this->ref_from_this(this)); + std::shared_ptr mappedSubscriber = + std::make_shared(this->ref_from_this(this)); mappedSubscriber->fmReference_ = mappedSubscriber; { @@ -635,8 +697,8 @@ class FlatMapOperator : public FlowableOperator> { // Subscribers might call onNext and then terminate; delay // removing its liveSubscriber reference until we've delivered // its element to the downstream subscriber and dropped its - // synchronized reference to `r`, as dropping the flatMapSubscription_ - // reference may invoke its destructor + // synchronized reference to `r`, as dropping the + // flatMapSubscription_ reference may invoke its destructor if (r->isTerminated) { r->freeze = true; terminatedTrash.push_back(*elem); @@ -650,9 +712,9 @@ class FlatMapOperator : public FlowableOperator> { elem->request(1); } - // phase 5: destroy any mapped subscribers which have terminated, enqueue - // another drain loop run if we do end up discarding any subscribers, as - // our live subscriber count may have gone to zero + // phase 5: destroy any mapped subscribers which have terminated, + // enqueue another drain loop run if we do end up discarding any + // subscribers, as our live subscriber count may have gone to zero if (!terminatedTrash.empty()) { drainLoopMutex_++; } @@ -666,7 +728,7 @@ class FlatMapOperator : public FlowableOperator> { } } - // called from MappedStreamSubscriber, recieves the R and the + // called from MappedStreamSubscriber, receives the R and the // subscriber which generated the R void drainLoop() { auto self = this->ref_from_this(this); @@ -684,7 +746,7 @@ class FlatMapOperator : public FlowableOperator> { auto l = lists.wlock(); auto r = elem->sync.wlock(); - if(r->freeze) { + if (r->freeze) { return; } @@ -737,8 +799,7 @@ class FlatMapOperator : public FlowableOperator> { // onComplete/onError fall through to onTerminateImpl, which // will call drainLoop and update the liveSubscribers_ count - void onCompleteImpl() final { - } + void onCompleteImpl() final {} void onErrorImpl(folly::exception_wrapper ex) final { std::lock_guard g(onErrorExGuard_); onErrorEx_ = std::move(ex); @@ -748,13 +809,13 @@ class FlatMapOperator : public FlowableOperator> { void onTerminateImpl() final { liveSubscribers_--; drainLoop(); + flowable_.reset(); } void request(int64_t n) override { - if((n + requested_) < requested_) { + if ((n + requested_) < requested_) { requested_ = std::numeric_limits::max(); - } - else { + } else { requested_ += n; } @@ -778,12 +839,17 @@ class FlatMapOperator : public FlowableOperator> { : public BaseSubscriber, public boost::intrusive::list_base_hook< boost::intrusive::link_mode> { - MappedStreamSubscriber(Reference subscription) + MappedStreamSubscriber(std::shared_ptr subscription) : flatMapSubscription_(std::move(subscription)) {} void onSubscribeImpl() final { -#ifdef DEBUG - if (auto fms = flatMapSubscription_.load()) { + auto fmsb = yarpl::atomic_load(&flatMapSubscription_); + if (!fmsb || fmsb->clearAllSubscribers_) { + BaseSubscriber::cancel(); + return; + } +#ifndef NDEBUG + if (auto fms = yarpl::atomic_load(&flatMapSubscription_)) { auto l = fms->lists.wlock(); auto r = sync.wlock(); if (!is_in_list(*this, l->pendingValue, l)) { @@ -801,14 +867,13 @@ class FlatMapOperator : public FlowableOperator> { } void onNextImpl(R value) final { - if (auto fms = flatMapSubscription_.load()) { + if (auto fms = yarpl::atomic_load(&flatMapSubscription_)) { fms->onMappedSubscriberNext(this, std::move(value)); } } // noop - void onCompleteImpl() final { - } + void onCompleteImpl() final {} void onErrorImpl(folly::exception_wrapper ex) final { auto r = sync.wlock(); @@ -816,7 +881,8 @@ class FlatMapOperator : public FlowableOperator> { } void onTerminateImpl() override { - if (auto fms = flatMapSubscription_.exchange(nullptr)) { + std::shared_ptr null; + if (auto fms = yarpl::atomic_exchange(&flatMapSubscription_, null)) { fms->onMappedSubscriberTerminate(this); } } @@ -830,11 +896,10 @@ class FlatMapOperator : public FlowableOperator> { }; folly::Synchronized sync; - // FMSubscription's 'reference' to this object. FMSubscription // clears this reference when it drops the MappedStreamSubscriber // from one of its atomic lists - Reference fmReference_{nullptr}; + std::shared_ptr fmReference_{nullptr}; // this is both a Subscriber and a Subscription AtomicReference flatMapSubscription_{nullptr}; @@ -881,10 +946,10 @@ class FlatMapOperator : public FlowableOperator> { L const& lists, bool should) { if (is_in_list(elem, list) != should) { -#ifdef DEBUG +#ifndef NDEBUG debug_is_in_list(elem, lists); #else - (void) lists; + (void)lists; #endif return false; } @@ -892,7 +957,9 @@ class FlatMapOperator : public FlowableOperator> { } template - static void debug_is_in_list(MappedStreamSubscriber const& elem, L const& lists) { + static void debug_is_in_list( + MappedStreamSubscriber const& elem, + L const& lists) { LOG(INFO) << "in without: " << is_in_list(elem, lists->withoutValue); LOG(INFO) << "in pending: " << is_in_list(elem, lists->pendingValue); LOG(INFO) << "in withval: " << is_in_list(elem, lists->withValue); @@ -911,6 +978,8 @@ class FlatMapOperator : public FlowableOperator> { return found; } + std::shared_ptr flowable_; + // got a terminating signal from the upstream flowable // always modified in the protected drainImpl() bool calledDownstreamTerminate_{false}; @@ -924,14 +993,18 @@ class FlatMapOperator : public FlowableOperator> { std::atomic requested_{0}; // number of subscribers (FMSubscription + MappedStreamSubscriber) which - // have not recieved a terminating signal yet + // have not received a terminating signal yet std::atomic liveSubscribers_{0}; }; - folly::Function>(T)> function_; + std::shared_ptr> upstream_; + folly::Function>(T)> function_; }; } // namespace flowable } // namespace yarpl +#include "yarpl/flowable/FlowableConcatOperators.h" +#include "yarpl/flowable/FlowableDoOperator.h" #include "yarpl/flowable/FlowableObserveOnOperator.h" +#include "yarpl/flowable/FlowableTimeoutOperator.h" diff --git a/yarpl/flowable/FlowableTimeoutOperator.h b/yarpl/flowable/FlowableTimeoutOperator.h new file mode 100644 index 000000000..bb14ccc9c --- /dev/null +++ b/yarpl/flowable/FlowableTimeoutOperator.h @@ -0,0 +1,162 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "yarpl/flowable/Flowable.h" + +#pragma once + +#include + +#include "yarpl/flowable/FlowableOperator.h" + +namespace yarpl { +class TimeoutException : public std::runtime_error { + public: + TimeoutException() : std::runtime_error("yarpl::TimeoutException") {} +}; +namespace detail { +class TimeoutExceptionGenerator { + public: + TimeoutException operator()() const { + return {}; + } +}; +} // namespace detail + +namespace flowable { +namespace details { + +template +class TimeoutOperator : public FlowableOperator { + using Super = FlowableOperator; + static_assert( + std::is_same, ExceptionGenerator>::value, + "undecayed"); + + public: + template + TimeoutOperator( + std::shared_ptr> upstream, + folly::EventBase& timerEvb, + std::chrono::milliseconds timeout, + std::chrono::milliseconds initTimeout, + F&& exnGen) + : upstream_(std::move(upstream)), + timerEvb_(timerEvb), + timeout_(timeout), + initTimeout_(initTimeout), + exnGen_(std::forward(exnGen)) {} + + void subscribe(std::shared_ptr> subscriber) override { + auto subscription = std::make_shared( + this->ref_from_this(this), + subscriber, + timerEvb_, + initTimeout_, + timeout_); + upstream_->subscribe(std::move(subscription)); + } + + protected: + class TimeoutSubscription : public Super::Subscription, + public folly::HHWheelTimer::Callback { + using SuperSub = typename Super::Subscription; + + public: + TimeoutSubscription( + std::shared_ptr> flowable, + std::shared_ptr> subscriber, + folly::EventBase& timerEvb, + std::chrono::milliseconds initTimeout, + std::chrono::milliseconds timeout) + : Super::Subscription(std::move(subscriber)), + flowable_(std::move(flowable)), + timerEvb_(timerEvb), + initTimeout_(initTimeout), + timeout_(timeout) {} + + void onSubscribeImpl() override { + DCHECK(timerEvb_.isInEventBaseThread()); + if (initTimeout_.count() > 0) { + nextTime_ = std::chrono::steady_clock::now() + initTimeout_; + timerEvb_.timer().scheduleTimeout(this, initTimeout_); + } else { + nextTime_ = std::chrono::steady_clock::time_point::max(); + } + + SuperSub::onSubscribeImpl(); + } + + void onNextImpl(T value) override { + DCHECK(timerEvb_.isInEventBaseThread()); + if (flowable_) { + if (nextTime_ != std::chrono::steady_clock::time_point::max()) { + cancelTimeout(); // cancel timer before calling onNext + auto currentTime = std::chrono::steady_clock::now(); + if (currentTime > nextTime_) { + timeoutExpired(); + return; + } + nextTime_ = std::chrono::steady_clock::time_point::max(); + } + + SuperSub::subscriberOnNext(std::move(value)); + + if (timeout_.count() > 0) { + nextTime_ = std::chrono::steady_clock::now() + timeout_; + timerEvb_.timer().scheduleTimeout(this, timeout_); + } + } + } + + void onTerminateImpl() override { + DCHECK(timerEvb_.isInEventBaseThread()); + flowable_.reset(); + cancelTimeout(); + } + + void timeoutExpired() noexcept override { + if (auto flowable = std::exchange(flowable_, nullptr)) { + SuperSub::terminateErr([&]() -> folly::exception_wrapper { + try { + return flowable->exnGen_(); + } catch (...) { + return folly::make_exception_wrapper(); + } + }()); + } + } + + void callbackCanceled() noexcept override { + // Do nothing.. + } + + private: + std::shared_ptr> flowable_; + folly::EventBase& timerEvb_; + std::chrono::milliseconds initTimeout_; + std::chrono::milliseconds timeout_; + std::chrono::steady_clock::time_point nextTime_; + }; + + std::shared_ptr> upstream_; + folly::EventBase& timerEvb_; + std::chrono::milliseconds timeout_; + std::chrono::milliseconds initTimeout_; + ExceptionGenerator exnGen_; +}; + +} // namespace details +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/Flowable_FromObservable.h b/yarpl/flowable/Flowable_FromObservable.h new file mode 100644 index 000000000..e191ad7c3 --- /dev/null +++ b/yarpl/flowable/Flowable_FromObservable.h @@ -0,0 +1,348 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "yarpl/Common.h" +#include "yarpl/Flowable.h" +#include "yarpl/utils/credits.h" + +namespace yarpl { +namespace observable { +template +class Observable; + +template +class Observer; +} // namespace observable + +template +class BackpressureStrategyBase : public IBackpressureStrategy, + public flowable::Subscription, + public observable::Observer { + protected: + // + // the following methods are to be overridden + // + virtual void onCreditsAvailable(int64_t /*credits*/) = 0; + virtual void onNextWithoutCredits(T /*t*/) = 0; + + public: + void init( + std::shared_ptr> observable, + std::shared_ptr> subscriber) override { + observable_ = std::move(observable); + subscriberWeak_ = subscriber; + subscriber_ = subscriber; + subscriber->onSubscribe(this->ref_from_this(this)); + observable_->subscribe(this->ref_from_this(this)); + } + + BackpressureStrategyBase() = default; + BackpressureStrategyBase(BackpressureStrategyBase&&) = delete; + + BackpressureStrategyBase(const BackpressureStrategyBase&) = delete; + BackpressureStrategyBase& operator=(BackpressureStrategyBase&&) = delete; + BackpressureStrategyBase& operator=(const BackpressureStrategyBase&) = delete; + + // only for testing purposes + void setTestSubscriber(std::shared_ptr> subscriber) { + subscriberWeak_ = subscriber; + subscriber_ = subscriber; + subscriber->onSubscribe(this->ref_from_this(this)); + } + + void request(int64_t n) override { + if (n <= 0) { + return; + } + auto r = credits::add(&requested_, n); + if (r <= 0) { + return; + } + + // it is possible that after calling subscribe or in onCreditsAvailable + // methods, there will be a stream of + // onNext calls which the processing chain might cancel. The cancel signal + // will remove all references to this class and we need to keep this + // instance around to finish this method + auto thisPtr = this->ref_from_this(this); + + if (r > 0) { + onCreditsAvailable(r); + } + } + + void cancel() override { + if (auto subscriber = subscriber_.exchange(nullptr)) { + observable::Observer::unsubscribe(); + observable_.reset(); + } + } + + // Observer override + void onNext(T t) override { + if (subscriberWeak_.expired()) { + return; + } + if (requested_ > 0) { + downstreamOnNext(std::move(t)); + return; + } + onNextWithoutCredits(std::move(t)); + } + + // Observer override + void onComplete() override { + downstreamOnComplete(); + } + + // Observer override + void onError(folly::exception_wrapper ex) override { + downstreamOnError(std::move(ex)); + } + + virtual void downstreamOnNext(T t) { + credits::consume(&requested_, 1); + if (auto subscriber = subscriberWeak_.lock()) { + subscriber->onNext(std::move(t)); + } + } + + void downstreamOnComplete() { + if (auto subscriber = subscriber_.exchange(nullptr)) { + subscriber->onComplete(); + observable::Observer::onComplete(); + observable_.reset(); + } + } + + void downstreamOnError(folly::exception_wrapper error) { + if (auto subscriber = subscriber_.exchange(nullptr)) { + subscriber->onError(std::move(error)); + observable::Observer::onError(folly::exception_wrapper()); + observable_.reset(); + } + } + + void downstreamOnErrorAndCancel(folly::exception_wrapper error) { + if (auto subscriber = subscriber_.exchange(nullptr)) { + subscriber->onError(std::move(error)); + + observable_.reset(); + observable::Observer::unsubscribe(); + } + } + + private: + std::shared_ptr> observable_; + folly::Synchronized>> subscriber_; + std::weak_ptr> subscriberWeak_; + std::atomic requested_{0}; +}; + +template +class DropBackpressureStrategy : public BackpressureStrategyBase { + public: + void onCreditsAvailable(int64_t /*credits*/) override {} + void onNextWithoutCredits(T /*t*/) override { + // drop anything while we don't have credits + } +}; + +template +class ErrorBackpressureStrategy : public BackpressureStrategyBase { + using Super = BackpressureStrategyBase; + + void onCreditsAvailable(int64_t /*credits*/) override {} + + void onNextWithoutCredits(T /*t*/) override { + Super::downstreamOnErrorAndCancel(flowable::MissingBackpressureException()); + } +}; + +template +class BufferBackpressureStrategy : public BackpressureStrategyBase { + public: + static constexpr size_t kNoLimit = 0; + + explicit BufferBackpressureStrategy(size_t bufferSizeLimit = kNoLimit) + : buffer_(folly::in_place, bufferSizeLimit) {} + + private: + using Super = BackpressureStrategyBase; + + void onComplete() override { + if (!buffer_.rlock()->empty()) { + // we have buffered some items so we will defer delivering on complete for + // later + completed_ = true; + } else { + Super::onComplete(); + } + } + + void onNext(T t) override { + { + auto buffer = buffer_.wlock(); + if (!buffer->empty()) { + if (buffer->push(std::move(t))) { + return; + } + buffer.unlock(); + Super::downstreamOnErrorAndCancel( + flowable::MissingBackpressureException()); + return; + } + } + BackpressureStrategyBase::onNext(std::move(t)); + } + + // + // onError signal is delivered immediately by design + // + + void onNextWithoutCredits(T t) override { + if (buffer_.wlock()->push(std::move(t))) { + return; + } + Super::downstreamOnErrorAndCancel(flowable::MissingBackpressureException()); + } + + void onCreditsAvailable(int64_t credits) override { + DCHECK(credits > 0); + auto lockedBuffer = buffer_.wlock(); + while (credits-- > 0 && !lockedBuffer->empty()) { + Super::downstreamOnNext(std::move(lockedBuffer->front())); + lockedBuffer->pop(); + } + + if (lockedBuffer->empty() && completed_) { + Super::onComplete(); + } + } + + struct Buffer { + public: + explicit Buffer(size_t sizeLimit) : sizeLimit_(sizeLimit) {} + + bool empty() const { + return buffer_.empty(); + } + + bool push(T&& value) { + if (sizeLimit_ != kNoLimit && buffer_.size() >= sizeLimit_) { + return false; + } + buffer_.push(std::move(value)); + return true; + } + + T& front() { + return buffer_.front(); + } + + void pop() { + buffer_.pop(); + } + + private: + const size_t sizeLimit_; + std::queue buffer_; + }; + + folly::Synchronized buffer_; + std::atomic completed_{false}; +}; + +template +class LatestBackpressureStrategy : public BackpressureStrategyBase { + using Super = BackpressureStrategyBase; + + void onComplete() override { + if (storesLatest_) { + // we have buffered an item so we will defer delivering on complete for + // later + completed_ = true; + } else { + Super::onComplete(); + } + } + + // + // onError signal is delivered immediately by design + // + + void onNextWithoutCredits(T t) override { + storesLatest_ = true; + *latest_.wlock() = std::move(t); + } + + void onCreditsAvailable(int64_t credits) override { + DCHECK(credits > 0); + if (storesLatest_) { + storesLatest_ = false; + Super::downstreamOnNext(std::move(*latest_.wlock())); + + if (completed_) { + Super::onComplete(); + } + } + } + + std::atomic storesLatest_{false}; + std::atomic completed_{false}; + folly::Synchronized latest_; +}; + +template +class MissingBackpressureStrategy : public BackpressureStrategyBase { + using Super = BackpressureStrategyBase; + + void onCreditsAvailable(int64_t /*credits*/) override {} + + void onNextWithoutCredits(T t) override { + // call onNext anyways (and potentially violating the protocol) + Super::downstreamOnNext(std::move(t)); + } +}; + +template +std::shared_ptr> IBackpressureStrategy::buffer() { + return std::make_shared>(); +} + +template +std::shared_ptr> IBackpressureStrategy::drop() { + return std::make_shared>(); +} + +template +std::shared_ptr> IBackpressureStrategy::error() { + return std::make_shared>(); +} + +template +std::shared_ptr> IBackpressureStrategy::latest() { + return std::make_shared>(); +} + +template +std::shared_ptr> IBackpressureStrategy::missing() { + return std::make_shared>(); +} + +} // namespace yarpl diff --git a/yarpl/flowable/Flowables.cpp b/yarpl/flowable/Flowables.cpp new file mode 100644 index 000000000..4b61540f5 --- /dev/null +++ b/yarpl/flowable/Flowables.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yarpl/flowable/Flowables.h" + +namespace yarpl { +namespace flowable { + +std::shared_ptr> Flowable<>::range( + int64_t start, + int64_t count) { + auto lambda = [start, count, i = start]( + Subscriber& subscriber, + int64_t requested) mutable { + int64_t end = start + count; + + while (i < end && requested-- > 0) { + subscriber.onNext(i++); + } + + if (i >= end) { + // TODO T27302402: Even though having two subscriptions exist concurrently + // for Emitters is not possible still. At least it possible to resubscribe + // and consume the same values again. + i = start; + subscriber.onComplete(); + } + }; + return Flowable::create(std::move(lambda)); +} + +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/Flowables.h b/yarpl/flowable/Flowables.h new file mode 100644 index 000000000..56cb8c034 --- /dev/null +++ b/yarpl/flowable/Flowables.h @@ -0,0 +1,64 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "yarpl/flowable/Flowable.h" + +namespace yarpl { +namespace flowable { + +template <> +class Flowable { + public: + /** + * Emit the sequence of numbers [start, start + count). + */ + static std::shared_ptr> range(int64_t start, int64_t count); + + template + static std::shared_ptr> just(T&& value) { + return Flowable>::just(std::forward(value)); + } + + template + static std::shared_ptr> justN(std::initializer_list list) { + return Flowable>::justN(std::move(list)); + } + + // this will generate a flowable which can be subscribed to only once + template + static std::shared_ptr> justOnce(T&& value) { + return Flowable>::justOnce(std::forward(value)); + } + + template + static std::shared_ptr> concat( + std::shared_ptr> first, + Args... args) { + return first->concatWith(args...); + } + + private: + Flowable() = delete; +}; + +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/PublishProcessor.h b/yarpl/flowable/PublishProcessor.h new file mode 100644 index 000000000..8232c4a82 --- /dev/null +++ b/yarpl/flowable/PublishProcessor.h @@ -0,0 +1,255 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "yarpl/Common.h" +#include "yarpl/flowable/Flowable.h" +#include "yarpl/observable/Observable.h" +#include "yarpl/utils/credits.h" + +namespace yarpl { +namespace flowable { + +// Processor that multicasts all subsequently observed items to its current +// Subscribers. The processor does not coordinate backpressure for its +// subscribers and implements a weaker onSubscribe which calls requests +// kNoFlowControl from the incoming Subscriptions. This makes it possible to +// subscribe the PublishProcessor to multiple sources unlike the standard +// Subscriber contract. If subscribers are not able to keep up with the flow +// control, they are terminated with MissingBackpressureException. The +// implementation of onXXX() and subscribe() methods are technically thread-safe +// but non-serialized calls to them may lead to undefined state in the currently +// subscribed Subscribers. +template +class PublishProcessor : public observable::Observable, + public Subscriber { + class PublisherSubscription; + using PublishersVector = std::vector>; + + public: + static std::shared_ptr create() { + return std::shared_ptr(new PublishProcessor()); + } + + ~PublishProcessor() { + auto publishers = std::make_shared(); + publishers_.swap(publishers); + + for (const auto& publisher : *publishers) { + publisher->terminate(); + } + } + + bool hasSubscribers() const { + return !publishers_.copy()->empty(); + } + + std::shared_ptr subscribe( + std::shared_ptr> subscriber) override { + auto publisher = std::make_shared(subscriber, this); + // we have to call onSubscribe before adding it to the list of publishers + // because they might start emitting right away + subscriber->onSubscribe(publisher); + + if (publisher->isCancelled()) { + return publisher; + } + + auto publishers = tryAddPublisher(publisher); + + if (publishers == kCompletedPublishers()) { + publisher->onComplete(); + } else if (publishers == kErroredPublishers()) { + publisher->onError(std::runtime_error("ErroredPublisher")); + } + + return publisher; + } + + void onSubscribe(std::shared_ptr subscription) override { + auto publishers = publishers_.copy(); + if (publishers == kCompletedPublishers() || + publishers == kErroredPublishers()) { + subscription->cancel(); + return; + } + + subscription->request(credits::kNoFlowControl); + } + + void onNext(T value) override { + auto publishers = publishers_.copy(); + DCHECK(publishers != kCompletedPublishers()); + DCHECK(publishers != kErroredPublishers()); + + for (const auto& publisher : *publishers) { + publisher->onNext(value); + } + } + + void onError(folly::exception_wrapper ex) override { + auto publishers = kErroredPublishers(); + publishers_.swap(publishers); + DCHECK(publishers != kCompletedPublishers()); + DCHECK(publishers != kErroredPublishers()); + + for (const auto& publisher : *publishers) { + publisher->onError(ex); + } + } + + void onComplete() override { + auto publishers = kCompletedPublishers(); + publishers_.swap(publishers); + DCHECK(publishers != kCompletedPublishers()); + DCHECK(publishers != kErroredPublishers()); + + for (const auto& publisher : *publishers) { + publisher->onComplete(); + } + } + + private: + PublishProcessor() : publishers_{std::make_shared()} {} + + std::shared_ptr tryAddPublisher( + std::shared_ptr subscriber) { + while (true) { + auto oldPublishers = publishers_.copy(); + if (oldPublishers == kCompletedPublishers() || + oldPublishers == kErroredPublishers()) { + return oldPublishers; + } + + auto newPublishers = std::make_shared(); + newPublishers->reserve(oldPublishers->size() + 1); + newPublishers->insert( + newPublishers->begin(), + oldPublishers->cbegin(), + oldPublishers->cend()); + newPublishers->push_back(subscriber); + + auto locked = publishers_.lock(); + if (*locked == oldPublishers) { + *locked = newPublishers; + return newPublishers; + } + // else the vector changed so we will have to do it again + } + } + + void removePublisher(PublisherSubscription* subscriber) { + while (true) { + auto oldPublishers = publishers_.copy(); + + auto removingItem = std::find_if( + oldPublishers->cbegin(), + oldPublishers->cend(), + [&](const auto& publisherPtr) { + return publisherPtr.get() == subscriber; + }); + + if (removingItem == oldPublishers->cend()) { + // not found anymore + return; + } + + auto newPublishers = std::make_shared(); + newPublishers->reserve(oldPublishers->size() - 1); + newPublishers->insert( + newPublishers->begin(), oldPublishers->cbegin(), removingItem); + newPublishers->insert( + newPublishers->end(), std::next(removingItem), oldPublishers->cend()); + + auto locked = publishers_.lock(); + if (*locked == oldPublishers) { + *locked = std::move(newPublishers); + return; + } + // else the vector changed so we will have to do it again + } + } + + class PublisherSubscription : public observable::Subscription { + public: + PublisherSubscription( + std::shared_ptr> subscriber, + PublishProcessor* processor) + : subscriber_(std::move(subscriber)), processor_(processor) {} + + // cancel may race with terminate(), but the + // PublishProcessor::removePublisher will take care of that the race with + // on{Next, Error, Complete} methods is allowed by the spec + void cancel() override { + subscriber_.reset(); + processor_->removePublisher(this); + } + + // terminate will never race with on{Next, Error, Complete} because they are + // all called from PublishProcessor and terminate is called only from dtor + void terminate() { + if (auto subscriber = std::exchange(subscriber_, nullptr)) { + subscriber->onError(std::runtime_error("PublishProcessor shutdown")); + } + } + + void onNext(T value) { + if (subscriber_) { + subscriber_->onNext(std::move(value)); + } + } + + // used internally, not an interface method + void onError(folly::exception_wrapper ex) { + if (auto subscriber = std::exchange(subscriber_, nullptr)) { + subscriber->onError(std::move(ex)); + } + } + + // used internally, not an interface method + void onComplete() { + if (auto subscriber = std::exchange(subscriber_, nullptr)) { + subscriber->onComplete(); + } + } + + bool isCancelled() const { + return !subscriber_; + } + + private: + std::shared_ptr> subscriber_; + PublishProcessor* processor_; + }; + + static const std::shared_ptr& kCompletedPublishers() { + static std::shared_ptr constant = + std::make_shared(); + return constant; + } + + static const std::shared_ptr& kErroredPublishers() { + static std::shared_ptr constant = + std::make_shared(); + return constant; + } + + folly::Synchronized, std::mutex> + publishers_; +}; +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/Subscriber.h b/yarpl/flowable/Subscriber.h new file mode 100644 index 000000000..d1dc3b525 --- /dev/null +++ b/yarpl/flowable/Subscriber.h @@ -0,0 +1,448 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "yarpl/Disposable.h" +#include "yarpl/Refcounted.h" +#include "yarpl/flowable/Subscription.h" +#include "yarpl/utils/credits.h" + +namespace yarpl { +namespace flowable { + +template +class Subscriber : boost::noncopyable { + public: + virtual ~Subscriber() = default; + virtual void onSubscribe(std::shared_ptr) = 0; + virtual void onComplete() = 0; + virtual void onError(folly::exception_wrapper) = 0; + virtual void onNext(T) = 0; + + template < + typename Next, + typename = typename std::enable_if< + folly::is_invocable&, T>::value>::type> + static std::shared_ptr> create( + Next&& next, + int64_t batch = credits::kNoFlowControl); + + template < + typename Next, + typename Error, + typename = typename std::enable_if< + folly::is_invocable&, T>::value && + folly::is_invocable&, folly::exception_wrapper>:: + value>::type> + static std::shared_ptr> + create(Next&& next, Error&& error, int64_t batch = credits::kNoFlowControl); + + template < + typename Next, + typename Error, + typename Complete, + typename = typename std::enable_if< + folly::is_invocable&, T>::value && + folly::is_invocable&, folly::exception_wrapper>:: + value && + folly::is_invocable&>::value>::type> + static std::shared_ptr> create( + Next&& next, + Error&& error, + Complete&& complete, + int64_t batch = credits::kNoFlowControl); + + static std::shared_ptr> create() { + class NullSubscriber : public Subscriber { + void onSubscribe(std::shared_ptr s) override final { + s->request(credits::kNoFlowControl); + } + + void onNext(T) override final {} + void onComplete() override {} + void onError(folly::exception_wrapper) override {} + }; + return std::make_shared(); + } +}; + +namespace details { + +template +class BaseSubscriberDisposable; + +} // namespace details + +#define KEEP_REF_TO_THIS() \ + std::shared_ptr self; \ + if (keep_reference_to_this) { \ + self = this->ref_from_this(this); \ + } + +// T : Type of Flowable that this Subscriber operates on +// +// keep_reference_to_this : BaseSubscriber will keep a live reference to +// itself on the stack while in a signaling or requesting method, in case +// the derived class causes all other references to itself to be dropped. +// +// Classes that ensure that at least one reference will stay live can +// use `keep_reference_to_this = false` as an optimization to +// prevent an atomic inc/dec pair +template +class BaseSubscriber : public Subscriber, public yarpl::enable_get_ref { + public: + // Note: If any of the following methods is overridden in a subclass, the new + // methods SHOULD ensure that these are invoked as well. + void onSubscribe(std::shared_ptr subscription) final override { + CHECK(subscription); + CHECK(!yarpl::atomic_load(&subscription_)); + +#ifndef NDEBUG + DCHECK(!gotOnSubscribe_.exchange(true)) + << "Already subscribed to BaseSubscriber"; +#endif + + yarpl::atomic_store(&subscription_, std::move(subscription)); + KEEP_REF_TO_THIS(); + onSubscribeImpl(); + } + + // No further calls to the subscription after this method is invoked. + void onComplete() final override { +#ifndef NDEBUG + DCHECK(gotOnSubscribe_.load()) << "Not subscribed to BaseSubscriber"; + DCHECK(!gotTerminating_.exchange(true)) + << "Already got terminating signal method"; +#endif + + std::shared_ptr null; + if (auto sub = yarpl::atomic_exchange(&subscription_, null)) { + KEEP_REF_TO_THIS(); + onCompleteImpl(); + onTerminateImpl(); + } + } + + // No further calls to the subscription after this method is invoked. + void onError(folly::exception_wrapper e) final override { +#ifndef NDEBUG + DCHECK(gotOnSubscribe_.load()) << "Not subscribed to BaseSubscriber"; + DCHECK(!gotTerminating_.exchange(true)) + << "Already got terminating signal method"; +#endif + + std::shared_ptr null; + if (auto sub = yarpl::atomic_exchange(&subscription_, null)) { + KEEP_REF_TO_THIS(); + onErrorImpl(std::move(e)); + onTerminateImpl(); + } + } + + void onNext(T t) final override { +#ifndef NDEBUG + DCHECK(gotOnSubscribe_.load()) << "Not subscibed to BaseSubscriber"; + if (gotTerminating_.load()) { + VLOG(2) << "BaseSubscriber already got terminating signal method"; + } +#endif + + if (auto sub = yarpl::atomic_load(&subscription_)) { + KEEP_REF_TO_THIS(); + onNextImpl(std::move(t)); + } + } + + void cancel() { + std::shared_ptr null; + if (auto sub = yarpl::atomic_exchange(&subscription_, null)) { + KEEP_REF_TO_THIS(); + sub->cancel(); + onTerminateImpl(); + } +#ifndef NDEBUG + else { + VLOG(2) << "cancel() on BaseSubscriber with no subscription_"; + } +#endif + } + + void request(int64_t n) { + if (auto sub = yarpl::atomic_load(&subscription_)) { + KEEP_REF_TO_THIS(); + sub->request(n); + } +#ifndef NDEBUG + else { + VLOG(2) << "request() on BaseSubscriber with no subscription_"; + } +#endif + } + + protected: + virtual void onSubscribeImpl() = 0; + virtual void onCompleteImpl() = 0; + virtual void onNextImpl(T) = 0; + virtual void onErrorImpl(folly::exception_wrapper) = 0; + + virtual void onTerminateImpl() {} + + private: + bool isTerminated() { + return !yarpl::atomic_load(&subscription_); + } + + friend class ::yarpl::flowable::details::BaseSubscriberDisposable; + + // keeps a reference alive to the subscription + AtomicReference subscription_; + +#ifndef NDEBUG + std::atomic gotOnSubscribe_{false}; + std::atomic gotTerminating_{false}; +#endif +}; + +namespace details { + +template +class BaseSubscriberDisposable : public Disposable { + public: + BaseSubscriberDisposable(std::shared_ptr> subscriber) + : subscriber_(std::move(subscriber)) {} + + void dispose() override { + if (auto sub = yarpl::atomic_exchange(&subscriber_, nullptr)) { + sub->cancel(); + } + } + + bool isDisposed() override { + if (auto sub = yarpl::atomic_load(&subscriber_)) { + return sub->isTerminated(); + } else { + return true; + } + } + + private: + AtomicReference> subscriber_; +}; + +template +class LambdaSubscriber : public BaseSubscriber { + public: + template < + typename Next, + typename = typename std::enable_if< + folly::is_invocable&, T>::value>::type> + static std::shared_ptr> create( + Next&& next, + int64_t batch = credits::kNoFlowControl); + + template < + typename Next, + typename Error, + typename = typename std::enable_if< + folly::is_invocable&, T>::value && + folly::is_invocable&, folly::exception_wrapper>:: + value>::type> + static std::shared_ptr> + create(Next&& next, Error&& error, int64_t batch = credits::kNoFlowControl); + + template < + typename Next, + typename Error, + typename Complete, + typename = typename std::enable_if< + folly::is_invocable&, T>::value && + folly::is_invocable&, folly::exception_wrapper>:: + value && + folly::is_invocable&>::value>::type> + static std::shared_ptr> create( + Next&& next, + Error&& error, + Complete&& complete, + int64_t batch = credits::kNoFlowControl); +}; + +template +class Base : public LambdaSubscriber { + static_assert(std::is_same, Next>::value, "undecayed"); + + public: + template + Base(FNext&& next, int64_t batch) + : next_(std::forward(next)), batch_(batch), pending_(0) {} + + void onSubscribeImpl() override final { + pending_ = batch_; + this->request(batch_); + } + + void onNextImpl(T value) override final { + try { + next_(std::move(value)); + } catch (const std::exception& exn) { + this->cancel(); + auto ew = folly::exception_wrapper{std::current_exception(), exn}; + LOG(ERROR) << "'next' method should not throw: " << ew.what(); + onErrorImpl(ew); + return; + } + + if (--pending_ <= batch_ / 2) { + const auto delta = batch_ - pending_; + pending_ += delta; + this->request(delta); + } + } + + void onCompleteImpl() override {} + void onErrorImpl(folly::exception_wrapper) override {} + + private: + Next next_; + const int64_t batch_; + int64_t pending_; +}; + +template +class WithError : public Base { + static_assert(std::is_same, Error>::value, "undecayed"); + + public: + template + WithError(FNext&& next, FError&& error, int64_t batch) + : Base(std::forward(next), batch), + error_(std::forward(error)) {} + + void onErrorImpl(folly::exception_wrapper error) override final { + try { + error_(std::move(error)); + } catch (const std::exception& exn) { + LOG(ERROR) << "'error' method should not throw: " << exn.what(); + } + } + + private: + Error error_; +}; + +template +class WithErrorAndComplete : public WithError { + static_assert( + std::is_same, Complete>::value, + "undecayed"); + + public: + template + WithErrorAndComplete( + FNext&& next, + FError&& error, + FComplete&& complete, + int64_t batch) + : WithError( + std::forward(next), + std::forward(error), + batch), + complete_(std::forward(complete)) {} + + void onCompleteImpl() override final { + try { + complete_(); + } catch (const std::exception& exn) { + LOG(ERROR) << "'complete' method should not throw: " << exn.what(); + } + } + + private: + Complete complete_; +}; + +template +template +std::shared_ptr> LambdaSubscriber::create( + Next&& next, + int64_t batch) { + return std::make_shared>>( + std::forward(next), batch); +} + +template +template +std::shared_ptr> +LambdaSubscriber::create(Next&& next, Error&& error, int64_t batch) { + return std::make_shared< + details::WithError, std::decay_t>>( + std::forward(next), std::forward(error), batch); +} + +template +template +std::shared_ptr> LambdaSubscriber::create( + Next&& next, + Error&& error, + Complete&& complete, + int64_t batch) { + return std::make_shared, + std::decay_t, + std::decay_t>>( + std::forward(next), + std::forward(error), + std::forward(complete), + batch); +} + +} // namespace details + +template +template +std::shared_ptr> Subscriber::create( + Next&& next, + int64_t batch) { + return details::LambdaSubscriber::create(std::forward(next), batch); +} + +template +template +std::shared_ptr> +Subscriber::create(Next&& next, Error&& error, int64_t batch) { + return details::LambdaSubscriber::create( + std::forward(next), std::forward(error), batch); +} + +template +template +std::shared_ptr> Subscriber::create( + Next&& next, + Error&& error, + Complete&& complete, + int64_t batch) { + return details::LambdaSubscriber::create( + std::forward(next), + std::forward(error), + std::forward(complete), + batch); +} + +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/Subscription.cpp b/yarpl/flowable/Subscription.cpp new file mode 100644 index 000000000..a49e1c97c --- /dev/null +++ b/yarpl/flowable/Subscription.cpp @@ -0,0 +1,29 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yarpl/flowable/Subscription.h" + +namespace yarpl { +namespace flowable { + +std::shared_ptr Subscription::create() { + class NullSubscription : public Subscription { + void request(int64_t) override {} + void cancel() override {} + }; + return std::make_shared(); +} + +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/Subscription.h b/yarpl/flowable/Subscription.h new file mode 100644 index 000000000..bc4c49bbe --- /dev/null +++ b/yarpl/flowable/Subscription.h @@ -0,0 +1,87 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/Refcounted.h" + +namespace yarpl { +namespace flowable { + +class Subscription { + public: + virtual ~Subscription() = default; + + virtual void request(int64_t n) = 0; + virtual void cancel() = 0; + + static std::shared_ptr create(); + + template + static std::shared_ptr create(CancelFunc&& onCancel); + + template + static std::shared_ptr create( + CancelFunc&& onCancel, + RequestFunc&& onRequest); +}; + +namespace detail { + +template +class CallbackSubscription : public Subscription { + static_assert( + std::is_same, CancelFunc>::value, + "undecayed"); + static_assert( + std::is_same, RequestFunc>::value, + "undecayed"); + + public: + template + CallbackSubscription(FCancel&& onCancel, FRequest&& onRequest) + : onCancel_(std::forward(onCancel)), + onRequest_(std::forward(onRequest)) {} + + void request(int64_t n) override { + onRequest_(n); + } + void cancel() override { + onCancel_(); + } + + private: + CancelFunc onCancel_; + RequestFunc onRequest_; +}; +} // namespace detail + +template +std::shared_ptr Subscription::create( + CancelFunc&& onCancel, + RequestFunc&& onRequest) { + return std::make_shared, + std::decay_t>>( + std::forward(onCancel), std::forward(onRequest)); +} + +template +std::shared_ptr Subscription::create(CancelFunc&& onCancel) { + return Subscription::create( + std::forward(onCancel), [](int64_t) {}); +} + +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/include/yarpl/flowable/TestSubscriber.h b/yarpl/flowable/TestSubscriber.h similarity index 80% rename from yarpl/include/yarpl/flowable/TestSubscriber.h rename to yarpl/flowable/TestSubscriber.h index 8ed3784d6..127b7fd0f 100644 --- a/yarpl/include/yarpl/flowable/TestSubscriber.h +++ b/yarpl/flowable/TestSubscriber.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -26,10 +38,8 @@ namespace flowable { * ts->assert... */ template -class TestSubscriber : - public BaseSubscriber, - public yarpl::flowable::Subscription -{ +class TestSubscriber : public BaseSubscriber, + public yarpl::flowable::Subscription { public: static_assert( std::is_copy_constructible::value, @@ -42,8 +52,9 @@ class TestSubscriber : * Create a TestSubscriber that will subscribe and store the value it * receives. */ - static Reference> create(int64_t initial = kNoFlowControl) { - return make_ref>(initial); + static std::shared_ptr> create( + int64_t initial = kNoFlowControl) { + return std::make_shared>(initial); } /** @@ -52,17 +63,17 @@ class TestSubscriber : * * This will store the value it receives to allow assertions. */ - static Reference> create( - Reference> delegate, + static std::shared_ptr> create( + std::shared_ptr> delegate, int64_t initial = kNoFlowControl) { - return make_ref>(std::move(delegate), initial); + return std::make_shared>(std::move(delegate), initial); } explicit TestSubscriber(int64_t initial = kNoFlowControl) - : TestSubscriber(Reference>{}, initial) {} + : TestSubscriber(std::shared_ptr>{}, initial) {} explicit TestSubscriber( - Reference> delegate, + std::shared_ptr> delegate, int64_t initial = kNoFlowControl) : delegate_(std::move(delegate)), initial_{initial} {} @@ -185,6 +196,10 @@ class TestSubscriber : return terminated_ && e_; } + const folly::exception_wrapper& exceptionWrapper() const { + return e_; + } + std::string getErrorMsg() const { return e_ ? e_.get_exception()->what() : ""; } @@ -229,6 +244,10 @@ class TestSubscriber : } } + folly::exception_wrapper getException() const { + return e_; + } + void dropValues(bool drop) { valueCount_ = getValueCount(); dropValues_ = drop; @@ -238,14 +257,14 @@ class TestSubscriber : bool dropValues_{false}; std::atomic valueCount_{0}; - Reference> delegate_; + std::shared_ptr> delegate_; std::vector values_; folly::exception_wrapper e_; int64_t initial_{kNoFlowControl}; bool terminated_{false}; std::mutex m_; std::condition_variable terminalEventCV_; - Reference subscription_; + std::shared_ptr subscription_; }; } // namespace flowable } // namespace yarpl diff --git a/yarpl/flowable/ThriftStreamShim.h b/yarpl/flowable/ThriftStreamShim.h new file mode 100644 index 000000000..7d42fef44 --- /dev/null +++ b/yarpl/flowable/ThriftStreamShim.h @@ -0,0 +1,263 @@ +// Copyright 2004-present Facebook. All Rights Reserved. + +#pragma once + +#include +#if FOLLY_HAS_COROUTINES +#include +#include +#include +#endif +#include + +#include +#include +#include +#include + +namespace yarpl { +namespace flowable { +class ThriftStreamShim { + public: +#if FOLLY_HAS_COROUTINES + template + static std::shared_ptr> fromClientStream( + apache::thrift::ClientBufferedStream&& stream, + folly::Executor::KeepAlive<> ex) { + struct SharedState { + SharedState( + apache::thrift::detail::ClientStreamBridge::ClientPtr streamBridge, + folly::Executor::KeepAlive<> ex) + : streamBridge_(std::move(streamBridge)), + ex_(folly::SerialExecutor::create(std::move(ex))) {} + apache::thrift::detail::ClientStreamBridge::Ptr streamBridge_; + folly::Executor::KeepAlive ex_; + std::atomic canceled_{false}; + }; + + return yarpl::flowable::internal::flowableFromSubscriber( + [state = + std::make_shared(std::move(stream.streamBridge_), ex), + decode = + stream.decode_](std::shared_ptr> + subscriber) mutable { + class Subscription : public yarpl::flowable::Subscription { + public: + explicit Subscription(std::weak_ptr state) + : state_(std::move(state)) {} + + void request(int64_t n) override { + CHECK(n != yarpl::credits::kNoFlowControl) + << "kNoFlowControl unsupported"; + + if (auto state = state_.lock()) { + state->ex_->add([n, state = std::move(state)]() { + state->streamBridge_->requestN(n); + }); + } + } + + void cancel() override { + if (auto state = state_.lock()) { + state->ex_->add([state = std::move(state)]() { + state->streamBridge_->cancel(); + state->canceled_ = true; + }); + } + } + + private: + std::weak_ptr state_; + }; + + state->ex_->add([keepAlive = state->ex_.copy(), + subscriber, + subscription = std::make_shared( + std::weak_ptr(state))]() mutable { + subscriber->onSubscribe(std::move(subscription)); + }); + + folly::coro::co_invoke( + [subscriber = std::move(subscriber), + state, + decode]() mutable -> folly::coro::Task { + apache::thrift::detail::ClientStreamBridge::ClientQueue queue; + class ReadyCallback + : public apache::thrift::detail::ClientStreamConsumer { + public: + void consume() override { + baton.post(); + } + + void canceled() override { + baton.post(); + } + + folly::coro::Baton baton; + }; + + while (!state->canceled_) { + if (queue.empty()) { + ReadyCallback callback; + if (state->streamBridge_->wait(&callback)) { + co_await callback.baton; + } + queue = state->streamBridge_->getMessages(); + if (queue.empty()) { + // we've been cancelled + apache::thrift::detail::ClientStreamBridge::Ptr( + state->streamBridge_.release()); + break; + } + } + + { + auto& payload = queue.front(); + if (!payload.hasValue() && !payload.hasException()) { + state->ex_->add([subscriber = std::move(subscriber), + keepAlive = state->ex_.copy()] { + subscriber->onComplete(); + }); + break; + } + auto value = decode(std::move(payload)); + queue.pop(); + if (value.hasValue()) { + state->ex_->add([subscriber, + keepAlive = state->ex_.copy(), + value = std::move(value)]() mutable { + subscriber->onNext(std::move(value).value()); + }); + } else if (value.hasException()) { + state->ex_->add([subscriber = std::move(subscriber), + keepAlive = state->ex_.copy(), + value = std::move(value)]() mutable { + subscriber->onError(std::move(value).exception()); + }); + break; + } else { + LOG(FATAL) << "unreachable"; + } + } + } + }) + .scheduleOn(state->ex_) + .start(); + }); + } +#endif + + template + static apache::thrift::ServerStream toServerStream( + std::shared_ptr> flowable) { + class StreamServerCallbackAdaptor final + : public apache::thrift::StreamServerCallback, + public Subscriber { + public: + StreamServerCallbackAdaptor( + apache::thrift::detail::StreamElementEncoder* encode, + folly::EventBase* eb, + apache::thrift::TilePtr&& interaction) + : encode_(encode), + eb_(eb), + interaction_(apache::thrift::TileStreamGuard::transferFrom( + std::move(interaction))) {} + // StreamServerCallback implementation + bool onStreamRequestN(uint64_t tokens) override { + if (!subscription_) { + tokensBeforeSubscribe_ += tokens; + } else { + DCHECK_EQ(0, tokensBeforeSubscribe_); + subscription_->request(tokens); + } + return clientCallback_; + } + void onStreamCancel() override { + clientCallback_ = nullptr; + if (auto subscription = std::move(subscription_)) { + subscription->cancel(); + } + self_.reset(); + } + void resetClientCallback( + apache::thrift::StreamClientCallback& clientCallback) override { + clientCallback_ = &clientCallback; + } + + // Subscriber implementation + void onSubscribe(std::shared_ptr subscription) override { + eb_->add([this, subscription = std::move(subscription)]() mutable { + if (!clientCallback_) { + return subscription->cancel(); + } + + subscription_ = std::move(subscription); + if (auto tokens = std::exchange(tokensBeforeSubscribe_, 0)) { + subscription_->request(tokens); + } + }); + } + void onNext(T next) override { + eb_->add([this, next = std::move(next), s = self_]() mutable { + if (clientCallback_) { + std::ignore = + clientCallback_->onStreamNext(apache::thrift::StreamPayload{ + (*encode_)(std::move(next)).value().payload, {}}); + } + }); + } + void onError(folly::exception_wrapper ew) override { + eb_->add([this, ew = std::move(ew), s = self_]() mutable { + if (clientCallback_) { + std::exchange(clientCallback_, nullptr) + ->onStreamError((*encode_)(std::move(ew)).exception()); + self_.reset(); + } + }); + } + void onComplete() override { + eb_->add([this, s = self_] { + if (clientCallback_) { + std::exchange(clientCallback_, nullptr)->onStreamComplete(); + self_.reset(); + } + }); + } + + void takeRef(std::shared_ptr self) { + self_ = std::move(self); + } + + private: + apache::thrift::StreamClientCallback* clientCallback_{nullptr}; + std::shared_ptr subscription_; + uint32_t tokensBeforeSubscribe_{0}; + apache::thrift::detail::StreamElementEncoder* encode_; + folly::Executor::KeepAlive eb_; + std::shared_ptr self_; + apache::thrift::TileStreamGuard interaction_; + }; + + return apache::thrift::ServerStream( + [flowable = std::move(flowable)]( + folly::Executor::KeepAlive<>, + apache::thrift::detail::StreamElementEncoder* encode) mutable { + return apache::thrift::detail::ServerStreamFactory( + [flowable = std::move(flowable), encode]( + apache::thrift::FirstResponsePayload&& payload, + apache::thrift::StreamClientCallback* callback, + folly::EventBase* clientEb, + apache::thrift::TilePtr&& interaction) mutable { + auto stream = std::make_shared( + encode, clientEb, std::move(interaction)); + stream->takeRef(stream); + stream->resetClientCallback(*callback); + std::ignore = callback->onFirstResponse( + std::move(payload), clientEb, stream.get()); + flowable->subscribe(std::move(stream)); + }); + }); + } +}; +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/include/yarpl/Disposable.h b/yarpl/include/yarpl/Disposable.h deleted file mode 100644 index 1c0451e06..000000000 --- a/yarpl/include/yarpl/Disposable.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -namespace yarpl { - -/** - * Represents a disposable resource. - */ -class Disposable { - public: - Disposable() {} - virtual ~Disposable() = default; - Disposable(Disposable&&) = delete; - Disposable(const Disposable&) = delete; - Disposable& operator=(Disposable&&) = delete; - Disposable& operator=(const Disposable&) = delete; - - /** - * Dispose the resource, the operation should be idempotent. - */ - virtual void dispose() = 0; - - /** - * Returns true if this resource has been disposed. - * @return true if this resource has been disposed - */ - virtual bool isDisposed() = 0; -}; -} diff --git a/yarpl/include/yarpl/Flowable.h b/yarpl/include/yarpl/Flowable.h deleted file mode 100644 index 488272411..000000000 --- a/yarpl/include/yarpl/Flowable.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -// include all the things a developer needs for using Flowable -#include "yarpl/flowable/Flowable.h" -#include "yarpl/flowable/Flowables.h" -#include "yarpl/flowable/Subscriber.h" -#include "yarpl/flowable/Subscribers.h" -#include "yarpl/flowable/Subscription.h" - -/** - * // TODO add documentation - */ diff --git a/yarpl/include/yarpl/Observable.h b/yarpl/include/yarpl/Observable.h deleted file mode 100644 index c70f71f77..000000000 --- a/yarpl/include/yarpl/Observable.h +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -// include all the things a developer needs for using Observable -#include "yarpl/observable/Observable.h" -#include "yarpl/observable/Observables.h" -#include "yarpl/observable/Observer.h" -#include "yarpl/observable/Observers.h" -#include "yarpl/observable/Subscription.h" -#include "yarpl/observable/Subscriptions.h" - -/** - * // TODO add documentation - */ diff --git a/yarpl/include/yarpl/Refcounted.h b/yarpl/include/yarpl/Refcounted.h deleted file mode 100644 index dfdb1cf02..000000000 --- a/yarpl/include/yarpl/Refcounted.h +++ /dev/null @@ -1,568 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace yarpl { - -namespace detail { -struct skip_initial_refcount_check {}; -struct do_initial_refcount_check {}; - - -// refcount debugging utilities -using refcount_map_type = std::unordered_map; -void inc_created(std::string const&); - -void inc_live(std::string const&); -void dec_live(std::string const&); -void debug_refcounts(std::ostream& o); - -template -std::string type_name() -{ - int status; - std::string tname = typeid(T).name(); - char *demangled_name = abi::__cxa_demangle(tname.c_str(), NULL, NULL, &status); - if (status == 0) { - tname = demangled_name; - std::free(demangled_name); - } - return tname; -} - -#ifdef YARPL_REFCOUNT_DEBUGGING -struct set_reference_name; -#endif - -} /* namespace detail */ - -template -class Reference; - -template -class AtomicReference; - -/// Base of refcounted objects. The intention is the same as that -/// of boost::intrusive_ptr<>, except that we have virtual methods -/// anyway, and want to avoid argument-dependent lookup. -/// -/// NOTE: Only derive using "virtual public" inheritance. -class Refcounted { - public: - - /// dtor is thread safe because we cast thread_fence before - /// calling delete this - virtual ~Refcounted() = default; - - // Return the current count. For testing. - std::size_t count() const { - return refcount_; - } - - private: - template - friend class Reference; - template - friend class AtomicReference; - -#ifdef YARPL_REFCOUNT_DEBUGGING - friend struct detail::set_reference_name; -#endif - - void incRef() const { - refcount_.fetch_add(1, std::memory_order_relaxed); - } - - void decRef() const { - auto previous = refcount_.fetch_sub(1, std::memory_order_relaxed); - assert(previous >= 1 && "decRef on a destroyed object!"); - if (previous == 1) { - std::atomic_thread_fence(std::memory_order_acquire); -#ifdef YARPL_REFCOUNT_DEBUGGING - detail::dec_live(this->demangled_name_); -#endif - delete this; - } - } - - // refcount starts at 1 always, so we don't destroy ourselves in - // the constructor if we call `ref_from_this` in it - mutable std::atomic_size_t refcount_{1}; - -#ifdef YARPL_REFCOUNT_DEBUGGING - // for memory debugging and instrumentation - std::string demangled_name_; -#endif -}; - -#ifdef YARPL_REFCOUNT_DEBUGGING -namespace detail { -struct set_reference_name { - set_reference_name(std::string name, Refcounted& refcounted) { - refcounted.demangled_name_ = std::move(name); - } -}; -} -#endif - -/// RAII-enabling smart pointer for refcounted objects. Each reference -/// constructed against a target refcounted object increases its count by 1 -/// during its lifetime. -template -class Reference final { - public: - template - friend class Reference; - - template - friend class AtomicReference; - - Reference() = default; - inline /* implicit */ Reference(std::nullptr_t) {} - - explicit Reference(T* pointer, detail::skip_initial_refcount_check) - : pointer_(pointer) { - // newly constructed object in `make_ref` already had a refcount of 1, - // so don't increment it (we take 'ownership' of the reference made in - // make_ref) - if(pointer) { - assert(pointer->Refcounted::count() >= 1); - } - } - explicit Reference(T* pointer, detail::do_initial_refcount_check) - : pointer_(pointer) { - /** - * consider the following: - * - class MyClass : Refcounted { - MyClass() { - // count() == 0 - auto r = ref_from_this(this) - // count() == 1 - do_something_with(r); - // if do_something_with(r) doens't keep a reference to r somewhere, then - // count() == 0 - // and we call ~MyClass() within the constructor, which is a Bad Thing - } - }; - - * the check below prevents (at runtime) taking a reference in situations - * like this - */ - if(pointer) { - assert( - pointer->Refcounted::count() >= 1 && - "can't take an additional reference to something with a zero refcount"); - } - inc(); - } - - ~Reference() { - dec(); - } - - ////////////////////////////////////////////////////////////////////////////// - - Reference(const Reference& other) : pointer_(other.pointer_) { - inc(); - } - - Reference(Reference&& other) noexcept : pointer_(other.pointer_) { - other.pointer_ = nullptr; - } - - template - Reference(const Reference& other) : pointer_(other.pointer_) { - inc(); - } - - template - Reference(Reference&& other) : pointer_(other.pointer_) { - other.pointer_ = nullptr; - } - - Reference(AtomicReference const& other) : pointer_(other.pointer_) { - inc(); - } - - Reference(AtomicReference&& other) { - pointer_ = other.pointer_.exchange(nullptr); - } - - ////////////////////////////////////////////////////////////////////////////// - - Reference& operator=(std::nullptr_t) { - reset(); - return *this; - } - - Reference& operator=(const Reference& other) { - return assign(other); - } - - Reference& operator=(Reference&& other) { - return assign(std::move(other)); - } - - template - Reference& operator=(const Reference& other) { - return assign(other); - } - - template - Reference& operator=(Reference&& other) { - return assign(std::move(other)); - } - - ////////////////////////////////////////////////////////////////////////////// - - // TODO: remove this from public Reference API - T* get() const { - return pointer_; - } - - T& operator*() const { - return *pointer_; - } - - T* operator->() const { - return pointer_; - } - - void reset() { - Reference{}.swap(*this); - } - - explicit operator bool() const { - return pointer_; - } - - private: - void inc() { - static_assert( - std::is_base_of::value, - "Reference must be used with types that virtually derive Refcounted"); - - if (pointer_) { - pointer_->incRef(); - } - } - - void dec() { - static_assert( - std::is_base_of::value, - "Reference must be used with types that virtually derive Refcounted"); - - if (pointer_) { - pointer_->decRef(); - } - } - - void swap(Reference& other) { - std::swap(pointer_, other.pointer_); - } - - template - Reference& assign(Ref&& other) { - Reference temp(std::forward(other)); - swap(temp); - return *this; - } - - T* pointer_{nullptr}; -}; - -template -class AtomicReference final { - template - friend class AtomicReference; - template - friend class Reference; - - public: - AtomicReference() = default; - AtomicReference(Reference const& other) : pointer_(other.pointer_) { - inc(); - } - AtomicReference(AtomicReference const& other) : pointer_(other.pointer_.load()) { - inc(); - } - - template - AtomicReference(Reference&& other) : pointer_(other.pointer_) { - other.pointer_ = nullptr; - } - - ~AtomicReference() { - dec(); - } - - template - AtomicReference& operator=(Reference&& other) { - return assign(std::move(other)); - } - - Reference load() { - return Reference(*this); - } - - Reference exchange(Reference other) { - T* old = pointer_.exchange(other.get()); - inc(); - Reference r{old, detail::skip_initial_refcount_check{}}; - return r; - } - - void store(Reference const& other) { - dec(); - pointer_.store(other.get()); - inc(); - } - - explicit operator bool() const { - return pointer_; - } - - T* operator->() const { - return pointer_; - } - - private: - template - AtomicReference& assign(AtomicReference&& other) { - other.pointer_.store(pointer_.exchange(other.pointer_.load())); - return *this; - } - - template - AtomicReference& assign(Reference&& other) { - AtomicReference atomic_other{std::forward>(other)}; - return assign(std::move(atomic_other)); - } - - void inc() { - static_assert( - std::is_base_of::value, - "Reference must be used with types that virtually derive Refcounted"); - - if (auto p = pointer_.load()) { - p->incRef(); - } - } - - void dec() { - static_assert( - std::is_base_of::value, - "Reference must be used with types that virtually derive Refcounted"); - - if (auto p = pointer_.load()) { - p->decRef(); - } - } - - std::atomic pointer_{nullptr}; -}; - -template -bool operator==(const Reference& lhs, const Reference& rhs) noexcept { - return lhs.get() == rhs.get(); -} - -template -bool operator==(const Reference& lhs, std::nullptr_t) noexcept { - return lhs.get() == nullptr; -} - -template -bool operator==(std::nullptr_t, const Reference& rhs) noexcept { - return rhs.get() == nullptr; -} - -template -bool operator!=(const Reference& lhs, const Reference& rhs) noexcept { - return lhs.get() != rhs.get(); -} - -template -bool operator!=(const Reference& lhs, std::nullptr_t) noexcept { - return lhs.get() != nullptr; -} - -template -bool operator!=(std::nullptr_t, const Reference& rhs) noexcept { - return rhs.get() != nullptr; -} - -template -bool operator<(const Reference& lhs, const Reference& rhs) noexcept { - return lhs.get() < rhs.get(); -} - -template -bool operator<(const Reference& lhs, std::nullptr_t) noexcept { - return lhs.get() < nullptr; -} - -template -bool operator<(std::nullptr_t, const Reference& rhs) noexcept { - return nullptr < rhs.get(); -} - -template -bool operator<=(const Reference& lhs, const Reference& rhs) noexcept { - return lhs.get() <= rhs.get(); -} - -template -bool operator<=(const Reference& lhs, std::nullptr_t) noexcept { - return lhs.get() <= nullptr; -} - -template -bool operator<=(std::nullptr_t, const Reference& rhs) noexcept { - return nullptr <= rhs.get(); -} - -template -bool operator>(const Reference& lhs, const Reference& rhs) noexcept { - return lhs.get() > rhs.get(); -} - -template -bool operator>(const Reference& lhs, std::nullptr_t) noexcept { - return lhs.get() > nullptr; -} - -template -bool operator>(std::nullptr_t, const Reference& rhs) noexcept { - return nullptr > rhs.get(); -} - -template -bool operator>=(const Reference& lhs, const Reference& rhs) noexcept { - return lhs.get() >= rhs.get(); -} - -template -bool operator>=(const Reference& lhs, std::nullptr_t) noexcept { - return lhs.get() >= nullptr; -} - -template -bool operator>=(std::nullptr_t, const Reference& rhs) noexcept { - return nullptr >= rhs.get(); -} - -//////////////////////////////////////////////////////////////////////////////// - -template -Reference make_ref(Args&&... args) { - static_assert( - std::is_base_of>::value, - "Reference can only be constructed with a Refcounted object"); - - static_assert( - std::is_base_of, std::decay_t>::value, - "Concrete type must be a subclass of casted-to-type"); - - auto r = Reference( - new T(std::forward(args)...), - detail::skip_initial_refcount_check{} - ); - -#ifdef YARPL_REFCOUNT_DEBUGGING - auto demangled_name = detail::type_name(); - detail::inc_created(demangled_name); - detail::inc_live(demangled_name); - detail::set_reference_name{std::move(demangled_name), *r}; -#endif - - return std::move(r); -} - -class enable_get_ref { - private: -#ifdef DEBUG - // force the class to be polymorphic so we can dynamic_cast(this) - virtual void dummy_internal_get_ref() {} -#endif - - protected: - // materialize a reference to 'this', but a type even further derived from - // Derived, because C++ doesn't have covariant return types on methods - template - Reference ref_from_this(As* ptr) { - // at runtime, ensure that the most derived class can indeed be - // converted into an 'as' - (void) ptr; // silence 'unused parameter' errors in Release builds -#ifdef DEBUG - assert( - dynamic_cast(this) && "must be able to convert from this to As*"); -#endif - - assert( - static_cast(this) == ptr && - "must give 'this' as argument to ref_from_this(this)"); - - static_assert( - std::is_base_of::value, - "Inferred type must be a subclass of Refcounted"); - - return Reference( - static_cast(this), detail::do_initial_refcount_check{}); - } - - template - Reference ref_from_this(As const* ptr) const { - (void) ptr; // silence 'unused parameter' errors in Release builds -#ifdef DEBUG - assert( - dynamic_cast(this) && "must be able to convert from this to As*"); -#endif - - assert( - static_cast(this) == ptr && - "must give 'this' as argument to ref_from_this(this)"); - - static_assert( - std::is_base_of::value, - "Inferred type must be a subclass of Refcounted"); - - return Reference( - static_cast(this), detail::do_initial_refcount_check{}); - } -}; - -} // namespace yarpl - -// -// custom specialization of std::hash> -// -namespace std -{ -template -struct hash> -{ - typedef yarpl::Reference argument_type; - typedef typename std::hash::result_type result_type; - - result_type operator()(argument_type const& s) const - { - return std::hash()(s.get()); - } -}; -} diff --git a/yarpl/include/yarpl/Single.h b/yarpl/include/yarpl/Single.h deleted file mode 100644 index 2c2ad24b9..000000000 --- a/yarpl/include/yarpl/Single.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "yarpl/Refcounted.h" - -// include all the things a developer needs for using Single -#include "yarpl/single/Single.h" -#include "yarpl/single/SingleObserver.h" -#include "yarpl/single/SingleObservers.h" -#include "yarpl/single/SingleSubscriptions.h" -#include "yarpl/single/Singles.h" - -/** - * Create a single with code such as this: - * - * auto a = Single::create([](Reference> obs) { - * obs->onSubscribe(SingleSubscriptions::empty()); - * obs->onSuccess(1); - * }); - * - * // TODO add more documentation - */ diff --git a/yarpl/include/yarpl/flowable/CancelingSubscriber.h b/yarpl/include/yarpl/flowable/CancelingSubscriber.h deleted file mode 100644 index 1764144eb..000000000 --- a/yarpl/include/yarpl/flowable/CancelingSubscriber.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "yarpl/flowable/Subscriber.h" - -#include - -namespace yarpl { -namespace flowable { - -/** - * A Subscriber that always cancels the subscription passed to it. - */ -template -class CancelingSubscriber final : public BaseSubscriber { - public: - void onSubscribeImpl() override { - this->cancel(); - } - - void onNextImpl(T) override { - throw std::logic_error{"CancelingSubscriber::onNext() can never be called"}; - } - void onCompleteImpl() override { - throw std::logic_error{"CancelingSubscriber::onComplete() can never be called"}; - } - void onErrorImpl(folly::exception_wrapper) override { - throw std::logic_error{"CancelingSubscriber::onError() can never be called"}; - } -}; -} -} diff --git a/yarpl/include/yarpl/flowable/Flowable.h b/yarpl/include/yarpl/flowable/Flowable.h deleted file mode 100644 index 02e15f5c3..000000000 --- a/yarpl/include/yarpl/flowable/Flowable.h +++ /dev/null @@ -1,209 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include -#include - -#include - -#include "yarpl/Refcounted.h" -#include "yarpl/flowable/Subscriber.h" -#include "yarpl/flowable/Subscribers.h" -#include "yarpl/utils/credits.h" -#include "yarpl/utils/type_traits.h" - -#include -#include - -namespace yarpl { -namespace flowable { - -template -class Flowable; - -namespace detail { - -template -struct IsFlowable : std::false_type {}; - -template -struct IsFlowable>> : std::true_type { - using ElemType = R; -}; - -} // namespace detail - -template -class Flowable : public virtual Refcounted, public yarpl::enable_get_ref { - public: - virtual void subscribe(Reference>) = 0; - - /** - * Subscribe overload that accepts lambdas. - */ - template < - typename Next, - typename = - typename std::enable_if::value>::type> - void subscribe(Next next, int64_t batch = credits::kNoFlowControl) { - subscribe(Subscribers::create(std::move(next), batch)); - } - - /** - * Subscribe overload that accepts lambdas. - * - * Takes an optional batch size for request_n. Default is no flow control. - */ - template < - typename Next, - typename Error, - typename = typename std::enable_if< - folly::is_invocable::value && - folly::is_invocable::value>::type> - void - subscribe(Next next, Error error, int64_t batch = credits::kNoFlowControl) { - subscribe(Subscribers::create(std::move(next), std::move(error), batch)); - } - - /** - * Subscribe overload that accepts lambdas. - * - * Takes an optional batch size for request_n. Default is no flow control. - */ - template < - typename Next, - typename Error, - typename Complete, - typename = typename std::enable_if< - folly::is_invocable::value && - folly::is_invocable::value && - folly::is_invocable::value>::type> - void subscribe( - Next next, - Error error, - Complete complete, - int64_t batch = credits::kNoFlowControl) { - subscribe(Subscribers::create( - std::move(next), std::move(error), std::move(complete), batch)); - } - - template < - typename Function, - typename R = typename std::result_of::type> - Reference> map(Function function); - - template < - typename Function, - typename R = typename detail::IsFlowable< - typename std::result_of::type>::ElemType> - Reference> flatMap(Function func); - - template - Reference> filter(Function function); - - template < - typename Function, - typename R = typename std::result_of::type> - Reference> reduce(Function function); - - Reference> take(int64_t); - - Reference> skip(int64_t); - - Reference> ignoreElements(); - - Reference> subscribeOn(folly::Executor&); - - Reference> observeOn(folly::Executor&); - - template - using enableWrapRef = - typename std::enable_if::value, Q>::type; - - template - enableWrapRef merge() { - return this->flatMap([](auto f) { return std::move(f); }); - } - - template < - typename Emitter, - typename = typename std::enable_if, - Emitter, Reference>, int64_t - >::value>::type> - static Reference> create(Emitter emitter); -}; - -} // flowable -} // yarpl - -#include "yarpl/flowable/EmitterFlowable.h" -#include "yarpl/flowable/FlowableOperator.h" - -namespace yarpl { -namespace flowable { - -template -template -Reference> Flowable::create(Emitter emitter) { - return make_ref>(std::move(emitter)); -} - -template -template -Reference> Flowable::map(Function function) { - return make_ref>( - this->ref_from_this(this), std::move(function)); -} - -template -template -Reference> Flowable::filter(Function function) { - return make_ref>( - this->ref_from_this(this), std::move(function)); -} - -template -template -Reference> Flowable::reduce(Function function) { - return make_ref>( - this->ref_from_this(this), std::move(function)); -} - -template -Reference> Flowable::take(int64_t limit) { - return make_ref>(this->ref_from_this(this), limit); -} - -template -Reference> Flowable::skip(int64_t offset) { - return make_ref>(this->ref_from_this(this), offset); -} - -template -Reference> Flowable::ignoreElements() { - return make_ref>(this->ref_from_this(this)); -} - -template -Reference> Flowable::subscribeOn(folly::Executor& executor) { - return make_ref>(this->ref_from_this(this), executor); -} - -template -Reference> Flowable::observeOn(folly::Executor& executor) { - return make_ref>( - this->ref_from_this(this), executor); -} - -template -template -Reference> Flowable::flatMap(Function function) { - return make_ref>( - this->ref_from_this(this), std::move(function)); -} - -} // flowable -} // yarpl diff --git a/yarpl/include/yarpl/flowable/FlowableObserveOnOperator.h b/yarpl/include/yarpl/flowable/FlowableObserveOnOperator.h deleted file mode 100644 index ceadb895f..000000000 --- a/yarpl/include/yarpl/flowable/FlowableObserveOnOperator.h +++ /dev/null @@ -1,109 +0,0 @@ -#pragma once - -namespace yarpl { -namespace flowable { -namespace detail { - -template -class ObserveOnOperatorSubscriber; - -template -class ObserveOnOperatorSubscription : public yarpl::flowable::Subscription, - public yarpl::enable_get_ref { - public: - ObserveOnOperatorSubscription( - Reference> subscriber, - Reference subscription) - : subscriber_(std::move(subscriber)), - subscription_(std::move(subscription)) {} - - // all requesting methods are called from 'executor_' in the - // associated subscriber - void cancel() override { - auto self = this->ref_from_this(this); - - if (auto subscriber = std::move(subscriber_)) { - subscriber->isCanceled_ = true; - } - - subscription_->cancel(); - } - - void request(int64_t n) override { - subscription_->request(n); - } - - private: - Reference> subscriber_; - Reference subscription_; -}; - -template -class ObserveOnOperatorSubscriber - : public yarpl::flowable::Subscriber { - public: - ObserveOnOperatorSubscriber( - Reference> inner, - folly::Executor& executor) - : inner_(std::move(inner)), executor_(executor) {} - - // all signaling methods are called from upstream EB - void onSubscribe(Reference subscription) override { - executor_.add([ - self = this->ref_from_this(this), - s = std::move(subscription) - ]() mutable { - auto subscription = - make_ref>(self, std::move(s)); - self->inner_->onSubscribe(std::move(subscription)); - }); - } - void onNext(T next) override { - executor_.add( - [ self = this->ref_from_this(this), n = std::move(next) ]() mutable { - if (!self->isCanceled_) { - self->inner_->onNext(std::move(n)); - } - }); - } - void onComplete() override { - executor_.add([self = this->ref_from_this(this)]() mutable { - if (!self->isCanceled_) { - self->inner_->onComplete(); - } - }); - } - void onError(folly::exception_wrapper err) override { - executor_.add( - [ self = this->ref_from_this(this), e = std::move(err) ]() mutable { - if (!self->isCanceled_) { - self->inner_->onError(std::move(e)); - } - }); - } - - private: - friend class ObserveOnOperatorSubscription; - bool isCanceled_{false}; // only accessed in executor_ thread - - Reference> inner_; - folly::Executor& executor_; -}; - -template -class ObserveOnOperator : public yarpl::flowable::Flowable { - public: - ObserveOnOperator(Reference> upstream, folly::Executor& executor) - : upstream_(std::move(upstream)), executor_(executor) {} - - void subscribe(Reference> subscriber) override { - upstream_->subscribe(make_ref>( - std::move(subscriber), executor_)); - } - - Reference> upstream_; - folly::Executor& executor_; -}; -} -} -} /* namespace yarpl::flowable::detail */ diff --git a/yarpl/include/yarpl/flowable/Flowable_FromObservable.h b/yarpl/include/yarpl/flowable/Flowable_FromObservable.h deleted file mode 100644 index 74203092e..000000000 --- a/yarpl/include/yarpl/flowable/Flowable_FromObservable.h +++ /dev/null @@ -1,257 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include -#include "yarpl/Flowable.h" -#include "yarpl/utils/credits.h" - -namespace yarpl { -namespace observable { -template -class Observable; -} -} - -namespace yarpl { -namespace flowable { - -// Exception thrown in case the downstream can't keep up. -class MissingBackpressureException : public std::runtime_error { - public: - MissingBackpressureException() - : std::runtime_error("BACK_PRESSURE: DROP (missing credits onNext)") {} -}; - -namespace details { - -template -class FlowableFromObservableSubscription : public flowable::Subscription, - public observable::Observer { - public: - FlowableFromObservableSubscription( - Reference> observable, - Reference> subscriber) - : observable_(std::move(observable)), - subscriber_(std::move(subscriber)) {} - - FlowableFromObservableSubscription(FlowableFromObservableSubscription&&) = - delete; - - FlowableFromObservableSubscription( - const FlowableFromObservableSubscription&) = delete; - FlowableFromObservableSubscription& operator=( - FlowableFromObservableSubscription&&) = delete; - FlowableFromObservableSubscription& operator=( - const FlowableFromObservableSubscription&) = delete; - - void request(int64_t n) override { - if (n <= 0) { - return; - } - auto r = credits::add(&requested_, n); - if (r <= 0) { - return; - } - - // it is possible that after calling subscribe or in onCreditsAvailable - // methods, there will be a stream of - // onNext calls which the processing chain might cancel. The cancel signal - // will remove all references to this class and we need to keep this - // instance around to finish this method - auto thisPtr = this->ref_from_this(this); - - if (!started.exchange(true)) { - observable_->subscribe(this->ref_from_this(this)); - - // the credits might have changed since subscribe - r = requested_.load(); - } - - if (r > 0) { - onCreditsAvailable(r); - } - } - - void cancel() override { - if (observable::Observer::isUnsubscribedOrTerminated()) { - return; - } - observable::Observer::subscription()->cancel(); - credits::cancel(&requested_); - } - - // Observer override - void onNext(T t) override { - if (requested_ > 0) { - subscriber_->onNext(std::move(t)); - credits::consume(&requested_, 1); - return; - } - onNextWithoutCredits(std::move(t)); - } - - // Observer override - void onComplete() override { - if (observable::Observer::isUnsubscribedOrTerminated()) { - return; - } - auto subscriber = std::move(subscriber_); - subscriber->onComplete(); - observable::Observer::onComplete(); - } - - // Observer override - void onError(folly::exception_wrapper error) override { - if (observable::Observer::isUnsubscribedOrTerminated()) { - return; - } - auto subscriber = std::move(subscriber_); - subscriber->onError(std::move(error)); - observable::Observer::onError(folly::exception_wrapper()); - } - - protected: - virtual void onCreditsAvailable(int64_t /*credits*/) {} - virtual void onNextWithoutCredits(T /*t*/) { - // by default drop anything else received while we don't have credits - } - - Reference> observable_; - Reference> subscriber_; - std::atomic_bool started{false}; - std::atomic requested_{0}; -}; - -template -using FlowableFromObservableSubscriptionDropStrategy = - FlowableFromObservableSubscription; - -template -class FlowableFromObservableSubscriptionErrorStrategy - : public FlowableFromObservableSubscription { - using Super = FlowableFromObservableSubscription; - - public: - using Super::FlowableFromObservableSubscription; - - private: - void onNextWithoutCredits(T /*t*/) override { - if (observable::Observer::isUnsubscribedOrTerminated()) { - return; - } - Super::onError(MissingBackpressureException()); - Super::cancel(); - } -}; - -template -class FlowableFromObservableSubscriptionBufferStrategy - : public FlowableFromObservableSubscription { - using Super = FlowableFromObservableSubscription; - - public: - using Super::FlowableFromObservableSubscription; - - private: - void onComplete() override { - if (!buffer_->empty()) { - // we have buffered some items so we will defer delivering on complete for - // later - completed_ = true; - } else { - Super::onComplete(); - } - } - - // - // onError signal is delivered immediately by design - // - - void onNextWithoutCredits(T t) override { - if (Super::isUnsubscribed()) { - return; - } - - buffer_->push_back(std::move(t)); - } - - void onCreditsAvailable(int64_t credits) override { - DCHECK(credits > 0); - auto&& lockedBuffer = buffer_.wlock(); - while (credits-- > 0 && !lockedBuffer->empty()) { - Super::onNext(std::move(lockedBuffer->front())); - lockedBuffer->pop_front(); - } - - if (lockedBuffer->empty() && completed_) { - Super::onComplete(); - } - } - - folly::Synchronized> buffer_; - std::atomic completed_{false}; -}; - -template -class FlowableFromObservableSubscriptionLatestStrategy - : public FlowableFromObservableSubscription { - using Super = FlowableFromObservableSubscription; - - public: - using Super::FlowableFromObservableSubscription; - - private: - void onComplete() override { - if (storesLatest_) { - // we have buffered an item so we will defer delivering on complete for - // later - completed_ = true; - } else { - Super::onComplete(); - } - } - - // - // onError signal is delivered immediately by design - // - - void onNextWithoutCredits(T t) override { - storesLatest_ = true; - *latest_.wlock() = std::move(t); - } - - void onCreditsAvailable(int64_t credits) override { - DCHECK(credits > 0); - if (storesLatest_) { - storesLatest_ = false; - Super::onNext(std::move(*latest_.wlock())); - - if (completed_) { - Super::onComplete(); - } - } - } - - std::atomic storesLatest_{false}; - std::atomic completed_{false}; - folly::Synchronized latest_; -}; - -template -class FlowableFromObservableSubscriptionMissingStrategy - : public FlowableFromObservableSubscription { - using Super = FlowableFromObservableSubscription; - - public: - using Super::FlowableFromObservableSubscription; - - private: - void onNextWithoutCredits(T t) override { - Super::onNext(std::move(t)); - } -}; -} -} -} diff --git a/yarpl/include/yarpl/flowable/Flowables.h b/yarpl/include/yarpl/flowable/Flowables.h deleted file mode 100644 index f866215d3..000000000 --- a/yarpl/include/yarpl/flowable/Flowables.h +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include -#include - -#include - -#include "yarpl/flowable/Flowable.h" - -#include - -namespace yarpl { -namespace flowable { - -class Flowables { - public: - /** - * Emit the sequence of numbers [start, start + count). - */ - static Reference> range(int64_t start, int64_t count) { - auto lambda = [ start, count, i = start ]( - Reference> subscriber, int64_t requested) mutable { - int64_t emitted = 0; - bool done = false; - int64_t end = start + count; - - while (i < end && emitted < requested) { - subscriber->onNext(i++); - ++emitted; - } - - if (i >= end) { - subscriber->onComplete(); - done = true; - } - - return std::make_tuple(requested, done); - }; - - return Flowable::create(std::move(lambda)); - } - - template - static Reference> just(const T& value) { - auto lambda = [value](Reference> subscriber, int64_t) { - // # requested should be > 0. Ignoring the actual parameter. - subscriber->onNext(value); - subscriber->onComplete(); - return std::make_tuple(static_cast(1), true); - }; - - return Flowable::create(std::move(lambda)); - } - - template - static Reference> justN(std::initializer_list list) { - std::vector vec(list); - - auto lambda = [ v = std::move(vec), i = size_t{0} ]( - Reference> subscriber, int64_t requested) mutable { - int64_t emitted = 0; - bool done = false; - - while (i < v.size() && emitted < requested) { - subscriber->onNext(v[i++]); - ++emitted; - } - - if (i == v.size()) { - subscriber->onComplete(); - done = true; - } - - return std::make_tuple(emitted, done); - }; - - return Flowable::create(std::move(lambda)); - } - - // this will generate a flowable which can be subscribed to only once - template - static Reference> justOnce(T value) { - auto lambda = [ value = std::move(value), used = false ]( - Reference> subscriber, int64_t) mutable { - if (used) { - subscriber->onError( - std::runtime_error("justOnce value was already used")); - return std::make_tuple(static_cast(0), true); - } - - used = true; - // # requested should be > 0. Ignoring the actual parameter. - subscriber->onNext(std::move(value)); - subscriber->onComplete(); - return std::make_tuple(static_cast(1), true); - }; - - return Flowable::create(std::move(lambda)); - } - - template < - typename T, - typename OnSubscribe, - typename = typename std::enable_if>>::value>::type> - static Reference> fromPublisher(OnSubscribe function) { - return make_ref>(std::move(function)); - } - - template - static Reference> empty() { - auto lambda = [](Reference> subscriber, int64_t) { - subscriber->onComplete(); - return std::make_tuple(static_cast(0), true); - }; - return Flowable::create(std::move(lambda)); - } - - template - static Reference> error(folly::exception_wrapper ex) { - auto lambda = [ex = std::move(ex)]( - Reference> subscriber, int64_t) { - subscriber->onError(std::move(ex)); - return std::make_tuple(static_cast(0), true); - }; - return Flowable::create(std::move(lambda)); - } - - template - static Reference> error(const ExceptionType& ex) { - auto lambda = [ex = std::move(ex)]( - Reference> subscriber, int64_t) { - subscriber->onError(std::move(ex)); - return std::make_tuple(static_cast(0), true); - }; - return Flowable::create(std::move(lambda)); - } - - template - static Reference> fromGenerator(TGenerator generator) { - auto lambda = [generator = std::move(generator)]( - Reference> subscriber, int64_t requested) { - int64_t generated = 0; - try { - while (generated < requested) { - subscriber->onNext(generator()); - ++generated; - } - return std::make_tuple(generated, false); - } catch (const std::exception& ex) { - subscriber->onError( - folly::exception_wrapper(std::current_exception(), ex)); - return std::make_tuple(generated, true); - } catch (...) { - subscriber->onError(std::runtime_error("unknown error")); - return std::make_tuple(generated, true); - } - }; - return Flowable::create(std::move(lambda)); - } - - private: - Flowables() = delete; -}; - -} // flowable -} // yarpl diff --git a/yarpl/include/yarpl/flowable/Subscriber.h b/yarpl/include/yarpl/flowable/Subscriber.h deleted file mode 100644 index 9d9fc94fb..000000000 --- a/yarpl/include/yarpl/flowable/Subscriber.h +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "yarpl/Refcounted.h" -#include "yarpl/flowable/Subscription.h" - -#include - -#include - -namespace yarpl { -namespace flowable { - -template -class Subscriber : public virtual Refcounted, public yarpl::enable_get_ref { - public: - virtual void onSubscribe(Reference) = 0; - virtual void onComplete() = 0; - virtual void onError(folly::exception_wrapper) = 0; - virtual void onNext(T) = 0; -}; - -#define KEEP_REF_TO_THIS() \ - Reference self; \ - if (keep_reference_to_this) { \ - self = this->ref_from_this(this); \ - } - -// T : Type of Flowable that this Subscriber operates on -// -// keep_reference_to_this : BaseSubscriber will keep a live reference to -// itself on the stack while in a signaling or requesting method, in case -// the derived class causes all other references to itself to be dropped. -// -// Classes that ensure that at least one reference will stay live can -// use `keep_reference_to_this = false` as an optimization to -// prevent an atomic inc/dec pair -template -class BaseSubscriber : public Subscriber { - public: - // Note: If any of the following methods is overridden in a subclass, the new - // methods SHOULD ensure that these are invoked as well. - void onSubscribe(Reference subscription) final override { - DCHECK(subscription); - CHECK(!subscription_.load()); - -#ifdef DEBUG - DCHECK(!gotOnSubscribe_.exchange(true)) - << "Already subscribed to BaseSubscriber"; -#endif - - subscription_.store(subscription); - KEEP_REF_TO_THIS(); - onSubscribeImpl(); - } - - // No further calls to the subscription after this method is invoked. - void onComplete() final override { -#ifdef DEBUG - DCHECK(gotOnSubscribe_.load()) << "Not subscribed to BaseSubscriber"; - DCHECK(!gotTerminating_.exchange(true)) - << "Already got terminating signal method"; -#endif - - if(auto sub = subscription_.exchange(nullptr)) { - KEEP_REF_TO_THIS(); - onCompleteImpl(); - onTerminateImpl(); - } - } - - // No further calls to the subscription after this method is invoked. - void onError(folly::exception_wrapper e) final override { -#ifdef DEBUG - DCHECK(gotOnSubscribe_.load()) << "Not subscribed to BaseSubscriber"; - DCHECK(!gotTerminating_.exchange(true)) - << "Already got terminating signal method"; -#endif - - if(auto sub = subscription_.exchange(nullptr)) { - KEEP_REF_TO_THIS(); - onErrorImpl(std::move(e)); - onTerminateImpl(); - } - } - - void onNext(T t) final override { -#ifdef DEBUG - DCHECK(gotOnSubscribe_.load()) << "Not subscibed to BaseSubscriber"; - if (gotTerminating_.load()) { - VLOG(2) << "BaseSubscriber already got terminating signal method"; - } -#endif - - if(auto sub = subscription_.load()) { - KEEP_REF_TO_THIS(); - onNextImpl(std::move(t)); - } - } - - void cancel() { - if(auto sub = subscription_.exchange(nullptr)) { - KEEP_REF_TO_THIS(); - sub->cancel(); - onTerminateImpl(); - } -#ifdef DEBUG - else { - VLOG(2) << "cancel() on BaseSubscriber with no subscription_"; - } -#endif - } - - void request(int64_t n) { - if(auto sub = subscription_.load()) { - KEEP_REF_TO_THIS(); - sub->request(n); - } -#ifdef DEBUG - else { - VLOG(2) << "request() on BaseSubscriber with no subscription_"; - } -#endif - } - -protected: - virtual void onSubscribeImpl() = 0; - virtual void onCompleteImpl() = 0; - virtual void onNextImpl(T) = 0; - virtual void onErrorImpl(folly::exception_wrapper) = 0; - - virtual void onTerminateImpl() {} - - private: - // keeps a reference alive to the subscription - AtomicReference subscription_; - -#ifdef DEBUG - std::atomic gotOnSubscribe_{false}; - std::atomic gotTerminating_{false}; -#endif -}; - -} -} /* namespace yarpl::flowable */ diff --git a/yarpl/include/yarpl/flowable/Subscribers.h b/yarpl/include/yarpl/flowable/Subscribers.h deleted file mode 100644 index 9f8ee49d6..000000000 --- a/yarpl/include/yarpl/flowable/Subscribers.h +++ /dev/null @@ -1,159 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include - -#include - -#include "yarpl/flowable/Subscriber.h" -#include "yarpl/utils/credits.h" -#include "yarpl/utils/type_traits.h" - -namespace yarpl { -namespace flowable { - -/// Helper methods for constructing subscriber instances from functions: -/// one, two, or three functions (callables; can be lamda, for instance) -/// may be specified, corresponding to onNext, onError and onSubscribe -/// method bodies in the subscriber. -class Subscribers { - constexpr static auto kNoFlowControl = credits::kNoFlowControl; - - public: - template < - typename T, - typename Next, - typename = - typename std::enable_if::value>::type> - static Reference> create( - Next next, - int64_t batch = kNoFlowControl) { - return make_ref>(std::move(next), batch); - } - - template < - typename T, - typename Next, - typename Error, - typename = typename std::enable_if< - folly::is_invocable::value && - folly::is_invocable::value>::type> - static Reference> - create(Next next, Error error, int64_t batch = kNoFlowControl) { - return make_ref>( - std::move(next), std::move(error), batch); - } - - template < - typename T, - typename Next, - typename Error, - typename Complete, - typename = typename std::enable_if< - folly::is_invocable::value && - folly::is_invocable::value && - folly::is_invocable::value>::type> - static Reference> create( - Next next, - Error error, - Complete complete, - int64_t batch = kNoFlowControl) { - return make_ref>( - std::move(next), std::move(error), std::move(complete), batch); - } - - private: - template - class Base : public BaseSubscriber { - public: - Base(Next next, int64_t batch) - : next_(std::move(next)), batch_(batch), pending_(0) {} - - void onSubscribeImpl() override final { - pending_ += batch_; - this->request(batch_); - } - - void onNextImpl(T value) override final { - try { - next_(std::move(value)); - } catch (const std::exception& exn) { - this->cancel(); - auto ew = folly::exception_wrapper{std::current_exception(), exn}; - LOG(ERROR) << "'next' method should not throw: " << ew.what(); - onErrorImpl(ew); - return; - } - - if (--pending_ < batch_ / 2) { - const auto delta = batch_ - pending_; - pending_ += delta; - this->request(delta); - } - } - - void onCompleteImpl() override {} - void onErrorImpl(folly::exception_wrapper) override {} - - private: - Next next_; - const int64_t batch_; - int64_t pending_; - }; - - template - class WithError : public Base { - public: - WithError(Next next, Error error, int64_t batch) - : Base(std::move(next), batch), error_(std::move(error)) {} - - void onErrorImpl(folly::exception_wrapper error) override final { - try { - error_(std::move(error)); - } catch (const std::exception& exn) { - auto ew = folly::exception_wrapper{std::current_exception(), exn}; - LOG(ERROR) << "'error' method should not throw: " << ew.what(); -#ifndef NDEBUG - throw ew; // Throw the wrapped exception -#endif - } - } - - private: - Error error_; - }; - - template - class WithErrorAndComplete : public WithError { - public: - WithErrorAndComplete( - Next next, - Error error, - Complete complete, - int64_t batch) - : WithError(std::move(next), std::move(error), batch), - complete_(std::move(complete)) {} - - void onCompleteImpl() override final { - try { - complete_(); - } catch (const std::exception& exn) { - auto ew = folly::exception_wrapper{std::current_exception(), exn}; - LOG(ERROR) << "'complete' method should not throw: " << ew.what(); -#ifndef NDEBUG - throw ew; // Throw the wrapped exception -#endif - } - } - - private: - Complete complete_; - }; - - Subscribers() = delete; -}; - -} // namespace flowable -} // namespace yarpl diff --git a/yarpl/include/yarpl/flowable/Subscription.h b/yarpl/include/yarpl/flowable/Subscription.h deleted file mode 100644 index b4ccfe523..000000000 --- a/yarpl/include/yarpl/flowable/Subscription.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "yarpl/Refcounted.h" - -namespace yarpl { -namespace flowable { - -class Subscription : public virtual Refcounted { - public: - virtual ~Subscription() = default; - - virtual void request(int64_t n) = 0; - virtual void cancel() = 0; - - static yarpl::Reference empty(); -}; - -} // flowable -} // yarpl diff --git a/yarpl/include/yarpl/observable/Observable.h b/yarpl/include/yarpl/observable/Observable.h deleted file mode 100644 index c25fc0e26..000000000 --- a/yarpl/include/yarpl/observable/Observable.h +++ /dev/null @@ -1,306 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include "yarpl/utils/type_traits.h" - -#include "yarpl/Refcounted.h" -#include "yarpl/observable/Observer.h" -#include "yarpl/observable/Observers.h" -#include "yarpl/observable/Subscription.h" - -#include "yarpl/Flowable.h" -#include "yarpl/flowable/Flowable_FromObservable.h" - -#include - -namespace yarpl { -namespace observable { - -/** -*Strategy for backpressure when converting from Observable to Flowable. -*/ -enum class BackpressureStrategy { BUFFER, DROP, ERROR, LATEST, MISSING }; - -template -class Observable : public virtual Refcounted, public yarpl::enable_get_ref { - public: - virtual Reference subscribe(Reference>) = 0; - - /** - * Subscribe overload that accepts lambdas. - */ - template < - typename Next, - typename = - typename std::enable_if::value>::type> - Reference subscribe(Next next) { - return subscribe(Observers::create(std::move(next))); - } - - /** - * Subscribe overload that accepts lambdas. - */ - template < - typename Next, - typename Error, - typename = typename std::enable_if< - folly::is_invocable::value && - folly::is_invocable::value>::type> - Reference subscribe(Next next, Error error) { - return subscribe(Observers::create( - std::move(next), std::move(error))); - } - - /** - * Subscribe overload that accepts lambdas. - */ - template < - typename Next, - typename Error, - typename Complete, - typename = typename std::enable_if< - folly::is_invocable::value && - folly::is_invocable::value && - folly::is_invocable::value>::type> - Reference subscribe(Next next, Error error, Complete complete) { - return subscribe(Observers::create( - std::move(next), - std::move(error), - std::move(complete))); - } - - Reference subscribe() { - return subscribe(Observers::createNull()); - } - - template - static Reference> create(OnSubscribe); - - template < - typename Function, - typename R = typename std::result_of::type> - Reference> map(Function function); - - template - Reference> filter(Function function); - - template < - typename Function, - typename R = typename std::result_of::type> - Reference> reduce(Function function); - - Reference> take(int64_t); - - Reference> skip(int64_t); - - Reference> ignoreElements(); - - Reference> subscribeOn(folly::Executor&); - - // function is invoked when onComplete occurs. - template - Reference> doOnSubscribe(Function function); - - // function is invoked when onNext occurs. - template - Reference> doOnNext(Function function); - - // function is invoked when onError occurs. - template - Reference> doOnError(Function function); - - // function is invoked when onComplete occurs. - template - Reference> doOnComplete(Function function); - - // function is invoked when either onComplete or onError occurs. - template - Reference> doOnTerminate(Function function); - - // the function is invoked for each of onNext, onCompleted, onError - template - Reference> doOnEach(Function function); - - // the callbacks will be invoked of each of the signals - template - Reference> doOn(OnNextFunc onNext, OnCompleteFunc onComplete); - - // the callbacks will be invoked of each of the signals - template - Reference> doOn(OnNextFunc onNext, OnCompleteFunc onComplete, OnErrorFunc onError); - - /** - * Convert from Observable to Flowable with a given BackpressureStrategy. - * - * Currently the only strategy is DROP. - */ - auto toFlowable(BackpressureStrategy strategy); -}; -} // observable -} // yarpl - -#include "yarpl/observable/ObservableOperator.h" - -namespace yarpl { -namespace observable { - -template -template -Reference> Observable::create(OnSubscribe function) { - static_assert( - folly::is_invocable>>::value, - "OnSubscribe must have type `void(Reference>)`"); - - return make_ref>( - std::move(function)); -} - -template -template -Reference> Observable::map(Function function) { - return make_ref>( - this->ref_from_this(this), std::move(function)); -} - -template -template -Reference> Observable::filter(Function function) { - return make_ref>( - this->ref_from_this(this), std::move(function)); -} - -template -template -Reference> Observable::reduce(Function function) { - return make_ref>( - this->ref_from_this(this), std::move(function)); -} - -template -Reference> Observable::take(int64_t limit) { - return make_ref>(this->ref_from_this(this), limit); -} - -template -Reference> Observable::skip(int64_t offset) { - return make_ref>(this->ref_from_this(this), offset); -} - -template -Reference> Observable::ignoreElements() { - return make_ref>(this->ref_from_this(this)); -} - -template -Reference> Observable::subscribeOn(folly::Executor& executor) { - return make_ref>(this->ref_from_this(this), executor); -} - -template -template -Reference> Observable::doOnSubscribe(Function function) { - return details::createDoOperator(ref_from_this(this), std::move(function), [](const T&){}, [](const auto&){}, []{}); -} - -template -template -Reference> Observable::doOnNext(Function function) { - return details::createDoOperator(ref_from_this(this), []{}, std::move(function), [](const auto&){}, []{}); -} - -template -template -Reference> Observable::doOnError(Function function) { - return details::createDoOperator(ref_from_this(this), []{}, [](const T&){}, std::move(function), []{}); -} - -template -template -Reference> Observable::doOnComplete(Function function) { - return details::createDoOperator(ref_from_this(this), []{}, [](const T&){}, [](const auto&){}, std::move(function)); -} - -template -template -Reference> Observable::doOnTerminate(Function function) { - auto sharedFunction = std::make_shared(std::move(function)); - return details::createDoOperator(ref_from_this(this), []{}, [](const T&){}, [sharedFunction](const auto&){(*sharedFunction)();}, [sharedFunction](){(*sharedFunction)();}); - -} - -template -template -Reference> Observable::doOnEach(Function function) { - auto sharedFunction = std::make_shared(std::move(function)); - return details::createDoOperator(ref_from_this(this), []{}, [sharedFunction](const T&){(*sharedFunction)();}, [sharedFunction](const auto&){(*sharedFunction)();}, [sharedFunction](){(*sharedFunction)();}); -} - -template -template -Reference> Observable::doOn(OnNextFunc onNext, OnCompleteFunc onComplete) { - return details::createDoOperator(ref_from_this(this), []{}, std::move(onNext), [](const auto&){}, std::move(onComplete)); -} - -template -template -Reference> Observable::doOn(OnNextFunc onNext, OnCompleteFunc onComplete, OnErrorFunc onError) { - return details::createDoOperator(ref_from_this(this), []{}, std::move(onNext), std::move(onError), std::move(onComplete)); -} - -template -auto Observable::toFlowable(BackpressureStrategy strategy) { - // we currently ONLY support the DROP strategy - // so do not use the strategy parameter for anything - return yarpl::flowable::Flowables::fromPublisher([ - thisObservable = this->ref_from_this(this), - strategy - ](Reference> subscriber) { - Reference subscription; - switch (strategy) { - case BackpressureStrategy::DROP: - subscription = - make_ref>( - thisObservable, subscriber); - break; - case BackpressureStrategy::ERROR: - subscription = - make_ref>( - thisObservable, subscriber); - break; - case BackpressureStrategy::BUFFER: - subscription = - make_ref>( - thisObservable, subscriber); - break; - case BackpressureStrategy::LATEST: - subscription = - make_ref>( - thisObservable, subscriber); - break; - case BackpressureStrategy::MISSING: - subscription = - make_ref>( - thisObservable, subscriber); - break; - default: - CHECK(false); // unknown value for strategy - } - subscriber->onSubscribe(std::move(subscription)); - }); -} - -} // observable -} // yarpl diff --git a/yarpl/include/yarpl/observable/ObservableDoOperator.h b/yarpl/include/yarpl/observable/ObservableDoOperator.h deleted file mode 100644 index 5e1d22e43..000000000 --- a/yarpl/include/yarpl/observable/ObservableDoOperator.h +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -namespace yarpl { -namespace observable { - -template < - typename U, - typename OnSubscribeFunc, - typename OnNextFunc, - typename OnErrorFunc, - typename OnCompleteFunc> -class DoOperator : public ObservableOperator> { - using ThisOperatorT = DoOperator; - using Super = ObservableOperator; - - public: - DoOperator(Reference> upstream, - OnSubscribeFunc onSubscribeFunc, - OnNextFunc onNextFunc, - OnErrorFunc onErrorFunc, - OnCompleteFunc onCompleteFunc) - : Super(std::move(upstream)), - onSubscribeFunc_(std::move(onSubscribeFunc)), - onNextFunc_(std::move(onNextFunc)), - onErrorFunc_(std::move(onErrorFunc)), - onCompleteFunc_(std::move(onCompleteFunc)) {} - - Reference subscribe(Reference> observer) override { - auto subscription = - make_ref(this->ref_from_this(this), std::move(observer)); - Super::upstream_->subscribe( - // Note: implicit cast to a reference to a observer. - subscription); - return subscription; - } - - private: - class DoSubscription : public Super::OperatorSubscription { - using SuperSub = typename Super::OperatorSubscription; - - public: - DoSubscription( - Reference observable, - Reference> observer) - : SuperSub(std::move(observable), std::move(observer)) {} - - void onSubscribe( - Reference subscription) override { - auto&& op = SuperSub::getObservableOperator(); - op->onSubscribeFunc_(); - SuperSub::onSubscribe(std::move(subscription)); - } - - void onNext(U value) override { - auto&& op = SuperSub::getObservableOperator(); - const auto& valueRef = value; - op->onNextFunc_(valueRef); - SuperSub::observerOnNext(std::move(value)); - } - - void onError(folly::exception_wrapper ex) override { - auto&& op = SuperSub::getObservableOperator(); - const auto& exRef = ex; - op->onErrorFunc_(exRef); - SuperSub::onError(std::move(ex)); - } - - void onComplete() override { - auto&& op = SuperSub::getObservableOperator(); - op->onCompleteFunc_(); - SuperSub::onComplete(); - } - }; - - OnSubscribeFunc onSubscribeFunc_; - OnNextFunc onNextFunc_; - OnErrorFunc onErrorFunc_; - OnCompleteFunc onCompleteFunc_; -}; - -namespace details { - -template < - typename U, - typename OnSubscribeFunc, - typename OnNextFunc, - typename OnErrorFunc, - typename OnCompleteFunc> -inline auto createDoOperator(Reference> upstream, - OnSubscribeFunc onSubscribeFunc, - OnNextFunc onNextFunc, - OnErrorFunc onErrorFunc, - OnCompleteFunc onCompleteFunc) { - return make_ref>( - std::move(upstream), std::move(onSubscribeFunc), std::move(onNextFunc), - std::move(onErrorFunc), std::move(onCompleteFunc)); -} - -} - -} -} diff --git a/yarpl/include/yarpl/observable/Observables.h b/yarpl/include/yarpl/observable/Observables.h deleted file mode 100644 index 718c286f9..000000000 --- a/yarpl/include/yarpl/observable/Observables.h +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include - -#include "yarpl/observable/Observable.h" -#include "yarpl/observable/Subscriptions.h" - -#include - -namespace yarpl { -namespace observable { - -class Observables { - public: - static Reference> range(int64_t start, int64_t end) { - auto lambda = [start, end](Reference> observer) { - for (int64_t i = start; i < end; ++i) { - observer->onNext(i); - } - observer->onComplete(); - }; - - return Observable::create(std::move(lambda)); - } - - template - static Reference> just(const T& value) { - auto lambda = [value](Reference> observer) { - observer->onNext(value); - observer->onComplete(); - }; - - return Observable::create(std::move(lambda)); - } - - template - static Reference> justN(std::initializer_list list) { - std::vector vec(list); - - auto lambda = [v = std::move(vec)](Reference> observer) { - for (auto const& elem : v) { - observer->onNext(elem); - } - observer->onComplete(); - }; - - return Observable::create(std::move(lambda)); - } - - // this will generate an observable which can be subscribed to only once - template - static Reference> justOnce(T value) { - auto lambda = [ value = std::move(value), used = false ]( - Reference> observer) mutable { - if (used) { - observer->onError( - std::runtime_error("justOnce value was already used")); - return; - } - - used = true; - observer->onNext(std::move(value)); - observer->onComplete(); - }; - - return Observable::create(std::move(lambda)); - } - - template < - typename T, - typename OnSubscribe, - typename = typename std::enable_if< - folly::is_invocable>>::value>:: - type> - static Reference> create(OnSubscribe function) { - return make_ref>(std::move(function)); - } - - template - static Reference> empty() { - auto lambda = [](Reference> observer) { - observer->onComplete(); - }; - return Observable::create(std::move(lambda)); - } - - template - static Reference> error(folly::exception_wrapper ex) { - auto lambda = [ex = std::move(ex)](Reference> observer) { - observer->onError(std::move(ex)); - }; - return Observable::create(std::move(lambda)); - } - - template - static Reference> error(const ExceptionType& ex) { - auto lambda = [ex = std::move(ex)](Reference> observer) { - observer->onError(std::move(ex)); - }; - return Observable::create(std::move(lambda)); - } - - private: - Observables() = delete; -}; - -} // observable -} // yarpl diff --git a/yarpl/include/yarpl/observable/Observer.h b/yarpl/include/yarpl/observable/Observer.h deleted file mode 100644 index 8f2a833f6..000000000 --- a/yarpl/include/yarpl/observable/Observer.h +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "yarpl/Refcounted.h" -#include "yarpl/observable/Subscriptions.h" - -#include - -#include - -namespace yarpl { -namespace observable { - -template -class Observer : public virtual Refcounted, public yarpl::enable_get_ref { - public: - // Note: If any of the following methods is overridden in a subclass, the new - // methods SHOULD ensure that these are invoked as well. - virtual void onSubscribe(Reference subscription) { - DCHECK(subscription); - - if (subscription_) { - DLOG(ERROR) << "attempt to double subscribe"; - subscription->cancel(); - return; - } - - subscription_ = std::move(subscription); - } - - // No further calls to the subscription after this method is invoked. - virtual void onComplete() { - DCHECK(subscription_) << "Calling onComplete() without a subscription"; - subscription_.reset(); - } - - // No further calls to the subscription after this method is invoked. - virtual void onError(folly::exception_wrapper) { - DCHECK(subscription_) << "Calling onError() without a subscription"; - subscription_.reset(); - } - - virtual void onNext(T) = 0; - - bool isUnsubscribed() const { - CHECK(subscription_); - return subscription_->isCancelled(); - } - - // Ability to add more subscription objects which will be notified when the - // subscription has been cancelled. - // Note that calling cancel on the tied subscription is not going to cancel - // this subscriber - void addSubscription(Reference subscription) { - if(!subscription_) { - subscription->cancel(); - return; - } - subscription_->tieSubscription(std::move(subscription)); - } - - template - void addSubscription(OnCancel onCancel) { - addSubscription(Subscriptions::create(std::move(onCancel))); - } - - bool isUnsubscribedOrTerminated() const { - return !subscription_ || subscription_->isCancelled(); - } - - protected: - Subscription* subscription() { - return subscription_.operator->(); - } - - void unsubscribe() { - CHECK(subscription_); - subscription_->cancel(); - } - - private: - Reference subscription_; -}; -} -} diff --git a/yarpl/include/yarpl/observable/Observers.h b/yarpl/include/yarpl/observable/Observers.h deleted file mode 100644 index a5cfb4e7e..000000000 --- a/yarpl/include/yarpl/observable/Observers.h +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include - -#include -#include - -#include "yarpl/observable/Observer.h" -#include "yarpl/utils/type_traits.h" - -namespace yarpl { -namespace observable { - -/// Helper methods for constructing subscriber instances from functions: -/// one, two, or three functions (callables; can be lamda, for instance) -/// may be specified, corresponding to onNext, onError and onComplete -/// method bodies in the subscriber. -class Observers { - private: - /// Defined if Next, Error and Complete are signature-compatible with - /// onNext, onError and onComplete subscriber methods respectively. - template < - typename T, - typename Next, - typename Error = void (*)(folly::exception_wrapper), - typename Complete = void (*)()> - using EnableIfCompatible = typename std::enable_if< - folly::is_invocable::value && - folly::is_invocable::value && - folly::is_invocable::value>::type; - - public: - template > - static auto create(Next next) { - return make_ref, Observer>(std::move(next)); - } - - template < - typename T, - typename Next, - typename Error, - typename = EnableIfCompatible> - static auto create(Next next, Error error) { - return make_ref, Observer>( - std::move(next), std::move(error)); - } - - template < - typename T, - typename Next, - typename Error, - typename Complete, - typename = EnableIfCompatible> - static auto create(Next next, Error error, Complete complete) { - return make_ref< - WithErrorAndComplete, - Observer>(std::move(next), std::move(error), std::move(complete)); - } - - template - static auto createNull() { - return make_ref>(); - } - - private: - template - class NullObserver : public Observer { - public: - void onNext(T) {} - }; - - template - class Base : public Observer { - public: - explicit Base(Next next) : next_(std::move(next)) {} - - void onNext(T value) override { - next_(std::move(value)); - } - - private: - Next next_; - }; - - template - class WithError : public Base { - public: - WithError(Next next, Error error) - : Base(std::move(next)), error_(std::move(error)) {} - - void onError(folly::exception_wrapper error) override { - error_(std::move(error)); - } - - private: - Error error_; - }; - - template - class WithErrorAndComplete : public WithError { - public: - WithErrorAndComplete(Next next, Error error, Complete complete) - : WithError( - std::move(next), - std::move(error)), - complete_(std::move(complete)) {} - - void onComplete() override { - complete_(); - } - - private: - Complete complete_; - }; - - Observers() = delete; -}; - -} // observable -} // yarpl diff --git a/yarpl/include/yarpl/observable/Subscription.h b/yarpl/include/yarpl/observable/Subscription.h deleted file mode 100644 index 7479b85cb..000000000 --- a/yarpl/include/yarpl/observable/Subscription.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include -#include "yarpl/Refcounted.h" - -namespace yarpl { -namespace observable { - -class Subscription : public virtual Refcounted { - public: - virtual ~Subscription() = default; - virtual void cancel(); - bool isCancelled() const; - - // Adds ability to tie another subscription to this instance. - // Whenever *this subscription is cancelled then all tied subscriptions get - // cancelled as well - void tieSubscription(Reference subscription); - - protected: - std::atomic cancelled_{false}; - folly::Synchronized>> tiedSubscriptions_; -}; - -} // observable -} // yarpl diff --git a/yarpl/include/yarpl/observable/Subscriptions.h b/yarpl/include/yarpl/observable/Subscriptions.h deleted file mode 100644 index 1d40c35e4..000000000 --- a/yarpl/include/yarpl/observable/Subscriptions.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include -#include -#include - -#include "yarpl/Refcounted.h" -#include "yarpl/observable/Subscription.h" - -namespace yarpl { -namespace observable { - -/** -* Implementation that gets a callback when cancellation occurs. -*/ -class CallbackSubscription : public Subscription { - public: - explicit CallbackSubscription(std::function onCancel); - void cancel() override; - - private: - std::function onCancel_; -}; - -class Subscriptions { - public: - static Reference create(std::function onCancel); - static Reference create(std::atomic_bool& cancelled); - static Reference create(); -}; - -} // observable namespace -} // yarpl namespace diff --git a/yarpl/include/yarpl/single/Single.h b/yarpl/include/yarpl/single/Single.h deleted file mode 100644 index 10dc8332c..000000000 --- a/yarpl/include/yarpl/single/Single.h +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include - -#include "yarpl/Refcounted.h" -#include "yarpl/single/SingleObserver.h" -#include "yarpl/single/SingleObservers.h" -#include "yarpl/single/SingleSubscription.h" -#include "yarpl/utils/type_traits.h" - -namespace yarpl { -namespace single { - -template -class Single : public virtual Refcounted, public yarpl::enable_get_ref { - public: - virtual void subscribe(Reference>) = 0; - - /** - * Subscribe overload that accepts lambdas. - */ - template < - typename Success, - typename = - typename std::enable_if::value>::type> - void subscribe(Success next) { - subscribe(SingleObservers::create(std::move(next))); - } - - /** - * Subscribe overload that accepts lambdas. - */ - template < - typename Success, - typename Error, - typename = typename std::enable_if< - folly::is_invocable::value && - folly::is_invocable::value>::type> - void subscribe(Success next, Error error) { - subscribe(SingleObservers::create(std::move(next), std::move(error))); - } - - /** - * Blocking subscribe that accepts lambdas. - * - * This blocks the current thread waiting on the response. - */ - template < - typename Success, - typename = - typename std::enable_if::value>::type> - void subscribeBlocking(Success next) { - auto waiting_ = std::make_shared>(); - subscribe( - SingleObservers::create([ next = std::move(next), waiting_ ](T t) { - next(std::move(t)); - waiting_->post(); - })); - // TODO get errors and throw if one is received - waiting_->wait(); - } - - template < - typename OnSubscribe, - typename = typename std::enable_if>>::value>::type> - static Reference> create(OnSubscribe); - - template - auto map(Function function); -}; - -template <> -class Single : public virtual Refcounted { - public: - virtual void subscribe(Reference>) = 0; - - /** - * Subscribe overload taking lambda for onSuccess that is called upon writing - * to the network. - */ - template < - typename Success, - typename = - typename std::enable_if::value>::type> - void subscribe(Success s) { - class SuccessSingleObserver : public SingleObserverBase { - public: - SuccessSingleObserver(Success success) : success_{std::move(success)} {} - - void onSubscribe(Reference subscription) override { - SingleObserverBase::onSubscribe(std::move(subscription)); - } - - void onSuccess() override { - success_(); - SingleObserverBase::onSuccess(); - } - - // No further calls to the subscription after this method is invoked. - void onError(folly::exception_wrapper ex) override { - SingleObserverBase::onError(std::move(ex)); - } - - private: - Success success_; - }; - - subscribe(make_ref(std::move(s))); - } - - template < - typename OnSubscribe, - typename = typename std::enable_if>>::value>::type> - static auto create(OnSubscribe); -}; - -} // namespace single -} // namespace yarpl - -#include "yarpl/single/SingleOperator.h" - -namespace yarpl { -namespace single { - -template -template -Reference> Single::create(OnSubscribe function) { - return make_ref>(std::move(function)); -} - -template -auto Single::create(OnSubscribe function) { - return make_ref, Single>( - std::forward(function)); -} - -template -template -auto Single::map(Function function) { - using D = typename std::result_of::type; - return make_ref, Single>( - this->ref_from_this(this), std::move(function)); -} - -} // namespace single -} // namespace yarpl diff --git a/yarpl/include/yarpl/single/SingleObserver.h b/yarpl/include/yarpl/single/SingleObserver.h deleted file mode 100644 index fa82015c2..000000000 --- a/yarpl/include/yarpl/single/SingleObserver.h +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "yarpl/Refcounted.h" -#include "yarpl/single/SingleSubscription.h" - -#include - -#include - -namespace yarpl { -namespace single { - -template -class SingleObserver : public virtual Refcounted, public yarpl::enable_get_ref { -public: - virtual void onSubscribe(Reference) = 0; - virtual void onSuccess(T) = 0; - virtual void onError(folly::exception_wrapper) = 0; -}; - -template -class SingleObserverBase : public SingleObserver { - public: - // Note: If any of the following methods is overridden in a subclass, the new - // methods SHOULD ensure that these are invoked as well. - void onSubscribe(Reference subscription) override { - DCHECK(subscription); - - if (subscription_) { - subscription->cancel(); - return; - } - - subscription_ = std::move(subscription); - } - - void onSuccess(T) override { - DCHECK(subscription_) << "Calling onSuccess() without a subscription"; - subscription_.reset(); - } - - // No further calls to the subscription after this method is invoked. - void onError(folly::exception_wrapper) override { - DCHECK(subscription_) << "Calling onError() without a subscription"; - subscription_.reset(); - } - - protected: - SingleSubscription* subscription() { - return subscription_.operator->(); - } - - private: - Reference subscription_; -}; - -/// Specialization of SingleObserverBase. -template <> -class SingleObserverBase : public virtual Refcounted { - public: - // Note: If any of the following methods is overridden in a subclass, the new - // methods SHOULD ensure that these are invoked as well. - virtual void onSubscribe(Reference subscription) { - DCHECK(subscription); - - if (subscription_) { - subscription->cancel(); - return; - } - - subscription_ = std::move(subscription); - } - - virtual void onSuccess() { - DCHECK(subscription_) << "Calling onSuccess() without a subscription"; - subscription_.reset(); - } - - // No further calls to the subscription after this method is invoked. - virtual void onError(folly::exception_wrapper) { - DCHECK(subscription_) << "Calling onError() without a subscription"; - subscription_.reset(); - } - - protected: - SingleSubscription* subscription() { - return subscription_.operator->(); - } - - private: - Reference subscription_; -}; -} -} diff --git a/yarpl/include/yarpl/single/SingleObservers.h b/yarpl/include/yarpl/single/SingleObservers.h deleted file mode 100644 index 508e489dc..000000000 --- a/yarpl/include/yarpl/single/SingleObservers.h +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "yarpl/utils/type_traits.h" - -#include "yarpl/single/SingleObserver.h" - -#include - -namespace yarpl { -namespace single { - -/// Helper methods for constructing subscriber instances from functions: -/// one or two functions (callables; can be lamda, for instance) -/// may be specified, corresponding to onNext, onError and onComplete -/// method bodies in the subscriber. -class SingleObservers { - private: - /// Defined if Success and Error are signature-compatible with - /// onSuccess and onError subscriber methods respectively. - template < - typename T, - typename Success, - typename Error = void (*)(folly::exception_wrapper)> - using EnableIfCompatible = typename std::enable_if< - folly::is_invocable::value && - folly::is_invocable::value>::type; - - public: - template > - static auto create(Next next) { - return make_ref, SingleObserverBase>(std::move(next)); - } - - template < - typename T, - typename Success, - typename Error, - typename = EnableIfCompatible> - static auto create(Success next, Error error) { - return make_ref, SingleObserverBase>( - std::move(next), std::move(error)); - } - - private: - template - class Base : public SingleObserverBase { - public: - explicit Base(Next next) : next_(std::move(next)) {} - - void onSuccess(T value) override { - next_(std::move(value)); - // TODO how do we call the super to trigger release? - // SingleObserver::onSuccess(value); - } - - private: - Next next_; - }; - - template - class WithError : public Base { - public: - WithError(Success next, Error error) - : Base(std::move(next)), error_(std::move(error)) {} - - void onError(folly::exception_wrapper error) override { - error_(error); - // TODO do we call the super here to trigger release? - Base::onError(std::move(error)); - } - - private: - Error error_; - }; - - SingleObservers() = delete; -}; - -} // observable -} // yarpl diff --git a/yarpl/include/yarpl/single/SingleSubscription.h b/yarpl/include/yarpl/single/SingleSubscription.h deleted file mode 100644 index b22faf838..000000000 --- a/yarpl/include/yarpl/single/SingleSubscription.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "yarpl/Refcounted.h" - -namespace yarpl { -namespace single { - -class SingleSubscription : public virtual Refcounted { - public: - virtual ~SingleSubscription() = default; - virtual void cancel() = 0; - - protected: - SingleSubscription() {} -}; - -} // single -} // yarpl diff --git a/yarpl/include/yarpl/single/Singles.h b/yarpl/include/yarpl/single/Singles.h deleted file mode 100644 index 1b8d661ca..000000000 --- a/yarpl/include/yarpl/single/Singles.h +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "yarpl/utils/type_traits.h" - -#include "yarpl/single/Single.h" -#include "yarpl/single/SingleSubscriptions.h" - -#include - -namespace yarpl { -namespace single { - -class Singles { - public: - template - static Reference> just(const T& value) { - auto lambda = [value](Reference> observer) { - observer->onSubscribe(SingleSubscriptions::empty()); - observer->onSuccess(value); - }; - - return Single::create(std::move(lambda)); - } - - template < - typename T, - typename OnSubscribe, - typename = typename std::enable_if>>::value>::type> - static Reference> create(OnSubscribe function) { - return make_ref>(std::move(function)); - } - - template - static Reference> error(folly::exception_wrapper ex) { - auto lambda = [e = std::move(ex)](Reference> observer) { - observer->onSubscribe(SingleSubscriptions::empty()); - observer->onError(e); - }; - return Single::create(std::move(lambda)); - } - - template - static Reference> error(const ExceptionType& ex) { - auto lambda = [ex](Reference> observer) { - observer->onSubscribe(SingleSubscriptions::empty()); - observer->onError(ex); - }; - return Single::create(std::move(lambda)); - } - - template - static Reference> fromGenerator(TGenerator generator) { - auto lambda = [generator = std::move(generator)]( - Reference> observer) mutable { - observer->onSubscribe(SingleSubscriptions::empty()); - observer->onSuccess(generator()); - }; - return Single::create(std::move(lambda)); - } - - private: - Singles() = delete; -}; - -} // observable -} // yarpl diff --git a/yarpl/include/yarpl/utils/type_traits.h b/yarpl/include/yarpl/utils/type_traits.h deleted file mode 100644 index eb2fb35e3..000000000 --- a/yarpl/include/yarpl/utils/type_traits.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include - -#if __cplusplus < 201500 - -namespace std { - -namespace implementation { - -template -struct is_callable : std::false_type {}; - -template -struct is_callable< - F(Args...), - R, - std::enable_if_t>::value>> - : std::true_type {}; - -} // implementation - -template -struct is_callable : implementation::is_callable {}; - -} // std -#endif // __cplusplus diff --git a/yarpl/observable/DeferObservable.h b/yarpl/observable/DeferObservable.h new file mode 100644 index 000000000..302aeecaa --- /dev/null +++ b/yarpl/observable/DeferObservable.h @@ -0,0 +1,50 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/observable/Observable.h" + +namespace yarpl { +namespace observable { +namespace details { + +template +class DeferObservable : public Observable { + static_assert( + std::is_same, ObservableFactory>::value, + "undecayed"); + + public: + template + explicit DeferObservable(F&& factory) : factory_(std::forward(factory)) {} + + virtual std::shared_ptr subscribe( + std::shared_ptr> observer) { + std::shared_ptr> observable; + try { + observable = factory_(); + } catch (const std::exception& ex) { + observable = Observable::error(ex, std::current_exception()); + } + return observable->subscribe(std::move(observer)); + } + + private: + ObservableFactory factory_; +}; + +} // namespace details +} // namespace observable +} // namespace yarpl diff --git a/yarpl/observable/Observable.h b/yarpl/observable/Observable.h new file mode 100644 index 000000000..30729c360 --- /dev/null +++ b/yarpl/observable/Observable.h @@ -0,0 +1,560 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "yarpl/Refcounted.h" +#include "yarpl/observable/Observer.h" +#include "yarpl/observable/Subscription.h" + +#include "yarpl/Common.h" +#include "yarpl/Flowable.h" +#include "yarpl/flowable/Flowable_FromObservable.h" + +namespace yarpl { + +namespace observable { + +template +class Observable : public yarpl::enable_get_ref { + public: + static std::shared_ptr> empty() { + auto lambda = [](std::shared_ptr> observer) { + observer->onComplete(); + }; + return Observable::create(std::move(lambda)); + } + + static std::shared_ptr> error(folly::exception_wrapper ex) { + auto lambda = + [ex = std::move(ex)](std::shared_ptr> observer) mutable { + observer->onError(std::move(ex)); + }; + return Observable::create(std::move(lambda)); + } + + template + static std::shared_ptr> error(Ex&) { + static_assert( + std::is_lvalue_reference::value, + "use variant of error() method accepting also exception_ptr"); + } + + template + static std::shared_ptr> error(Ex& ex, std::exception_ptr ptr) { + auto lambda = [ew = folly::exception_wrapper(std::move(ptr), ex)]( + std::shared_ptr> observer) mutable { + observer->onError(std::move(ew)); + }; + return Observable::create(std::move(lambda)); + } + + static std::shared_ptr> just(T value) { + auto lambda = + [value = std::move(value)](std::shared_ptr> observer) { + observer->onNext(value); + observer->onComplete(); + }; + + return Observable::create(std::move(lambda)); + } + + /** + * The Defer operator waits until an observer subscribes to it, and then it + * generates an Observable with an ObservableFactory function. It + * does this afresh for each subscriber, so although each subscriber may + * think it is subscribing to the same Observable, in fact each subscriber + * gets its own individual sequence. + */ + template < + typename ObservableFactory, + typename = typename std::enable_if>, + std::decay_t&>::value>::type> + static std::shared_ptr> defer(ObservableFactory&&); + + static std::shared_ptr> justN(std::initializer_list list) { + auto lambda = [v = std::vector(std::move(list))]( + std::shared_ptr> observer) { + for (auto const& elem : v) { + observer->onNext(elem); + } + observer->onComplete(); + }; + + return Observable::create(std::move(lambda)); + } + + // this will generate an observable which can be subscribed to only once + static std::shared_ptr> justOnce(T value) { + auto lambda = [value = std::move(value), used = false]( + std::shared_ptr> observer) mutable { + if (used) { + observer->onError( + std::runtime_error("justOnce value was already used")); + return; + } + + used = true; + observer->onNext(std::move(value)); + observer->onComplete(); + }; + + return Observable::create(std::move(lambda)); + } + + template + static std::shared_ptr> create(OnSubscribe&&); + + template + static std::shared_ptr> createEx(OnSubscribe&&); + + virtual std::shared_ptr subscribe( + std::shared_ptr>) = 0; + + /** + * Subscribe overload that accepts lambdas. + */ + template < + typename Next, + typename = typename std::enable_if< + folly::is_invocable&, T>::value>::type> + std::shared_ptr subscribe(Next&& next) { + return subscribe(Observer::create(std::forward(next))); + } + + /** + * Subscribe overload that accepts lambdas. + */ + template < + typename Next, + typename Error, + typename = typename std::enable_if< + folly::is_invocable&, T>::value && + folly::is_invocable&, folly::exception_wrapper>:: + value>::type> + std::shared_ptr subscribe(Next&& next, Error&& error) { + return subscribe(Observer::create( + std::forward(next), std::forward(error))); + } + + /** + * Subscribe overload that accepts lambdas. + */ + template < + typename Next, + typename Error, + typename Complete, + typename = typename std::enable_if< + folly::is_invocable&, T>::value && + folly::is_invocable&, folly::exception_wrapper>:: + value && + folly::is_invocable&>::value>::type> + std::shared_ptr + subscribe(Next&& next, Error&& error, Complete&& complete) { + return subscribe(Observer::create( + std::forward(next), + std::forward(error), + std::forward(complete))); + } + + std::shared_ptr subscribe() { + return subscribe(Observer::create()); + } + + template < + typename Function, + typename R = typename folly::invoke_result_t> + std::shared_ptr> map(Function&& function); + + template + std::shared_ptr> filter(Function&& function); + + template < + typename Function, + typename R = typename folly::invoke_result_t> + std::shared_ptr> reduce(Function&& function); + + std::shared_ptr> take(int64_t); + + std::shared_ptr> skip(int64_t); + + std::shared_ptr> ignoreElements(); + + std::shared_ptr> subscribeOn(folly::Executor&); + + std::shared_ptr> concatWith(std::shared_ptr>); + + template + std::shared_ptr> concatWith( + std::shared_ptr> first, + Args... args) { + return concatWith(first)->concatWith(args...); + } + + template + static std::shared_ptr> concat( + std::shared_ptr> first, + Args... args) { + return first->concatWith(args...); + } + + // function is invoked when onComplete occurs. + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOnSubscribe(Function&& function); + + // function is invoked when onNext occurs. + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&, const T&>::value>::type> + std::shared_ptr> doOnNext(Function&& function); + + // function is invoked when onError occurs. + template < + typename Function, + typename = typename std::enable_if&, + folly::exception_wrapper&>::value>::type> + std::shared_ptr> doOnError(Function&& function); + + // function is invoked when onComplete occurs. + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOnComplete(Function&& function); + + // function is invoked when either onComplete or onError occurs. + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOnTerminate(Function&& function); + + // the function is invoked for each of onNext, onCompleted, onError + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOnEach(Function&& function); + + // the callbacks will be invoked of each of the signals + template < + typename OnNextFunc, + typename OnCompleteFunc, + typename = typename std::enable_if< + folly::is_invocable&, const T&>::value>:: + type, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOn( + OnNextFunc&& onNext, + OnCompleteFunc&& onComplete); + + // the callbacks will be invoked of each of the signals + template < + typename OnNextFunc, + typename OnCompleteFunc, + typename OnErrorFunc, + typename = typename std::enable_if< + folly::is_invocable&, const T&>::value>:: + type, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type, + typename = typename std::enable_if&, + folly::exception_wrapper&>::value>::type> + std::shared_ptr> + doOn(OnNextFunc&& onNext, OnCompleteFunc&& onComplete, OnErrorFunc&& onError); + + // function is invoked when cancel is called. + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOnCancel(Function&& function); + + /** + * Convert from Observable to Flowable with a given BackpressureStrategy. + */ + auto toFlowable(BackpressureStrategy strategy); + + /** + * Convert from Observable to Flowable with a given BackpressureStrategy. + */ + auto toFlowable(std::shared_ptr> strategy); +}; +} // namespace observable +} // namespace yarpl + +#include "yarpl/observable/DeferObservable.h" +#include "yarpl/observable/ObservableOperator.h" + +namespace yarpl { +namespace observable { + +template +template +std::shared_ptr> Observable::create(OnSubscribe&& function) { + static_assert( + folly::is_invocable>>::value, + "OnSubscribe must have type `void(std::shared_ptr>)`"); + + return createEx([func = std::forward(function)]( + std::shared_ptr> observer, + std::shared_ptr) mutable { + func(std::move(observer)); + }); +} + +template +template +std::shared_ptr> Observable::createEx(OnSubscribe&& function) { + static_assert( + folly::is_invocable< + OnSubscribe&&, + std::shared_ptr>, + std::shared_ptr>::value, + "OnSubscribe must have type " + "`void(std::shared_ptr>, std::shared_ptr)`"); + + return std::make_shared>>( + std::forward(function)); +} + +template +template +std::shared_ptr> Observable::defer( + ObservableFactory&& factory) { + return std::make_shared< + details::DeferObservable>>( + std::forward(factory)); +} + +template +template +std::shared_ptr> Observable::map(Function&& function) { + return std::make_shared>>( + this->ref_from_this(this), std::forward(function)); +} + +template +template +std::shared_ptr> Observable::filter(Function&& function) { + return std::make_shared>>( + this->ref_from_this(this), std::forward(function)); +} + +template +template +std::shared_ptr> Observable::reduce(Function&& function) { + return std::make_shared>>( + this->ref_from_this(this), std::forward(function)); +} + +template +std::shared_ptr> Observable::take(int64_t limit) { + return std::make_shared>(this->ref_from_this(this), limit); +} + +template +std::shared_ptr> Observable::skip(int64_t offset) { + return std::make_shared>(this->ref_from_this(this), offset); +} + +template +std::shared_ptr> Observable::ignoreElements() { + return std::make_shared>(this->ref_from_this(this)); +} + +template +std::shared_ptr> Observable::subscribeOn( + folly::Executor& executor) { + return std::make_shared>( + this->ref_from_this(this), executor); +} + +template +template +std::shared_ptr> Observable::doOnSubscribe( + Function&& function) { + return details::createDoOperator( + ref_from_this(this), + std::forward(function), + [](const T&) {}, + [](const auto&) {}, + [] {}, + [] {}); // onCancel +} + +template +std::shared_ptr> Observable::concatWith( + std::shared_ptr> next) { + return std::make_shared>( + this->ref_from_this(this), std::move(next)); +} + +template +template +std::shared_ptr> Observable::doOnNext(Function&& function) { + return details::createDoOperator( + ref_from_this(this), + [] {}, + std::forward(function), + [](const auto&) {}, + [] {}, + [] {}); // onCancel +} + +template +template +std::shared_ptr> Observable::doOnError(Function&& function) { + return details::createDoOperator( + ref_from_this(this), + [] {}, + [](const T&) {}, + std::forward(function), + [] {}, + [] {}); // onCancel +} + +template +template +std::shared_ptr> Observable::doOnComplete( + Function&& function) { + return details::createDoOperator( + ref_from_this(this), + [] {}, + [](const T&) {}, + [](const auto&) {}, + std::forward(function), + [] {}); // onCancel +} + +template +template +std::shared_ptr> Observable::doOnTerminate( + Function&& function) { + auto sharedFunction = std::make_shared>( + std::forward(function)); + return details::createDoOperator( + ref_from_this(this), + [] {}, + [](const T&) {}, + [sharedFunction](const auto&) { (*sharedFunction)(); }, + [sharedFunction]() { (*sharedFunction)(); }, + [] {}); // onCancel +} + +template +template +std::shared_ptr> Observable::doOnEach(Function&& function) { + auto sharedFunction = std::make_shared>( + std::forward(function)); + return details::createDoOperator( + ref_from_this(this), + [] {}, + [sharedFunction](const T&) { (*sharedFunction)(); }, + [sharedFunction](const auto&) { (*sharedFunction)(); }, + [sharedFunction]() { (*sharedFunction)(); }, + [] {}); // onCancel +} + +template +template +std::shared_ptr> Observable::doOn( + OnNextFunc&& onNext, + OnCompleteFunc&& onComplete) { + return details::createDoOperator( + ref_from_this(this), + [] {}, + std::forward(onNext), + [](const auto&) {}, + std::forward(onComplete), + [] {}); // onCancel +} + +template +template < + typename OnNextFunc, + typename OnCompleteFunc, + typename OnErrorFunc, + typename, + typename, + typename> +std::shared_ptr> Observable::doOn( + OnNextFunc&& onNext, + OnCompleteFunc&& onComplete, + OnErrorFunc&& onError) { + return details::createDoOperator( + ref_from_this(this), + [] {}, + std::forward(onNext), + std::forward(onError), + std::forward(onComplete), + [] {}); // onCancel +} + +template +template +std::shared_ptr> Observable::doOnCancel(Function&& function) { + return details::createDoOperator( + ref_from_this(this), + [] {}, // onSubscribe + [](const auto&) {}, // onNext + [](const auto&) {}, // onError + [] {}, // onComplete + std::forward(function)); // onCancel +} + +template +auto Observable::toFlowable(BackpressureStrategy strategy) { + switch (strategy) { + case BackpressureStrategy::DROP: + return toFlowable(IBackpressureStrategy::drop()); + case BackpressureStrategy::ERROR: + return toFlowable(IBackpressureStrategy::error()); + case BackpressureStrategy::BUFFER: + return toFlowable(IBackpressureStrategy::buffer()); + case BackpressureStrategy::LATEST: + return toFlowable(IBackpressureStrategy::latest()); + case BackpressureStrategy::MISSING: + return toFlowable(IBackpressureStrategy::missing()); + default: + CHECK(false); // unknown value for strategy + } +} + +template +auto Observable::toFlowable( + std::shared_ptr> strategy) { + return yarpl::flowable::internal::flowableFromSubscriber( + [thisObservable = this->ref_from_this(this), + strategy = std::move(strategy)]( + std::shared_ptr> subscriber) { + strategy->init(std::move(thisObservable), std::move(subscriber)); + }); +} + +} // namespace observable +} // namespace yarpl diff --git a/yarpl/observable/ObservableConcatOperators.h b/yarpl/observable/ObservableConcatOperators.h new file mode 100644 index 000000000..4a3879a4e --- /dev/null +++ b/yarpl/observable/ObservableConcatOperators.h @@ -0,0 +1,154 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/observable/ObservableOperator.h" + +namespace yarpl { +namespace observable { +namespace details { + +template +class ConcatWithOperator : public ObservableOperator { + using Super = ObservableOperator; + + public: + ConcatWithOperator( + std::shared_ptr> first, + std::shared_ptr> second) + : first_(std::move(first)), second_(std::move(second)) { + CHECK(first_); + CHECK(second_); + } + + std::shared_ptr subscribe( + std::shared_ptr> observer) override { + auto subscription = + std::make_shared(observer, first_, second_); + subscription->init(); + + return subscription; + } + + private: + class ForwardObserver; + + // Downstream will always point to this subscription + class ConcatWithSubscription + : public yarpl::observable::Subscription, + public std::enable_shared_from_this { + public: + ConcatWithSubscription( + std::shared_ptr> observer, + std::shared_ptr> first, + std::shared_ptr> second) + : downObserver_(std::move(observer)), + first_(std::move(first)), + second_(std::move(second)) {} + + void init() { + upObserver_ = std::make_shared(this->shared_from_this()); + downObserver_->onSubscribe(this->shared_from_this()); + if (upObserver_) { + first_->subscribe(upObserver_); + } + } + + void cancel() override { + if (auto observer = std::move(upObserver_)) { + observer->cancel(); + } + first_.reset(); + second_.reset(); + upObserver_.reset(); + downObserver_.reset(); + } + + void onNext(T value) { + downObserver_->onNext(std::move(value)); + } + + void onComplete() { + if (auto first = std::move(first_)) { + upObserver_ = + std::make_shared(this->shared_from_this()); + second_->subscribe(upObserver_); + second_.reset(); + } else { + downObserver_->onComplete(); + downObserver_.reset(); + } + } + + void onError(folly::exception_wrapper ew) { + downObserver_->onError(std::move(ew)); + first_.reset(); + second_.reset(); + upObserver_.reset(); + downObserver_.reset(); + } + + private: + std::shared_ptr> downObserver_; + std::shared_ptr> first_; + std::shared_ptr> second_; + std::shared_ptr upObserver_; + }; + + class ForwardObserver : public yarpl::observable::Observer, + public yarpl::observable::Subscription { + public: + ForwardObserver( + std::shared_ptr concatWithSubscription) + : concatWithSubscription_(std::move(concatWithSubscription)) {} + + void cancel() override { + if (auto subs = std::move(subscription_)) { + subs->cancel(); + } + } + + void onSubscribe(std::shared_ptr subscription) override { + // Don't forward the subscription to downstream observer + subscription_ = std::move(subscription); + } + + void onComplete() override { + concatWithSubscription_->onComplete(); + concatWithSubscription_.reset(); + } + + void onError(folly::exception_wrapper ew) override { + concatWithSubscription_->onError(std::move(ew)); + concatWithSubscription_.reset(); + } + + void onNext(T value) override { + concatWithSubscription_->onNext(std::move(value)); + } + + private: + std::shared_ptr concatWithSubscription_; + std::shared_ptr subscription_; + }; + + private: + const std::shared_ptr> first_; + const std::shared_ptr> second_; +}; + +} // namespace details +} // namespace observable +} // namespace yarpl diff --git a/yarpl/observable/ObservableDoOperator.h b/yarpl/observable/ObservableDoOperator.h new file mode 100644 index 000000000..66f655eaf --- /dev/null +++ b/yarpl/observable/ObservableDoOperator.h @@ -0,0 +1,159 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/observable/ObservableOperator.h" + +namespace yarpl { +namespace observable { +namespace details { + +template < + typename U, + typename OnSubscribeFunc, + typename OnNextFunc, + typename OnErrorFunc, + typename OnCompleteFunc, + typename OnCancelFunc> +class DoOperator : public ObservableOperator { + using Super = ObservableOperator; + static_assert( + std::is_same, OnSubscribeFunc>::value, + "undecayed"); + static_assert( + std::is_same, OnNextFunc>::value, + "undecayed"); + static_assert( + std::is_same, OnErrorFunc>::value, + "undecayed"); + static_assert( + std::is_same, OnCompleteFunc>::value, + "undecayed"); + static_assert( + std::is_same, OnCancelFunc>::value, + "undecayed"); + + public: + template < + typename FSubscribe, + typename FNext, + typename FError, + typename FComplete, + typename FCancel> + DoOperator( + std::shared_ptr> upstream, + FSubscribe&& onSubscribeFunc, + FNext&& onNextFunc, + FError&& onErrorFunc, + FComplete&& onCompleteFunc, + FCancel&& onCancelFunc) + : upstream_(std::move(upstream)), + onSubscribeFunc_(std::forward(onSubscribeFunc)), + onNextFunc_(std::forward(onNextFunc)), + onErrorFunc_(std::forward(onErrorFunc)), + onCompleteFunc_(std::forward(onCompleteFunc)), + onCancelFunc_(std::forward(onCancelFunc)) {} + + std::shared_ptr subscribe( + std::shared_ptr> observer) override { + auto subscription = std::make_shared( + this->ref_from_this(this), std::move(observer)); + upstream_->subscribe( + // Note: implicit cast to a reference to a observer. + subscription); + return subscription; + } + + private: + class DoSubscription : public Super::OperatorSubscription { + using SuperSub = typename Super::OperatorSubscription; + + public: + DoSubscription( + std::shared_ptr observable, + std::shared_ptr> observer) + : SuperSub(std::move(observer)), observable_(std::move(observable)) {} + + void onSubscribe(std::shared_ptr + subscription) override { + observable_->onSubscribeFunc_(); + SuperSub::onSubscribe(std::move(subscription)); + } + + void onNext(U value) override { + const auto& valueRef = value; + observable_->onNextFunc_(valueRef); + SuperSub::observerOnNext(std::move(value)); + } + + void onError(folly::exception_wrapper ex) override { + const auto& exRef = ex; + observable_->onErrorFunc_(exRef); + SuperSub::onError(std::move(ex)); + } + + void onComplete() override { + observable_->onCompleteFunc_(); + SuperSub::onComplete(); + } + + void cancel() override { + observable_->onCancelFunc_(); + SuperSub::cancel(); + } + + private: + std::shared_ptr observable_; + }; + + std::shared_ptr> upstream_; + OnSubscribeFunc onSubscribeFunc_; + OnNextFunc onNextFunc_; + OnErrorFunc onErrorFunc_; + OnCompleteFunc onCompleteFunc_; + OnCancelFunc onCancelFunc_; +}; + +template < + typename U, + typename OnSubscribeFunc, + typename OnNextFunc, + typename OnErrorFunc, + typename OnCompleteFunc, + typename OnCancelFunc> +inline auto createDoOperator( + std::shared_ptr> upstream, + OnSubscribeFunc&& onSubscribeFunc, + OnNextFunc&& onNextFunc, + OnErrorFunc&& onErrorFunc, + OnCompleteFunc&& onCompleteFunc, + OnCancelFunc&& onCancelFunc) { + return std::make_shared, + std::decay_t, + std::decay_t, + std::decay_t, + std::decay_t>>( + std::move(upstream), + std::forward(onSubscribeFunc), + std::forward(onNextFunc), + std::forward(onErrorFunc), + std::forward(onCompleteFunc), + std::forward(onCancelFunc)); +} +} // namespace details +} // namespace observable +} // namespace yarpl diff --git a/yarpl/include/yarpl/observable/ObservableOperator.h b/yarpl/observable/ObservableOperator.h similarity index 52% rename from yarpl/include/yarpl/observable/ObservableOperator.h rename to yarpl/observable/ObservableOperator.h index 6e667839a..451c6bd13 100644 --- a/yarpl/include/yarpl/observable/ObservableOperator.h +++ b/yarpl/observable/ObservableOperator.h @@ -1,14 +1,26 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include +#include + #include "yarpl/Observable.h" #include "yarpl/observable/Observer.h" -#include "yarpl/observable/Subscriptions.h" - -#include +#include "yarpl/observable/Observable.h" namespace yarpl { namespace observable { @@ -20,13 +32,8 @@ namespace observable { * pipelines * can be built: a Observable heading a sequence of Operators. */ -template +template class ObservableOperator : public Observable { - public: - explicit ObservableOperator(Reference> upstream) - : upstream_(std::move(upstream)) {} - using ThisOperatorT = ThisOp; - protected: /// An Operator's subscription. /// @@ -38,21 +45,11 @@ class ObservableOperator : public Observable { class OperatorSubscription : public ::yarpl::observable::Subscription, public Observer { protected: - OperatorSubscription( - Reference observable, - Reference> observer) - : observable_(std::move(observable)), observer_(std::move(observer)) { - assert(observable_); + explicit OperatorSubscription(std::shared_ptr> observer) + : observer_(std::move(observer)) { assert(observer_); } - Reference& getObservableOperator() { - static_assert( - std::is_base_of, ThisOperatorT>::value, - "Operator must be a subclass of Observable"); - return observable_; - } - void observerOnNext(D value) { if (observer_) { observer_->onNext(std::move(value)); @@ -78,8 +75,8 @@ class ObservableOperator : public Observable { // Observer. - void onSubscribe( - Reference subscription) override { + void onSubscribe(std::shared_ptr + subscription) override { if (upstream_) { DLOG(ERROR) << "attempt to subscribe twice"; subscription->cancel(); @@ -147,14 +144,11 @@ class ObservableOperator : public Observable { } } - /// The Observable has the lambda, and other creation parameters. - Reference observable_; - /// This subscription controls the life-cycle of the observer. The /// observer is retained as long as calls on it can be made. (Note: /// the observer in turn maintains a reference on this subscription /// object until cancellation and/or completion.) - Reference> observer_; + std::shared_ptr> observer_; /// In an active pipeline, cancel and (possibly modified) request(n) /// calls should be forwarded upstream. Note that `this` is also a @@ -162,29 +156,27 @@ class ObservableOperator : public Observable { /// the objects drop their references at cancel/complete. // TODO(lehecka): this is extra field... base class has this member so // remove it - Reference<::yarpl::observable::Subscription> upstream_; + std::shared_ptr<::yarpl::observable::Subscription> upstream_; }; - - Reference> upstream_; }; -template < - typename U, - typename D, - typename F, - typename = typename std::enable_if::value>::type> -class MapOperator : public ObservableOperator> { - using ThisOperatorT = MapOperator; - using Super = ObservableOperator; +template +class MapOperator : public ObservableOperator { + using Super = ObservableOperator; + static_assert(std::is_same, F>::value, "undecayed"); + static_assert(folly::is_invocable_r::value, "not invocable"); public: - MapOperator(Reference> upstream, F function) - : Super(std::move(upstream)), function_(std::move(function)) {} - - Reference subscribe(Reference> observer) override { - auto subscription = - make_ref(this->ref_from_this(this), std::move(observer)); - Super::upstream_->subscribe( + template + MapOperator(std::shared_ptr> upstream, Func&& function) + : upstream_(std::move(upstream)), + function_(std::forward(function)) {} + + std::shared_ptr subscribe( + std::shared_ptr> observer) override { + auto subscription = std::make_shared( + this->ref_from_this(this), std::move(observer)); + upstream_->subscribe( // Note: implicit cast to a reference to a observer. subscription); return subscription; @@ -196,41 +188,44 @@ class MapOperator : public ObservableOperator> { public: MapSubscription( - Reference observable, - Reference> observer) - : SuperSub(std::move(observable), std::move(observer)) {} + std::shared_ptr observable, + std::shared_ptr> observer) + : SuperSub(std::move(observer)), observable_(std::move(observable)) {} void onNext(U value) override { try { - auto& map = this->getObservableOperator(); - this->observerOnNext(map->function_(std::move(value))); + this->observerOnNext(observable_->function_(std::move(value))); } catch (const std::exception& exn) { folly::exception_wrapper ew{std::current_exception(), exn}; this->terminateErr(std::move(ew)); } } + + private: + std::shared_ptr observable_; }; + std::shared_ptr> upstream_; F function_; }; -template < - typename U, - typename F, - typename = - typename std::enable_if::value>::type> -class FilterOperator : public ObservableOperator> { - using ThisOperatorT = FilterOperator; - using Super = ObservableOperator; +template +class FilterOperator : public ObservableOperator { + using Super = ObservableOperator; + static_assert(std::is_same, F>::value, "undecayed"); + static_assert(folly::is_invocable_r::value, "not invocable"); public: - FilterOperator(Reference> upstream, F function) - : Super(std::move(upstream)), function_(std::move(function)) {} - - Reference subscribe(Reference> observer) override { - auto subscription = - make_ref(this->ref_from_this(this), std::move(observer)); - Super::upstream_->subscribe( + template + FilterOperator(std::shared_ptr> upstream, Func&& function) + : upstream_(std::move(upstream)), + function_(std::forward(function)) {} + + std::shared_ptr subscribe( + std::shared_ptr> observer) override { + auto subscription = std::make_shared( + this->ref_from_this(this), std::move(observer)); + upstream_->subscribe( // Note: implicit cast to a reference to a observer. subscription); return subscription; @@ -242,42 +237,42 @@ class FilterOperator : public ObservableOperator> { public: FilterSubscription( - Reference observable, - Reference> observer) - : SuperSub(std::move(observable), std::move(observer)) {} + std::shared_ptr observable, + std::shared_ptr> observer) + : SuperSub(std::move(observer)), observable_(std::move(observable)) {} void onNext(U value) override { - auto& filter = SuperSub::getObservableOperator(); - if (filter->function_(value)) { + if (observable_->function_(value)) { SuperSub::observerOnNext(std::move(value)); } } + + private: + std::shared_ptr observable_; }; + std::shared_ptr> upstream_; F function_; }; -template < - typename U, - typename D, - typename F, - typename = typename std::enable_if::value>, - typename = - typename std::enable_if::value>::type> -class ReduceOperator - : public ObservableOperator> { - using ThisOperatorT = ReduceOperator; - using Super = ObservableOperator; +template +class ReduceOperator : public ObservableOperator { + using Super = ObservableOperator; + static_assert(std::is_same, F>::value, "undecayed"); + static_assert(std::is_assignable::value, "not assignable"); + static_assert(folly::is_invocable_r::value, "not invocable"); public: - ReduceOperator(Reference> upstream, F function) - : Super(std::move(upstream)), function_(std::move(function)) {} - - Reference subscribe( - Reference> subscriber) override { - auto subscription = make_ref( + template + ReduceOperator(std::shared_ptr> upstream, Func&& function) + : upstream_(std::move(upstream)), + function_(std::forward(function)) {} + + std::shared_ptr subscribe( + std::shared_ptr> subscriber) override { + auto subscription = std::make_shared( this->ref_from_this(this), std::move(subscriber)); - Super::upstream_->subscribe( + upstream_->subscribe( // Note: implicit cast to a reference to a subscriber. subscription); return subscription; @@ -289,15 +284,15 @@ class ReduceOperator public: ReduceSubscription( - Reference flowable, - Reference> subscriber) - : SuperSub(std::move(flowable), std::move(subscriber)), + std::shared_ptr observable, + std::shared_ptr> observer) + : SuperSub(std::move(observer)), + observable_(std::move(observable)), accInitialized_(false) {} void onNext(U value) override { - auto& reduce = SuperSub::getObservableOperator(); if (accInitialized_) { - acc_ = reduce->function_(std::move(acc_), std::move(value)); + acc_ = observable_->function_(std::move(acc_), std::move(value)); } else { acc_ = std::move(value); accInitialized_ = true; @@ -312,26 +307,28 @@ class ReduceOperator } private: + std::shared_ptr observable_; bool accInitialized_; D acc_; }; + std::shared_ptr> upstream_; F function_; }; template -class TakeOperator : public ObservableOperator> { - using ThisOperatorT = TakeOperator; - using Super = ObservableOperator; +class TakeOperator : public ObservableOperator { + using Super = ObservableOperator; public: - TakeOperator(Reference> upstream, int64_t limit) - : Super(std::move(upstream)), limit_(limit) {} + TakeOperator(std::shared_ptr> upstream, int64_t limit) + : upstream_(std::move(upstream)), limit_(limit) {} - Reference subscribe(Reference> observer) override { - auto subscription = make_ref( - this->ref_from_this(this), limit_, std::move(observer)); - Super::upstream_->subscribe(subscription); + std::shared_ptr subscribe( + std::shared_ptr> observer) override { + auto subscription = + std::make_shared(limit_, std::move(observer)); + upstream_->subscribe(subscription); return subscription; } @@ -340,16 +337,20 @@ class TakeOperator : public ObservableOperator> { using SuperSub = typename Super::OperatorSubscription; public: - TakeSubscription( - Reference observable, - int64_t limit, - Reference> observer) - : SuperSub(std::move(observable), std::move(observer)), limit_(limit) {} + TakeSubscription(int64_t limit, std::shared_ptr> observer) + : SuperSub(std::move(observer)), limit_(limit) {} + + void onSubscribe(std::shared_ptr + subscription) override { + SuperSub::onSubscribe(std::move(subscription)); + + if (limit_ <= 0) { + SuperSub::terminate(); + } + } void onNext(T value) override { if (limit_-- > 0) { - if (pending_ > 0) - --pending_; SuperSub::observerOnNext(std::move(value)); if (limit_ == 0) { SuperSub::terminate(); @@ -358,26 +359,26 @@ class TakeOperator : public ObservableOperator> { } private: - int64_t pending_{0}; int64_t limit_; }; + std::shared_ptr> upstream_; const int64_t limit_; }; template -class SkipOperator : public ObservableOperator> { - using ThisOperatorT = SkipOperator; - using Super = ObservableOperator; +class SkipOperator : public ObservableOperator { + using Super = ObservableOperator; public: - SkipOperator(Reference> upstream, int64_t offset) - : Super(std::move(upstream)), offset_(offset) {} + SkipOperator(std::shared_ptr> upstream, int64_t offset) + : upstream_(std::move(upstream)), offset_(offset) {} - Reference subscribe(Reference> observer) override { - auto subscription = make_ref( - this->ref_from_this(this), offset_, std::move(observer)); - Super::upstream_->subscribe(subscription); + std::shared_ptr subscribe( + std::shared_ptr> observer) override { + auto subscription = + std::make_shared(offset_, std::move(observer)); + upstream_->subscribe(subscription); return subscription; } @@ -386,12 +387,8 @@ class SkipOperator : public ObservableOperator> { using SuperSub = typename Super::OperatorSubscription; public: - SkipSubscription( - Reference observable, - int64_t offset, - Reference> observer) - : SuperSub(std::move(observable), std::move(observer)), - offset_(offset) {} + SkipSubscription(int64_t offset, std::shared_ptr> observer) + : SuperSub(std::move(observer)), offset_(offset) {} void onNext(T value) override { if (offset_ <= 0) { @@ -405,23 +402,23 @@ class SkipOperator : public ObservableOperator> { int64_t offset_; }; + std::shared_ptr> upstream_; const int64_t offset_; }; template -class IgnoreElementsOperator - : public ObservableOperator> { - using ThisOperatorT = IgnoreElementsOperator; - using Super = ObservableOperator; +class IgnoreElementsOperator : public ObservableOperator { + using Super = ObservableOperator; public: - explicit IgnoreElementsOperator(Reference> upstream) - : Super(std::move(upstream)) {} + explicit IgnoreElementsOperator(std::shared_ptr> upstream) + : upstream_(std::move(upstream)) {} - Reference subscribe(Reference> observer) override { - auto subscription = make_ref( - this->ref_from_this(this), std::move(observer)); - Super::upstream_->subscribe(subscription); + std::shared_ptr subscribe( + std::shared_ptr> observer) override { + auto subscription = + std::make_shared(std::move(observer)); + upstream_->subscribe(subscription); return subscription; } @@ -430,31 +427,32 @@ class IgnoreElementsOperator using SuperSub = typename Super::OperatorSubscription; public: - IgnoreElementsSubscription( - Reference observable, - Reference> observer) - : SuperSub(std::move(observable), std::move(observer)) {} + IgnoreElementsSubscription(std::shared_ptr> observer) + : SuperSub(std::move(observer)) {} void onNext(T) override {} }; + + std::shared_ptr> upstream_; }; template -class SubscribeOnOperator - : public ObservableOperator> { - using ThisOperatorT = SubscribeOnOperator; - using Super = ObservableOperator; +class SubscribeOnOperator : public ObservableOperator { + using Super = ObservableOperator; public: SubscribeOnOperator( - Reference> upstream, + std::shared_ptr> upstream, folly::Executor& executor) - : Super(std::move(upstream)), executor_(executor) {} - - Reference subscribe(Reference> observer) override { - auto subscription = make_ref( - this->ref_from_this(this), executor_, std::move(observer)); - Super::upstream_->subscribe(subscription); + : upstream_(std::move(upstream)), executor_(executor) {} + + std::shared_ptr subscribe( + std::shared_ptr> observer) override { + auto subscription = std::make_shared( + executor_, std::move(observer)); + executor_.add([subscription, upstream = upstream_]() mutable { + upstream->subscribe(std::move(subscription)); + }); return subscription; } @@ -464,14 +462,12 @@ class SubscribeOnOperator public: SubscribeOnSubscription( - Reference observable, folly::Executor& executor, - Reference> observer) - : SuperSub(std::move(observable), std::move(observer)), - executor_(executor) {} + std::shared_ptr> observer) + : SuperSub(std::move(observer)), executor_(executor) {} void cancel() override { - executor_.add([ self = this->ref_from_this(this), this ] { + executor_.add([self = this->ref_from_this(this), this] { this->callSuperCancel(); }); } @@ -489,55 +485,67 @@ class SubscribeOnOperator folly::Executor& executor_; }; + std::shared_ptr> upstream_; folly::Executor& executor_; }; template class FromPublisherOperator : public Observable { + static_assert( + std::is_same, OnSubscribe>::value, + "undecayed"); + public: - explicit FromPublisherOperator(OnSubscribe function) - : function_(std::move(function)) {} + template + explicit FromPublisherOperator(F&& function) + : function_(std::forward(function)) {} private: class PublisherObserver : public Observer { public: PublisherObserver( - Reference> inner, - Reference subscription) + std::shared_ptr> inner, + std::shared_ptr subscription) : inner_(std::move(inner)) { Observer::onSubscribe(std::move(subscription)); } - void onSubscribe(Reference) override { + void onSubscribe(std::shared_ptr) override { DLOG(ERROR) << "not allowed to call"; CHECK(false); } void onComplete() override { - inner_->onComplete(); + if (auto inner = atomic_exchange(&inner_, nullptr)) { + inner->onComplete(); + } Observer::onComplete(); } void onError(folly::exception_wrapper ex) override { - inner_->onError(std::move(ex)); + if (auto inner = atomic_exchange(&inner_, nullptr)) { + inner->onError(std::move(ex)); + } Observer::onError(folly::exception_wrapper()); } void onNext(T t) override { - inner_->onNext(std::move(t)); + atomic_load(&inner_)->onNext(std::move(t)); } private: - Reference> inner_; + AtomicReference> inner_; }; public: - Reference subscribe(Reference> observer) override { - auto subscription = Subscriptions::create(); + std::shared_ptr subscribe( + std::shared_ptr> observer) override { + auto subscription = Subscription::create(); observer->onSubscribe(subscription); if (!subscription->isCancelled()) { - function_(make_ref(std::move(observer), subscription)); + function_(std::make_shared( + std::move(observer), subscription), subscription); } return subscription; } @@ -545,7 +553,8 @@ class FromPublisherOperator : public Observable { private: OnSubscribe function_; }; -} -} +} // namespace observable +} // namespace yarpl +#include "yarpl/observable/ObservableConcatOperators.h" #include "yarpl/observable/ObservableDoOperator.h" diff --git a/yarpl/observable/Observables.cpp b/yarpl/observable/Observables.cpp new file mode 100644 index 000000000..6107938fe --- /dev/null +++ b/yarpl/observable/Observables.cpp @@ -0,0 +1,34 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yarpl/observable/Observables.h" + +namespace yarpl { +namespace observable { + +std::shared_ptr> Observable<>::range( + int64_t start, + int64_t count) { + auto lambda = [start, count](std::shared_ptr> observer) { + auto end = start + count; + for (int64_t i = start; i < end; ++i) { + observer->onNext(i); + } + observer->onComplete(); + }; + + return Observable::create(std::move(lambda)); +} +} // namespace observable +} // namespace yarpl diff --git a/yarpl/observable/Observables.h b/yarpl/observable/Observables.h new file mode 100644 index 000000000..7c30c4bec --- /dev/null +++ b/yarpl/observable/Observables.h @@ -0,0 +1,57 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "yarpl/observable/Observable.h" +#include "yarpl/observable/Subscription.h" + +namespace yarpl { +namespace observable { + +template <> +class Observable { + public: + /** + * Emit the sequence of numbers [start, start + count). + */ + static std::shared_ptr> range( + int64_t start, + int64_t count); + + template + static std::shared_ptr> just(T&& value) { + return Observable>::just(std::forward(value)); + } + + template + static std::shared_ptr> justN(std::initializer_list list) { + return Observable>::justN(std::move(list)); + } + + // this will generate an observable which can be subscribed to only once + template + static std::shared_ptr> justOnce(T&& value) { + return Observable>::justOnce( + std::forward(value)); + } + + private: + Observable() = delete; +}; + +} // namespace observable +} // namespace yarpl diff --git a/yarpl/observable/Observer.h b/yarpl/observable/Observer.h new file mode 100644 index 000000000..3e1e456b4 --- /dev/null +++ b/yarpl/observable/Observer.h @@ -0,0 +1,226 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "yarpl/Refcounted.h" +#include "yarpl/observable/Subscription.h" + +namespace yarpl { +namespace observable { + +template +class Observer : public yarpl::enable_get_ref { + public: + // Note: If any of the following methods is overridden in a subclass, the new + // methods SHOULD ensure that these are invoked as well. + virtual void onSubscribe(std::shared_ptr subscription) { + DCHECK(subscription); + + if (subscription_) { + DLOG(ERROR) << "attempt to double subscribe"; + subscription->cancel(); + return; + } + + if (cancelled_) { + subscription->cancel(); + } + + subscription_ = std::move(subscription); + } + + // No further calls to the subscription after this method is invoked. + virtual void onComplete() { + DCHECK(subscription_) << "Calling onComplete() without a subscription"; + subscription_.reset(); + } + + // No further calls to the subscription after this method is invoked. + virtual void onError(folly::exception_wrapper) { + DCHECK(subscription_) << "Calling onError() without a subscription"; + subscription_.reset(); + } + + virtual void onNext(T) = 0; + + bool isUnsubscribed() const { + CHECK(subscription_); + return subscription_->isCancelled(); + } + + // Ability to add more subscription objects which will be notified when the + // subscription has been cancelled. + // Note that calling cancel on the tied subscription is not going to cancel + // this subscriber + void addSubscription(std::shared_ptr subscription) { + if (!subscription_) { + subscription->cancel(); + return; + } + subscription_->tieSubscription(std::move(subscription)); + } + + template + void addSubscription(OnCancel onCancel) { + addSubscription(Subscription::create(std::move(onCancel))); + } + + bool isUnsubscribedOrTerminated() const { + return !subscription_ || subscription_->isCancelled(); + } + + protected: + void unsubscribe() { + if (subscription_) { + subscription_->cancel(); + } else { + cancelled_ = true; + } + } + + public: + template < + typename Next, + typename = + typename std::enable_if::value>::type> + static std::shared_ptr> create(Next next); + + template < + typename Next, + typename Error, + typename = + typename std::enable_if::value>::type, + typename = typename std::enable_if< + folly::is_invocable::value>::type> + static std::shared_ptr> create(Next next, Error error); + + template < + typename Next, + typename Error, + typename Complete, + typename = + typename std::enable_if::value>::type, + typename = typename std::enable_if< + folly::is_invocable::value>::type, + typename = + typename std::enable_if::value>::type> + static std::shared_ptr> + create(Next next, Error error, Complete complete); + + static std::shared_ptr> create() { + class NullObserver : public Observer { + public: + void onNext(T) {} + }; + return std::make_shared(); + } + + private: + std::shared_ptr subscription_; + bool cancelled_{false}; +}; + +namespace details { + +template +class Base : public Observer { + static_assert(std::is_same, Next>::value, "undecayed"); + + public: + template + explicit Base(FNext&& next) : next_(std::forward(next)) {} + + void onNext(T value) override { + next_(std::move(value)); + } + + private: + Next next_; +}; + +template +class WithError : public Base { + static_assert(std::is_same, Error>::value, "undecayed"); + + public: + template + WithError(FNext&& next, FError&& error) + : Base(std::forward(next)), + error_(std::forward(error)) {} + + void onError(folly::exception_wrapper error) override { + error_(std::move(error)); + } + + private: + Error error_; +}; + +template +class WithErrorAndComplete : public WithError { + static_assert( + std::is_same, Complete>::value, + "undecayed"); + + public: + template + WithErrorAndComplete(FNext&& next, FError&& error, FComplete&& complete) + : WithError( + std::forward(next), + std::forward(error)), + complete_(std::move(complete)) {} + + void onComplete() override { + complete_(); + } + + private: + Complete complete_; +}; +} // namespace details + +template +template +std::shared_ptr> Observer::create(Next next) { + return std::make_shared>(std::move(next)); +} + +template +template +std::shared_ptr> Observer::create(Next next, Error error) { + return std::make_shared>( + std::move(next), std::move(error)); +} + +template +template < + typename Next, + typename Error, + typename Complete, + typename, + typename, + typename> +std::shared_ptr> +Observer::create(Next next, Error error, Complete complete) { + return std::make_shared< + details::WithErrorAndComplete>( + std::move(next), std::move(error), std::move(complete)); +} + +} // namespace observable +} // namespace yarpl diff --git a/yarpl/observable/Subscription.cpp b/yarpl/observable/Subscription.cpp new file mode 100644 index 000000000..6a0abda1a --- /dev/null +++ b/yarpl/observable/Subscription.cpp @@ -0,0 +1,85 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yarpl/observable/Subscription.h" +#include +#include +#include + +namespace yarpl { +namespace observable { + +/** + * Implementation that allows checking if a Subscription is cancelled. + */ +void Subscription::cancel() { + cancelled_ = true; + // Lock must be obtained here and not in the range expression for it to + // apply to the loop body. + auto locked = tiedSubscriptions_.wlock(); + for (auto& subscription : *locked) { + subscription->cancel(); + } +} + +bool Subscription::isCancelled() const { + return cancelled_; +} + +void Subscription::tieSubscription(std::shared_ptr subscription) { + CHECK(subscription); + if (isCancelled()) { + subscription->cancel(); + } + tiedSubscriptions_.wlock()->push_back(std::move(subscription)); +} + +std::shared_ptr Subscription::create( + std::function onCancel) { + class CallbackSubscription : public Subscription { + public: + explicit CallbackSubscription(std::function onCancel) + : onCancel_(std::move(onCancel)) {} + + void cancel() override { + bool expected = false; + // mark cancelled 'true' and only if successful invoke 'onCancel()' + if (cancelled_.compare_exchange_strong(expected, true)) { + onCancel_(); + // Lock must be obtained here and not in the range expression for it to + // apply to the loop body. + auto locked = tiedSubscriptions_.wlock(); + for (auto& subscription : *locked) { + subscription->cancel(); + } + } + } + + private: + std::function onCancel_; + }; + return std::make_shared(std::move(onCancel)); +} + +std::shared_ptr Subscription::create( + std::atomic_bool& cancelled) { + return create([&cancelled]() { cancelled = true; }); +} + +std::shared_ptr Subscription::create() { + return std::make_shared(); +} + +} // namespace observable +} // namespace yarpl diff --git a/yarpl/observable/Subscription.h b/yarpl/observable/Subscription.h new file mode 100644 index 000000000..38dc17792 --- /dev/null +++ b/yarpl/observable/Subscription.h @@ -0,0 +1,45 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace yarpl { +namespace observable { + +class Subscription { + public: + virtual ~Subscription() = default; + virtual void cancel(); + bool isCancelled() const; + + // Adds ability to tie another subscription to this instance. + // Whenever *this subscription is cancelled then all tied subscriptions get + // cancelled as well + void tieSubscription(std::shared_ptr subscription); + + static std::shared_ptr create(std::function onCancel); + static std::shared_ptr create(std::atomic_bool& cancelled); + static std::shared_ptr create(); + + protected: + std::atomic cancelled_{false}; + folly::Synchronized>> + tiedSubscriptions_; +}; + +} // namespace observable +} // namespace yarpl diff --git a/yarpl/include/yarpl/observable/TestObserver.h b/yarpl/observable/TestObserver.h similarity index 77% rename from yarpl/include/yarpl/observable/TestObserver.h rename to yarpl/observable/TestObserver.h index 7c5fab2a0..a4d290492 100644 --- a/yarpl/include/yarpl/observable/TestObserver.h +++ b/yarpl/observable/TestObserver.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -18,7 +30,7 @@ namespace observable { * Example usage: * * auto observable = ... - * auto ts = TestObserver::create(); + * auto ts = std::make_shared>(); * observable->subscribe(ts->unique_observer()); * ts->awaitTerminalEvent(); * ts->assert... @@ -29,7 +41,8 @@ namespace observable { * * For example: * - * auto ts = TestObserver::create(std::make_unique()); + * auto ts = + * std::make_shared>(std::make_unique()); * observable->subscribe(ts->unique_observer()); * * Now when 'observable' is subscribed to, the TestObserver behavior @@ -44,27 +57,11 @@ class TestObserver : public yarpl::observable::Observer, using Observer = yarpl::observable::Observer; public: - /** - * Create a TestObserver that will subscribe upwards - * with no flow control (max value) and store all values it receives. - * @return - */ - static std::shared_ptr> create(); - - /** - * Create a TestObserver that will delegate all on* method calls - * to the provided Observer. - * - * This will store all values it receives to allow assertions. - * @return - */ - static std::shared_ptr> create(std::unique_ptr); - TestObserver(); explicit TestObserver(std::unique_ptr delegate); - void onSubscribe(Subscription* s) override; - void onNext(const T& t) override; + void onSubscribe(std::shared_ptr s) override; + void onNext(T t) override; void onComplete() override; void onError(folly::exception_wrapper ex) override; @@ -84,7 +81,8 @@ class TestObserver : public yarpl::observable::Observer, /** * Block the current thread until either onComplete or onError is called. */ - void awaitTerminalEvent(); + void awaitTerminalEvent( + std::chrono::milliseconds ms = std::chrono::seconds{1}); /** * If the onNext values received does not match the given count, @@ -116,14 +114,24 @@ class TestObserver : public yarpl::observable::Observer, */ void cancel(); + bool isComplete() const { + return complete_; + } + + bool isError() const { + return error_; + } + private: std::unique_ptr delegate_; std::vector values_; folly::exception_wrapper e_; bool terminated_{false}; + bool complete_{false}; + bool error_{false}; std::mutex m_; std::condition_variable terminalEventCV_; - Subscription* subscription_; + std::shared_ptr subscription_; }; template @@ -134,18 +142,7 @@ TestObserver::TestObserver(std::unique_ptr delegate) : delegate_(std::move(delegate)){}; template -std::shared_ptr> TestObserver::create() { - return std::make_shared>(); -} - -template -std::shared_ptr> TestObserver::create( - std::unique_ptr s) { - return std::make_shared>(std::move(s)); -} - -template -void TestObserver::onSubscribe(Subscription* s) { +void TestObserver::onSubscribe(std::shared_ptr s) { subscription_ = s; if (delegate_) { delegate_->onSubscribe(s); @@ -153,7 +150,7 @@ void TestObserver::onSubscribe(Subscription* s) { } template -void TestObserver::onNext(const T& t) { +void TestObserver::onNext(T t) { if (delegate_) { // std::cout << "TestObserver onNext& => copy then delegate" << // std::endl; @@ -171,6 +168,7 @@ void TestObserver::onComplete() { delegate_->onComplete(); } terminated_ = true; + complete_ = true; terminalEventCV_.notify_all(); } @@ -181,15 +179,18 @@ void TestObserver::onError(folly::exception_wrapper ex) { } e_ = std::move(ex); terminated_ = true; + error_ = true; terminalEventCV_.notify_all(); } template -void TestObserver::awaitTerminalEvent() { +void TestObserver::awaitTerminalEvent(std::chrono::milliseconds ms) { // now block this thread std::unique_lock lk(m_); // if shutdown gets implemented this would then be released by it - terminalEventCV_.wait(lk, [this] { return terminated_; }); + if (!terminalEventCV_.wait_for(lk, ms, [this] { return terminated_; })) { + throw std::runtime_error("timeout in awaitTerminalEvent"); + } } template @@ -253,5 +254,5 @@ void TestObserver::assertOnErrorMessage(std::string msg) { throw std::runtime_error(ss.str()); } } -} -} +} // namespace observable +} // namespace yarpl diff --git a/yarpl/perf/Function_perf.cpp b/yarpl/perf/Function_perf.cpp index 4783d2a8a..4c868a44f 100644 --- a/yarpl/perf/Function_perf.cpp +++ b/yarpl/perf/Function_perf.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include diff --git a/yarpl/perf/Observable_perf.cpp b/yarpl/perf/Observable_perf.cpp index abbe81e90..ecfb6a92c 100644 --- a/yarpl/perf/Observable_perf.cpp +++ b/yarpl/perf/Observable_perf.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -9,7 +21,7 @@ using namespace yarpl::observable; static void Observable_OnNextOne_ConstructOnly(benchmark::State& state) { while (state.KeepRunning()) { - auto a = Observable::create([](yarpl::Reference> obs) { + auto a = Observable::create([](std::shared_ptr> obs) { obs->onSubscribe(Subscriptions::empty()); obs->onNext(1); obs->onComplete(); @@ -19,20 +31,20 @@ static void Observable_OnNextOne_ConstructOnly(benchmark::State& state) { BENCHMARK(Observable_OnNextOne_ConstructOnly); static void Observable_OnNextOne_SubscribeOnly(benchmark::State& state) { - auto a = Observable::create([](yarpl::Reference> obs) { + auto a = Observable::create([](std::shared_ptr> obs) { obs->onSubscribe(Subscriptions::empty()); obs->onNext(1); obs->onComplete(); }); while (state.KeepRunning()) { - a->subscribe(Observers::create([](int /* value */) {})); + a->subscribe(Observer::create([](int /* value */) {})); } } BENCHMARK(Observable_OnNextOne_SubscribeOnly); static void Observable_OnNextN(benchmark::State& state) { auto a = - Observable::create([&state](yarpl::Reference> obs) { + Observable::create([&state](std::shared_ptr> obs) { obs->onSubscribe(Subscriptions::empty()); for (int i = 0; i < state.range(0); i++) { obs->onNext(i); @@ -40,7 +52,7 @@ static void Observable_OnNextN(benchmark::State& state) { obs->onComplete(); }); while (state.KeepRunning()) { - a->subscribe(Observers::create([](int /* value */) {})); + a->subscribe(Observer::create([](int /* value */) {})); } } diff --git a/yarpl/single/Single.h b/yarpl/single/Single.h new file mode 100644 index 000000000..1355e30ba --- /dev/null +++ b/yarpl/single/Single.h @@ -0,0 +1,175 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "yarpl/Refcounted.h" +#include "yarpl/single/SingleObserver.h" +#include "yarpl/single/SingleObservers.h" +#include "yarpl/single/SingleSubscription.h" + +namespace yarpl { +namespace single { + +template +class Single : public yarpl::enable_get_ref { + public: + virtual ~Single() = default; + + virtual void subscribe(std::shared_ptr>) = 0; + + /** + * Subscribe overload that accepts lambdas. + */ + template < + typename Success, + typename = typename std::enable_if< + folly::is_invocable&, T>::value>::type> + void subscribe(Success&& next) { + subscribe(SingleObservers::create(std::forward(next))); + } + + /** + * Subscribe overload that accepts lambdas. + */ + template < + typename Success, + typename Error, + typename = typename std::enable_if< + folly::is_invocable&, T>::value && + folly::is_invocable&, folly::exception_wrapper>:: + value>::type> + void subscribe(Success next, Error error) { + subscribe(SingleObservers::create( + std::forward(next), std::forward(error))); + } + + /** + * Blocking subscribe that accepts lambdas. + * + * This blocks the current thread waiting on the response. + */ + template < + typename Success, + typename = typename std::enable_if< + folly::is_invocable&, T>::value>::type> + void subscribeBlocking(Success&& next) { + auto waiting_ = std::make_shared>(); + subscribe( + SingleObservers::create([next = std::forward(next), waiting_](T t) { + next(std::move(t)); + waiting_->post(); + })); + // TODO get errors and throw if one is received + waiting_->wait(); + } + + template < + typename OnSubscribe, + typename = typename std::enable_if&, + std::shared_ptr>>::value>::type> + static std::shared_ptr> create(OnSubscribe&&); + + template + auto map(Function&& function); +}; + +template <> +class Single { + public: + virtual ~Single() = default; + + virtual void subscribe(std::shared_ptr>) = 0; + + /** + * Subscribe overload taking lambda for onSuccess that is called upon writing + * to the network. + */ + template < + typename Success, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + void subscribe(Success&& s) { + class SuccessSingleObserver : public SingleObserverBase { + public: + explicit SuccessSingleObserver(Success&& success) + : success_{std::forward(success)} {} + + void onSubscribe( + std::shared_ptr subscription) override { + SingleObserverBase::onSubscribe(std::move(subscription)); + } + + void onSuccess() override { + success_(); + SingleObserverBase::onSuccess(); + } + + // No further calls to the subscription after this method is invoked. + void onError(folly::exception_wrapper ex) override { + SingleObserverBase::onError(std::move(ex)); + } + + private: + std::decay_t success_; + }; + + subscribe( + std::make_shared(std::forward(s))); + } + + template < + typename OnSubscribe, + typename = typename std::enable_if&, + std::shared_ptr>>::value>::type> + static auto create(OnSubscribe&&); +}; + +} // namespace single +} // namespace yarpl + +#include "yarpl/single/SingleOperator.h" + +namespace yarpl { +namespace single { + +template +template +std::shared_ptr> Single::create(OnSubscribe&& function) { + return std::make_shared>>( + std::forward(function)); +} + +template +auto Single::create(OnSubscribe&& function) { + return std::make_shared< + SingleVoidFromPublisherOperator>>( + std::forward(function)); +} + +template +template +auto Single::map(Function&& function) { + using D = typename folly::invoke_result_t; + return std::make_shared>>( + this->ref_from_this(this), std::forward(function)); +} + +} // namespace single +} // namespace yarpl diff --git a/yarpl/single/SingleObserver.h b/yarpl/single/SingleObserver.h new file mode 100644 index 000000000..8c74337c5 --- /dev/null +++ b/yarpl/single/SingleObserver.h @@ -0,0 +1,171 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "yarpl/single/SingleSubscription.h" + +namespace yarpl { +namespace single { + +template +class SingleObserver { + public: + virtual ~SingleObserver() = default; + virtual void onSubscribe(std::shared_ptr) = 0; + virtual void onSuccess(T) = 0; + virtual void onError(folly::exception_wrapper) = 0; + + template + static std::shared_ptr> create(Success&& success); + + template + static std::shared_ptr> create( + Success&& success, + Error&& error); +}; + +template +class SingleObserverBase : public SingleObserver { + public: + // Note: If any of the following methods is overridden in a subclass, the new + // methods SHOULD ensure that these are invoked as well. + void onSubscribe(std::shared_ptr subscription) override { + DCHECK(subscription); + + if (subscription_) { + subscription->cancel(); + return; + } + + subscription_ = std::move(subscription); + } + + void onSuccess(T) override { + DCHECK(subscription_) << "Calling onSuccess() without a subscription"; + subscription_.reset(); + } + + // No further calls to the subscription after this method is invoked. + void onError(folly::exception_wrapper) override { + DCHECK(subscription_) << "Calling onError() without a subscription"; + subscription_.reset(); + } + + protected: + SingleSubscription* subscription() { + return subscription_.operator->(); + } + + private: + std::shared_ptr subscription_; +}; + +/// Specialization of SingleObserverBase. +template <> +class SingleObserverBase { + public: + virtual ~SingleObserverBase() = default; + + // Note: If any of the following methods is overridden in a subclass, the new + // methods SHOULD ensure that these are invoked as well. + virtual void onSubscribe(std::shared_ptr subscription) { + DCHECK(subscription); + + if (subscription_) { + subscription->cancel(); + return; + } + + subscription_ = std::move(subscription); + } + + virtual void onSuccess() { + DCHECK(subscription_) << "Calling onSuccess() without a subscription"; + subscription_.reset(); + } + + // No further calls to the subscription after this method is invoked. + virtual void onError(folly::exception_wrapper) { + DCHECK(subscription_) << "Calling onError() without a subscription"; + subscription_.reset(); + } + + protected: + SingleSubscription* subscription() { + return subscription_.operator->(); + } + + private: + std::shared_ptr subscription_; +}; + +template +class SimpleSingleObserver : public SingleObserver { + public: + SimpleSingleObserver(Success success, Error error) + : success_(std::move(success)), error_(std::move(error)) {} + + void onSubscribe(std::shared_ptr) { + // throw away the subscription + } + + void onSuccess(T value) override { + success_(std::move(value)); + } + + void onError(folly::exception_wrapper ew) { + error_(std::move(ew)); + } + + Success success_; + Error error_; +}; + +template +template +std::shared_ptr> SingleObserver::create( + Success&& success) { + static_assert( + folly::is_invocable::value, + "Input `success` should be invocable with a parameter of `T`."); + return std::make_shared, + folly::Function>>( + std::forward(success), [](folly::exception_wrapper) {}); +} + +template +template +std::shared_ptr> SingleObserver::create( + Success&& success, + Error&& error) { + static_assert( + folly::is_invocable::value, + "Input `success` should be invocable with a parameter of `T`."); + static_assert( + folly::is_invocable::value, + "Input `error` should be invocable with a parameter of " + "`folly::exception_wrapper`."); + + return std::make_shared< + SimpleSingleObserver, std::decay_t>>( + std::forward(success), std::forward(error)); +} + +} // namespace single +} // namespace yarpl diff --git a/yarpl/single/SingleObservers.h b/yarpl/single/SingleObservers.h new file mode 100644 index 000000000..118b25fa9 --- /dev/null +++ b/yarpl/single/SingleObservers.h @@ -0,0 +1,107 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/single/SingleObserver.h" + +#include + +namespace yarpl { +namespace single { + +/// Helper methods for constructing subscriber instances from functions: +/// one or two functions (callables; can be lamda, for instance) +/// may be specified, corresponding to onNext, onError and onComplete +/// method bodies in the subscriber. +class SingleObservers { + private: + /// Defined if Success and Error are signature-compatible with + /// onSuccess and onError subscriber methods respectively. + template < + typename T, + typename Success, + typename Error = void (*)(folly::exception_wrapper)> + using EnableIfCompatible = typename std::enable_if< + folly::is_invocable&, T>::value && + folly::is_invocable&, folly::exception_wrapper>:: + value>::type; + + public: + template > + static auto create(Next&& next) { + return std::make_shared>>( + std::forward(next)); + } + + template < + typename T, + typename Success, + typename Error, + typename = EnableIfCompatible> + static auto create(Success&& next, Error&& error) { + return std::make_shared< + WithError, std::decay_t>>( + std::forward(next), std::forward(error)); + } + + template + static auto create() { + return std::make_shared>(); + } + + private: + template + class Base : public SingleObserverBase { + static_assert(std::is_same, Next>::value, "undecayed"); + + public: + template + explicit Base(FNext&& next) : next_(std::forward(next)) {} + + void onSuccess(T value) override { + next_(std::move(value)); + // TODO how do we call the super to trigger release? + // SingleObserver::onSuccess(value); + } + + private: + Next next_; + }; + + template + class WithError : public Base { + static_assert(std::is_same, Error>::value, "undecayed"); + + public: + template + WithError(FSuccess&& success, FError&& error) + : Base(std::forward(success)), + error_(std::forward(error)) {} + + void onError(folly::exception_wrapper error) override { + error_(error); + // TODO do we call the super here to trigger release? + Base::onError(std::move(error)); + } + + private: + Error error_; + }; + + SingleObservers() = delete; +}; + +} // namespace single +} // namespace yarpl diff --git a/yarpl/include/yarpl/single/SingleOperator.h b/yarpl/single/SingleOperator.h similarity index 68% rename from yarpl/include/yarpl/single/SingleOperator.h rename to yarpl/single/SingleOperator.h index d21c5a94c..0b3e7392e 100644 --- a/yarpl/include/yarpl/single/SingleOperator.h +++ b/yarpl/single/SingleOperator.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -23,7 +35,7 @@ namespace single { template class SingleOperator : public Single { public: - explicit SingleOperator(Reference> upstream) + explicit SingleOperator(std::shared_ptr> upstream) : upstream_(std::move(upstream)) {} protected: @@ -37,11 +49,12 @@ class SingleOperator : public Single { /// the user-supplied observer being the last of the pipeline stages. template class Subscription : public ::yarpl::single::SingleSubscription, - public SingleObserver { + public SingleObserver, + public yarpl::enable_get_ref { protected: Subscription( - Reference single, - Reference> observer) + std::shared_ptr single, + std::shared_ptr> observer) : single_(std::move(single)), observer_(std::move(observer)) {} ~Subscription() { @@ -56,7 +69,7 @@ class SingleOperator : public Single { terminateImpl(TerminateState::Down(), folly::Try{std::move(ew)}); } - Reference getOperator() { + std::shared_ptr getOperator() { return single_; } @@ -72,8 +85,8 @@ class SingleOperator : public Single { // Subscriber. - void onSubscribe( - Reference subscription) override { + void onSubscribe(std::shared_ptr + subscription) override { upstream_ = std::move(subscription); observer_->onSubscribe(this->ref_from_this(this)); } @@ -129,43 +142,45 @@ class SingleOperator : public Single { } /// The Single has the lambda, and other creation parameters. - Reference single_; + std::shared_ptr single_; /// This subscription controls the life-cycle of the observer. The /// observer is retained as long as calls on it can be made. (Note: /// the observer in turn maintains a reference on this subscription /// object until cancellation and/or completion.) - Reference> observer_; + std::shared_ptr> observer_; /// In an active pipeline, cancel and (possibly modified) request(n) /// calls should be forwarded upstream. Note that `this` is also a /// observer for the upstream stage: thus, there are cycles; all of /// the objects drop their references at cancel/complete. - Reference upstream_; + std::shared_ptr upstream_; }; - Reference> upstream_; + std::shared_ptr> upstream_; }; template < typename U, typename D, - typename F, - typename = typename std::enable_if::value>::type> + typename F> class MapOperator : public SingleOperator { using ThisOperatorT = MapOperator; using Super = SingleOperator; using OperatorSubscription = typename Super::template Subscription; + static_assert(std::is_same, F>::value, "undecayed"); + static_assert(folly::is_invocable_r::value, "not invocable"); public: - MapOperator(Reference> upstream, F function) - : Super(std::move(upstream)), function_(std::move(function)) {} + template + MapOperator(std::shared_ptr> upstream, Func&& function) + : Super(std::move(upstream)), function_(std::forward(function)) {} - void subscribe(Reference> observer) override { + void subscribe(std::shared_ptr> observer) override { Super::upstream_->subscribe( // Note: implicit cast to a reference to a observer. - make_ref( + std::make_shared( this->ref_from_this(this), std::move(observer))); } @@ -173,8 +188,8 @@ class MapOperator : public SingleOperator { class MapSubscription : public OperatorSubscription { public: MapSubscription( - Reference single, - Reference> observer) + std::shared_ptr single, + std::shared_ptr> observer) : OperatorSubscription(std::move(single), std::move(observer)) {} void onSuccess(U value) override { @@ -193,11 +208,16 @@ class MapOperator : public SingleOperator { template class FromPublisherOperator : public Single { + static_assert( + std::is_same, OnSubscribe>::value, + "undecayed"); + public: - explicit FromPublisherOperator(OnSubscribe function) - : function_(std::move(function)) {} + template + explicit FromPublisherOperator(F&& function) + : function_(std::forward(function)) {} - void subscribe(Reference> observer) override { + void subscribe(std::shared_ptr> observer) override { function_(std::move(observer)); } @@ -207,11 +227,16 @@ class FromPublisherOperator : public Single { template class SingleVoidFromPublisherOperator : public Single { + static_assert( + std::is_same, OnSubscribe>::value, + "undecayed"); + public: - explicit SingleVoidFromPublisherOperator(OnSubscribe&& function) - : function_(std::move(function)) {} + template + explicit SingleVoidFromPublisherOperator(F&& function) + : function_(std::forward(function)) {} - void subscribe(Reference> observer) override { + void subscribe(std::shared_ptr> observer) override { function_(std::move(observer)); } @@ -219,5 +244,5 @@ class SingleVoidFromPublisherOperator : public Single { OnSubscribe function_; }; -} // single -} // yarpl +} // namespace single +} // namespace yarpl diff --git a/yarpl/single/SingleSubscription.h b/yarpl/single/SingleSubscription.h new file mode 100644 index 000000000..ef898c1ba --- /dev/null +++ b/yarpl/single/SingleSubscription.h @@ -0,0 +1,32 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/Refcounted.h" + +namespace yarpl { +namespace single { + +class SingleSubscription { + public: + virtual ~SingleSubscription() = default; + virtual void cancel() = 0; + + protected: + SingleSubscription() {} +}; + +} // namespace single +} // namespace yarpl diff --git a/yarpl/include/yarpl/single/SingleSubscriptions.h b/yarpl/single/SingleSubscriptions.h similarity index 62% rename from yarpl/include/yarpl/single/SingleSubscriptions.h rename to yarpl/single/SingleSubscriptions.h index be37cfdad..9ebfe4498 100644 --- a/yarpl/include/yarpl/single/SingleSubscriptions.h +++ b/yarpl/single/SingleSubscriptions.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -13,8 +25,8 @@ namespace yarpl { namespace single { /** -* Implementation that allows checking if a Subscription is cancelled. -*/ + * Implementation that allows checking if a Subscription is cancelled. + */ class AtomicBoolSingleSubscription : public SingleSubscription { public: void cancel() override { @@ -29,8 +41,8 @@ class AtomicBoolSingleSubscription : public SingleSubscription { }; /** -* Implementation that gets a callback when cancellation occurs. -*/ + * Implementation that gets a callback when cancellation occurs. + */ class CallbackSingleSubscription : public SingleSubscription { public: explicit CallbackSingleSubscription(std::function onCancel) @@ -52,10 +64,10 @@ class CallbackSingleSubscription : public SingleSubscription { }; /** -* Implementation that can be cancelled with or without + * Implementation that can be cancelled with or without * a delegate, and when the delegate exists (before or after cancel) * it will be cancelled in a thread-safe manner. -*/ + */ class DelegateSingleSubscription : public SingleSubscription { public: explicit DelegateSingleSubscription() {} @@ -80,7 +92,7 @@ class DelegateSingleSubscription : public SingleSubscription { /** * This can be called once. */ - void setDelegate(Reference d) { + void setDelegate(std::shared_ptr d) { bool shouldCancelDelegate = false; { std::lock_guard g(m_); @@ -102,26 +114,27 @@ class DelegateSingleSubscription : public SingleSubscription { // all must be protected by a mutex mutable std::mutex m_; bool cancelled_{false}; - Reference delegate_; + std::shared_ptr delegate_; }; class SingleSubscriptions { public: - static Reference create( + static std::shared_ptr create( std::function onCancel) { - return make_ref(std::move(onCancel)); + return std::make_shared(std::move(onCancel)); } - static Reference create( + static std::shared_ptr create( std::atomic_bool& cancelled) { return create([&cancelled]() { cancelled = true; }); } - static Reference empty() { - return make_ref(); + static std::shared_ptr empty() { + return std::make_shared(); } - static Reference atomicBoolSubscription() { - return make_ref(); + static std::shared_ptr + atomicBoolSubscription() { + return std::make_shared(); } }; -} // single namespace -} // yarpl namespace +} // namespace single +} // namespace yarpl diff --git a/yarpl/include/yarpl/single/SingleTestObserver.h b/yarpl/single/SingleTestObserver.h similarity index 77% rename from yarpl/include/yarpl/single/SingleTestObserver.h rename to yarpl/single/SingleTestObserver.h index a0c8a84e7..2557f5d10 100644 --- a/yarpl/include/yarpl/single/SingleTestObserver.h +++ b/yarpl/single/SingleTestObserver.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -32,7 +44,7 @@ namespace single { * * For example: * - * auto to = SingleTestObserver::create(make_ref()); + * auto to = SingleTestObserver::create(std::make_shared()); * single->subscribe(to); * * Now when 'single' is subscribed to, the SingleTestObserver behavior @@ -49,8 +61,8 @@ class SingleTestObserver : public yarpl::single::SingleObserver { * * @return */ - static Reference> create() { - return make_ref>(); + static std::shared_ptr> create() { + return std::make_shared>(); } /** @@ -60,9 +72,9 @@ class SingleTestObserver : public yarpl::single::SingleObserver { * This will store the value it receives to allow assertions. * @return */ - static Reference> create( - Reference> delegate) { - return make_ref>(std::move(delegate)); + static std::shared_ptr> create( + std::shared_ptr> delegate) { + return std::make_shared>(std::move(delegate)); } SingleTestObserver() : delegate_(nullptr) {} @@ -74,10 +86,10 @@ class SingleTestObserver : public yarpl::single::SingleObserver { // and then access them for verification/assertion // on the unit test main thread. - explicit SingleTestObserver(Reference> delegate) + explicit SingleTestObserver(std::shared_ptr> delegate) : delegate_(std::move(delegate)) {} - void onSubscribe(Reference subscription) override { + void onSubscribe(std::shared_ptr subscription) override { if (delegate_) { delegateSubscription_->setDelegate(subscription); // copy delegate_->onSubscribe(std::move(subscription)); @@ -149,7 +161,10 @@ class SingleTestObserver : public yarpl::single::SingleObserver { throw std::runtime_error("Did not receive terminal event."); } if (e_) { - throw std::runtime_error("Received onError instead of onSuccess"); + std::stringstream ss; + ss << "Received onError instead of onSuccess"; + ss << " (error was " << e_ << ")"; + throw std::runtime_error(ss.str()); } } @@ -195,6 +210,10 @@ class SingleTestObserver : public yarpl::single::SingleObserver { } } + folly::exception_wrapper getException() const { + return e_; + } + /** * Submit SingleSubscription->cancel(); */ @@ -206,15 +225,15 @@ class SingleTestObserver : public yarpl::single::SingleObserver { private: std::mutex m_; std::condition_variable terminalEventCV_; - Reference> delegate_; + std::shared_ptr> delegate_; // The following variables must be protected by mutex m_ T value_; folly::exception_wrapper e_; bool terminated_{false}; // allows thread-safe cancellation against a delegate // regardless of when it is received - Reference delegateSubscription_{ - make_ref()}; + std::shared_ptr delegateSubscription_{ + std::make_shared()}; }; -} -} +} // namespace single +} // namespace yarpl diff --git a/yarpl/single/Singles.h b/yarpl/single/Singles.h new file mode 100644 index 000000000..b6fe896cb --- /dev/null +++ b/yarpl/single/Singles.h @@ -0,0 +1,83 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/single/Single.h" +#include "yarpl/single/SingleSubscriptions.h" + +#include + +namespace yarpl { +namespace single { + +class Singles { + public: + template + static std::shared_ptr> just(const T& value) { + auto lambda = [value](std::shared_ptr> observer) { + observer->onSubscribe(SingleSubscriptions::empty()); + observer->onSuccess(value); + }; + + return Single::create(std::move(lambda)); + } + + template < + typename T, + typename OnSubscribe, + typename = typename std::enable_if>>::value>::type> + static std::shared_ptr> create(OnSubscribe&& function) { + return std::make_shared< + FromPublisherOperator>>( + std::forward(function)); + } + + template + static std::shared_ptr> error(folly::exception_wrapper ex) { + auto lambda = + [e = std::move(ex)](std::shared_ptr> observer) { + observer->onSubscribe(SingleSubscriptions::empty()); + observer->onError(e); + }; + return Single::create(std::move(lambda)); + } + + template + static std::shared_ptr> error(const ExceptionType& ex) { + auto lambda = [ex](std::shared_ptr> observer) { + observer->onSubscribe(SingleSubscriptions::empty()); + observer->onError(ex); + }; + return Single::create(std::move(lambda)); + } + + template + static std::shared_ptr> fromGenerator(TGenerator&& generator) { + auto lambda = [generator = std::forward(generator)]( + std::shared_ptr> observer) mutable { + observer->onSubscribe(SingleSubscriptions::empty()); + observer->onSuccess(generator()); + }; + return Single::create(std::move(lambda)); + } + + private: + Singles() = delete; +}; + +} // namespace single +} // namespace yarpl diff --git a/yarpl/src/yarpl/Refcounted.cpp b/yarpl/src/yarpl/Refcounted.cpp deleted file mode 100644 index 47edc60f1..000000000 --- a/yarpl/src/yarpl/Refcounted.cpp +++ /dev/null @@ -1,111 +0,0 @@ -#include "yarpl/Refcounted.h" - -#include "folly/Synchronized.h" - -#include -#include - -namespace yarpl { -namespace detail { - -using sync_map_type = folly::Synchronized; -using refcount_pair = std::pair; - -#ifdef YARPL_REFCOUNT_DEBUGGING - -// number of objects currently live of this type -sync_map_type live_refcounted_map; - -// number of objects ever created of this type -sync_map_type total_refcounted_map; - -static void inc_in_map(std::string const& typestring, sync_map_type& the_map) { - auto map = the_map.wlock(); - auto it = map->find(typestring); - if (it == map->end()) { - map->emplace(typestring, 1); - it = map->find(typestring); - } - else { - it->second = it->second + 1; - } - - VLOG(6) << "increment " << typestring << " to " << it->second; -} - -static void dec_in_map(std::string const& typestring, sync_map_type& the_map) { - auto map = the_map.wlock(); - auto it = map->find(typestring); - if (it == map->end()) { - VLOG(6) << "didn't find " << typestring << " in the map"; - return; - } - else { - if (it->second >= 1) { - it->second = it->second - 1; - } - else { - VLOG(6) << "deallocating " << typestring << " past zero?"; - return; - } - } - - VLOG(6) << "decrement " << typestring << " to " << it->second; -} - -void inc_live(std::string const& typestring) { inc_in_map(typestring, live_refcounted_map); } -void dec_live(std::string const& typestring) { dec_in_map(typestring, live_refcounted_map); } - -void inc_created(std::string const& typestring) { inc_in_map(typestring, total_refcounted_map); } - -template -void debug_refcounts_map(std::ostream& o, sync_map_type const& the_map, ComparePred& pred) { - // truncate demangled typename - auto const max_type_len = 50; - // only print the first 'n' entries - auto max_entries = 50; - - auto the_map_locked = the_map.rlock(); - std::vector entries(the_map_locked->begin(), the_map_locked->end()); - std::sort(entries.begin(), entries.end(), pred); - - for(auto& pair : entries) { - if (!max_entries--) break; - - auto s = pair.first; - if (s.size() > max_type_len) { - s = s.substr(0, max_type_len); - } - o << std::left << std::setw (max_type_len) << s << " :: " << pair.second << std::endl; - } -} - -void debug_refcounts(std::ostream& o) { - struct { - bool operator()(refcount_pair const& a, refcount_pair const& b) { - return a.second > b.second; - } - } max_refcount_pred; - - o << "===============" << std::endl; - o << "LIVE REFCOUNTS: " << std::endl; - debug_refcounts_map(o, live_refcounted_map, max_refcount_pred); - o << "===============" << std::endl; - o << "===============" << std::endl; - o << "TOTAL REFCOUNTS: " << std::endl; - debug_refcounts_map(o, total_refcounted_map, max_refcount_pred); - o << "===============" << std::endl; -} - -#else /* YARPL_REFCOUNT_DEBUGGING */ - -void inc_created(std::string const&) { assert(false); } -void inc_live(std::string const&) { assert(false); } -void dec_live(std::string const&) { assert(false); } -void debug_refcounts(std::ostream& o) { - o << "Compile with YARPL_REFCOUNT_DEBUGGING (-DYARPL_REFCOUNT_DEBUGGING=On) to get Refcounted allocation counts" << std::endl; -} - -#endif - -} } diff --git a/yarpl/src/yarpl/flowable/sources/Subscription.cpp b/yarpl/src/yarpl/flowable/sources/Subscription.cpp deleted file mode 100644 index 905837fd2..000000000 --- a/yarpl/src/yarpl/flowable/sources/Subscription.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "yarpl/flowable/Subscription.h" - -namespace yarpl { -namespace flowable { - -yarpl::Reference Subscription::empty() { - class NullSubscription : public Subscription { - void request(int64_t) override {} - void cancel() override {} - }; - return make_ref(); -} - -} // flowable -} // yarpl diff --git a/yarpl/src/yarpl/observable/Subscriptions.cpp b/yarpl/src/yarpl/observable/Subscriptions.cpp deleted file mode 100644 index 77fbf7641..000000000 --- a/yarpl/src/yarpl/observable/Subscriptions.cpp +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "yarpl/observable/Subscriptions.h" -#include -#include -#include - -namespace yarpl { -namespace observable { - -/** - * Implementation that allows checking if a Subscription is cancelled. - */ -void Subscription::cancel() { - cancelled_ = true; - // Lock must be obtained here and not in the range expression for it to - // apply to the loop body. - auto locked = tiedSubscriptions_.wlock(); - for(auto& subscription : *locked) { - subscription->cancel(); - } -} - -bool Subscription::isCancelled() const { - return cancelled_; -} - -void Subscription::tieSubscription(Reference subscription) { - CHECK(subscription); - if (isCancelled()) { - subscription->cancel(); - } - tiedSubscriptions_->push_back(std::move(subscription)); -} - -/** - * Implementation that gets a callback when cancellation occurs. - */ -CallbackSubscription::CallbackSubscription(std::function onCancel) - : onCancel_(std::move(onCancel)) {} - -void CallbackSubscription::cancel() { - bool expected = false; - // mark cancelled 'true' and only if successful invoke 'onCancel()' - if (cancelled_.compare_exchange_strong(expected, true)) { - onCancel_(); - // Lock must be obtained here and not in the range expression for it to - // apply to the loop body. - auto locked = tiedSubscriptions_.wlock(); - for(auto& subscription : *locked) { - subscription->cancel(); - } - } -} - -Reference Subscriptions::create(std::function onCancel) { - return make_ref(std::move(onCancel)); -} - -Reference Subscriptions::create(std::atomic_bool& cancelled) { - return create([&cancelled]() { cancelled = true; }); -} - -Reference Subscriptions::create() { - return make_ref(); -} - -} -} diff --git a/yarpl/test/FlowableFlatMapTest.cpp b/yarpl/test/FlowableFlatMapTest.cpp index df398aea9..5306d9971 100644 --- a/yarpl/test/FlowableFlatMapTest.cpp +++ b/yarpl/test/FlowableFlatMapTest.cpp @@ -1,19 +1,28 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include +#include #include #include #include #include #include - -#include -#include -#include - -#include "yarpl/test_utils/Mocks.h" - #include "yarpl/Flowable.h" #include "yarpl/flowable/TestSubscriber.h" +#include "yarpl/test_utils/Mocks.h" namespace yarpl { namespace flowable { @@ -25,9 +34,9 @@ namespace { /// exception was sent, the exception is thrown. template std::vector run( - Reference> flowable, + std::shared_ptr> flowable, int64_t requestCount = 100) { - auto subscriber = make_ref>(requestCount); + auto subscriber = std::make_shared>(requestCount); flowable->subscribe(subscriber); return std::move(subscriber->values()); } @@ -49,16 +58,16 @@ filter_range(std::vector in, int64_t startat, int64_t endat) { } auto make_flowable_mapper_func() { - return folly::Function>(int)>([](int n) { + return folly::Function>(int)>([](int n) { switch (n) { case 10: - return Flowables::range(n, 2); + return Flowable<>::range(n, 2); case 20: - return Flowables::range(n, 3); + return Flowable<>::range(n, 3); case 30: - return Flowables::range(n, 4); + return Flowable<>::range(n, 4); } - return Flowables::range(n, 3); + return Flowable<>::range(n, 3); }); } @@ -91,8 +100,8 @@ bool validate_flatmapped_values( } TEST(FlowableFlatMapTest, AllRequestedTest) { - auto f = - Flowables::justN({10, 20, 30})->flatMap(make_flowable_mapper_func()); + auto f = Flowable<>::justN({10, 20, 30}) + ->flatMap(make_flowable_mapper_func()); std::vector res = run(f); EXPECT_EQ(9UL, res.size()); @@ -102,10 +111,10 @@ TEST(FlowableFlatMapTest, AllRequestedTest) { } TEST(FlowableFlatMapTest, FiniteRequested) { - auto f = - Flowables::justN({10, 20, 30})->flatMap(make_flowable_mapper_func()); + auto f = Flowable<>::justN({10, 20, 30}) + ->flatMap(make_flowable_mapper_func()); - auto subscriber = make_ref>(1); + auto subscriber = std::make_shared>(1); f->subscribe(subscriber); EXPECT_EQ(1UL, subscriber->values().size()); @@ -121,15 +130,15 @@ TEST(FlowableFlatMapTest, FiniteRequested) { } TEST(FlowableFlatMapTest, MappingLambdaThrowsErrorOnFirstCall) { - folly::Function>(int)> func = [](int n) { + folly::Function>(int)> func = [](int n) { CHECK_EQ(1, n); throw std::runtime_error{"throwing in mapper!"}; - return Flowables::empty(); + return Flowable::empty(); }; - auto f = Flowables::just(1)->flatMap(std::move(func)); + auto f = Flowable<>::just(1)->flatMap(std::move(func)); - auto subscriber = make_ref>(1); + auto subscriber = std::make_shared>(1); f->subscribe(subscriber); EXPECT_EQ(subscriber->getValueCount(), 0); @@ -138,26 +147,26 @@ TEST(FlowableFlatMapTest, MappingLambdaThrowsErrorOnFirstCall) { } TEST(FlowableFlatMapTest, MappedStreamThrows) { - folly::Function>(int)> func = [](int n) { + folly::Function>(int)> func = [](int n) { CHECK_EQ(1, n); // flowable which emits an onNext, then the next iteration, emits an error int64_t i = 1; - return Flowable::create([i](auto subscriber, int64_t req) mutable { - CHECK_EQ(1, req); - if (i > 0) { - subscriber->onNext(i); - i--; - } else { - subscriber->onError(std::runtime_error{"throwing in stream!"}); - } - return std::tuple(1, false); - }); + return Flowable::create( + [i](auto& subscriber, int64_t req) mutable { + CHECK_EQ(1, req); + if (i > 0) { + subscriber.onNext(i); + i--; + } else { + subscriber.onError(std::runtime_error{"throwing in stream!"}); + } + }); }; - auto f = Flowables::just(1)->flatMap(std::move(func)); + auto f = Flowable<>::just(1)->flatMap(std::move(func)); - auto subscriber = make_ref>(2); + auto subscriber = std::make_shared>(2); f->subscribe(subscriber); EXPECT_EQ(subscriber->values(), std::vector({1})); @@ -183,34 +192,16 @@ struct CBSubscription : yarpl::flowable::Subscription { struct FlowableEvbPair { FlowableEvbPair() = default; - Reference> flowable{nullptr}; + std::shared_ptr> flowable{nullptr}; folly::EventBaseThread evb{}; }; std::shared_ptr make_range_flowable(int start, int end) { auto ret = std::make_shared(); ret->evb.start("MRF_Worker"); - - ret->flowable = Flowables::fromPublisher( - [&ret, start, end](Reference> s) mutable { - auto evb = ret->evb.getEventBase(); - auto subscription = yarpl::make_ref( - [=](int64_t req) mutable { - /* request */ - CHECK_EQ(req, 1); - if (start >= end) { - evb->runInEventBaseThread([=] { s->onComplete(); }); - } else { - auto n = start++; - evb->runInEventBaseThread([=] { s->onNext(n); }); - } - }, - /* onCancel: do nothing */ - []() {}); - - evb->runInEventBaseThread([=] { s->onSubscribe(subscription); }); - }); - + ret->flowable = Flowable<>::range(start, end - start) + ->map([](int64_t val) { return (int)val; }) + ->subscribeOn(*ret->evb.getEventBase()); return ret; } @@ -218,7 +209,7 @@ TEST(FlowableFlatMapTest, Multithreaded) { auto p1 = make_range_flowable(10, 12); auto p2 = make_range_flowable(20, 25); - auto f = Flowables::range(0, 2)->flatMap([&](auto i) { + auto f = Flowable<>::range(0, 2)->flatMap([&](auto i) { if (i == 0) { return p1->flowable; } else { @@ -226,7 +217,7 @@ TEST(FlowableFlatMapTest, Multithreaded) { } }); - auto sub = yarpl::make_ref>(0); + auto sub = std::make_shared>(0); f->subscribe(sub); sub->request(2); @@ -242,7 +233,7 @@ TEST(FlowableFlatMapTest, MultithreadedLargeAmount) { auto p1 = make_range_flowable(10000, 40000); auto p2 = make_range_flowable(50000, 80000); - auto f = Flowables::range(0, 2)->flatMap([&](auto i) { + auto f = Flowable<>::range(0, 2)->flatMap([&](auto i) { if (i == 0) { return p1->flowable; } else { @@ -250,7 +241,7 @@ TEST(FlowableFlatMapTest, MultithreadedLargeAmount) { } }); - auto sub = yarpl::make_ref>(); + auto sub = std::make_shared>(); sub->dropValues(true); f->subscribe(sub); @@ -264,14 +255,14 @@ TEST(FlowableFlatMapTest, MultithreadedLargeAmount) { } TEST(FlowableFlatMapTest, MergeOperator) { - auto sub = yarpl::make_ref>(0); + auto sub = std::make_shared>(0); - auto p1 = Flowables::justN({"foo", "bar"}); - auto p2 = Flowables::justN({"baz", "quxx"}); - Reference>>> p3 = - Flowables::justN>>({p1, p2}); + auto p1 = Flowable<>::justN({"foo", "bar"}); + auto p2 = Flowable<>::justN({"baz", "quxx"}); + std::shared_ptr>>> p3 = + Flowable<>::justN>>({p1, p2}); - Reference> p4 = p3->merge(); + std::shared_ptr> p4 = p3->merge(); p4->subscribe(sub); EXPECT_EQ(0, sub->getValueCount()); @@ -303,4 +294,4 @@ TEST(FlowableFlatMapTest, MergeOperator) { } } // namespace flowable -} // namespace yarpl \ No newline at end of file +} // namespace yarpl diff --git a/yarpl/test/FlowableSubscriberTest.cpp b/yarpl/test/FlowableSubscriberTest.cpp index 73835df59..683f57f4b 100644 --- a/yarpl/test/FlowableSubscriberTest.cpp +++ b/yarpl/test/FlowableSubscriberTest.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "yarpl/flowable/Subscriber.h" #include "yarpl/test_utils/Mocks.h" @@ -10,68 +22,125 @@ using namespace testing; namespace { +TEST(FlowableSubscriberTest, CreateSubscriber) { + int calls{0}; + struct Functor { + explicit Functor(int& calls) : calls_(calls) {} + // If we update the template definition of the Subscriber, + // then we should comment out this method and observe the compiler output + // with and without the change. + void operator()(int) & { + ++calls_; + } + void operator()(int) && { + FAIL() << "onNext lambda should be stored as l-value"; + } + void operator()(std::string) const& { + ++calls_; + } + void operator()(std::string) const&& { + FAIL() << "onNext lambda should be stored as l-value"; + } + int& calls_; + }; + auto s1 = Subscriber::create(Functor(calls)); + s1->onSubscribe(yarpl::flowable::Subscription::create()); + s1->onNext(1); + EXPECT_EQ(1, calls); + + auto s2 = Subscriber::create(Functor(calls)); + s2->onSubscribe(yarpl::flowable::Subscription::create()); + s2->onNext((long)1); + EXPECT_EQ(2, calls); + + auto s3 = Subscriber::create(Functor(calls)); + s3->onSubscribe(yarpl::flowable::Subscription::create()); + s3->onNext("test"); + EXPECT_EQ(3, calls); + + // by reference + auto f = Functor(calls); + auto s4 = Subscriber::create(f); + s4->onSubscribe(yarpl::flowable::Subscription::create()); + s4->onNext(1); + EXPECT_EQ(4, calls); +} + TEST(FlowableSubscriberTest, TestBasicFunctionality) { Sequence subscriber_seq; - auto subscriber = yarpl::make_ref>>(); + auto subscriber = std::make_shared>>(); EXPECT_CALL(*subscriber, onSubscribeImpl()) - .Times(1) - .InSequence(subscriber_seq) - .WillOnce(Invoke([&] { - subscriber->request(3); - })); - EXPECT_CALL(*subscriber, onNextImpl(5)) - .Times(1) - .InSequence(subscriber_seq); + .Times(1) + .InSequence(subscriber_seq) + .WillOnce(Invoke([&] { subscriber->request(3); })); + EXPECT_CALL(*subscriber, onNextImpl(5)).Times(1).InSequence(subscriber_seq); EXPECT_CALL(*subscriber, onCompleteImpl()) - .Times(1) - .InSequence(subscriber_seq); + .Times(1) + .InSequence(subscriber_seq); - auto subscription = yarpl::make_ref>(); + auto subscription = std::make_shared>(); EXPECT_CALL(*subscription, request_(3)) - .Times(1) - .WillOnce(InvokeWithoutArgs([&] { - subscriber->onNext(5); - subscriber->onComplete(); - })); + .Times(1) + .WillOnce(InvokeWithoutArgs([&] { + subscriber->onNext(5); + subscriber->onComplete(); + })); subscriber->onSubscribe(subscription); } TEST(FlowableSubscriberTest, TestKeepRefToThisIsDisabled) { - auto subscriber = yarpl::make_ref>>(); - auto subscription = yarpl::make_ref>(); + auto subscriber = + std::make_shared>>(); + auto subscription = std::make_shared>(); // tests that only a single reference exists to the Subscriber; clearing // reference in `auto subscriber` would cause it to deallocate { InSequence s; - EXPECT_CALL(*subscriber, onSubscribeImpl()) - .Times(1) - .WillOnce(Invoke([&] { - EXPECT_EQ(1UL, subscriber->count()); - })); + EXPECT_CALL(*subscriber, onSubscribeImpl()).Times(1).WillOnce(Invoke([&] { + EXPECT_EQ(1UL, subscriber.use_count()); + })); } subscriber->onSubscribe(subscription); } TEST(FlowableSubscriberTest, TestKeepRefToThisIsEnabled) { - auto subscriber = yarpl::make_ref>>(); - auto subscription = yarpl::make_ref>(); + auto subscriber = std::make_shared>>(); + auto subscription = std::make_shared>(); // tests that only a reference is held somewhere on the stack, so clearing // references to `BaseSubscriber` while in a signaling method won't // deallocate it (until it's safe to do so) { InSequence s; - EXPECT_CALL(*subscriber, onSubscribeImpl()) - .Times(1) - .WillOnce(Invoke([&] { - EXPECT_EQ(2UL, subscriber->count()); - })); + EXPECT_CALL(*subscriber, onSubscribeImpl()).Times(1).WillOnce(Invoke([&] { + EXPECT_EQ(2UL, subscriber.use_count()); + })); } subscriber->onSubscribe(subscription); } +TEST(FlowableSubscriberTest, AutoFlowControl) { + size_t count = 0; + auto subscriber = Subscriber::create( + [&](int value) { + ++count; + EXPECT_EQ(value, count); + }, + 1); + auto subscription = std::make_shared>(); + + EXPECT_CALL(*subscription, request_(1)) + .Times(3) + .WillOnce(InvokeWithoutArgs([&] { subscriber->onNext(1); })) + .WillOnce(InvokeWithoutArgs([&] { + subscriber->onNext(2); + subscriber->onComplete(); + })); + + subscriber->onSubscribe(subscription); } +} // namespace diff --git a/yarpl/test/FlowableTest.cpp b/yarpl/test/FlowableTest.cpp index 6f2524f09..7c8a77353 100644 --- a/yarpl/test/FlowableTest.cpp +++ b/yarpl/test/FlowableTest.cpp @@ -1,5 +1,20 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include #include #include #include @@ -7,11 +22,18 @@ #include #include "yarpl/Flowable.h" +#include "yarpl/flowable/Subscriber.h" #include "yarpl/flowable/TestSubscriber.h" #include "yarpl/test_utils/Mocks.h" -namespace yarpl { -namespace flowable { +#if FOLLY_HAS_COROUTINES +#include +#include "yarpl/flowable/AsyncGeneratorShim.h" +#endif + +using namespace yarpl::flowable; +using namespace testing; + namespace { /* @@ -73,29 +95,29 @@ class CollectingSubscriber : public BaseSubscriber { /// exception was sent, the exception is thrown. template std::vector run( - Reference> flowable, + std::shared_ptr> flowable, int64_t requestCount = 100) { - auto subscriber = make_ref>(requestCount); + auto subscriber = std::make_shared>(requestCount); flowable->subscribe(subscriber); + subscriber->awaitTerminalEvent(std::chrono::seconds(1)); return std::move(subscriber->values()); } - } // namespace TEST(FlowableTest, SingleFlowable) { - auto flowable = Flowables::just(10); + auto flowable = Flowable<>::just(10); flowable.reset(); } TEST(FlowableTest, SingleMovableFlowable) { auto value = std::make_unique(123456); - auto flowable = Flowables::justOnce(std::move(value)); - EXPECT_EQ(std::size_t{1}, flowable->count()); + auto flowable = Flowable<>::justOnce(std::move(value)); + EXPECT_EQ(1, flowable.use_count()); size_t received = 0; auto subscriber = - Subscribers::create>([&](std::unique_ptr p) { + Subscriber>::create([&](std::unique_ptr p) { EXPECT_EQ(*p, 123456); received++; }); @@ -105,24 +127,24 @@ TEST(FlowableTest, SingleMovableFlowable) { } TEST(FlowableTest, JustFlowable) { - EXPECT_EQ(run(Flowables::just(22)), std::vector{22}); + EXPECT_EQ(run(Flowable<>::just(22)), std::vector{22}); EXPECT_EQ( - run(Flowables::justN({12, 34, 56, 98})), + run(Flowable<>::justN({12, 34, 56, 98})), std::vector({12, 34, 56, 98})); EXPECT_EQ( - run(Flowables::justN({"ab", "pq", "yz"})), + run(Flowable<>::justN({"ab", "pq", "yz"})), std::vector({"ab", "pq", "yz"})); } TEST(FlowableTest, JustIncomplete) { - auto flowable = Flowables::justN({"a", "b", "c"})->take(2); + auto flowable = Flowable<>::justN({"a", "b", "c"})->take(2); EXPECT_EQ(run(std::move(flowable)), std::vector({"a", "b"})); - flowable = Flowables::justN({"a", "b", "c"})->take(2)->take(1); + flowable = Flowable<>::justN({"a", "b", "c"})->take(2)->take(1); EXPECT_EQ(run(std::move(flowable)), std::vector({"a"})); flowable.reset(); - flowable = Flowables::justN( + flowable = Flowable<>::justN( {"a", "b", "c", "d", "e", "f", "g", "h", "i"}) ->map([](std::string s) { s[0] = ::toupper(s[0]); @@ -137,14 +159,14 @@ TEST(FlowableTest, JustIncomplete) { } TEST(FlowableTest, MapWithException) { - auto flowable = Flowables::justN({1, 2, 3, 4})->map([](int n) { + auto flowable = Flowable<>::justN({1, 2, 3, 4})->map([](int n) { if (n > 2) { throw std::runtime_error{"Too big!"}; } return n; }); - auto subscriber = yarpl::make_ref>(); + auto subscriber = std::make_shared>(); flowable->subscribe(subscriber); EXPECT_EQ(subscriber->values(), std::vector({1, 2})); @@ -154,11 +176,12 @@ TEST(FlowableTest, MapWithException) { TEST(FlowableTest, Range) { EXPECT_EQ( - run(Flowables::range(10, 5)), std::vector({10, 11, 12, 13, 14})); + run(Flowable<>::range(10, 5)), + std::vector({10, 11, 12, 13, 14})); } TEST(FlowableTest, RangeWithMap) { - auto flowable = Flowables::range(1, 3) + auto flowable = Flowable<>::range(1, 3) ->map([](int64_t v) { return v * v; }) ->map([](int64_t v) { return v * v; }) ->map([](int64_t v) { return std::to_string(v); }); @@ -167,17 +190,17 @@ TEST(FlowableTest, RangeWithMap) { } TEST(FlowableTest, RangeWithReduceMoreItems) { - auto flowable = Flowables::range(0, 10)->reduce( + auto flowable = Flowable<>::range(0, 10)->reduce( [](int64_t acc, int64_t v) { return acc + v; }); EXPECT_EQ(run(std::move(flowable)), std::vector({45})); } TEST(FlowableTest, RangeWithReduceByMultiplication) { - auto flowable = Flowables::range(0, 10)->reduce( + auto flowable = Flowable<>::range(0, 10)->reduce( [](int64_t acc, int64_t v) { return acc * v; }); EXPECT_EQ(run(std::move(flowable)), std::vector({0})); - flowable = Flowables::range(1, 10)->reduce( + flowable = Flowable<>::range(1, 10)->reduce( [](int64_t acc, int64_t v) { return acc * v; }); EXPECT_EQ( run(std::move(flowable)), @@ -185,22 +208,22 @@ TEST(FlowableTest, RangeWithReduceByMultiplication) { } TEST(FlowableTest, RangeWithReduceLessItems) { - auto flowable = Flowables::range(0, 10)->reduce( + auto flowable = Flowable<>::range(0, 10)->reduce( [](int64_t acc, int64_t v) { return acc + v; }); // Even if we ask for 1 item only, it will reduce all the items EXPECT_EQ(run(std::move(flowable), 5), std::vector({45})); } TEST(FlowableTest, RangeWithReduceOneItem) { - auto flowable = Flowables::range(5, 1)->reduce( + auto flowable = Flowable<>::range(5, 1)->reduce( [](int64_t acc, int64_t v) { return acc + v; }); EXPECT_EQ(run(std::move(flowable)), std::vector({5})); } TEST(FlowableTest, RangeWithReduceNoItem) { - auto flowable = Flowables::range(0, 0)->reduce( + auto flowable = Flowable<>::range(0, 0)->reduce( [](int64_t acc, int64_t v) { return acc + v; }); - auto subscriber = make_ref>(100); + auto subscriber = std::make_shared>(100); flowable->subscribe(subscriber); EXPECT_TRUE(subscriber->isComplete()); @@ -208,7 +231,7 @@ TEST(FlowableTest, RangeWithReduceNoItem) { } TEST(FlowableTest, RangeWithFilterAndReduce) { - auto flowable = Flowables::range(0, 10) + auto flowable = Flowable<>::range(0, 10) ->filter([](int64_t v) { return v % 2 != 0; }) ->reduce([](int64_t acc, int64_t v) { return acc + v; }); EXPECT_EQ( @@ -216,7 +239,7 @@ TEST(FlowableTest, RangeWithFilterAndReduce) { } TEST(FlowableTest, RangeWithReduceToBiggerType) { - auto flowable = Flowables::range(5, 1) + auto flowable = Flowable<>::range(5, 1) ->map([](int64_t v) { return (char)(v + 10); }) ->reduce([](int64_t acc, char v) { return acc + v; }); EXPECT_EQ(run(std::move(flowable)), std::vector({15})); @@ -224,7 +247,7 @@ TEST(FlowableTest, RangeWithReduceToBiggerType) { TEST(FlowableTest, StringReduce) { auto flowable = - Flowables::justN( + Flowable<>::justN( {"a", "b", "c", "d", "e", "f", "g", "h", "i"}) ->reduce([](std::string acc, std::string v) { return acc + v; }); EXPECT_EQ(run(std::move(flowable)), std::vector({"abcdefghi"})); @@ -232,18 +255,18 @@ TEST(FlowableTest, StringReduce) { TEST(FlowableTest, RangeWithFilterRequestMoreItems) { auto flowable = - Flowables::range(0, 10)->filter([](int64_t v) { return v % 2 != 0; }); + Flowable<>::range(0, 10)->filter([](int64_t v) { return v % 2 != 0; }); EXPECT_EQ(run(std::move(flowable)), std::vector({1, 3, 5, 7, 9})); } TEST(FlowableTest, RangeWithFilterRequestLessItems) { auto flowable = - Flowables::range(0, 10)->filter([](int64_t v) { return v % 2 != 0; }); + Flowable<>::range(0, 10)->filter([](int64_t v) { return v % 2 != 0; }); EXPECT_EQ(run(std::move(flowable), 5), std::vector({1, 3, 5, 7, 9})); } TEST(FlowableTest, RangeWithFilterAndMap) { - auto flowable = Flowables::range(0, 10) + auto flowable = Flowable<>::range(0, 10) ->filter([](int64_t v) { return v % 2 != 0; }) ->map([](int64_t v) { return v + 10; }); EXPECT_EQ( @@ -251,7 +274,7 @@ TEST(FlowableTest, RangeWithFilterAndMap) { } TEST(FlowableTest, RangeWithMapAndFilter) { - auto flowable = Flowables::range(0, 10) + auto flowable = Flowable<>::range(0, 10) ->map([](int64_t v) { return (char)(v + 10); }) ->filter([](char v) { return v % 2 != 0; }); EXPECT_EQ(run(std::move(flowable)), std::vector({11, 13, 15, 17, 19})); @@ -259,23 +282,49 @@ TEST(FlowableTest, RangeWithMapAndFilter) { TEST(FlowableTest, SimpleTake) { EXPECT_EQ( - run(Flowables::range(0, 100)->take(3)), std::vector({0, 1, 2})); + run(Flowable<>::range(0, 100)->take(3)), std::vector({0, 1, 2})); EXPECT_EQ( - run(Flowables::range(10, 5)), std::vector({10, 11, 12, 13, 14})); + run(Flowable<>::range(10, 5)), + std::vector({10, 11, 12, 13, 14})); + + EXPECT_EQ(run(Flowable<>::range(0, 100)->take(0)), std::vector({})); +} + +TEST(FlowableTest, TakeError) { + auto take0 = + Flowable::error(std::runtime_error("something broke!"))->take(0); + + auto subscriber = std::make_shared>(); + take0->subscribe(subscriber); + + EXPECT_EQ(subscriber->values(), std::vector({})); + EXPECT_TRUE(subscriber->isComplete()); + EXPECT_FALSE(subscriber->isError()); +} + +TEST(FlowableTes, NeverTake) { + auto take0 = Flowable::never()->take(0); + + auto subscriber = std::make_shared>(); + take0->subscribe(subscriber); + + EXPECT_EQ(subscriber->values(), std::vector({})); + EXPECT_TRUE(subscriber->isComplete()); + EXPECT_FALSE(subscriber->isError()); } TEST(FlowableTest, SimpleSkip) { EXPECT_EQ( - run(Flowables::range(0, 10)->skip(8)), std::vector({8, 9})); + run(Flowable<>::range(0, 10)->skip(8)), std::vector({8, 9})); } TEST(FlowableTest, OverflowSkip) { - EXPECT_EQ(run(Flowables::range(0, 10)->skip(12)), std::vector({})); + EXPECT_EQ(run(Flowable<>::range(0, 10)->skip(12)), std::vector({})); } TEST(FlowableTest, SkipPartial) { - auto subscriber = make_ref>(2); - auto flowable = Flowables::range(0, 10)->skip(5); + auto subscriber = std::make_shared>(2); + auto flowable = Flowable<>::range(0, 10)->skip(5); flowable->subscribe(subscriber); EXPECT_EQ(subscriber->values(), std::vector({5, 6})); @@ -283,14 +332,14 @@ TEST(FlowableTest, SkipPartial) { } TEST(FlowableTest, IgnoreElements) { - auto flowable = Flowables::range(0, 100)->ignoreElements()->map( + auto flowable = Flowable<>::range(0, 100)->ignoreElements()->map( [](int64_t v) { return v * v; }); EXPECT_EQ(run(flowable), std::vector({})); } TEST(FlowableTest, IgnoreElementsPartial) { - auto subscriber = make_ref>(5); - auto flowable = Flowables::range(0, 10)->ignoreElements(); + auto subscriber = std::make_shared>(5); + auto flowable = Flowable<>::range(0, 10)->ignoreElements(); flowable->subscribe(subscriber); EXPECT_EQ(subscriber->values(), std::vector({})); @@ -300,11 +349,11 @@ TEST(FlowableTest, IgnoreElementsPartial) { subscriber->cancel(); } -TEST(FlowableTest, IgnoreElementsError) { +TEST(FlowableTest, FlowableErrorNoRequestN) { constexpr auto kMsg = "Failure"; - auto subscriber = make_ref>(); - auto flowable = Flowables::error(std::runtime_error(kMsg)); + auto subscriber = std::make_shared>(0); + auto flowable = Flowable::error(std::runtime_error(kMsg)); flowable->subscribe(subscriber); EXPECT_TRUE(subscriber->isError()); @@ -314,8 +363,8 @@ TEST(FlowableTest, IgnoreElementsError) { TEST(FlowableTest, FlowableError) { constexpr auto kMsg = "something broke!"; - auto flowable = Flowables::error(std::runtime_error(kMsg)); - auto subscriber = make_ref>(); + auto flowable = Flowable::error(std::runtime_error(kMsg)); + auto subscriber = std::make_shared>(); flowable->subscribe(subscriber); EXPECT_FALSE(subscriber->isComplete()); @@ -323,32 +372,58 @@ TEST(FlowableTest, FlowableError) { EXPECT_EQ(subscriber->getErrorMsg(), kMsg); } -TEST(FlowableTest, FlowableErrorPtr) { - constexpr auto kMsg = "something broke!"; +TEST(FlowableTest, FlowableEmpty) { + auto flowable = Flowable::empty(); + auto subscriber = std::make_shared>(); + flowable->subscribe(subscriber); - auto flowable = Flowables::error(std::runtime_error(kMsg)); - auto subscriber = make_ref>(); + EXPECT_TRUE(subscriber->isComplete()); + EXPECT_FALSE(subscriber->isError()); +} + +TEST(FlowableTest, FlowableEmptyNoRequestN) { + auto flowable = Flowable::empty(); + auto subscriber = std::make_shared>(0); flowable->subscribe(subscriber); + EXPECT_TRUE(subscriber->isComplete()); + EXPECT_FALSE(subscriber->isError()); +} + +TEST(FlowableTest, FlowableNever) { + auto flowable = Flowable::never(); + auto subscriber = std::make_shared>(); + flowable->subscribe(subscriber); + EXPECT_THROW( + subscriber->awaitTerminalEvent(std::chrono::milliseconds(100)), + std::runtime_error); + EXPECT_FALSE(subscriber->isComplete()); - EXPECT_TRUE(subscriber->isError()); - EXPECT_EQ(subscriber->getErrorMsg(), kMsg); + EXPECT_FALSE(subscriber->isError()); + + subscriber->cancel(); } -TEST(FlowableTest, FlowableEmpty) { - auto flowable = Flowables::empty(); - auto subscriber = make_ref>(); +TEST(FlowableTest, FlowableNeverNoRequestN) { + auto flowable = Flowable::never(); + auto subscriber = std::make_shared>(0); flowable->subscribe(subscriber); + EXPECT_THROW( + subscriber->awaitTerminalEvent(std::chrono::milliseconds(100)), + std::runtime_error); - EXPECT_TRUE(subscriber->isComplete()); + EXPECT_FALSE(subscriber->isComplete()); EXPECT_FALSE(subscriber->isError()); + + subscriber->cancel(); } TEST(FlowableTest, FlowableFromGenerator) { - auto flowable = Flowables::fromGenerator>( + auto flowable = Flowable>::fromGenerator( [] { return std::unique_ptr(); }); - auto subscriber = make_ref>>(10); + auto subscriber = + std::make_shared>>(10); flowable->subscribe(subscriber); EXPECT_FALSE(subscriber->isComplete()); @@ -361,14 +436,15 @@ TEST(FlowableTest, FlowableFromGenerator) { TEST(FlowableTest, FlowableFromGeneratorException) { constexpr auto errorMsg = "error from generator"; int count = 5; - auto flowable = Flowables::fromGenerator>([&] { + auto flowable = Flowable>::fromGenerator([&] { while (count--) { return std::unique_ptr(); } throw std::runtime_error(errorMsg); }); - auto subscriber = make_ref>>(10); + auto subscriber = + std::make_shared>>(10); flowable->subscribe(subscriber); EXPECT_FALSE(subscriber->isComplete()); @@ -378,31 +454,29 @@ TEST(FlowableTest, FlowableFromGeneratorException) { } TEST(FlowableTest, SubscribersComplete) { - auto flowable = Flowables::empty(); - auto subscriber = Subscribers::create( + auto flowable = Flowable::empty(); + auto subscriber = Subscriber::create( [](int) { FAIL(); }, [](folly::exception_wrapper) { FAIL(); }, [&] {}); flowable->subscribe(std::move(subscriber)); } TEST(FlowableTest, SubscribersError) { - auto flowable = Flowables::error(std::runtime_error("Whoops")); - auto subscriber = Subscribers::create( + auto flowable = Flowable::error(std::runtime_error("Whoops")); + auto subscriber = Subscriber::create( [](int) { FAIL(); }, [&](folly::exception_wrapper) {}, [] { FAIL(); }); flowable->subscribe(std::move(subscriber)); } TEST(FlowableTest, FlowableCompleteInTheMiddle) { auto flowable = - Flowable::create( - [](Reference> subscriber, int64_t requested) { - EXPECT_GT(requested, 1); - subscriber->onNext(123); - subscriber->onComplete(); - return std::make_tuple(int64_t(1), true); - }) + Flowable::create([](auto& subscriber, int64_t requested) { + EXPECT_GT(requested, 1); + subscriber.onNext(123); + subscriber.onComplete(); + }) ->map([](int v) { return std::to_string(v); }); - auto subscriber = make_ref>(10); + auto subscriber = std::make_shared>(10); flowable->subscribe(subscriber); EXPECT_TRUE(subscriber->isComplete()); @@ -443,22 +517,20 @@ namespace { // workaround for gcc-4.9 auto const expect_count = 10000; TEST(FlowableTest, FlowableFromDifferentThreads) { - auto flowable = Flowable::create([&](auto subscriber, int64_t req) { + auto flowable = Flowable::create([&](auto& subscriber, int64_t req) { EXPECT_EQ(req, expect_count); auto t1 = std::thread([&] { for (int32_t i = 0; i < req; i++) { - subscriber->onNext(i); + subscriber.onNext(i); } - - subscriber->onComplete(); + subscriber.onComplete(); }); t1.join(); - return std::make_tuple(req, true); }); auto t2 = std::thread([&] { folly::Baton<> on_flowable_complete; - flowable->subscribe(yarpl::make_ref( + flowable->subscribe(std::make_shared( expect_count, on_flowable_complete)); on_flowable_complete.timed_wait(std::chrono::milliseconds(100)); }); @@ -515,24 +587,22 @@ auto const expect = 5000; auto const the_ex = folly::make_exception_wrapper("wat"); TEST(FlowableTest, FlowableFromDifferentThreadsWithError) { - auto flowable = Flowable::create([=](auto subscriber, int64_t req) { + auto flowable = Flowable::create([=](auto& subscriber, int64_t req) { EXPECT_EQ(req, request); EXPECT_LT(expect, request); auto t1 = std::thread([&] { for (int32_t i = 0; i < expect; i++) { - subscriber->onNext(i); + subscriber.onNext(i); } - - subscriber->onError(the_ex); + subscriber.onError(the_ex); }); t1.join(); - return std::make_tuple(expect, true); }); auto t2 = std::thread([&] { folly::Baton<> on_flowable_error; - flowable->subscribe(yarpl::make_ref( + flowable->subscribe(std::make_shared( expect, request, on_flowable_error, the_ex)); on_flowable_error.timed_wait(std::chrono::milliseconds(100)); }); @@ -545,17 +615,16 @@ TEST(FlowableTest, SubscribeMultipleTimes) { using namespace ::testing; using StrictMockSubscriber = testing::StrictMock>; - auto f = Flowable::create([](auto subscriber, int64_t req) { + auto f = Flowable::create([](auto& subscriber, int64_t req) { for (int64_t i = 0; i < req; i++) { - subscriber->onNext(i); + subscriber.onNext(i); } - subscriber->onComplete(); - return std::make_tuple(req, true); + subscriber.onComplete(); }); auto setup_mock = [](auto request_num, auto& resps) { - auto mock = make_ref(request_num); + auto mock = std::make_shared(request_num); Sequence seq; EXPECT_CALL(*mock, onSubscribe_(_)).InSequence(seq); @@ -597,7 +666,7 @@ TEST(FlowableTest, SubscribeMultipleTimes) { /* following test should probably behave like: * TEST(FlowableTest, ConsumerThrows_OnNext) { - auto range = Flowables::range(1, 10); + auto range = Flowable<>::range(1, 10); EXPECT_THROWS({ range->subscribe( @@ -613,7 +682,7 @@ TEST(FlowableTest, ConsumerThrows_OnNext) { TEST(FlowableTest, ConsumerThrows_OnNext) { bool onErrorIsCalled{false}; - Flowables::range(1, 10)->subscribe( + Flowable<>::range(1, 10)->subscribe( [](auto) { throw std::runtime_error("throw at consumption"); }, [&onErrorIsCalled](auto ex) { onErrorIsCalled = true; }, []() { FAIL() << "onError should have been called"; }); @@ -621,31 +690,830 @@ TEST(FlowableTest, ConsumerThrows_OnNext) { EXPECT_TRUE(onErrorIsCalled); } -TEST(FlowableTest, ConsumerThrows_OnError) { - try { - Flowables::range(1, 10)->subscribe( - [](auto) { throw std::runtime_error("throw at consumption"); }, - [](auto) { throw std::runtime_error("throw at onError"); }, - []() { FAIL() << "onError should have been called"; }); - } catch (const std::runtime_error& exn) { - FAIL() << "Error thrown in onError should have been caught."; - } catch (...) { - LOG(INFO) << "The app crashes in DEBUG mode to inform the implementor."; +TEST(FlowableTest, ConsumerThrows_OnNext_Cancel) { + class TestOperator : public FlowableOperator { + public: + void subscribe(std::shared_ptr> subscriber) override { + auto subscription = + std::make_shared>(); + EXPECT_CALL(*subscription, request_(_)); + EXPECT_CALL(*subscription, cancel_()); + subscriber->onSubscribe(subscription); + + try { + subscriber->onNext(1); + } catch (const std::exception&) { + FAIL() + << "onNext should not throw but subscription should get canceled."; + } + } + }; + + auto testOperator = std::make_shared(); + auto mapped = testOperator->map([](uint32_t i) { + throw std::runtime_error("test"); + return i; + }); + auto mockSubscriber = + std::make_shared>>(); + EXPECT_CALL(*mockSubscriber, onSubscribe_(_)); + EXPECT_CALL(*mockSubscriber, onError_(_)); + mapped->subscribe(mockSubscriber); +} + +TEST(FlowableTest, DeferTest) { + int switchValue = 0; + auto flowable = Flowable::defer([&]() { + if (switchValue == 0) { + return Flowable<>::range(1, 1); + } else { + return Flowable<>::range(3, 1); + } + }); + + EXPECT_EQ(run(flowable), std::vector({1})); + switchValue = 1; + EXPECT_EQ(run(flowable), std::vector({3})); +} + +TEST(FlowableTest, DeferExceptionTest) { + auto flowable = Flowable::defer([&]() -> std::shared_ptr> { + throw std::runtime_error{"Too big!"}; + }); + + auto subscriber = std::make_shared>(); + flowable->subscribe(subscriber); + + EXPECT_TRUE(subscriber->isError()); + EXPECT_EQ(subscriber->getErrorMsg(), "Too big!"); +} + +TEST(FlowableTest, DoOnSubscribeTest) { + auto a = Flowable::empty(); + + MockFunction checkpoint; + EXPECT_CALL(checkpoint, Call()); + + a->doOnSubscribe([&] { checkpoint.Call(); })->subscribe(); +} + +TEST(FlowableTest, DoOnNextTest) { + std::vector values; + auto a = Flowable<>::range(10, 14)->doOnNext( + [&](int64_t v) { values.push_back(v); }); + auto values2 = run(std::move(a)); + EXPECT_EQ(values, values2); +} + +TEST(FlowableTest, DoOnErrorTest) { + auto a = Flowable::error(std::runtime_error("something broke!")); + + MockFunction checkpoint; + EXPECT_CALL(checkpoint, Call()); + + a->doOnError([&](const auto&) { checkpoint.Call(); })->subscribe(); +} + +TEST(FlowableTest, DoOnTerminateTest) { + auto a = Flowable::empty(); + + MockFunction checkpoint; + EXPECT_CALL(checkpoint, Call()); + + a->doOnTerminate([&]() { checkpoint.Call(); })->subscribe(); +} + +TEST(FlowableTest, DoOnTerminate2Test) { + auto a = Flowable::error(std::runtime_error("something broke!")); + + MockFunction checkpoint; + EXPECT_CALL(checkpoint, Call()); + + a->doOnTerminate([&]() { checkpoint.Call(); })->subscribe(); +} + +TEST(FlowableTest, DoOnEachTest) { + // TODO(lehecka): rewrite with concatWith + auto a = Flowable::create([](Subscriber& s, int64_t) { + s.onNext(5); + s.onError(std::runtime_error("something broke!")); + }); + + MockFunction checkpoint; + EXPECT_CALL(checkpoint, Call()).Times(2); + a->doOnEach([&]() { checkpoint.Call(); })->subscribe(); +} + +TEST(FlowableTest, DoOnTest) { + // TODO(lehecka): rewrite with concatWith + auto a = Flowable::create([](Subscriber& s, int64_t) { + s.onNext(5); + s.onError(std::runtime_error("something broke!")); + }); + + MockFunction checkpoint1; + EXPECT_CALL(checkpoint1, Call()); + MockFunction checkpoint2; + EXPECT_CALL(checkpoint2, Call()); + + a->doOn( + [&](int value) { + checkpoint1.Call(); + EXPECT_EQ(value, 5); + }, + [] { FAIL(); }, + [&](const auto&) { checkpoint2.Call(); }) + ->subscribe(); +} + +TEST(FlowableTest, DoOnCancelTest) { + auto a = Flowable<>::range(1, 10); + + MockFunction checkpoint; + EXPECT_CALL(checkpoint, Call()); + + a->doOnCancel([&]() { checkpoint.Call(); })->take(1)->subscribe(); +} + +template +void cancelDuringOnNext(Op&& op, F&& f) { + folly::Baton<> next, cancelled; + + folly::ScopedEventBaseThread thread; + + auto d = op(Flowable<>::justN({1, 2}), + [&, marker = std::make_shared(1), f](auto&&... args) { + auto weak = std::weak_ptr(marker); + // This simulates subscription cancellation during onNext + next.post(); + cancelled.wait(); + // Lambda with all captures should still exist, while it's + // handling onNext call. If it doesn't exist, the following + // lock will fail. + EXPECT_TRUE(weak.lock()); + return f(args...); + }) + ->observeOn(thread.getEventBase()) + ->subscribe([](int) {}); + + // Wait till onNext is called, and cancel subscription while onNext is still + // in progress + ASSERT_TRUE(next.try_wait_for(std::chrono::seconds(1))); + d->dispose(); + + // Let onNext finish + cancelled.post(); +} + +TEST(FlowableTest, CancelDuringMapOnNext) { + cancelDuringOnNext( + [](auto&& flowable, auto&& f) { return flowable->map(f); }, + [](int value) { return value; }); +} + +TEST(FlowableTest, CancelDuringFilterOnNext) { + cancelDuringOnNext( + [](auto&& flowable, auto&& f) { return flowable->filter(f); }, + [](int value) { return value > 0; }); +} + +TEST(FlowableTest, CancelDuringReduceOnNext) { + cancelDuringOnNext( + [](auto&& flowable, auto&& f) { return flowable->reduce(f); }, + [](int acc, int value) { return acc + value; }); +} + +TEST(FlowableTest, CancelDuringDoOnNext) { + cancelDuringOnNext( + [](auto&& flowable, auto&& f) { return flowable->doOnNext(f); }, + [](int) {}); +} + +TEST(FlowableTest, DoOnRequestTest) { + auto a = Flowable<>::range(1, 10); + + MockFunction checkpoint; + EXPECT_CALL(checkpoint, Call(2)); + + a->doOnRequest([&](int64_t n) { checkpoint.Call(n); })->take(2)->subscribe(); +} + +TEST(FlowableTest, ConcatWithTest) { + auto first = Flowable<>::range(1, 2); + auto second = Flowable<>::range(5, 2); + auto combined = first->concatWith(second); + + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6})); +} + +TEST(FlowableTest, ConcatWithMultipleTest) { + auto first = Flowable<>::range(1, 2); + auto second = Flowable<>::range(5, 2); + auto third = Flowable<>::range(10, 2); + auto fourth = Flowable<>::range(15, 2); + auto firstSecond = first->concatWith(second); + auto thirdFourth = third->concatWith(fourth); + auto combined = firstSecond->concatWith(thirdFourth); + + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6, 10, 11, 15, 16})); +} + +TEST(FlowableTest, ConcatWithExceptionTest) { + auto first = Flowable<>::range(1, 2); + auto second = Flowable<>::range(5, 2); + auto third = Flowable::error(std::runtime_error("error")); + + auto combined = first->concatWith(second)->concatWith(third); + + auto subscriber = std::make_shared>(); + combined->subscribe(subscriber); + + EXPECT_EQ(subscriber->values(), std::vector({1, 2, 5, 6})); + EXPECT_TRUE(subscriber->isError()); + EXPECT_EQ(subscriber->getErrorMsg(), "error"); +} + +TEST(FlowableTest, ConcatWithFlowControlTest) { + auto first = Flowable<>::range(1, 2); + auto second = Flowable<>::range(5, 2); + auto third = Flowable<>::range(10, 2); + auto fourth = Flowable<>::range(15, 2); + auto firstSecond = first->concatWith(second); + auto thirdFourth = third->concatWith(fourth); + auto combined = firstSecond->concatWith(thirdFourth); + + auto subscriber = std::make_shared>(0); + combined->subscribe(subscriber); + EXPECT_EQ(subscriber->values(), std::vector{}); + + const std::vector allResults{1, 2, 5, 6, 10, 11, 15, 16}; + for (int i = 1; i <= 8; ++i) { + subscriber->request(1); + subscriber->awaitValueCount(1, std::chrono::seconds(1)); + EXPECT_EQ( + subscriber->values(), + std::vector(allResults.begin(), allResults.begin() + i)); } } -TEST(FlowableTest, ConsumerThrows_OnComplete) { - try { - Flowables::range(1, 10)->subscribe( - [](auto) {}, - [](auto) { FAIL() << "onComplete should have been called"; }, - []() { throw std::runtime_error("throw at onComplete"); }); - } catch (const std::runtime_error&) { - FAIL() << "Error thrown in onComplete should have been caught."; - } catch (...) { - LOG(INFO) << "The app crashes in DEBUG mode to inform the implementor."; +TEST(FlowableTest, ConcatWithCancel) { + auto first = Flowable<>::range(1, 2); + auto second = Flowable<>::range(5, 2); + + auto combined = first->concatWith(second); + auto subscriber = std::make_shared>(0); + + MockFunction checkpoint; + EXPECT_CALL(checkpoint, Call()); + combined->doOnCancel([&]() { checkpoint.Call(); })->subscribe(subscriber); + + subscriber->request(3); + subscriber->awaitValueCount(3, std::chrono::seconds(1)); + + subscriber->cancel(); + EXPECT_EQ(subscriber->values(), std::vector({1, 2, 5})); +} + +TEST(FlowableTest, ConcatWithCompleteAtSubscription) { + auto first = Flowable<>::range(1, 2); + auto second = Flowable<>::range(5, 2); + + auto combined = first->concatWith(second)->take(0); + EXPECT_EQ(run(combined), std::vector({})); +} + +TEST(FlowableTest, ConcatWithVarArgsTest) { + auto first = Flowable<>::range(1, 2); + auto second = Flowable<>::range(5, 2); + auto third = Flowable<>::range(10, 2); + auto fourth = Flowable<>::range(15, 2); + + auto combined = first->concatWith(second, third, fourth); + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6, 10, 11, 15, 16})); +} + +TEST(FlowableTest, ConcatTest) { + auto combined = Flowable::concat( + Flowable<>::range(1, 2), Flowable<>::range(5, 2)); + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6})); + + // Flowable::concat shoud not accept one parameter! + // Next line should cause compiler failure: OK! + // combined = Flowable::concat(Flowable<>::range(1, 2)); + + combined = Flowable::concat( + Flowable<>::range(1, 2), + Flowable<>::range(5, 2), + Flowable<>::range(10, 2)); + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6, 10, 11})); + + combined = Flowable::concat( + Flowable<>::range(1, 2), + Flowable<>::range(5, 2), + Flowable<>::range(10, 2), + Flowable<>::range(15, 2)); + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6, 10, 11, 15, 16})); +} + +TEST(FlowableTest, ConcatWith_DelaySubscribe) { + // If there is no request for the second flowable, don't subscribe to it + bool subscribed = false; + auto a = Flowable<>::range(1, 1); + auto b = Flowable<>::range(2, 1)->doOnSubscribe( + [&subscribed]() { subscribed = true; }); + auto combined = a->concatWith(b); + + uint32_t request = 0; + auto subscriber = std::make_shared>(request); + combined->subscribe(subscriber); + subscriber->request(1); + + ASSERT_EQ(subscriber->values(), std::vector({1})); + ASSERT_FALSE(subscribed); + + // termination signal! + subscriber->cancel(); // otherwise we leak the active subscription +} + +TEST(FlowableTest, ConcatWith_EagerCancel) { + // If there is no request for the second flowable, don't subscribe to it + bool subscribed = false; + + // Control the execution of SubscribeOn operator + folly::EventBase evb; + + auto a = Flowable<>::range(1, 1); + auto b = Flowable<>::range(2, 1)->subscribeOn(evb)->doOnSubscribe( + [&subscribed]() { subscribed = true; }); + auto combined = a->concatWith(b); + + uint32_t request = 2; + std::vector values; + auto subscriber = yarpl::flowable::Subscriber::create( + [&values](int64_t value) { values.push_back(value); }, request); + + combined->subscribe(subscriber); + + // Even though we requested 2 items, we received 1 item + ASSERT_EQ(values, std::vector({1})); + ASSERT_FALSE(subscribed); // not yet, callback did not arrive yet! + + // We have requested 2 items, but did not consume the second item yet + // and we send a cancel before looping the eventBase + auto baseSubscriber = static_cast*>(subscriber.get()); + baseSubscriber->cancel(); + + // If the evb is never looped, it will cause memory leak + evb.loop(); + ASSERT_EQ(values, std::vector({1})); // no change! + ASSERT_TRUE(subscribed); // subscribe() already issued before the cancel +} + +class TestTimeout : public folly::AsyncTimeout { + public: + explicit TestTimeout(folly::EventBase* eventBase, folly::Function fn) + : AsyncTimeout(eventBase), fn_(std::move(fn)) {} + + void timeoutExpired() noexcept override { + fn_(); } + + folly::Function fn_; +}; + +TEST(FlowableTest, Timeout_SpecialException) { + class RestrictedType { + public: + RestrictedType() = default; + RestrictedType(RestrictedType&&) noexcept = default; + RestrictedType& operator=(RestrictedType&&) noexcept = default; + auto operator()() { + return std::logic_error("RestrictedType"); + } + }; + + folly::EventBase timerEvb; + auto flowable = Flowable::never()->timeout( + timerEvb, + std::chrono::milliseconds(0), + std::chrono::milliseconds(1), + RestrictedType{}); + + int requestCount = 1; + auto subscriber = std::make_shared>(requestCount); + flowable->subscribe(subscriber); + + timerEvb.loop(); + + EXPECT_EQ(subscriber->values(), std::vector({})); + EXPECT_TRUE(subscriber->exceptionWrapper().with_exception( + [](const std::logic_error& ex) { + EXPECT_STREQ("RestrictedType", ex.what()); + })); } -} // namespace flowable -} // namespace yarpl +TEST(FlowableTest, Timeout_NoTimeout) { + folly::EventBase timerEvb; + auto flowable = Flowable<>::range(1, 1)->observeOn(timerEvb)->timeout( + timerEvb, std::chrono::milliseconds(0), std::chrono::milliseconds(0)); + + int requestCount = 1; + auto subscriber = std::make_shared>(requestCount); + flowable->subscribe(subscriber); + flowable.reset(); + + timerEvb.loop(); + + subscriber->awaitTerminalEvent(std::chrono::seconds(1)); + EXPECT_EQ(subscriber->values(), std::vector({1})); + + flowable = + Flowable::create([=](auto& subscriber, int64_t) { + subscriber.onNext(2); + subscriber.onComplete(); + }) + ->observeOn(timerEvb) + ->timeout(timerEvb, std::chrono::seconds(0), std::chrono::seconds(0)); + + subscriber = std::make_shared>(requestCount); + flowable->subscribe(subscriber); + flowable.reset(); + + timerEvb.loop(); + + subscriber->awaitTerminalEvent(std::chrono::seconds(1)); + EXPECT_EQ(subscriber->values(), std::vector({2})); +} + +TEST(FlowableTest, Timeout_OnNextTimeout) { + folly::EventBase timerEvb; + + auto flowable = Flowable<>::range(1, 2)->observeOn(timerEvb)->timeout( + timerEvb, + std::chrono::milliseconds(50), + std::chrono::milliseconds(0)); // no init_timeout + + int requestCount = 1; + auto subscriber = std::make_shared>(requestCount); + flowable->subscribe(subscriber); + flowable.reset(); + + TestTimeout timeout(&timerEvb, [subscriber]() { subscriber->request(1); }); + timeout.scheduleTimeout(100); // request next in 100 msec, timeout! + + timerEvb.loop(); + + subscriber->awaitTerminalEvent(std::chrono::seconds(1)); + + // first one is consumed + EXPECT_EQ(subscriber->values(), std::vector({1})); + EXPECT_TRUE(subscriber->isError()); +} + +TEST(FlowableTest, Timeout_InitTimeout) { + folly::EventBase timerEvb; + auto flowable = Flowable::create([=](auto& subscriber, int64_t req) { + if (req > 0) { + subscriber.onNext(2); + subscriber.onComplete(); + } + }) + ->observeOn(timerEvb) + ->timeout( + timerEvb, + std::chrono::milliseconds(0), + std::chrono::milliseconds(10)); + + int requestCount = 0; + auto subscriber = std::make_shared>(requestCount); + + TestTimeout timeout(&timerEvb, [subscriber]() { subscriber->request(1); }); + timeout.scheduleTimeout(100); // timeout the init + + flowable->subscribe(subscriber); + flowable.reset(); + timerEvb.loop(); + + subscriber->awaitTerminalEvent(std::chrono::seconds(1)); + + EXPECT_EQ(subscriber->values(), std::vector({})); + EXPECT_TRUE(subscriber->isError()); +} + +TEST(FlowableTest, Timeout_StopUsageOfTimer) { + // When the consumption completes, it should stop using the timer + auto flowable = Flowable<>::range(1, 1); + { + // EventBase will be deleted before the flowable + folly::EventBase timerEvb; + auto flowableIn = flowable->timeout( + timerEvb, std::chrono::milliseconds(1), std::chrono::milliseconds(0)); + EXPECT_EQ(run(flowableIn), std::vector({1})); + } +} + +TEST(FlowableTest, Timeout_NeverOperator_Timesout) { + folly::EventBase timerEvb; + auto flowable = Flowable::never()->observeOn(timerEvb)->timeout( + timerEvb, std::chrono::milliseconds(10), std::chrono::milliseconds(10)); + + int requestCount = 10; + auto subscriber = std::make_shared>(requestCount); + flowable->subscribe(subscriber); + flowable.reset(); + + timerEvb.loop(); + + subscriber->awaitTerminalEvent(std::chrono::seconds(1)); + + EXPECT_EQ(subscriber->values(), std::vector({})); + EXPECT_TRUE(subscriber->isError()); +} + +TEST(FlowableTest, Timeout_BecauseOfNoRequest) { + folly::ScopedEventBaseThread timerThread; + auto flowable = Flowable<>::range(1, 2) + ->observeOn(*timerThread.getEventBase()) + ->timeout( + *timerThread.getEventBase(), + std::chrono::seconds(1), + std::chrono::milliseconds(10)); + + int requestCount = 0; + auto subscriber = std::make_shared>(requestCount); + flowable->subscribe(subscriber); + subscriber->awaitTerminalEvent(std::chrono::seconds(1)); + + EXPECT_EQ(subscriber->values(), std::vector({})); + EXPECT_TRUE(subscriber->isError()); +} + +TEST(FlowableTest, Timeout_WithObserveOnSubscribeOn) { + folly::ScopedEventBaseThread subscribeOnThread; + folly::EventBase timerEvb; + auto flowable = Flowable<>::range(1, 2) + ->subscribeOn(*subscribeOnThread.getEventBase()) + ->observeOn(timerEvb) + ->timeout( + timerEvb, + std::chrono::milliseconds(10), + std::chrono::milliseconds(100)); + + int requestCount = 1; + auto subscriber = std::make_shared>(requestCount); + + TestTimeout timeout(&timerEvb, [subscriber]() { subscriber->request(1); }); + timeout.scheduleTimeout(100); // timeout onNext + + flowable->subscribe(subscriber); + flowable.reset(); + timerEvb.loop(); + + subscriber->awaitTerminalEvent(std::chrono::seconds(1)); + + // first one is consumed + EXPECT_EQ(subscriber->values(), std::vector({1})); + EXPECT_TRUE(subscriber->isError()); +} + +TEST(FlowableTest, Timeout_SameThread) { + folly::EventBase timerEvb; + auto flowable = Flowable<>::range(1, 2) + ->subscribeOn(timerEvb) + ->observeOn(timerEvb) + ->timeout( + timerEvb, + std::chrono::milliseconds(10), + std::chrono::milliseconds(100)); + + int requestCount = 1; + auto subscriber = std::make_shared>(requestCount); + + TestTimeout timeout(&timerEvb, [subscriber]() { subscriber->request(1); }); + timeout.scheduleTimeout(100); // timeout onNext + + flowable->subscribe(subscriber); + flowable.reset(); + timerEvb.loop(); + + subscriber->awaitTerminalEvent(std::chrono::seconds(1)); + + // first one is consumed + EXPECT_EQ(subscriber->values(), std::vector({1})); + EXPECT_TRUE(subscriber->isError()); +} + +TEST(FlowableTest, SwapException) { + auto flowable = Flowable::error(std::runtime_error("private")); + flowable = flowable->map( + [](auto&& a) { return a; }, + [](auto) { return std::runtime_error("public"); }); + + auto subscriber = std::make_shared>(); + flowable->subscribe(subscriber); + + EXPECT_EQ(subscriber->values(), std::vector({})); + EXPECT_TRUE(subscriber->isError()); + EXPECT_EQ(subscriber->getErrorMsg(), "public"); +} + +#if FOLLY_HAS_COROUTINES +TEST(AsyncGeneratorShimTest, CoroAsyncGeneratorIntType) { + folly::ScopedEventBaseThread th; + const int length = 5; + folly::Baton<> baton; + auto stream = + folly::coro::co_invoke([]() -> folly::coro::AsyncGenerator { + for (int i = 0; i < length; i++) { + co_yield std::move(i); + } + }); + + int expected_i = 0; + yarpl::toFlowable(std::move(stream), th.getEventBase()) + ->subscribe( + [&](int i) { EXPECT_EQ(expected_i++, i); }, + [&](folly::exception_wrapper) { + ADD_FAILURE() << "on Error"; + baton.post(); + }, + [&] { + EXPECT_EQ(expected_i, length); + baton.post(); + }, + 2); + baton.wait(); +} + +TEST(AsyncGeneratorShimTest, CoroAsyncGeneratorStringType) { + folly::ScopedEventBaseThread th; + const int length = 5; + folly::Baton<> baton; + auto stream = folly::coro::co_invoke( + []() -> folly::coro::AsyncGenerator { + for (int i = 0; i < length; i++) { + co_yield folly::to(i); + } + }); + + int expected_i = 0; + yarpl::toFlowable(std::move(stream), th.getEventBase()) + ->subscribe( + [&](std::string i) { EXPECT_EQ(expected_i++, folly::to(i)); }, + [&](folly::exception_wrapper) { + ADD_FAILURE() << "on Error"; + baton.post(); + }, + [&] { + EXPECT_EQ(expected_i, length); + baton.post(); + }, + 2); + baton.wait(); +} + +TEST(AsyncGeneratorShimTest, CoroAsyncGeneratorReverseFulfill) { + folly::ScopedEventBaseThread th; + folly::ScopedEventBaseThread pth; + const int length = 5; + std::vector> vp(length); + + int i = 0; + folly::Baton<> baton; + auto stream = + folly::coro::co_invoke([&]() -> folly::coro::AsyncGenerator { + while (i < length) { + co_yield co_await vp[i++].getSemiFuture(); + } + }); + + // intentionally let promised fulfilled in reverse order, but the result + // should come back to stream in order + for (int i = length - 1; i >= 0; i--) { + pth.add([&vp, i]() { vp[i].setValue(i); }); + } + + int expected_i = 0; + yarpl::toFlowable(std::move(stream), th.getEventBase()) + ->subscribe( + [&](int i) { EXPECT_EQ(expected_i++, i); }, + [&](folly::exception_wrapper) { + ADD_FAILURE() << "on Error"; + baton.post(); + }, + [&] { + EXPECT_EQ(expected_i, length); + baton.post(); + }, + 2); + baton.wait(); +} + +TEST(AsyncGeneratorShimTest, CoroAsyncGeneratorLambdaGaptureVariable) { + folly::ScopedEventBaseThread th; + std::string t = "test"; + folly::Baton<> baton; + auto stream = folly::coro::co_invoke( + [&, t = std::move(t) ]() mutable + -> folly::coro::AsyncGenerator { + co_yield std::move(t); + co_return; + }); + + std::string result; + yarpl::toFlowable(std::move(stream), th.getEventBase()) + ->subscribe( + [&](std::string t) { result = t; }, + [&](folly::exception_wrapper ex) { + ADD_FAILURE() << "on Error " << ex.what(); + baton.post(); + }, + [&] { baton.post(); }, + 2); + baton.wait(); + + EXPECT_EQ("test", result); +} + +TEST(AsyncGeneratorShimTest, ShouldNotHaveCoAwaitMoreThanOnce) { + folly::ScopedEventBaseThread th; + folly::ScopedEventBaseThread pth; + const int length = 5; + std::vector> vp(length); + + int i = 0; + folly::Baton<> baton; + auto stream = + folly::coro::co_invoke([&]() -> folly::coro::AsyncGenerator { + while (i < length) { + co_yield co_await vp[i++].getSemiFuture(); + } + }); + + int expected_i = 0; + yarpl::toFlowable(std::move(stream), th.getEventBase()) + ->subscribe( + [&](int i) { EXPECT_EQ(expected_i++, i); }, + [&](folly::exception_wrapper) { + ADD_FAILURE() << "on Error"; + baton.post(); + }, + [&] { + EXPECT_EQ(expected_i, length); + baton.post(); + }, + 5); + // subscribe before fulfill future, expecting co_await on the future will + // happen before setValue() + for (int i = 0; i < length; i++) { + pth.add([&vp, i]() { vp[i].setValue(i); }); + } + baton.wait(); +} + +TEST(AsyncGeneratorShimTest, CoroAsyncGeneratorPreemptiveCancel) { + folly::ScopedEventBaseThread th; + folly::coro::Baton b; + bool canceled = false; + auto stream = folly::coro:: + co_invoke([&]() -> folly::coro::AsyncGenerator { + // cancelCallback will be execute in the same event loop + // as async generator + folly::CancellationCallback cancelCallback( + co_await folly::coro::co_current_cancellation_token, [&]() { + canceled = true; + b.post(); + }); + co_yield "first"; + co_await b; + if (!canceled) { + co_yield "never_reach"; + } + }); + + struct TestSubscriber : public Subscriber { + void onSubscribe(std::shared_ptr s) override final { + s->request(2); + s_ = std::move(s); + } + + void onNext(std::string s) override { + EXPECT_EQ("first", s); + b1_.post(); + } + void onComplete() override { + b2_.post(); + } + void onError(folly::exception_wrapper) override {} + std::shared_ptr s_; + folly::Baton<> b1_, b2_; + }; + auto subscriber = std::make_shared(); + yarpl::toFlowable(std::move(stream), th.getEventBase()) + ->subscribe(subscriber); + subscriber->b1_.wait(); + subscriber->s_->cancel(); + subscriber->b2_.wait(); + EXPECT_TRUE(canceled); +} +#endif diff --git a/yarpl/test/MocksTest.cpp b/yarpl/test/MocksTest.cpp index 18655479a..db97df1ff 100644 --- a/yarpl/test/MocksTest.cpp +++ b/yarpl/test/MocksTest.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "yarpl/test_utils/Mocks.h" @@ -15,12 +27,12 @@ TEST(MocksTest, SelfManagedMocks) { int value = 42; MockFlowable flowable; - auto subscription = yarpl::make_ref(); - auto subscriber = yarpl::make_ref>(0); + auto subscription = std::make_shared(); + auto subscriber = std::make_shared>(0); { InSequence dummy; EXPECT_CALL(flowable, subscribe_(_)) - .WillOnce(Invoke([&](yarpl::Reference> consumer) { + .WillOnce(Invoke([&](std::shared_ptr> consumer) { consumer->onSubscribe(subscription); })); EXPECT_CALL(*subscriber, onSubscribe_(_)); diff --git a/yarpl/test/Observable_test.cpp b/yarpl/test/Observable_test.cpp index bfc845507..f2c31a1ba 100644 --- a/yarpl/test/Observable_test.cpp +++ b/yarpl/test/Observable_test.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -7,8 +19,8 @@ #include #include "yarpl/Observable.h" +#include "yarpl/flowable/Flowable.h" #include "yarpl/flowable/Subscriber.h" -#include "yarpl/flowable/Subscribers.h" #include "yarpl/test_utils/Mocks.h" #include "yarpl/test_utils/Tuple.h" @@ -34,12 +46,14 @@ class CollectingObserver : public Observer { void onComplete() override { Observer::onComplete(); complete_ = true; + terminated_ = true; } void onError(folly::exception_wrapper ex) override { Observer::onError(ex); error_ = true; errorMsg_ = ex.get_exception()->what(); + terminated_ = true; } std::vector& values() { @@ -58,39 +72,57 @@ class CollectingObserver : public Observer { return errorMsg_; } + /** + * Block the current thread until either onSuccess or onError is called. + */ + void awaitTerminalEvent( + std::chrono::milliseconds ms = std::chrono::seconds{1}) { + // now block this thread + std::unique_lock lk(m_); + // if shutdown gets implemented this would then be released by it + if (!terminalEventCV_.wait_for(lk, ms, [this] { return terminated_; })) { + throw std::runtime_error("timeout in awaitTerminalEvent"); + } + } + private: std::vector values_; std::string errorMsg_; bool complete_{false}; bool error_{false}; + + bool terminated_{false}; + std::mutex m_; + std::condition_variable terminalEventCV_; }; /// Construct a pipeline with a collecting observer against the supplied /// observable. Return the items that were sent to the observer. If some /// exception was sent, the exception is thrown. template -std::vector run(Reference> observable) { - auto collector = make_ref>(); +std::vector run(std::shared_ptr> observable) { + auto collector = std::make_shared>(); observable->subscribe(collector); + collector->awaitTerminalEvent(std::chrono::seconds(1)); return std::move(collector->values()); } } // namespace TEST(Observable, SingleOnNext) { - auto a = Observable::create([](Reference> obs) { + auto a = Observable::create([](std::shared_ptr> obs) { obs->onNext(1); obs->onComplete(); }); std::vector v; a->subscribe( - Observers::create([&v](const int& value) { v.push_back(value); })); + Observer::create([&v](const int& value) { v.push_back(value); })); EXPECT_EQ(v.at(0), 1); } TEST(Observable, MultiOnNext) { - auto a = Observable::create([](Reference> obs) { + auto a = Observable::create([](std::shared_ptr> obs) { obs->onNext(1); obs->onNext(2); obs->onNext(3); @@ -99,7 +131,7 @@ TEST(Observable, MultiOnNext) { std::vector v; a->subscribe( - Observers::create([&v](const int& value) { v.push_back(value); })); + Observer::create([&v](const int& value) { v.push_back(value); })); EXPECT_EQ(v.at(0), 1); EXPECT_EQ(v.at(1), 2); @@ -108,11 +140,11 @@ TEST(Observable, MultiOnNext) { TEST(Observable, OnError) { std::string errorMessage("DEFAULT->No Error Message"); - auto a = Observable::create([](Reference> obs) { + auto a = Observable::create([](std::shared_ptr> obs) { obs->onError(std::runtime_error("something broke!")); }); - a->subscribe(Observers::create( + a->subscribe(Observer::create( [](int) { /* do nothing */ }, [&errorMessage](folly::exception_wrapper ex) { errorMessage = ex.get_exception()->what(); @@ -125,14 +157,14 @@ TEST(Observable, OnError) { * Assert that all items passed through the Observable get destroyed */ TEST(Observable, ItemsCollectedSynchronously) { - auto a = Observable::create([](Reference> obs) { + auto a = Observable::create([](std::shared_ptr> obs) { obs->onNext(Tuple{1, 2}); obs->onNext(Tuple{2, 3}); obs->onNext(Tuple{3, 4}); obs->onComplete(); }); - a->subscribe(Observers::create([](const Tuple& value) { + a->subscribe(Observer::create([](const Tuple& value) { std::cout << "received value " << value.a << std::endl; })); } @@ -145,7 +177,7 @@ TEST(Observable, ItemsCollectedSynchronously) { * in a Vector which could then be consumed on another thread. */ TEST(DISABLED_Observable, ItemsCollectedAsynchronously) { - auto a = Observable::create([](Reference> obs) { + auto a = Observable::create([](std::shared_ptr> obs) { std::cout << "-----------------------------" << std::endl; obs->onNext(Tuple{1, 2}); std::cout << "-----------------------------" << std::endl; @@ -158,18 +190,17 @@ TEST(DISABLED_Observable, ItemsCollectedAsynchronously) { std::vector v; v.reserve(10); // otherwise it resizes and copies on each push_back - a->subscribe(Observers::create([&v](const Tuple& value) { + a->subscribe(Observer::create([&v](const Tuple& value) { std::cout << "received value " << value.a << std::endl; // copy into vector v.push_back(value); std::cout << "done pushing into vector" << std::endl; })); - // expect that 3 instances were originally created, then 3 more when copying - EXPECT_EQ(6, Tuple::createdCount); - // expect that 3 instances still exist in the vector, so only 3 destroyed so - // far - EXPECT_EQ(3, Tuple::destroyedCount); + // 3 copy & 3 move and 3 more copy constructed + EXPECT_EQ(9, Tuple::createdCount); + // 3 still exists in the vector, 6 destroyed + EXPECT_EQ(6, Tuple::destroyedCount); std::cout << "Leaving block now so Vector should release Tuples..." << std::endl; @@ -179,7 +210,7 @@ class TakeObserver : public Observer { private: const int limit; int count = 0; - Reference subscription_; + std::shared_ptr subscription_; std::vector& v; public: @@ -187,7 +218,8 @@ class TakeObserver : public Observer { v.reserve(5); } - void onSubscribe(Reference s) override { + void onSubscribe( + std::shared_ptr s) override { subscription_ = std::move(s); } @@ -207,7 +239,7 @@ class TakeObserver : public Observer { // assert behavior of onComplete after subscription.cancel TEST(Observable, SubscriptionCancellation) { std::atomic_int emitted{0}; - auto a = Observable::create([&](Reference> obs) { + auto a = Observable::create([&](std::shared_ptr> obs) { int i = 0; while (!obs->isUnsubscribed() && i <= 10) { emitted++; @@ -220,7 +252,7 @@ TEST(Observable, SubscriptionCancellation) { }); std::vector v; - a->subscribe(make_ref(2, v)); + a->subscribe(std::make_shared(2, v)); EXPECT_EQ((unsigned long)2, v.size()); EXPECT_EQ(2, emitted); } @@ -234,7 +266,7 @@ TEST(Observable, CancelFromDifferentThread) { std::atomic cancelled2{false}; std::thread t; - auto a = Observable::create([&](Reference> obs) { + auto a = Observable::create([&](std::shared_ptr> obs) { t = std::thread([obs, &emitted, &cancelled1]() { obs->addSubscription([&]() { cancelled1 = true; }); while (!obs->isUnsubscribed()) { @@ -259,12 +291,13 @@ TEST(Observable, CancelFromDifferentThread) { } TEST(Observable, toFlowableDrop) { - auto a = Observables::range(1, 10); + auto a = Observable<>::range(1, 10); auto f = a->toFlowable(BackpressureStrategy::DROP); std::vector v; - auto subscriber = make_ref>>(5); + auto subscriber = + std::make_shared>>(5); EXPECT_CALL(*subscriber, onSubscribe_(_)); EXPECT_CALL(*subscriber, onNext_(_)) @@ -277,7 +310,7 @@ TEST(Observable, toFlowableDrop) { } TEST(Observable, toFlowableDropWithCancel) { - auto a = Observable::create([](Reference> obs) { + auto a = Observable::create([](std::shared_ptr> obs) { int i = 0; while (!obs->isUnsubscribed()) { obs->onNext(++i); @@ -287,19 +320,26 @@ TEST(Observable, toFlowableDropWithCancel) { auto f = a->toFlowable(BackpressureStrategy::DROP); std::vector v; - f->take(5)->subscribe(yarpl::flowable::Subscribers::create( + f->take(5)->subscribe(yarpl::flowable::Subscriber::create( [&v](const int& value) { v.push_back(value); })); EXPECT_EQ(v, std::vector({1, 2, 3, 4, 5})); } TEST(Observable, toFlowableErrorStrategy) { - auto a = Observables::range(1, 10); + auto a = Observable::createEx([](auto observer, auto subscription) { + int64_t i = 1; + for (; !subscription->isCancelled() && i <= 10; ++i) { + observer->onNext(i); + } + EXPECT_EQ(7, i); + }); auto f = a->toFlowable(BackpressureStrategy::ERROR); std::vector v; - auto subscriber = make_ref>>(5); + auto subscriber = + std::make_shared>>(5); EXPECT_CALL(*subscriber, onSubscribe_(_)); EXPECT_CALL(*subscriber, onNext_(_)) @@ -316,12 +356,13 @@ TEST(Observable, toFlowableErrorStrategy) { } TEST(Observable, toFlowableBufferStrategy) { - auto a = Observables::range(1, 10); + auto a = Observable<>::range(1, 10); auto f = a->toFlowable(BackpressureStrategy::BUFFER); std::vector v; - auto subscriber = make_ref>>(5); + auto subscriber = + std::make_shared>>(5); EXPECT_CALL(*subscriber, onSubscribe_(_)); EXPECT_CALL(*subscriber, onNext_(_)) @@ -332,16 +373,123 @@ TEST(Observable, toFlowableBufferStrategy) { EXPECT_EQ(v, std::vector({1, 2, 3, 4, 5})); subscriber->subscription()->request(5); - EXPECT_EQ(v, std::vector({1, 2, 3, 4, 5, 6, 7, 8, 9})); + EXPECT_EQ(v, std::vector({1, 2, 3, 4, 5, 6, 7, 8, 9, 10})); +} + +TEST(Observable, toFlowableBufferStrategyLimit) { + std::shared_ptr> observer; + std::shared_ptr subscription; + + auto a = Observable::createEx([&](auto o, auto s) { + observer = std::move(o); + subscription = std::move(s); + }); + auto f = + a->toFlowable(std::make_shared>(3)); + + std::vector v; + + auto subscriber = + std::make_shared>>(5); + + EXPECT_CALL(*subscriber, onSubscribe_(_)); + EXPECT_CALL(*subscriber, onNext_(_)) + .WillRepeatedly(Invoke([&](int64_t value) { v.push_back(value); })); + + EXPECT_FALSE(observer); + EXPECT_FALSE(subscription); + + f->subscribe(subscriber); + + EXPECT_TRUE(observer); + EXPECT_TRUE(subscription); + + for (size_t i = 1; i <= 5; ++i) { + observer->onNext(i); + } + + EXPECT_EQ(v, std::vector({1, 2, 3, 4, 5})); + + observer->onNext(6); + observer->onNext(7); + observer->onNext(8); + + EXPECT_FALSE(observer->isUnsubscribedOrTerminated()); + EXPECT_FALSE(subscription->isCancelled()); + + EXPECT_CALL(*subscriber, onError_(_)) + .WillOnce(Invoke([&](folly::exception_wrapper ex) { + EXPECT_TRUE(ex.is_compatible_with< + yarpl::flowable::MissingBackpressureException>()); + })); + + observer->onNext(9); + + EXPECT_TRUE(observer->isUnsubscribedOrTerminated()); + EXPECT_TRUE(subscription->isCancelled()); +} + +TEST(Observable, toFlowableBufferStrategyStress) { + std::shared_ptr> observer; + auto a = Observable::createEx( + [&](auto o, auto) { observer = std::move(o); }); + auto f = a->toFlowable(BackpressureStrategy::BUFFER); + + std::vector v; + std::atomic tokens{0}; + + auto subscriber = + std::make_shared>>(0); + + EXPECT_CALL(*subscriber, onSubscribe_(_)); + EXPECT_CALL(*subscriber, onNext_(_)) + .WillRepeatedly(Invoke([&](int64_t value) { v.push_back(value); })); + EXPECT_CALL(*subscriber, onComplete_()); + + f->subscribe(subscriber); + EXPECT_TRUE(observer); + + constexpr size_t kNumElements = 100000; + + std::thread nextThread([&] { + for (size_t i = 0; i < kNumElements; ++i) { + while (tokens.load() < -5) { + std::this_thread::yield(); + } + + observer->onNext(i); + --tokens; + } + observer->onComplete(); + }); + + std::thread requestThread([&] { + for (size_t i = 0; i < kNumElements; ++i) { + while (tokens.load() > 5) { + std::this_thread::yield(); + } + + subscriber->subscription()->request(1); + ++tokens; + } + }); + + nextThread.join(); + requestThread.join(); + + for (size_t i = 0; i < kNumElements; ++i) { + CHECK_EQ(i, v[i]); + } } TEST(Observable, toFlowableLatestStrategy) { - auto a = Observables::range(1, 10); + auto a = Observable<>::range(1, 10); auto f = a->toFlowable(BackpressureStrategy::LATEST); std::vector v; - auto subscriber = make_ref>>(5); + auto subscriber = + std::make_shared>>(5); EXPECT_CALL(*subscriber, onSubscribe_(_)); EXPECT_CALL(*subscriber, onNext_(_)) @@ -352,24 +500,24 @@ TEST(Observable, toFlowableLatestStrategy) { EXPECT_EQ(v, std::vector({1, 2, 3, 4, 5})); subscriber->subscription()->request(5); - EXPECT_EQ(v, std::vector({1, 2, 3, 4, 5, 9})); + EXPECT_EQ(v, std::vector({1, 2, 3, 4, 5, 10})); } TEST(Observable, Just) { - EXPECT_EQ(run(Observables::just(22)), std::vector{22}); + EXPECT_EQ(run(Observable<>::just(22)), std::vector{22}); EXPECT_EQ( - run(Observables::justN({12, 34, 56, 98})), + run(Observable<>::justN({12, 34, 56, 98})), std::vector({12, 34, 56, 98})); EXPECT_EQ( - run(Observables::justN({"ab", "pq", "yz"})), + run(Observable<>::justN({"ab", "pq", "yz"})), std::vector({"ab", "pq", "yz"})); } TEST(Observable, SingleMovable) { auto value = std::make_unique(123456); - auto observable = Observables::justOnce(std::move(value)); - EXPECT_EQ(std::size_t{1}, observable->count()); + auto observable = Observable<>::justOnce(std::move(value)); + EXPECT_EQ(std::size_t{1}, observable.use_count()); auto values = run(std::move(observable)); EXPECT_EQ(values.size(), size_t(1)); @@ -378,14 +526,14 @@ TEST(Observable, SingleMovable) { } TEST(Observable, MapWithException) { - auto observable = Observables::justN({1, 2, 3, 4})->map([](int n) { + auto observable = Observable<>::justN({1, 2, 3, 4})->map([](int n) { if (n > 2) { throw std::runtime_error{"Too big!"}; } return n; }); - auto observer = yarpl::make_ref>(); + auto observer = std::make_shared>(); observable->subscribe(observer); EXPECT_EQ(observer->values(), std::vector({1, 2})); @@ -394,12 +542,12 @@ TEST(Observable, MapWithException) { } TEST(Observable, Range) { - auto observable = Observables::range(10, 14); + auto observable = Observable<>::range(10, 4); EXPECT_EQ(run(std::move(observable)), std::vector({10, 11, 12, 13})); } TEST(Observable, RangeWithMap) { - auto observable = Observables::range(1, 4) + auto observable = Observable<>::range(1, 3) ->map([](int64_t v) { return v * v; }) ->map([](int64_t v) { return v * v; }) ->map([](int64_t v) { return std::to_string(v); }); @@ -408,33 +556,33 @@ TEST(Observable, RangeWithMap) { } TEST(Observable, RangeWithReduce) { - auto observable = Observables::range(0, 10)->reduce( + auto observable = Observable<>::range(0, 10)->reduce( [](int64_t acc, int64_t v) { return acc + v; }); EXPECT_EQ(run(std::move(observable)), std::vector({45})); } TEST(Observable, RangeWithReduceByMultiplication) { - auto observable = Observables::range(0, 10)->reduce( + auto observable = Observable<>::range(0, 10)->reduce( [](int64_t acc, int64_t v) { return acc * v; }); EXPECT_EQ(run(std::move(observable)), std::vector({0})); - observable = Observables::range(1, 10)->reduce( + observable = Observable<>::range(1, 10)->reduce( [](int64_t acc, int64_t v) { return acc * v; }); EXPECT_EQ( run(std::move(observable)), - std::vector({2 * 3 * 4 * 5 * 6 * 7 * 8 * 9})); + std::vector({1 * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10})); } TEST(Observable, RangeWithReduceOneItem) { - auto observable = Observables::range(5, 6)->reduce( + auto observable = Observable<>::range(5, 1)->reduce( [](int64_t acc, int64_t v) { return acc + v; }); EXPECT_EQ(run(std::move(observable)), std::vector({5})); } TEST(Observable, RangeWithReduceNoItem) { - auto observable = Observables::range(0, 0)->reduce( + auto observable = Observable<>::range(0, 0)->reduce( [](int64_t acc, int64_t v) { return acc + v; }); - auto collector = make_ref>(); + auto collector = std::make_shared>(); observable->subscribe(collector); EXPECT_EQ(collector->error(), false); EXPECT_EQ(collector->values(), std::vector({})); @@ -442,7 +590,7 @@ TEST(Observable, RangeWithReduceNoItem) { TEST(Observable, RangeWithReduceToBiggerType) { auto observable = - Observables::range(5, 6) + Observable<>::range(5, 1) ->map([](int64_t v) { return (int32_t)v; }) ->reduce([](int64_t acc, int32_t v) { return acc + v; }); EXPECT_EQ(run(std::move(observable)), std::vector({5})); @@ -450,7 +598,7 @@ TEST(Observable, RangeWithReduceToBiggerType) { TEST(Observable, StringReduce) { auto observable = - Observables::justN( + Observable<>::justN( {"a", "b", "c", "d", "e", "f", "g", "h", "i"}) ->reduce([](std::string acc, std::string v) { return acc + v; }); EXPECT_EQ( @@ -459,28 +607,45 @@ TEST(Observable, StringReduce) { TEST(Observable, RangeWithFilter) { auto observable = - Observables::range(0, 10)->filter([](int64_t v) { return v % 2 != 0; }); + Observable<>::range(0, 10)->filter([](int64_t v) { return v % 2 != 0; }); EXPECT_EQ(run(std::move(observable)), std::vector({1, 3, 5, 7, 9})); } TEST(Observable, SimpleTake) { EXPECT_EQ( - run(Observables::range(0, 100)->take(3)), + run(Observable<>::range(0, 100)->take(3)), std::vector({0, 1, 2})); + + EXPECT_EQ( + run(Observable<>::range(0, 100)->take(0)), std::vector({})); +} + +TEST(Observable, TakeError) { + auto take0 = + Observable::error(std::runtime_error("something broke!")) + ->take(0); + + auto collector = std::make_shared>(); + take0->subscribe(collector); + + EXPECT_EQ(collector->values(), std::vector({})); + EXPECT_TRUE(collector->complete()); + EXPECT_FALSE(collector->error()); } TEST(Observable, SimpleSkip) { EXPECT_EQ( - run(Observables::range(0, 10)->skip(8)), std::vector({8, 9})); + run(Observable<>::range(0, 10)->skip(8)), std::vector({8, 9})); } TEST(Observable, OverflowSkip) { - EXPECT_EQ(run(Observables::range(0, 10)->skip(12)), std::vector({})); + EXPECT_EQ( + run(Observable<>::range(0, 10)->skip(12)), std::vector({})); } TEST(Observable, IgnoreElements) { - auto collector = make_ref>(); - auto observable = Observables::range(0, 105)->ignoreElements()->map( + auto collector = std::make_shared>(); + auto observable = Observable<>::range(0, 105)->ignoreElements()->map( [](int64_t v) { return v + 1; }); observable->subscribe(collector); @@ -491,8 +656,8 @@ TEST(Observable, IgnoreElements) { TEST(Observable, Error) { auto observable = - Observables::error(std::runtime_error("something broke!")); - auto collector = make_ref>(); + Observable::error(std::runtime_error("something broke!")); + auto collector = std::make_shared>(); observable->subscribe(collector); EXPECT_EQ(collector->complete(), false); @@ -502,8 +667,8 @@ TEST(Observable, Error) { TEST(Observable, ErrorPtr) { auto observable = - Observables::error(std::runtime_error("something broke!")); - auto collector = make_ref>(); + Observable::error(std::runtime_error("something broke!")); + auto collector = std::make_shared>(); observable->subscribe(collector); EXPECT_EQ(collector->complete(), false); @@ -512,8 +677,8 @@ TEST(Observable, ErrorPtr) { } TEST(Observable, Empty) { - auto observable = Observables::empty(); - auto collector = make_ref>(); + auto observable = Observable::empty(); + auto collector = std::make_shared>(); observable->subscribe(collector); EXPECT_EQ(collector->complete(), true); @@ -521,10 +686,10 @@ TEST(Observable, Empty) { } TEST(Observable, ObserversComplete) { - auto observable = Observables::empty(); + auto observable = Observable::empty(); bool completed = false; - auto observer = Observers::create( + auto observer = Observer::create( [](int) { unreachable(); }, [](folly::exception_wrapper) { unreachable(); }, [&] { completed = true; }); @@ -534,10 +699,10 @@ TEST(Observable, ObserversComplete) { } TEST(Observable, ObserversError) { - auto observable = Observables::error(std::runtime_error("Whoops")); + auto observable = Observable::error(std::runtime_error("Whoops")); bool errored = false; - auto observer = Observers::create( + auto observer = Observer::create( [](int) { unreachable(); }, [&](folly::exception_wrapper) { errored = true; }, [] { unreachable(); }); @@ -547,30 +712,61 @@ TEST(Observable, ObserversError) { } TEST(Observable, CancelReleasesObjects) { - auto lambda = [](Reference> observer) { + auto lambda = [](std::shared_ptr> observer) { // we will send nothing }; auto observable = Observable::create(std::move(lambda)); - auto collector = make_ref>(); + auto collector = std::make_shared>(); observable->subscribe(collector); } -class InfiniteAsyncTestOperator - : public ObservableOperator { - using Super = ObservableOperator; +TEST(Observable, CompleteReleasesObjects) { + auto shared = std::make_shared>>(); + { + auto observable = Observable::create( + [shared](std::shared_ptr> observer) { + *shared = observer; + // onComplete releases the DoOnComplete operator + // so the lambda params will be freed + observer->onComplete(); + }) + ->doOnComplete([shared] {}); + observable->subscribe(); + } + EXPECT_EQ(1, shared->use_count()); +} + +TEST(Observable, ErrorReleasesObjects) { + auto shared = std::make_shared>>(); + { + auto observable = Observable::create( + [shared](std::shared_ptr> observer) { + *shared = observer; + // onError releases the DoOnComplete operator + // so the lambda params will be freed + observer->onError(std::runtime_error("error")); + }) + ->doOnComplete([shared] { /*never executed*/ }); + observable->subscribe(); + } + EXPECT_EQ(1, shared->use_count()); +} + +class InfiniteAsyncTestOperator : public ObservableOperator { + using Super = ObservableOperator; public: InfiniteAsyncTestOperator( - Reference> upstream, + std::shared_ptr> upstream, MockFunction& checkpoint) - : Super(std::move(upstream)), checkpoint_(checkpoint) {} + : upstream_(std::move(upstream)), checkpoint_(checkpoint) {} - Reference subscribe( - Reference> observer) override { - auto subscription = make_ref( - this->ref_from_this(this), std::move(observer), checkpoint_); - Super::upstream_->subscribe( + std::shared_ptr subscribe( + std::shared_ptr> observer) override { + auto subscription = + std::make_shared(std::move(observer), checkpoint_); + upstream_->subscribe( // Note: implicit cast to a reference to a observer. subscription); return subscription; @@ -580,24 +776,22 @@ class InfiniteAsyncTestOperator class TestSubscription : public Super::OperatorSubscription { using SuperSub = typename Super::OperatorSubscription; - ~TestSubscription() { + public: + ~TestSubscription() override { t_.join(); } - public: void sendSuperNext() { // workaround for gcc bug 58972. SuperSub::observerOnNext(1); } TestSubscription( - Reference observable, - Reference> observer, + std::shared_ptr> observer, MockFunction& checkpoint) - : SuperSub(std::move(observable), std::move(observer)), - checkpoint_(checkpoint) {} + : SuperSub(std::move(observer)), checkpoint_(checkpoint) {} - void onSubscribe(yarpl::Reference subscription) override { + void onSubscribe(std::shared_ptr subscription) override { SuperSub::onSubscribe(std::move(subscription)); t_ = std::thread([this]() { while (!isCancelled()) { @@ -606,12 +800,13 @@ class InfiniteAsyncTestOperator checkpoint_.Call(); }); } - void onNext(int value) override {} + void onNext(int /*value*/) override {} std::thread t_; MockFunction& checkpoint_; }; + std::shared_ptr> upstream_; MockFunction& checkpoint_; }; @@ -626,21 +821,24 @@ TEST(Observable, DISABLED_CancelSubscriptionChain) { MockFunction checkpoint2; MockFunction checkpoint3; std::thread t; - auto infinite1 = Observable::create([&](Reference> obs) { - EXPECT_CALL(checkpoint, Call()).Times(1); - EXPECT_CALL(checkpoint2, Call()).Times(1); - EXPECT_CALL(checkpoint3, Call()).Times(1); - t = std::thread([obs, &emitted, &checkpoint]() { - while (!obs->isUnsubscribed()) { - ++emitted; - obs->onNext(0); - } - checkpoint.Call(); - }); - }); + auto infinite1 = + Observable::create([&](std::shared_ptr> obs) { + EXPECT_CALL(checkpoint, Call()).Times(1); + EXPECT_CALL(checkpoint2, Call()).Times(1); + EXPECT_CALL(checkpoint3, Call()).Times(1); + t = std::thread([obs, &emitted, &checkpoint]() { + while (!obs->isUnsubscribed()) { + ++emitted; + obs->onNext(0); + } + checkpoint.Call(); + }); + }); auto infinite2 = infinite1->skip(1)->skip(1); - auto test1 = make_ref(infinite2, checkpoint2); - auto test2 = make_ref(test1->skip(1), checkpoint3); + auto test1 = + std::make_shared(infinite2, checkpoint2); + auto test2 = + std::make_shared(test1->skip(1), checkpoint3); auto skip = test2->skip(8); auto subscription = skip->subscribe([](int) {}); @@ -656,8 +854,7 @@ TEST(Observable, DISABLED_CancelSubscriptionChain) { } TEST(Observable, DoOnSubscribeTest) { - auto a = Observable::create( - [](Reference> obs) { obs->onComplete(); }); + auto a = Observable::empty(); MockFunction checkpoint; EXPECT_CALL(checkpoint, Call()); @@ -667,16 +864,14 @@ TEST(Observable, DoOnSubscribeTest) { TEST(Observable, DoOnNextTest) { std::vector values; - auto observable = Observables::range(10, 14)->doOnNext( + auto observable = Observable<>::range(10, 14)->doOnNext( [&](int64_t v) { values.push_back(v); }); auto values2 = run(std::move(observable)); EXPECT_EQ(values, values2); } TEST(Observable, DoOnErrorTest) { - auto a = Observable::create([](Reference> obs) { - obs->onError(std::runtime_error("something broke!")); - }); + auto a = Observable::error(std::runtime_error("something broke!")); MockFunction checkpoint; EXPECT_CALL(checkpoint, Call()); @@ -685,8 +880,7 @@ TEST(Observable, DoOnErrorTest) { } TEST(Observable, DoOnTerminateTest) { - auto a = Observable::create( - [](Reference> obs) { obs->onComplete(); }); + auto a = Observable::empty(); MockFunction checkpoint; EXPECT_CALL(checkpoint, Call()); @@ -695,9 +889,7 @@ TEST(Observable, DoOnTerminateTest) { } TEST(Observable, DoOnTerminate2Test) { - auto a = Observable::create([](Reference> obs) { - obs->onError(std::runtime_error("something broke!")); - }); + auto a = Observable::error(std::runtime_error("something broke!")); MockFunction checkpoint; EXPECT_CALL(checkpoint, Call()); @@ -706,7 +898,8 @@ TEST(Observable, DoOnTerminate2Test) { } TEST(Observable, DoOnEachTest) { - auto a = Observable::create([](Reference> obs) { + // TODO(lehecka): rewrite with concatWith + auto a = Observable::create([](std::shared_ptr> obs) { obs->onNext(5); obs->onError(std::runtime_error("something broke!")); }); @@ -717,7 +910,8 @@ TEST(Observable, DoOnEachTest) { } TEST(Observable, DoOnTest) { - auto a = Observable::create([](Reference> obs) { + // TODO(lehecka): rewrite with concatWith + auto a = Observable::create([](std::shared_ptr> obs) { obs->onNext(5); obs->onError(std::runtime_error("something broke!")); }); @@ -736,3 +930,163 @@ TEST(Observable, DoOnTest) { [&](const auto&) { checkpoint2.Call(); }) ->subscribe(); } + +TEST(Observable, DoOnCancelTest) { + auto a = Observable<>::range(1, 10); + + MockFunction checkpoint; + EXPECT_CALL(checkpoint, Call()); + + a->doOnCancel([&]() { checkpoint.Call(); })->take(1)->subscribe(); +} + +TEST(Observable, DeferTest) { + int switchValue = 0; + auto observable = Observable::defer([&]() { + if (switchValue == 0) { + return Observable<>::range(1, 1); + } else { + return Observable<>::range(3, 1); + } + }); + + EXPECT_EQ(run(observable), std::vector({1})); + switchValue = 1; + EXPECT_EQ(run(observable), std::vector({3})); +} + +TEST(Observable, DeferExceptionTest) { + auto observable = + Observable::defer([&]() -> std::shared_ptr> { + throw std::runtime_error{"Too big!"}; + }); + + auto observer = std::make_shared>(); + observable->subscribe(observer); + + EXPECT_TRUE(observer->error()); + EXPECT_EQ(observer->errorMsg(), "Too big!"); +} + +TEST(Observable, ConcatWithTest) { + auto first = Observable<>::range(1, 2); + auto second = Observable<>::range(5, 2); + auto combined = first->concatWith(second); + + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6})); + // Subscribe again + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6})); +} + +TEST(Observable, ConcatWithMultipleTest) { + auto first = Observable<>::range(1, 2); + auto second = Observable<>::range(5, 2); + auto third = Observable<>::range(10, 2); + auto fourth = Observable<>::range(15, 2); + auto firstSecond = first->concatWith(second); + auto thirdFourth = third->concatWith(fourth); + auto combined = firstSecond->concatWith(thirdFourth); + + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6, 10, 11, 15, 16})); +} + +TEST(Observable, ConcatWithExceptionTest) { + auto first = Observable<>::range(1, 2); + auto second = Observable<>::range(5, 2); + auto third = Observable::error(std::runtime_error("error")); + + auto combined = first->concatWith(second)->concatWith(third); + + auto observer = std::make_shared>(); + combined->subscribe(observer); + + EXPECT_EQ(observer->values(), std::vector({1, 2, 5, 6})); + EXPECT_TRUE(observer->error()); + EXPECT_EQ(observer->errorMsg(), "error"); +} + +TEST(Observable, ConcatWithCancelTest) { + auto first = Observable<>::range(1, 2); + auto second = Observable<>::range(5, 2); + auto combined = first->concatWith(second); + auto take0 = combined->take(0); + + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6})); + EXPECT_EQ(run(take0), std::vector({})); +} + +TEST(Observable, ConcatWithCompleteAtSubscription) { + auto first = Observable<>::range(1, 2); + auto second = Observable<>::range(5, 2); + + auto combined = first->concatWith(second)->take(0); + EXPECT_EQ(run(combined), std::vector({})); +} + +TEST(Observable, ConcatWithVarArgsTest) { + auto first = Observable<>::range(1, 2); + auto second = Observable<>::range(5, 2); + auto third = Observable<>::range(10, 2); + auto fourth = Observable<>::range(15, 2); + + auto combined = first->concatWith(second, third, fourth); + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6, 10, 11, 15, 16})); +} + +TEST(Observable, ConcatTest) { + auto combined = Observable::concat( + Observable<>::range(1, 2), Observable<>::range(5, 2)); + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6})); + + // Observable::concat shoud not accept one parameter! + // Next line should cause compiler failure: OK! + // combined = Observable::concat(Observable<>::range(1, 2)); + + combined = Observable::concat( + Observable<>::range(1, 2), + Observable<>::range(5, 2), + Observable<>::range(10, 2)); + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6, 10, 11})); + + combined = Observable::concat( + Observable<>::range(1, 2), + Observable<>::range(5, 2), + Observable<>::range(10, 2), + Observable<>::range(15, 2)); + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6, 10, 11, 15, 16})); +} + +TEST(Observable, ToFlowableConcat) { + // Concat a flowable with an observable. + // Convert the observable to flowable before concat. + // Use ERROR as backpressure strategy. + + // Test: Request only as much as the initial flowable provides + // - Check that the observable is not subscribed to so it doesn't flood + + auto a = yarpl::flowable::Flowable<>::range(1, 1); + auto b = Observable<>::range(2, 9)->toFlowable(BackpressureStrategy::ERROR); + + auto c = a->concatWith(b); + + uint32_t request = 1; + auto subscriber = + std::make_shared>>(request); + + std::vector v; + + EXPECT_CALL(*subscriber, onSubscribe_(_)); + EXPECT_CALL(*subscriber, onNext_(_)) + .WillRepeatedly(Invoke([&](int64_t value) { v.push_back(value); })); + EXPECT_CALL(*subscriber, onError_(_)).Times(0); + + c->subscribe(subscriber); + + // As only 1 item is requested, the second flowable will not be subscribed. So + // the observer will not flood the stream and cause ERROR. + EXPECT_EQ(v, std::vector({1})); + + // Now flood the stream + EXPECT_CALL(*subscriber, onError_(_)); + subscriber->subscription()->request(1); +} diff --git a/yarpl/test/PublishProcessorTest.cpp b/yarpl/test/PublishProcessorTest.cpp new file mode 100644 index 000000000..802f41e24 --- /dev/null +++ b/yarpl/test/PublishProcessorTest.cpp @@ -0,0 +1,219 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yarpl/flowable/PublishProcessor.h" +#include +#include "yarpl/Flowable.h" +#include "yarpl/flowable/TestSubscriber.h" + +using namespace yarpl; +using namespace yarpl::flowable; + +TEST(PublishProcessorTest, OnNextTest) { + auto pp = PublishProcessor::create(); + + auto subscriber = std::make_shared>(); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber); + + pp->onNext(1); + pp->onNext(2); + pp->onNext(3); + + EXPECT_EQ(subscriber->values(), std::vector({1, 2, 3})); + + // cancel the subscription as its a cyclic reference + subscriber->cancel(); +} + +TEST(PublishProcessorTest, OnCompleteTest) { + auto pp = PublishProcessor::create(); + + auto subscriber = std::make_shared>(); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber); + + pp->onNext(1); + pp->onNext(2); + pp->onComplete(); + + EXPECT_EQ( + subscriber->values(), + std::vector({ + 1, + 2, + })); + EXPECT_TRUE(subscriber->isComplete()); + + auto subscriber2 = std::make_shared>(); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber2); + EXPECT_EQ(subscriber2->values(), std::vector()); + EXPECT_TRUE(subscriber2->isComplete()); +} + +TEST(PublishProcessorTest, OnErrorTest) { + auto pp = PublishProcessor::create(); + + auto subscriber = std::make_shared>(); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber); + + pp->onNext(1); + pp->onNext(2); + pp->onError(std::runtime_error("error!")); + + EXPECT_EQ( + subscriber->values(), + std::vector({ + 1, + 2, + })); + EXPECT_TRUE(subscriber->isError()); + EXPECT_EQ(subscriber->getErrorMsg(), "error!"); + + auto subscriber2 = std::make_shared>(); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber2); + EXPECT_EQ(subscriber2->values(), std::vector()); + EXPECT_TRUE(subscriber2->isError()); +} + +TEST(PublishProcessorTest, OnNextMultipleSubscribersTest) { + auto pp = PublishProcessor::create(); + + auto subscriber1 = std::make_shared>(); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber1); + auto subscriber2 = std::make_shared>(); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber2); + + pp->onNext(1); + pp->onNext(2); + pp->onNext(3); + + EXPECT_EQ(subscriber1->values(), std::vector({1, 2, 3})); + EXPECT_EQ(subscriber2->values(), std::vector({1, 2, 3})); + + subscriber1->cancel(); + subscriber2->cancel(); +} + +TEST(PublishProcessorTest, OnNextSlowSubscriberTest) { + auto pp = PublishProcessor::create(); + + auto subscriber1 = std::make_shared>(); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber1); + auto subscriber2 = std::make_shared>(1); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber2); + + pp->onNext(1); + pp->onNext(2); + pp->onNext(3); + + EXPECT_EQ(subscriber1->values(), std::vector({1, 2, 3})); + subscriber1->cancel(); + + EXPECT_EQ(subscriber2->values(), std::vector({1})); + EXPECT_TRUE(subscriber2->isError()); + EXPECT_EQ( + subscriber2->exceptionWrapper().type(), + typeid(MissingBackpressureException)); +} + +TEST(PublishProcessorTest, CancelTest) { + auto pp = PublishProcessor::create(); + + auto subscriber = std::make_shared>(); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber); + + pp->onNext(1); + pp->onNext(2); + + subscriber->cancel(); + + pp->onNext(3); + pp->onNext(4); + + EXPECT_EQ(subscriber->values(), std::vector({1, 2})); + + subscriber->onComplete(); // to break any reference cycles +} + +TEST(PublishProcessorTest, OnMultipleSubscribersMultithreadedWithErrorTest) { + auto pp = PublishProcessor::create(); + + std::vector threads; + std::atomic threadsDone{0}; + + for (int i = 0; i < 100; i++) { + threads.push_back(std::thread([&] { + for (int j = 0; j < 100; j++) { + auto subscriber = std::make_shared>(1); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber); + + subscriber->awaitTerminalEvent(std::chrono::milliseconds(500)); + + EXPECT_EQ(subscriber->values().size(), 1ULL); + + EXPECT_TRUE(subscriber->isError()); + EXPECT_EQ( + subscriber->exceptionWrapper().type(), + typeid(MissingBackpressureException)); + } + ++threadsDone; + })); + } + + int k = 0; + while (threadsDone < threads.size()) { + pp->onNext(k++); + } + + for (auto& thread : threads) { + thread.join(); + } +} + +TEST(PublishProcessorTest, OnMultipleSubscribersMultithreadedTest) { + auto pp = PublishProcessor::create(); + + std::vector threads; + std::atomic subscribersReady{0}; + std::atomic threadsDone{0}; + + for (int i = 0; i < 100; i++) { + threads.push_back(std::thread([&] { + auto subscriber = std::make_shared>(); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber); + + ++subscribersReady; + subscriber->awaitTerminalEvent(std::chrono::milliseconds(50)); + + EXPECT_EQ(subscriber->values(), std::vector({1, 2, 3, 4, 5})); + EXPECT_FALSE(subscriber->isError()); + EXPECT_TRUE(subscriber->isComplete()); + + ++threadsDone; + })); + } + + while (subscribersReady < threads.size()) + ; + + pp->onNext(1); + pp->onNext(2); + pp->onNext(3); + pp->onNext(4); + pp->onNext(5); + pp->onComplete(); + + for (auto& thread : threads) { + thread.join(); + } +} diff --git a/yarpl/test/RefcountedTest.cpp b/yarpl/test/RefcountedTest.cpp deleted file mode 100644 index 0c73bfc0b..000000000 --- a/yarpl/test/RefcountedTest.cpp +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include -#include - -#include - -#include "yarpl/Refcounted.h" - -namespace yarpl { - -TEST(RefcountedTest, ObjectCountsAreMaintained) { - std::vector> v; - for (std::size_t i = 0; i < 16; ++i) { - v.push_back(std::make_unique()); - EXPECT_EQ(1U, v[i]->count()); // no references. - } - - v.resize(11); -} - -TEST(RefcountedTest, ReferenceCountingWorks) { - auto first = make_ref(); - EXPECT_EQ(1U, first->count()); - - auto second = first; - - EXPECT_EQ(second.get(), first.get()); - EXPECT_EQ(2U, first->count()); - - auto third = std::move(second); - EXPECT_EQ(nullptr, second.get()); - EXPECT_EQ(third.get(), first.get()); - EXPECT_EQ(2U, first->count()); - - // second was already moved from, above. - second.reset(); - EXPECT_EQ(nullptr, second.get()); - EXPECT_EQ(2U, first->count()); - - auto fourth = third; - EXPECT_EQ(3U, first->count()); - - fourth.reset(); - EXPECT_EQ(nullptr, fourth.get()); - EXPECT_EQ(2U, first->count()); -} - - -class MyRefInside : public virtual Refcounted, public yarpl::enable_get_ref { -public: - MyRefInside() { - auto r = this->ref_from_this(this); - } - - auto a_const_method() const { - return ref_from_this(this); - } -}; - -TEST(RefcountedTest, CanCallGetRefInCtor) { - auto r = make_ref(); - auto r2 = r->a_const_method(); - EXPECT_EQ(r, r2); -} - -} // yarpl diff --git a/yarpl/test/ReferenceTest.cpp b/yarpl/test/ReferenceTest.cpp deleted file mode 100644 index 206327e3c..000000000 --- a/yarpl/test/ReferenceTest.cpp +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include -#include - -#include - -#include "yarpl/Flowable.h" -#include "yarpl/Refcounted.h" - -using yarpl::Refcounted; -using yarpl::Reference; -using yarpl::AtomicReference; -using yarpl::flowable::Subscriber; -using yarpl::flowable::BaseSubscriber; - -namespace { - -template -class MySubscriber : public BaseSubscriber { - void onSubscribeImpl() override {} - void onNextImpl(T) override {} - void onCompleteImpl() override {} - void onErrorImpl(folly::exception_wrapper) override {} -}; -} - -struct MyRefcounted : virtual Refcounted { - MyRefcounted(int i) : i(i) {} - int i; -}; - -TEST(ReferenceTest, Upcast) { - Reference> derived = yarpl::make_ref>(); - Reference> base1(derived); - - Reference> base2; - base2 = derived; - - Reference> derivedCopy1(derived); - Reference> derivedCopy2(derived); - - Reference> base3(std::move(derivedCopy1)); - - Reference> base4; - base4 = std::move(derivedCopy2); -} - -TEST(ReferenceTest, CopyAssign) { - using Sub = MySubscriber; - Reference a = yarpl::make_ref(); - Reference b(a); - EXPECT_EQ(2u, a->count()); - Reference c = yarpl::make_ref(); - b = c; - EXPECT_EQ(1u, a->count()); - EXPECT_EQ(2u, b->count()); - EXPECT_EQ(2u, c->count()); - EXPECT_EQ(c, b); -} - -TEST(ReferenceTest, MoveAssign) { - using Sub = MySubscriber; - Reference a = yarpl::make_ref(); - Reference b(std::move(a)); - EXPECT_EQ(nullptr, a); - EXPECT_EQ(1u, b->count()); - - Reference c; - c = std::move(b); - EXPECT_EQ(nullptr, b); - EXPECT_EQ(1u, c->count()); -} - -TEST(ReferenceTest, MoveAssignTemplate) { - using Sub = MySubscriber; - Reference a = yarpl::make_ref(); - Reference b(a); - EXPECT_EQ(2u, a->count()); - using Sub2 = MySubscriber; - b = yarpl::make_ref(); - EXPECT_EQ(1u, a->count()); -} - -TEST(ReferenceTest, Atomic) { - auto a = yarpl::make_ref(1); - AtomicReference b = a; - EXPECT_EQ(2u, a->count()); - EXPECT_EQ(2u, b->count()); // b and a point to same object - EXPECT_EQ(1, a->i); - EXPECT_EQ(1, b->i); - - auto c = yarpl::make_ref(2); - { - auto a_copy = b.exchange(c); - EXPECT_EQ(2, b->i); - EXPECT_EQ(2u, a->count()); - EXPECT_EQ(2u, a_copy->count()); - EXPECT_EQ(1, a_copy->i); - } - EXPECT_EQ(1u, a->count()); // a_copy destroyed - - EXPECT_EQ(2u, c->count()); - EXPECT_EQ(2u, b->count()); // b and c point to same object -} - -TEST(ReferenceTest, Construction) { - AtomicReference a{yarpl::make_ref(1)}; - EXPECT_EQ(1u, a->count()); - EXPECT_EQ(1, a->i); - - AtomicReference b = yarpl::make_ref(2); - EXPECT_EQ(1u, b->count()); - EXPECT_EQ(2, b->i); -} diff --git a/yarpl/test/Single_test.cpp b/yarpl/test/Single_test.cpp index 6f0db4baa..48d3c666e 100644 --- a/yarpl/test/Single_test.cpp +++ b/yarpl/test/Single_test.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -9,12 +21,10 @@ #include "yarpl/single/SingleTestObserver.h" #include "yarpl/test_utils/Tuple.h" -// TODO can we eliminate need to import both of these? -using namespace yarpl; using namespace yarpl::single; TEST(Single, SingleOnNext) { - auto a = Single::create([](Reference> obs) { + auto a = Single::create([](std::shared_ptr> obs) { obs->onSubscribe(SingleSubscriptions::empty()); obs->onSuccess(1); }); @@ -27,7 +37,7 @@ TEST(Single, SingleOnNext) { TEST(Single, OnError) { std::string errorMessage("DEFAULT->No Error Message"); - auto a = Single::create([](Reference> obs) { + auto a = Single::create([](std::shared_ptr> obs) { obs->onError( folly::exception_wrapper(std::runtime_error("something broke!"))); }); @@ -58,7 +68,7 @@ TEST(Single, Error) { } TEST(Single, SingleMap) { - auto a = Single::create([](Reference> obs) { + auto a = Single::create([](std::shared_ptr> obs) { obs->onSubscribe(SingleSubscriptions::empty()); obs->onSuccess(1); }); @@ -77,7 +87,7 @@ TEST(Single, MapWithException) { return n; }); - auto observer = yarpl::make_ref>(); + auto observer = std::make_shared>(); single->subscribe(observer); observer->assertOnErrorMessage("Too big!"); diff --git a/yarpl/test/SubscribeObserveOnTests.cpp b/yarpl/test/SubscribeObserveOnTests.cpp index 9bd905d95..dc95e28d9 100644 --- a/yarpl/test/SubscribeObserveOnTests.cpp +++ b/yarpl/test/SubscribeObserveOnTests.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -6,42 +18,57 @@ #include #include #include - #include "yarpl/Flowable.h" +#include "yarpl/Observable.h" #include "yarpl/flowable/TestSubscriber.h" -#include "yarpl/test_utils/utils.h" +#include "yarpl/observable/TestObserver.h" -namespace yarpl { -namespace flowable { -namespace { +using namespace yarpl::flowable; +using namespace yarpl::observable; + +constexpr std::chrono::milliseconds timeout{100}; -TEST(ObserveSubscribeTests, SubscribeOnWorksAsExpected) { +TEST(FlowableTests, SubscribeOnWorksAsExpected) { folly::ScopedEventBaseThread worker; - auto f = Flowable::create([&](auto subscriber, auto req) { + auto f = Flowable::create([&](auto& subscriber, auto req) { EXPECT_TRUE(worker.getEventBase()->isInEventBaseThread()); EXPECT_EQ(1, req); - subscriber->onNext("foo"); - subscriber->onComplete(); - return std::tuple(1, true); + subscriber.onNext("foo"); + subscriber.onComplete(); }); - auto subscriber = make_ref>(1); + auto subscriber = std::make_shared>(1); f->subscribeOn(*worker.getEventBase())->subscribe(subscriber); subscriber->awaitTerminalEvent(std::chrono::milliseconds(100)); EXPECT_EQ(1, subscriber->getValueCount()); EXPECT_TRUE(subscriber->isComplete()); } -TEST(ObserveSubscribeTests, ObserveOnWorksAsExpectedSuccess) { +TEST(ObservableTests, SubscribeOnWorksAsExpected) { + folly::ScopedEventBaseThread worker; + + auto f = Observable::create([&](auto observer) { + EXPECT_TRUE(worker.getEventBase()->isInEventBaseThread()); + observer->onNext("foo"); + observer->onComplete(); + }); + + auto observer = std::make_shared>(); + f->subscribeOn(*worker.getEventBase())->subscribe(observer); + observer->awaitTerminalEvent(std::chrono::milliseconds(100)); + EXPECT_EQ(1, observer->getValueCount()); + EXPECT_TRUE(observer->isComplete()); +} + +TEST(FlowableTests, ObserveOnWorksAsExpectedSuccess) { folly::ScopedEventBaseThread worker; folly::Baton<> subscriber_complete; - auto f = Flowable::create([&](auto subscriber, auto req) { + auto f = Flowable::create([&](auto& subscriber, auto req) { EXPECT_EQ(1, req); - subscriber->onNext("foo"); - subscriber->onComplete(); - return std::tuple(1, true); + subscriber.onNext("foo"); + subscriber.onComplete(); }); bool calledOnNext{false}; @@ -68,17 +95,16 @@ TEST(ObserveSubscribeTests, ObserveOnWorksAsExpectedSuccess) { 1 /* initial request(n) */ ); - CHECK_WAIT(subscriber_complete); + subscriber_complete.timed_wait(timeout); } -TEST(ObserveSubscribeTests, ObserveOnWorksAsExpectedError) { +TEST(FlowableTests, ObserveOnWorksAsExpectedError) { folly::ScopedEventBaseThread worker; folly::Baton<> subscriber_complete; - auto f = Flowable::create([&](auto subscriber, auto req) { + auto f = Flowable::create([&](auto& subscriber, auto req) { EXPECT_EQ(1, req); - subscriber->onError(std::runtime_error("oops!")); - return std::tuple(0, true); + subscriber.onError(std::runtime_error("oops!")); }); f->observeOn(*worker.getEventBase()) @@ -98,20 +124,19 @@ TEST(ObserveSubscribeTests, ObserveOnWorksAsExpectedError) { 1 /* initial request(n) */ ); - CHECK_WAIT(subscriber_complete); + subscriber_complete.timed_wait(timeout); } -TEST(ObserveSubscribeTests, BothObserveAndSubscribeOn) { +TEST(FlowableTests, BothObserveAndSubscribeOn) { folly::ScopedEventBaseThread subscriber_eb; folly::ScopedEventBaseThread producer_eb; folly::Baton<> subscriber_complete; - auto f = Flowable::create([&](auto subscriber, auto req) { + auto f = Flowable::create([&](auto& subscriber, auto req) { EXPECT_EQ(1, req); EXPECT_TRUE(producer_eb.getEventBase()->isInEventBaseThread()); - subscriber->onNext("foo"); - subscriber->onComplete(); - return std::tuple(1, true); + subscriber.onNext("foo"); + subscriber.onComplete(); }) ->subscribeOn(*producer_eb.getEventBase()) ->observeOn(*subscriber_eb.getEventBase()); @@ -139,7 +164,7 @@ TEST(ObserveSubscribeTests, BothObserveAndSubscribeOn) { 1 /* initial request(n) */ ); - CHECK_WAIT(subscriber_complete); + subscriber_complete.timed_wait(timeout); } namespace { @@ -166,7 +191,7 @@ class EarlyCancelSubscriber : public yarpl::flowable::BaseSubscriber { subscriber_complete_.post(); } - void onErrorImpl(folly::exception_wrapper e) override { + void onErrorImpl(folly::exception_wrapper /*e*/) override { FAIL(); } @@ -180,18 +205,15 @@ class EarlyCancelSubscriber : public yarpl::flowable::BaseSubscriber { }; } // namespace -TEST(ObserveSubscribeTests, EarlyCancelObserveOn) { +TEST(FlowableTests, EarlyCancelObserveOn) { folly::ScopedEventBaseThread worker; folly::Baton<> subscriber_complete; - Flowables::range(1, 100) + Flowable<>::range(1, 100) ->observeOn(*worker.getEventBase()) - ->subscribe(make_ref( + ->subscribe(std::make_shared( *worker.getEventBase(), subscriber_complete)); - CHECK_WAIT(subscriber_complete); + subscriber_complete.timed_wait(timeout); } -} // namespace -} // namespace flowable -} // namespace yarpl diff --git a/yarpl/test/ThriftStreamShimTest.cpp b/yarpl/test/ThriftStreamShimTest.cpp new file mode 100644 index 000000000..78d178bbb --- /dev/null +++ b/yarpl/test/ThriftStreamShimTest.cpp @@ -0,0 +1,72 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "yarpl/Flowable.h" +#include "yarpl/flowable/TestSubscriber.h" +#include "yarpl/flowable/ThriftStreamShim.h" + +using namespace yarpl::flowable; + +template +std::vector run( + std::shared_ptr> flowable, + int64_t requestCount = 100) { + auto subscriber = std::make_shared>(requestCount); + flowable->subscribe(subscriber); + subscriber->awaitTerminalEvent(std::chrono::seconds(1)); + return std::move(subscriber->values()); +} +template +std::vector run(apache::thrift::ServerStream&& stream) { + std::vector values; + std::move(stream).toClientStreamUnsafeDoNotUse().subscribeInline([&](auto&& val) { + if (val.hasValue()) { + values.push_back(std::move(*val)); + } + }); + return values; +} + +apache::thrift::ClientBufferedStream makeRange(int start, int count) { + auto streamAndPublisher = + apache::thrift::ServerStream::createPublisher(); + for (int i = 0; i < count; ++i) { + streamAndPublisher.second.next(i + start); + } + std::move(streamAndPublisher.second).complete(); + return std::move(streamAndPublisher.first).toClientStreamUnsafeDoNotUse(); +} + +TEST(ThriftStreamShimTest, ClientStream) { + auto flowable = ThriftStreamShim::fromClientStream( + makeRange(1, 5), folly::getEventBase()); + EXPECT_EQ(run(flowable), std::vector({1, 2, 3, 4, 5})); +} + +TEST(ThriftStreamShimTest, ServerStream) { + auto stream = ThriftStreamShim::toServerStream(Flowable<>::range(1, 5)); + EXPECT_EQ(run(std::move(stream)), std::vector({1, 2, 3, 4, 5})); + + stream = ThriftStreamShim::toServerStream(Flowable::never()); + auto sub = std::move(stream).toClientStreamUnsafeDoNotUse().subscribeExTry( + folly::getEventBase(), [](auto) {}); + sub.cancel(); + std::move(sub).join(); + + ThriftStreamShim::toServerStream(Flowable<>::just(std::make_unique(42))); +} diff --git a/yarpl/test/credits-test.cpp b/yarpl/test/credits-test.cpp index 26e2206ca..41d23880d 100644 --- a/yarpl/test/credits-test.cpp +++ b/yarpl/test/credits-test.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -76,6 +88,14 @@ TEST(Credits, cancel3) { ASSERT_TRUE(isCancelled(&rn)); } +TEST(Credits, cancel4) { + std::atomic rn{9999}; + cancel(&rn); + // it should stay cancelled once cancelled + consume(&rn, 1); + ASSERT_TRUE(isCancelled(&rn)); +} + TEST(Credits, isInfinite) { std::atomic rn{0}; add(&rn, INT64_MAX); diff --git a/yarpl/test/test_has_shared_ptr_support.cpp b/yarpl/test/test_has_shared_ptr_support.cpp new file mode 100644 index 000000000..61bcbe004 --- /dev/null +++ b/yarpl/test/test_has_shared_ptr_support.cpp @@ -0,0 +1,22 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +int main() { + std::shared_ptr i; + auto il = std::atomic_load(&i); + return 0; +} diff --git a/yarpl/test/test_wrap_shared_in_atomic_support.cpp b/yarpl/test/test_wrap_shared_in_atomic_support.cpp new file mode 100644 index 000000000..136e87f2b --- /dev/null +++ b/yarpl/test/test_wrap_shared_in_atomic_support.cpp @@ -0,0 +1,23 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +int main() { + std::atomic> i; + std::shared_ptr j; + std::atomic_store(&i, j); + return 0; +} diff --git a/yarpl/test/yarpl-tests.cpp b/yarpl/test/yarpl-tests.cpp index 655e5a737..5046e46dd 100644 --- a/yarpl/test/yarpl-tests.cpp +++ b/yarpl/test/yarpl-tests.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -15,7 +27,5 @@ int main(int argc, char** argv) { ret = RUN_ALL_TESTS(); } - yarpl::detail::debug_refcounts(std::cerr); - return ret; } diff --git a/yarpl/test/yarpl/test_utils/Tuple.cpp b/yarpl/test/yarpl/test_utils/Tuple.cpp deleted file mode 100644 index 9d23a2cc9..000000000 --- a/yarpl/test/yarpl/test_utils/Tuple.cpp +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "Tuple.h" - -namespace yarpl { - -std::atomic Tuple::createdCount; -std::atomic Tuple::destroyedCount; -std::atomic Tuple::instanceCount; -} diff --git a/yarpl/test/yarpl/test_utils/utils.h b/yarpl/test/yarpl/test_utils/utils.h deleted file mode 100644 index e8e2cb68b..000000000 --- a/yarpl/test/yarpl/test_utils/utils.h +++ /dev/null @@ -1,10 +0,0 @@ -#pragma once - -namespace yarpl { -namespace test_utils { - -auto const default_baton_timeout = std::chrono::milliseconds(100); -#define CHECK_WAIT(baton) \ - CHECK(baton.timed_wait(::yarpl::test_utils::default_baton_timeout)) -} -} diff --git a/yarpl/test/yarpl/test_utils/Mocks.h b/yarpl/test_utils/Mocks.h similarity index 76% rename from yarpl/test/yarpl/test_utils/Mocks.h rename to yarpl/test_utils/Mocks.h index fac963d63..8662fcdf1 100644 --- a/yarpl/test/yarpl/test_utils/Mocks.h +++ b/yarpl/test_utils/Mocks.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -22,10 +34,10 @@ class MockFlowable : public flowable::Flowable { public: MOCK_METHOD1_T( subscribe_, - void(yarpl::Reference> subscriber)); + void(std::shared_ptr> subscriber)); void subscribe( - yarpl::Reference> subscriber) noexcept override { + std::shared_ptr> subscriber) noexcept override { subscribe_(std::move(subscriber)); } }; @@ -35,11 +47,12 @@ class MockFlowable : public flowable::Flowable { /// For the same reason putting mock instance in a smart pointer is a poor idea. /// Can only be instanciated for CopyAssignable E type. template -class MockSubscriber : public flowable::Subscriber { +class MockSubscriber : public flowable::Subscriber, + public yarpl::enable_get_ref { public: MOCK_METHOD1( onSubscribe_, - void(yarpl::Reference subscription)); + void(std::shared_ptr subscription)); MOCK_METHOD1_T(onNext_, void(const T& value)); MOCK_METHOD0(onComplete_, void()); MOCK_METHOD1_T(onError_, void(folly::exception_wrapper ex)); @@ -48,7 +61,7 @@ class MockSubscriber : public flowable::Subscriber { : initial_(initial) {} void onSubscribe( - yarpl::Reference subscription) override { + std::shared_ptr subscription) override { subscription_ = subscription; auto this_ = this->ref_from_this(this); onSubscribe_(subscription); @@ -116,7 +129,7 @@ class MockSubscriber : public flowable::Subscriber { protected: // As the 'subscription_' member in the parent class is private, // we define it here again. - yarpl::Reference subscription_; + std::shared_ptr subscription_; int64_t initial_; @@ -142,15 +155,16 @@ class MockSubscription : public flowable::Subscription { cancel_(); } }; -} +} // namespace mocks template -class MockBaseSubscriber : public flowable::BaseSubscriber { -public: +class MockBaseSubscriber + : public flowable::BaseSubscriber { + public: MOCK_METHOD0_T(onSubscribeImpl, void()); MOCK_METHOD1_T(onNextImpl, void(T)); MOCK_METHOD0_T(onCompleteImpl, void()); MOCK_METHOD1_T(onErrorImpl, void(folly::exception_wrapper)); }; -} // namespace yarpl::mocks +} // namespace yarpl diff --git a/yarpl/test_utils/Tuple.cpp b/yarpl/test_utils/Tuple.cpp new file mode 100644 index 000000000..0ab948c42 --- /dev/null +++ b/yarpl/test_utils/Tuple.cpp @@ -0,0 +1,22 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "Tuple.h" + +namespace yarpl { + +std::atomic Tuple::createdCount; +std::atomic Tuple::destroyedCount; +std::atomic Tuple::instanceCount; +} // namespace yarpl diff --git a/yarpl/test/yarpl/test_utils/Tuple.h b/yarpl/test_utils/Tuple.h similarity index 61% rename from yarpl/test/yarpl/test_utils/Tuple.h rename to yarpl/test_utils/Tuple.h index ec1829132..663e29637 100644 --- a/yarpl/test/yarpl/test_utils/Tuple.h +++ b/yarpl/test_utils/Tuple.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -43,4 +55,4 @@ struct Tuple { static std::atomic instanceCount; }; -} // yarpl +} // namespace yarpl diff --git a/yarpl/src/yarpl/utils/credits.cpp b/yarpl/utils/credits.cpp similarity index 51% rename from yarpl/src/yarpl/utils/credits.cpp rename to yarpl/utils/credits.cpp index 12687fddf..28cd9e9b5 100644 --- a/yarpl/src/yarpl/utils/credits.cpp +++ b/yarpl/utils/credits.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "yarpl/utils/credits.h" @@ -8,7 +20,7 @@ namespace yarpl { namespace credits { -int64_t add(std::atomic* current, int64_t n) { +int64_t add(std::atomic* current, int64_t n) { for (;;) { auto r = current->load(); // if already "infinite" @@ -52,7 +64,7 @@ int64_t add(int64_t current, int64_t n) { return current + n; } -bool cancel(std::atomic* current) { +bool cancel(std::atomic* current) { for (;;) { auto r = current->load(); if (r == kCanceled) { @@ -67,9 +79,17 @@ bool cancel(std::atomic* current) { } } -int64_t consume(std::atomic* current, int64_t n) { +int64_t consume(std::atomic* current, int64_t n) { for (;;) { auto r = current->load(); + // if already "infinite" + if (r == kNoFlowControl) { + return kNoFlowControl; + } + // if already "cancelled" + if (r == kCanceled) { + return kCanceled; + } if (n <= 0) { // do nothing, return existing unmodified value return r; @@ -89,11 +109,47 @@ int64_t consume(std::atomic* current, int64_t n) { } } -bool isCancelled(std::atomic* current) { +bool tryConsume(std::atomic* current, int64_t n) { + if (n <= 0) { + // do nothing, return existing unmodified value + return false; + } + + for (;;) { + auto r = current->load(); + if (r < n) { + return false; + } + + auto u = r - n; + + // set the new number + if (current->compare_exchange_strong(r, u)) { + return true; + } + // if failed to set (concurrent modification) loop and try again + } +} + +bool isCancelled(std::atomic* current) { return current->load() == kCanceled; } -bool isInfinite(std::atomic* current) { +int64_t consume(int64_t& current, int64_t n) { + if (n <= 0) { + // do nothing, return existing unmodified value + return current; + } + if (current < n) { + // bad usage somewhere ... be resilient, just set to r + n = current; + } + + current -= n; + return current; +} + +bool isInfinite(std::atomic* current) { return current->load() == kNoFlowControl; } diff --git a/yarpl/include/yarpl/utils/credits.h b/yarpl/utils/credits.h similarity index 63% rename from yarpl/include/yarpl/utils/credits.h rename to yarpl/utils/credits.h index 39da9fb15..10063b728 100644 --- a/yarpl/include/yarpl/utils/credits.h +++ b/yarpl/utils/credits.h @@ -1,10 +1,23 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include #include #include +#include namespace yarpl { namespace credits { @@ -34,6 +47,8 @@ constexpr int64_t kNoFlowControl{std::numeric_limits::max()}; * * If 'current' is set to "cancelled" using the magic number INT64_MIN it will * not be changed. + * + * Returns new value of credits. */ int64_t add(std::atomic*, int64_t); @@ -54,9 +69,23 @@ bool cancel(std::atomic*); * Consume (remove) credits from the 'current' atomic. * * This MUST only be used to remove credits after emitting a value via onNext. + * + * Returns new value of credits. */ int64_t consume(std::atomic*, int64_t); +/** + * Try Consume (remove) credits from the 'current' atomic. + * + * Returns true if consuming the credit was successful. + */ +bool tryConsume(std::atomic*, int64_t); + +/** + * Version of consume that works for non-atomic integers. + */ +int64_t consume(int64_t&, int64_t); + /** * Whether the current value represents a "cancelled" subscription. */ @@ -67,5 +96,5 @@ bool isCancelled(std::atomic*); */ bool isInfinite(std::atomic*); -} -} +} // namespace credits +} // namespace yarpl