diff --git a/.github/workflows/TagIt.yml b/.github/workflows/TagIt.yml new file mode 100644 index 000000000..2c4b889d6 --- /dev/null +++ b/.github/workflows/TagIt.yml @@ -0,0 +1,68 @@ +on: + push: + tags: + # Only match TagIt tags, which always start with this prefix + - 'v20*' + +name: TagIt + +jobs: + build: + name: Release + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v2 + - name: Archive project + id: archive_project + run: | + FILE_NAME=${GITHUB_REPOSITORY#*/}-${GITHUB_REF##*/} + git archive ${{ github.ref }} -o ${FILE_NAME}.zip + git archive ${{ github.ref }} -o ${FILE_NAME}.tar.gz + echo "::set-output name=file_name::${FILE_NAME}" + - name: Compute digests + id: compute_digests + run: | + echo "::set-output name=tgz_256::$(openssl dgst -sha256 ${{ steps.archive_project.outputs.file_name }}.tar.gz)" + echo "::set-output name=tgz_512::$(openssl dgst -sha512 ${{ steps.archive_project.outputs.file_name }}.tar.gz)" + echo "::set-output name=zip_256::$(openssl dgst -sha256 ${{ steps.archive_project.outputs.file_name }}.zip)" + echo "::set-output name=zip_512::$(openssl dgst -sha512 ${{ steps.archive_project.outputs.file_name }}.zip)" + - name: Create Release + id: create_release + uses: actions/create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ github.ref }} + release_name: ${{ github.ref }} + body: | + Automated release from TagIt +
+ File Hashes + +
+ draft: false + prerelease: false + - name: Upload zip + uses: actions/upload-release-asset@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + upload_url: ${{ steps.create_release.outputs.upload_url }} + asset_path: ./${{ steps.archive_project.outputs.file_name }}.zip + asset_name: ${{ steps.archive_project.outputs.file_name }}.zip + asset_content_type: application/zip + - name: Upload tar.gz + uses: actions/upload-release-asset@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + upload_url: ${{ steps.create_release.outputs.upload_url }} + asset_path: ./${{ steps.archive_project.outputs.file_name }}.tar.gz + asset_name: ${{ steps.archive_project.outputs.file_name }}.tar.gz + asset_content_type: application/gzip diff --git a/.gitignore b/.gitignore index bf568c4c9..e00c3821f 100644 --- a/.gitignore +++ b/.gitignore @@ -46,7 +46,6 @@ install_manifest.txt ### Project ### -/build /reactivesocket-cpp/CTestTestfile.cmake /reactivesocket-cpp/ReactiveSocketTest /reactivesocket-cpp/compile_commands.json diff --git a/.travis.yml b/.travis.yml index b69e6c25a..1c88d192b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,9 +14,6 @@ addons: packages: &common_deps - lcov # Folly dependencies - - autoconf - - autoconf-archive - - automake - binutils-dev - g++ - libboost-all-dev @@ -25,42 +22,27 @@ addons: - libgflags-dev - libgoogle-glog-dev - libiberty-dev - - libjemalloc-dev - liblz4-dev - liblzma-dev - libsnappy-dev - libssl-dev - - libtool - make - - pkg-config - zlib1g-dev matrix: include: - # Set COMPILER environment variable instead of CC or CXX because the latter - # are overriden by Travis. Setting the compiler in Travis doesn't work - # either because it strips version. - - env: COMPILER=clang-4.0 + - env: COMPILER_EVAL="CC=clang-6.0 CXX=clang++-6.0" addons: apt: sources: - *common_srcs - - llvm-toolchain-trusty-4.0 + - llvm-toolchain-trusty-6.0 packages: - *common_deps - - clang-4.0 + - clang-6.0 - libstdc++-4.9-dev - - env: COMPILER=gcc-4.9 - addons: - apt: - sources: - - *common_srcs - packages: - - *common_deps - - g++-4.9 - - - env: COMPILER=gcc-5 + - env: COMPILER_EVAL="CC=gcc-5 CXX=g++-5" addons: apt: sources: @@ -68,8 +50,9 @@ matrix: packages: - *common_deps - g++-5 + - libjemalloc-dev - - env: COMPILER=gcc-6 + - env: COMPILER_EVAL="CC=gcc-6 CXX=g++-6" addons: apt: sources: @@ -77,6 +60,7 @@ matrix: packages: - *common_deps - g++-6 + - libjemalloc-dev env: global: @@ -93,29 +77,38 @@ env: eHz/lHAoLXWg/BhtgQbPmMYYKRrQaH7EKzBbqEHv6PhOk7vLMtdx5X7KmhVuFjpAMbaYoj zwxxH0u+VAnVB5iazzyjhySjvzkvx6pGzZtTnjLJHxKcp9633z4OU= -cache: - directories: - - $HOME/folly - before_script: + - eval "$COMPILER_EVAL" + - export DEP_INSTALL_DIR=$PWD/build/dep-install + # Ubuntu trusty only comes with OpenSSL 1.0.1f, but we require + # at least OpenSSL 1.0.2 for ALPN support. + - curl -L https://github.com/openssl/openssl/archive/OpenSSL_1_1_1.tar.gz -o OpenSSL_1_1_1.tar.gz + - tar -xzf OpenSSL_1_1_1.tar.gz + - cd openssl-OpenSSL_1_1_1 + - ./config --prefix=$DEP_INSTALL_DIR no-shared + - make -j4 + - make install_sw install_ssldirs + - cd .. # Install lcov to coveralls conversion + upload tool. - gem install coveralls-lcov - - lcov --directory build --zerocounters + - lcov --version + # Build folly + - ./scripts/build_folly.sh build/folly-src $DEP_INSTALL_DIR script: - - mkdir build && - cd build && - cmake -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DRSOCKET_CC=$COMPILER - -DRSOCKET_ASAN=$ASAN -DRSOCKET_INSTALL_DEPS=True - -DRSOCKET_BUILD_WITH_COVERAGE=ON .. && - make -j8 && - make test - - cd .. - - ./scripts/tck_test.sh -c cpp -s cpp - - ./scripts/tck_test.sh -c java -s java - - ./scripts/tck_test.sh -c java -s cpp - - ./scripts/tck_test.sh -c cpp -s java - - cd build && make coverage + - cd build + - cmake -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DRSOCKET_ASAN=$ASAN + -DCMAKE_PREFIX_PATH=$DEP_INSTALL_DIR + -DRSOCKET_BUILD_WITH_COVERAGE=ON .. + - make -j4 + - lcov --directory . --zerocounters + # - make test + # - make coverage + # - cd .. + # - ./scripts/tck_test.sh -c cpp -s cpp + # - ./scripts/tck_test.sh -c java -s java + # - ./scripts/tck_test.sh -c java -s cpp + # - ./scripts/tck_test.sh -c cpp -s java after_success: # Upload to coveralls. diff --git a/.ycm_extra_conf.py b/.ycm_extra_conf.py deleted file mode 100644 index 4f893b15e..000000000 --- a/.ycm_extra_conf.py +++ /dev/null @@ -1,158 +0,0 @@ -import os -import os.path -import logging -import ycm_core - -BASE_FLAGS = [ - '-xc++', - '-Wall', - '-Wextra', - '-Werror', - '-std=c++11', - '-I.', - '-isystem/usr/lib/', - '-isystem/usr/include/', -] - -SOURCE_EXTENSIONS = [ - '.cpp', - '.cxx', - '.cc', - '.c', - '.m', - '.mm' -] - -HEADER_EXTENSIONS = [ - '.h', - '.hxx', - '.hpp', - '.hh', - '.icc', - '.tcc', -] - - -def IsHeaderFile(filename): - extension = os.path.splitext(filename)[1] - return extension in HEADER_EXTENSIONS - - -def GetCompilationInfoForFile(database, filename): - if IsHeaderFile(filename): - basename = os.path.splitext(filename)[0] - for extension in SOURCE_EXTENSIONS: - replacement_file = basename + extension - if os.path.exists(replacement_file): - compilation_info = database.GetCompilationInfoForFile( - replacement_file) - if compilation_info.compiler_flags_: - return compilation_info - return None - return database.GetCompilationInfoForFile(filename) - - -def FindNearest(path, target): - candidate = os.path.join(path, target) - if(os.path.isfile(candidate) or os.path.isdir(candidate)): - logging.info("Found nearest " + target + " at " + candidate) - return candidate - else: - parent = os.path.dirname(os.path.abspath(path)) - if(parent == path): - raise RuntimeError("Could not find " + target) - return FindNearest(parent, target) - - -def MakeRelativePathsInFlagsAbsolute(flags, working_directory): - if not working_directory: - return list(flags) - new_flags = [] - make_next_absolute = False - path_flags = ['-isystem', '-I', '-iquote', '--sysroot='] - for flag in flags: - new_flag = flag - - if make_next_absolute: - make_next_absolute = False - if not flag.startswith('/'): - new_flag = os.path.join(working_directory, flag) - - for path_flag in path_flags: - if flag == path_flag: - make_next_absolute = True - break - - if flag.startswith(path_flag): - path = flag[len(path_flag):] - new_flag = path_flag + os.path.join(working_directory, path) - break - - if new_flag: - new_flags.append(new_flag) - return new_flags - - -def FlagsForClangComplete(root): - try: - clang_complete_path = FindNearest(root, '.clang_complete') - clang_complete_flags = open( - clang_complete_path, 'r').read().splitlines() - return clang_complete_flags - except: - return None - - -def FlagsForInclude(root): - try: - include_path = FindNearest(root, 'include') - flags = [] - for dirroot, dirnames, filenames in os.walk(include_path): - for dir_path in dirnames: - real_path = os.path.join(dirroot, dir_path) - flags = flags + ["-I" + real_path] - return flags - except: - return None - - -def FlagsForCompilationDatabase(root, filename): - try: - compilation_db_path = FindNearest( - os.path.join(root, 'build'), 'compile_commands.json') - compilation_db_dir = os.path.dirname(compilation_db_path) - logging.info( - "Set compilation database directory to " + compilation_db_dir) - compilation_db = ycm_core.CompilationDatabase(compilation_db_dir) - if not compilation_db: - logging.info("Compilation database file found but unable to load") - return None - compilation_info = GetCompilationInfoForFile(compilation_db, filename) - if not compilation_info: - logging.info( - "No compilation info for " + filename + " in compilation database") - return None - return MakeRelativePathsInFlagsAbsolute( - compilation_info.compiler_flags_, - compilation_info.compiler_working_dir_) - except: - return None - - -def FlagsForFile(filename): - root = os.path.realpath(filename) - compilation_db_flags = FlagsForCompilationDatabase(root, filename) - if compilation_db_flags: - final_flags = compilation_db_flags - else: - final_flags = BASE_FLAGS - clang_flags = FlagsForClangComplete(root) - if clang_flags: - final_flags = final_flags + clang_flags - include_flags = FlagsForInclude(root) - if include_flags: - final_flags = final_flags + include_flags - return { - 'flags': final_flags, - 'do_cache': True - } diff --git a/CMakeLists.txt b/CMakeLists.txt index c29187c80..c736ccbf0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,23 +1,21 @@ cmake_minimum_required(VERSION 3.2) -# The RSOCKET_CC CMake variable specifies the C compiler, e.g. gcc-4.9. -# The C++ compiler name is obtained by replacing "gcc" with "g++" and "clang" -# with "clang++"". If RSOCKET_CC is not given, the compiler is detected -# automatically. -if (RSOCKET_CC) - set(ENV{CC} ${RSOCKET_CC}) - if (${RSOCKET_CC} MATCHES clang) - string(REPLACE clang clang++ CXX ${RSOCKET_CC}) - else () - string(REPLACE gcc g++ CXX ${RSOCKET_CC}) - endif () - set(ENV{CXX} ${CXX}) -endif () - project(ReactiveSocket) +if (NOT DEFINED CPACK_GENERATOR) + set(CPACK_GENERATOR "RPM") +endif() +include(CPack) + # CMake modules. -set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake/") +set(CMAKE_MODULE_PATH + "${CMAKE_SOURCE_DIR}/cmake/" + # For in-fbsource builds + "${CMAKE_CURRENT_SOURCE_DIR}/../opensource/fbcode_builder/CMake" + # For shipit-transformed builds + "${CMAKE_CURRENT_SOURCE_DIR}/build/fbcode_builder/CMake" + ${CMAKE_MODULE_PATH} +) # Joins arguments and stores the result in ${var}. function(join var) @@ -48,11 +46,21 @@ endif(NOT CMAKE_BUILD_TYPE) string(TOLOWER "${CMAKE_BUILD_TYPE}" BUILD_TYPE_LOWER) +if (BUILD_TYPE_LOWER MATCHES "debug") + add_definitions(-DDEBUG) +endif () + # Enable ASAN by default on debug macOS builds. if (APPLE) set(OPENSSL_ROOT_DIR "/usr/local/opt/openssl") if ("${BUILD_TYPE_LOWER}" MATCHES "debug") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address,integer -fno-sanitize=unsigned-integer-overflow") + if (${CMAKE_CXX_COMPILER_ID} MATCHES Clang) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address,integer -fno-sanitize=unsigned-integer-overflow") + elseif (${CMAKE_CXX_COMPILER_ID} MATCHES GNU) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address") + else () + message(FATAL_ERROR "Unsupported compiler on macOS") + endif() endif() endif() @@ -66,27 +74,50 @@ if (CMAKE_COMPILER_IS_GNUCXX) endif () set(EXTRA_LINK_FLAGS ${EXTRA_LINK_FLAGS} -fuse-ld=gold) - if (RSOCKET_BUILD_WITH_COVERAGE) - # Enable code coverage. - add_compile_options(--coverage) - set(EXTRA_LINK_FLAGS ${EXTRA_LINK_FLAGS} --coverage) - set(COVERAGE_INFO coverage.info) +elseif (${CMAKE_CXX_COMPILER_ID} MATCHES Clang) + if (RSOCKET_ASAN) + set(ASAN_FLAGS + -fsanitize=address,undefined,integer + -fno-sanitize=unsigned-integer-overflow) + endif () +endif () + +# Enable code coverage, if the compiler is supported +if (RSOCKET_BUILD_WITH_COVERAGE) + set(COVERAGE_INFO coverage.info) + + if (${CMAKE_SYSTEM_NAME} MATCHES Linux AND ${CMAKE_CXX_COMPILER_ID} MATCHES Clang) + # clang and linux's lcov don't play nice together; don't attempt with linux/clang + add_custom_command( + OUTPUT ${COVERAGE_INFO} + COMMAND echo "Coverage info omitted for clang/linux builds") + + else () add_custom_command( OUTPUT ${COVERAGE_INFO} # Capture coverage info. COMMAND lcov --directory . --capture --output-file ${COVERAGE_INFO} # Filter out system and test code. - COMMAND lcov --remove ${COVERAGE_INFO} 'tests/*' 'test/*' 'tck-test/*' '/usr/*' 'gmock/*' 'folly/*' --output-file + COMMAND lcov --remove ${COVERAGE_INFO} '*/tests/*' '*/test/*' '*/tck-test/*' '*/usr/include/*' '/usr/*' '*/gmock/*' '*/folly/*' --output-file ${COVERAGE_INFO} # Debug before upload. COMMAND lcov --list ${COVERAGE_INFO}) - endif() -elseif (${CMAKE_CXX_COMPILER_ID} MATCHES Clang) - if (RSOCKET_ASAN) - set(ASAN_FLAGS - -fsanitize=address,undefined,integer - -fno-sanitize=unsigned-integer-overflow) endif () + + if (CMAKE_COMPILER_IS_GNUCXX) + add_compile_options(-g -O0 --coverage) + set(EXTRA_LINK_FLAGS ${EXTRA_LINK_FLAGS} --coverage) + + elseif (${CMAKE_CXX_COMPILER_ID} MATCHES Clang) + add_compile_options(-g -O0 -fprofile-arcs -ftest-coverage) + set(EXTRA_LINK_FLAGS ${EXTRA_LINK_FLAGS} -fprofile-arcs -ftest-coverage) + + else () + message(FATAL_ERROR "Code coverage not supported with this compiler/host combination") + endif () + + message(STATUS "Building with coverage") + add_custom_target(coverage DEPENDS ${COVERAGE_INFO}) endif () if (DEFINED ASAN_FLAGS) @@ -94,99 +125,55 @@ if (DEFINED ASAN_FLAGS) set(EXTRA_LINK_FLAGS ${EXTRA_LINK_FLAGS} ${ASAN_FLAGS}) endif () -add_custom_target(coverage DEPENDS ${COVERAGE_INFO}) - option(BUILD_BENCHMARKS "Build benchmarks" ON) +option(BUILD_EXAMPLES "Build examples" ON) +option(BUILD_TESTS "Build tests" ON) enable_testing() include(ExternalProject) include(CTest) -if (NOT FOLLY_INSTALL_DIR) - set(FOLLY_INSTALL_DIR $ENV{HOME}/folly) -endif () +include(${CMAKE_SOURCE_DIR}/cmake/InstallFolly.cmake) -# Check if the correct version of folly is already installed. -set(FOLLY_VERSION v2017.07.10.00) -set(FOLLY_VERSION_FILE ${FOLLY_INSTALL_DIR}/${FOLLY_VERSION}) -if (RSOCKET_INSTALL_DEPS) - if (NOT EXISTS ${FOLLY_VERSION_FILE}) - # Remove the old version of folly. - file(REMOVE_RECURSE ${FOLLY_INSTALL_DIR}) - set(INSTALL_FOLLY True) - endif () -endif () - -if (INSTALL_FOLLY) - # Build and install folly. +if(BUILD_TESTS) + # gmock ExternalProject_Add( - folly-ext - GIT_REPOSITORY https://github.com/facebook/folly - GIT_TAG ${FOLLY_VERSION} - BINARY_DIR folly-ext-prefix/src/folly-ext/folly - CONFIGURE_COMMAND autoreconf -ivf - COMMAND ./configure CXX=${CMAKE_CXX_COMPILER} - --prefix=${FOLLY_INSTALL_DIR} - BUILD_COMMAND make -j4 - INSTALL_COMMAND make install - COMMAND cmake -E touch ${FOLLY_VERSION_FILE}) - - set(FOLLY_INCLUDE_DIR ${FOLLY_INSTALL_DIR}/include) - set(lib ${CMAKE_SHARED_LIBRARY_PREFIX}folly${CMAKE_SHARED_LIBRARY_SUFFIX}) - set(FOLLY_LIBRARY ${FOLLY_INSTALL_DIR}/lib/${lib}) - - # CMake requires directories listed in INTERFACE_INCLUDE_DIRECTORIES to exist. - file(MAKE_DIRECTORY ${FOLLY_INCLUDE_DIR}) -else () - # Use installed folly. - find_package(Folly REQUIRED) -endif () - -find_package(Threads) -find_library(EVENT_LIBRARY event) + gmock + URL ${CMAKE_CURRENT_SOURCE_DIR}/googletest-release-1.8.0.zip + INSTALL_COMMAND "" + ) -add_library(folly SHARED IMPORTED) -set_property(TARGET folly PROPERTY IMPORTED_LOCATION ${FOLLY_LIBRARY}) -set_property(TARGET folly - APPEND PROPERTY INTERFACE_LINK_LIBRARIES - ${EXTRA_LINK_FLAGS} ${EVENT_LIBRARY} ${CMAKE_THREAD_LIBS_INIT}) -if (TARGET folly-ext) - add_dependencies(folly folly-ext) -endif () + ExternalProject_Get_Property(gmock source_dir) + set(GMOCK_SOURCE_DIR ${source_dir}) + ExternalProject_Get_Property(gmock binary_dir) + set(GMOCK_BINARY_DIR ${binary_dir}) -# Folly includes are marked as system to prevent errors on non-standard -# extensions when compiling with -pedantic and -Werror. -set_property(TARGET folly - APPEND PROPERTY INTERFACE_SYSTEM_INCLUDE_DIRECTORIES ${FOLLY_INCLUDE_DIR}) -set_property(TARGET folly - APPEND PROPERTY INTERFACE_INCLUDE_DIRECTORIES ${FOLLY_INCLUDE_DIR}) - -# gmock -ExternalProject_Add( - gmock - URL ${CMAKE_CURRENT_SOURCE_DIR}/googletest-release-1.8.0.zip - INSTALL_COMMAND "" -) + set(GMOCK_LIBS + ${GMOCK_BINARY_DIR}/${CMAKE_CFG_INTDIR}/googlemock/${CMAKE_STATIC_LIBRARY_PREFIX}gmock${CMAKE_STATIC_LIBRARY_SUFFIX} + ${GMOCK_BINARY_DIR}/${CMAKE_CFG_INTDIR}/googlemock/${CMAKE_STATIC_LIBRARY_PREFIX}gmock_main${CMAKE_STATIC_LIBRARY_SUFFIX} + ) -ExternalProject_Get_Property(gmock source_dir) -set(GMOCK_SOURCE_DIR ${source_dir}) -ExternalProject_Get_Property(gmock binary_dir) -set(GMOCK_BINARY_DIR ${binary_dir}) + include_directories(${GMOCK_SOURCE_DIR}/googlemock/include) + include_directories(${GMOCK_SOURCE_DIR}/googletest/include) -set(GMOCK_LIBS - ${GMOCK_BINARY_DIR}/${CMAKE_CFG_INTDIR}/googlemock/${CMAKE_STATIC_LIBRARY_PREFIX}gmock${CMAKE_STATIC_LIBRARY_SUFFIX} - ${GMOCK_BINARY_DIR}/${CMAKE_CFG_INTDIR}/googlemock/${CMAKE_STATIC_LIBRARY_PREFIX}gmock_main${CMAKE_STATIC_LIBRARY_SUFFIX} - ) +endif() set(CMAKE_CXX_STANDARD 14) +include(CheckCXXCompilerFlag) + # Common configuration for all build modes. -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Woverloaded-virtual") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g") +if (NOT MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Woverloaded-virtual") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g") +endif() -set(EXTRA_CXX_FLAGS ${EXTRA_CXX_FLAGS} -Werror) +CHECK_CXX_COMPILER_FLAG(-Wnoexcept-type COMPILER_HAS_W_NOEXCEPT_TYPE) +if (COMPILER_HAS_W_NOEXCEPT_TYPE) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-noexcept-type") +endif() if("${BUILD_TYPE_LOWER}" MATCHES "debug") message("debug mode was set") @@ -203,141 +190,143 @@ find_library(DOUBLE-CONVERSION double-conversion) find_package(OpenSSL REQUIRED) -# Find glog and gflags libraries specifically -find_path(GLOG_INCLUDE_DIR glog/logging.h) -find_path(GFLAGS_INCLUDE_DIR gflags/gflags.h) +find_package(Gflags REQUIRED) -find_library(GLOG_LIBRARY glog) -find_library(GFLAGS_LIBRARY gflags) +# find glog::glog to satisfy the folly dep. +find_package(Glog REQUIRED) -message("gflags include_dir <${GFLAGS_INCLUDE_DIR}> lib <${GFLAGS_LIBRARY}>") -message("glog include_dir <${GLOG_INCLUDE_DIR}> lib <${GLOG_LIBRARY}>") +find_package(fmt CONFIG REQUIRED) include_directories(SYSTEM ${OPENSSL_INCLUDE_DIR}) include_directories(SYSTEM ${GFLAGS_INCLUDE_DIR}) -include_directories(SYSTEM ${GLOG_INCLUDE_DIR}) - -include_directories(${CMAKE_SOURCE_DIR}) - -include_directories(${CMAKE_CURRENT_BINARY_DIR}/reactivestreams/include) -include_directories(${GMOCK_SOURCE_DIR}/googlemock/include) -include_directories(${GMOCK_SOURCE_DIR}/googletest/include) add_subdirectory(yarpl) add_library( ReactiveSocket + rsocket/ColdResumeHandler.cpp + rsocket/ColdResumeHandler.h + rsocket/ConnectionAcceptor.h + rsocket/ConnectionFactory.h + rsocket/DuplexConnection.h + rsocket/Payload.cpp + rsocket/Payload.h rsocket/RSocket.cpp rsocket/RSocket.h - rsocket/RSocketServer.h - rsocket/RSocketServer.cpp - rsocket/RSocketClient.h rsocket/RSocketClient.cpp - rsocket/RSocketRequester.h - rsocket/RSocketRequester.cpp + rsocket/RSocketClient.h rsocket/RSocketErrors.h + rsocket/RSocketException.h rsocket/RSocketParameters.cpp rsocket/RSocketParameters.h - rsocket/ConnectionAcceptor.h - rsocket/ConnectionFactory.h - rsocket/transports/tcp/TcpConnectionAcceptor.h - rsocket/transports/tcp/TcpConnectionAcceptor.cpp - rsocket/transports/tcp/TcpConnectionFactory.h - rsocket/transports/tcp/TcpConnectionFactory.cpp + rsocket/RSocketRequester.cpp + rsocket/RSocketRequester.h rsocket/RSocketResponder.cpp rsocket/RSocketResponder.h - rsocket/statemachine/ChannelRequester.cpp - rsocket/statemachine/ChannelRequester.h - rsocket/statemachine/ChannelResponder.cpp - rsocket/statemachine/ChannelResponder.h - rsocket/statemachine/ConsumerBase.cpp - rsocket/statemachine/ConsumerBase.h - rsocket/statemachine/PublisherBase.cpp - rsocket/statemachine/PublisherBase.h - rsocket/statemachine/RequestResponseRequester.cpp - rsocket/statemachine/RequestResponseRequester.h - rsocket/statemachine/RequestResponseResponder.cpp - rsocket/statemachine/RequestResponseResponder.h - rsocket/statemachine/StreamStateMachineBase.cpp - rsocket/statemachine/StreamStateMachineBase.h - rsocket/statemachine/StreamRequester.cpp - rsocket/statemachine/StreamRequester.h - rsocket/statemachine/StreamResponder.cpp - rsocket/statemachine/StreamResponder.h - rsocket/statemachine/StreamsFactory.cpp - rsocket/statemachine/StreamsFactory.h - rsocket/statemachine/StreamsWriter.h - rsocket/statemachine/StreamState.cpp - rsocket/statemachine/StreamState.h - rsocket/internal/ClientResumeStatusCallback.h - rsocket/internal/Common.cpp - rsocket/internal/Common.h - rsocket/statemachine/RSocketStateMachine.cpp - rsocket/statemachine/RSocketStateMachine.h - rsocket/DuplexConnection.h - rsocket/internal/FollyKeepaliveTimer.cpp - rsocket/internal/FollyKeepaliveTimer.h + rsocket/RSocketServer.cpp + rsocket/RSocketServer.h + rsocket/RSocketServerState.h + rsocket/RSocketServiceHandler.cpp + rsocket/RSocketServiceHandler.h + rsocket/RSocketStats.cpp + rsocket/RSocketStats.h + rsocket/ResumeManager.h rsocket/framing/ErrorCode.cpp rsocket/framing/ErrorCode.h rsocket/framing/Frame.cpp rsocket/framing/Frame.h rsocket/framing/FrameFlags.cpp rsocket/framing/FrameFlags.h + rsocket/framing/FrameHeader.cpp + rsocket/framing/FrameHeader.h rsocket/framing/FrameProcessor.h rsocket/framing/FrameSerializer.cpp rsocket/framing/FrameSerializer.h - rsocket/framing/FrameTransport.cpp + rsocket/framing/FrameSerializer_v1_0.cpp + rsocket/framing/FrameSerializer_v1_0.h rsocket/framing/FrameTransport.h - rsocket/framing/FrameType.cpp + rsocket/framing/FrameTransportImpl.cpp + rsocket/framing/FrameTransportImpl.h rsocket/framing/FrameType.cpp rsocket/framing/FrameType.h - rsocket/framing/FrameType.h rsocket/framing/FramedDuplexConnection.cpp rsocket/framing/FramedDuplexConnection.h rsocket/framing/FramedReader.cpp rsocket/framing/FramedReader.h - rsocket/framing/FramedWriter.cpp - rsocket/framing/FramedWriter.h - rsocket/Payload.cpp - rsocket/Payload.h - rsocket/internal/ResumeCache.cpp - rsocket/internal/ResumeCache.h - rsocket/internal/SetupResumeAcceptor.cpp - rsocket/internal/SetupResumeAcceptor.h - rsocket/RSocketStats.cpp - rsocket/RSocketStats.h - rsocket/transports/tcp/TcpDuplexConnection.cpp - rsocket/transports/tcp/TcpDuplexConnection.h - rsocket/framing/FrameSerializer_v0.cpp - rsocket/framing/FrameSerializer_v0.h - rsocket/framing/FrameSerializer_v0_1.cpp - rsocket/framing/FrameSerializer_v0_1.h - rsocket/framing/FrameSerializer_v1_0.cpp - rsocket/framing/FrameSerializer_v1_0.h - rsocket/internal/ScheduledSubscription.cpp - rsocket/internal/ScheduledSubscription.h - rsocket/internal/ScheduledSubscriber.h + rsocket/framing/ProtocolVersion.cpp + rsocket/framing/ProtocolVersion.h + rsocket/framing/ResumeIdentificationToken.cpp + rsocket/framing/ResumeIdentificationToken.h + rsocket/framing/ScheduledFrameProcessor.cpp + rsocket/framing/ScheduledFrameProcessor.h + rsocket/framing/ScheduledFrameTransport.cpp + rsocket/framing/ScheduledFrameTransport.h + rsocket/internal/ClientResumeStatusCallback.h + rsocket/internal/Common.cpp + rsocket/internal/Common.h + rsocket/internal/ConnectionSet.cpp + rsocket/internal/ConnectionSet.h + rsocket/internal/KeepaliveTimer.cpp + rsocket/internal/KeepaliveTimer.h rsocket/internal/ScheduledRSocketResponder.cpp rsocket/internal/ScheduledRSocketResponder.h rsocket/internal/ScheduledSingleObserver.h rsocket/internal/ScheduledSingleSubscription.cpp rsocket/internal/ScheduledSingleSubscription.h - rsocket/internal/RSocketConnectionManager.cpp - rsocket/internal/RSocketConnectionManager.h + rsocket/internal/ScheduledSubscriber.h + rsocket/internal/ScheduledSubscription.cpp + rsocket/internal/ScheduledSubscription.h + rsocket/internal/SetupResumeAcceptor.cpp + rsocket/internal/SetupResumeAcceptor.h rsocket/internal/SwappableEventBase.cpp rsocket/internal/SwappableEventBase.h - rsocket/ResumeManager.h - rsocket/ColdResumeHandler.h - rsocket/RSocketServiceHandler.cpp - rsocket/RSocketServiceHandler.h - rsocket/RSocketServerState.h - rsocket/RSocketException.h) + rsocket/internal/WarmResumeManager.cpp + rsocket/internal/WarmResumeManager.h + rsocket/statemachine/ChannelRequester.cpp + rsocket/statemachine/ChannelRequester.h + rsocket/statemachine/ChannelResponder.cpp + rsocket/statemachine/ChannelResponder.h + rsocket/statemachine/ConsumerBase.cpp + rsocket/statemachine/ConsumerBase.h + rsocket/statemachine/FireAndForgetResponder.cpp + rsocket/statemachine/FireAndForgetResponder.h + rsocket/statemachine/PublisherBase.cpp + rsocket/statemachine/PublisherBase.h + rsocket/statemachine/RSocketStateMachine.cpp + rsocket/statemachine/RSocketStateMachine.h + rsocket/statemachine/RequestResponseRequester.cpp + rsocket/statemachine/RequestResponseRequester.h + rsocket/statemachine/RequestResponseResponder.cpp + rsocket/statemachine/RequestResponseResponder.h + rsocket/statemachine/StreamRequester.cpp + rsocket/statemachine/StreamRequester.h + rsocket/statemachine/StreamResponder.cpp + rsocket/statemachine/StreamResponder.h + rsocket/statemachine/StreamStateMachineBase.cpp + rsocket/statemachine/StreamStateMachineBase.h + rsocket/statemachine/StreamFragmentAccumulator.cpp + rsocket/statemachine/StreamFragmentAccumulator.h + rsocket/statemachine/StreamsWriter.h + rsocket/statemachine/StreamsWriter.cpp + rsocket/transports/tcp/TcpConnectionAcceptor.cpp + rsocket/transports/tcp/TcpConnectionAcceptor.h + rsocket/transports/tcp/TcpConnectionFactory.cpp + rsocket/transports/tcp/TcpConnectionFactory.h + rsocket/transports/tcp/TcpDuplexConnection.cpp + rsocket/transports/tcp/TcpDuplexConnection.h) + +target_include_directories( + ReactiveSocket + PUBLIC + $ + $ +) -target_include_directories(ReactiveSocket PUBLIC "${PROJECT_SOURCE_DIR}/yarpl/include") -target_include_directories(ReactiveSocket PUBLIC "${PROJECT_SOURCE_DIR}/yarpl/src") -target_link_libraries(ReactiveSocket yarpl ${GFLAGS_LIBRARY} ${GLOG_LIBRARY}) +target_link_libraries(ReactiveSocket + PUBLIC yarpl glog::glog gflags + INTERFACE ${EXTRA_LINK_FLAGS}) target_compile_options( ReactiveSocket @@ -345,133 +334,185 @@ target_compile_options( enable_testing() -install(TARGETS ReactiveSocket DESTINATION lib) +install(TARGETS ReactiveSocket EXPORT rsocket-exports DESTINATION lib) install(DIRECTORY rsocket DESTINATION include FILES_MATCHING PATTERN "*.h") +install(EXPORT rsocket-exports NAMESPACE rsocket:: DESTINATION lib/cmake/rsocket) +include(CMakePackageConfigHelpers) +configure_package_config_file( + cmake/rsocket-config.cmake.in + rsocket-config.cmake + INSTALL_DESTINATION lib/cmake/rsocket +) +install( + FILES ${CMAKE_CURRENT_BINARY_DIR}/rsocket-config.cmake + DESTINATION lib/cmake/rsocket +) -# CMake doesn't seem to support "transitive" installing, and I can't access the -# "yarpl" target from this file, so just grab the library file directly. -install(FILES "${CMAKE_CURRENT_BINARY_DIR}/yarpl/libyarpl.a" DESTINATION lib) -install(DIRECTORY yarpl/include/yarpl DESTINATION include FILES_MATCHING PATTERN "*.h") - +if(BUILD_TESTS) add_executable( tests - test/MocksTest.cpp - test/PayloadTest.cpp - test/RSocketClientServerTest.cpp - test/RSocketTests.h - test/RequestChannelTest.cpp - test/RequestResponseTest.cpp - test/RequestStreamTest.cpp - test/WarmResumptionTest.cpp - test/Test.cpp - test/framing/FrameTest.cpp - test/framing/FrameTransportTest.cpp - test/handlers/HelloStreamRequestHandler.cpp - test/handlers/HelloStreamRequestHandler.h - test/internal/AllowanceSemaphoreTest.cpp - test/internal/FollyKeepaliveTimerTest.cpp - test/internal/SwappableEventBaseTest.cpp - test/internal/SetupResumeAcceptorTest.cpp - test/test_utils/MockDuplexConnection.h - test/test_utils/MockKeepaliveTimer.h - test/test_utils/MockRequestHandler.h - test/test_utils/MockStats.h - test/test_utils/Mocks.h - test/transport/DuplexConnectionTest.cpp - test/transport/DuplexConnectionTest.h - test/transport/TcpDuplexConnectionTest.cpp) - + rsocket/test/ColdResumptionTest.cpp + rsocket/test/ConnectionEventsTest.cpp + rsocket/test/PayloadTest.cpp + rsocket/test/RSocketClientServerTest.cpp + rsocket/test/RSocketClientTest.cpp + rsocket/test/RSocketTests.cpp + rsocket/test/RSocketTests.h + rsocket/test/RequestChannelTest.cpp + rsocket/test/RequestResponseTest.cpp + rsocket/test/RequestStreamTest.cpp + rsocket/test/RequestStreamTest_concurrency.cpp + rsocket/test/Test.cpp + rsocket/test/WarmResumeManagerTest.cpp + rsocket/test/WarmResumptionTest.cpp + rsocket/test/framing/FrameTest.cpp + rsocket/test/framing/FrameTransportTest.cpp + rsocket/test/framing/FramedReaderTest.cpp + rsocket/test/handlers/HelloServiceHandler.cpp + rsocket/test/handlers/HelloServiceHandler.h + rsocket/test/handlers/HelloStreamRequestHandler.cpp + rsocket/test/handlers/HelloStreamRequestHandler.h + rsocket/test/internal/AllowanceTest.cpp + rsocket/test/internal/ConnectionSetTest.cpp + rsocket/test/internal/KeepaliveTimerTest.cpp + rsocket/test/internal/ResumeIdentificationToken.cpp + rsocket/test/internal/SetupResumeAcceptorTest.cpp + rsocket/test/internal/SwappableEventBaseTest.cpp + rsocket/test/statemachine/RSocketStateMachineTest.cpp + rsocket/test/statemachine/StreamStateTest.cpp + rsocket/test/statemachine/StreamsWriterTest.cpp + rsocket/test/test_utils/ColdResumeManager.cpp + rsocket/test/test_utils/ColdResumeManager.h + rsocket/test/test_utils/GenericRequestResponseHandler.h + rsocket/test/test_utils/MockDuplexConnection.h + rsocket/test/test_utils/MockStreamsWriter.h + rsocket/test/test_utils/MockStats.h + rsocket/test/transport/DuplexConnectionTest.cpp + rsocket/test/transport/DuplexConnectionTest.h + rsocket/test/transport/TcpDuplexConnectionTest.cpp) + +add_dependencies(tests gmock) target_link_libraries( tests ReactiveSocket yarpl - ${GMOCK_LIBS} - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + yarpl-test-utils + ${GMOCK_LIBS} # This also needs the preceding `add_dependencies` + glog::glog + gflags) +target_include_directories(tests PUBLIC "${PROJECT_SOURCE_DIR}/yarpl/test/") target_compile_options( tests PRIVATE ${TEST_CXX_FLAGS}) -add_dependencies(tests gmock ReactiveSocket) - -add_test(NAME ReactiveSocketTests COMMAND tests) +add_dependencies(tests gmock yarpl-test-utils ReactiveSocket) -######################################## -# TCK Drivers -######################################## +add_test(NAME RSocketTests COMMAND tests) +### Fuzzer harnesses add_executable( - tckclient - tck-test/client.cpp - tck-test/TestFileParser.cpp - tck-test/TestFileParser.h - tck-test/FlowableSubscriber.cpp - tck-test/FlowableSubscriber.h - tck-test/SingleSubscriber.cpp - tck-test/SingleSubscriber.h - tck-test/TestSuite.cpp - tck-test/TestSuite.h - tck-test/TestInterpreter.cpp - tck-test/TestInterpreter.h - tck-test/TypedCommands.h - tck-test/BaseSubscriber.cpp - tck-test/BaseSubscriber.h) + frame_fuzzer + rsocket/test/fuzzers/frame_fuzzer.cpp) target_link_libraries( - tckclient + frame_fuzzer ReactiveSocket yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) -add_executable( - tckserver - tck-test/server.cpp - tck-test/MarbleProcessor.cpp - tck-test/MarbleProcessor.h - test/test_utils/StatsPrinter.cpp - test/test_utils/StatsPrinter.h) +add_dependencies(frame_fuzzer gmock ReactiveSocket) -target_link_libraries( - tckserver - ReactiveSocket - yarpl - ${GFLAGS_LIBRARY} - ${GMOCK_LIBS} - ${GLOG_LIBRARY} - ${DOUBLE-CONVERSION}) +add_test( + NAME FrameFuzzerTests + COMMAND ./scripts/frame_fuzzer_test.sh + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}) +endif() + +######################################## +# TCK Drivers +######################################## + +if(BUILD_TESTS) + add_executable( + tckclient + rsocket/tck-test/client.cpp + rsocket/tck-test/TestFileParser.cpp + rsocket/tck-test/TestFileParser.h + rsocket/tck-test/FlowableSubscriber.cpp + rsocket/tck-test/FlowableSubscriber.h + rsocket/tck-test/SingleSubscriber.cpp + rsocket/tck-test/SingleSubscriber.h + rsocket/tck-test/TestSuite.cpp + rsocket/tck-test/TestSuite.h + rsocket/tck-test/TestInterpreter.cpp + rsocket/tck-test/TestInterpreter.h + rsocket/tck-test/TypedCommands.h + rsocket/tck-test/BaseSubscriber.cpp + rsocket/tck-test/BaseSubscriber.h) + + target_link_libraries( + tckclient + ReactiveSocket + yarpl + glog::glog + gflags) + + add_executable( + tckserver + rsocket/tck-test/server.cpp + rsocket/tck-test/MarbleProcessor.cpp + rsocket/tck-test/MarbleProcessor.h + rsocket/test/test_utils/StatsPrinter.cpp + rsocket/test/test_utils/StatsPrinter.h) + + add_dependencies(tckserver gmock) + target_link_libraries( + tckserver + ReactiveSocket + yarpl + ${GMOCK_LIBS} # This also needs the preceding `add_dependencies` + glog::glog + gflags + ${DOUBLE-CONVERSION}) # Download the latest TCK drivers JAR. -set(TCK_DRIVERS_JAR rsocket-tck-drivers-0.9-SNAPSHOT.jar) -join(TCK_DRIVERS_URL - "https://oss.jfrog.org/libs-snapshot/io/rsocket/" - "rsocket-tck-drivers/0.9-SNAPSHOT/${TCK_DRIVERS_JAR}") -message(STATUS "Downloading ${TCK_DRIVERS_URL}") -file(DOWNLOAD ${TCK_DRIVERS_URL} ${CMAKE_SOURCE_DIR}/${TCK_DRIVERS_JAR}) + set(TCK_DRIVERS_JAR rsocket-tck-drivers-0.9.10.jar) + if (NOT EXISTS ${CMAKE_SOURCE_DIR}/${TCK_DRIVERS_JAR}) + join(TCK_DRIVERS_URL + "https://oss.jfrog.org/libs-release/io/rsocket/" + "rsocket-tck-drivers/0.9.10/${TCK_DRIVERS_JAR}") + message(STATUS "Downloading ${TCK_DRIVERS_URL}") + file(DOWNLOAD ${TCK_DRIVERS_URL} ${CMAKE_SOURCE_DIR}/${TCK_DRIVERS_JAR}) + endif () +endif() ######################################## # Examples ######################################## +if (BUILD_EXAMPLES) add_library( reactivesocket_examples_util - examples/util/ExampleSubscriber.cpp - examples/util/ExampleSubscriber.h + rsocket/examples/util/ExampleSubscriber.cpp + rsocket/examples/util/ExampleSubscriber.h + rsocket/test/test_utils/ColdResumeManager.h + rsocket/test/test_utils/ColdResumeManager.cpp ) target_link_libraries( reactivesocket_examples_util yarpl ReactiveSocket - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) # request-response-hello-world add_executable( example_request-response-hello-world-server - examples/request-response-hello-world/RequestResponseHelloWorld_Server.cpp + rsocket/examples/request-response-hello-world/RequestResponseHelloWorld_Server.cpp ) target_link_libraries( @@ -479,12 +520,12 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) add_executable( example_request-response-hello-world-client - examples/request-response-hello-world/RequestResponseHelloWorld_Client.cpp + rsocket/examples/request-response-hello-world/RequestResponseHelloWorld_Client.cpp ) target_link_libraries( @@ -492,14 +533,14 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) # fire-and-forget-hello-world add_executable( example_fire-and-forget-hello-world-server - examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Server.cpp + rsocket/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Server.cpp ) target_link_libraries( @@ -507,12 +548,12 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) add_executable( example_fire-and-forget-hello-world-client - examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Client.cpp + rsocket/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Client.cpp ) target_link_libraries( @@ -520,15 +561,15 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) # stream-hello-world add_executable( example_stream-hello-world-server - examples/stream-hello-world/StreamHelloWorld_Server.cpp + rsocket/examples/stream-hello-world/StreamHelloWorld_Server.cpp ) target_link_libraries( @@ -536,12 +577,12 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) add_executable( example_stream-hello-world-client - examples/stream-hello-world/StreamHelloWorld_Client.cpp + rsocket/examples/stream-hello-world/StreamHelloWorld_Client.cpp ) target_link_libraries( @@ -549,14 +590,14 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) # channel-hello-world add_executable( example_channel-hello-world-server - examples/channel-hello-world/ChannelHelloWorld_Server.cpp + rsocket/examples/channel-hello-world/ChannelHelloWorld_Server.cpp ) target_link_libraries( @@ -564,12 +605,14 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) + + add_executable( example_channel-hello-world-client - examples/channel-hello-world/ChannelHelloWorld_Client.cpp + rsocket/examples/channel-hello-world/ChannelHelloWorld_Client.cpp ) target_link_libraries( @@ -577,14 +620,14 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) # stream-observable-to-flowable add_executable( example_observable-to-flowable-server - examples/stream-observable-to-flowable/StreamObservableToFlowable_Server.cpp + rsocket/examples/stream-observable-to-flowable/StreamObservableToFlowable_Server.cpp ) target_link_libraries( @@ -592,12 +635,12 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) add_executable( example_observable-to-flowable-client - examples/stream-observable-to-flowable/StreamObservableToFlowable_Client.cpp + rsocket/examples/stream-observable-to-flowable/StreamObservableToFlowable_Client.cpp ) target_link_libraries( @@ -605,18 +648,18 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) # conditional-request-handling add_executable( example_conditional-request-handling-server - examples/conditional-request-handling/ConditionalRequestHandling_Server.cpp - examples/conditional-request-handling/TextRequestHandler.h - examples/conditional-request-handling/TextRequestHandler.cpp - examples/conditional-request-handling/JsonRequestHandler.cpp - examples/conditional-request-handling/JsonRequestHandler.h + rsocket/examples/conditional-request-handling/ConditionalRequestHandling_Server.cpp + rsocket/examples/conditional-request-handling/TextRequestHandler.h + rsocket/examples/conditional-request-handling/TextRequestHandler.cpp + rsocket/examples/conditional-request-handling/JsonRequestHandler.cpp + rsocket/examples/conditional-request-handling/JsonRequestHandler.h ) target_link_libraries( @@ -624,12 +667,12 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) add_executable( example_conditional-request-handling-client - examples/conditional-request-handling/ConditionalRequestHandling_Client.cpp + rsocket/examples/conditional-request-handling/ConditionalRequestHandling_Client.cpp ) target_link_libraries( @@ -637,27 +680,27 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) # warm-resumption add_executable( - example_warm-resumption-server - examples/warm-resumption/WarmResumption_Server.cpp + example_resumption-server + rsocket/examples/resumption/Resumption_Server.cpp ) target_link_libraries( - example_warm-resumption-server + example_resumption-server ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) add_executable( example_warm-resumption-client - examples/warm-resumption/WarmResumption_Client.cpp + rsocket/examples/resumption/WarmResumption_Client.cpp ) target_link_libraries( @@ -665,50 +708,28 @@ target_link_libraries( ReactiveSocket reactivesocket_examples_util yarpl - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) + glog::glog + gflags) +add_executable( + example_cold-resumption-client + rsocket/examples/resumption/ColdResumption_Client.cpp +) + +target_link_libraries( + example_cold-resumption-client + ReactiveSocket + reactivesocket_examples_util + yarpl + glog::glog + gflags) + +endif () # BUILD_EXAMPLES ######################################## # End Examples ######################################## -if(BUILD_BENCHMARKS) - ExternalProject_Add( - google_benchmark - URL ${CMAKE_SOURCE_DIR}/benchmark-1.1.0.zip - URL_MD5 c3c5cca410a1959efc93946f1739547f - CMAKE_ARGS "-DCMAKE_BUILD_TYPE=Release" - INSTALL_COMMAND "" - ) - - ExternalProject_Get_Property(google_benchmark source_dir) - set(GOOGLE_BENCHMARK_SOURCE_DIR ${source_dir}) - ExternalProject_Get_Property(google_benchmark binary_dir) - set(GOOGLE_BENCHMARK_BINARY_DIR ${binary_dir}) - - set( - GOOGLE_BENCHMARK_LIBS - ${GOOGLE_BENCHMARK_BINARY_DIR}/src/${CMAKE_STATIC_LIBRARY_PREFIX}benchmark${CMAKE_STATIC_LIBRARY_SUFFIX} - ) - - include_directories(${GOOGLE_BENCHMARK_SOURCE_DIR}/include) - include_directories(${CMAKE_SOURCE_DIR}/experimental) - - function(benchmark name file) - add_executable(${name} ${file}) - target_link_libraries( - ${name} - ReactiveSocket - yarpl - ${GOOGLE_BENCHMARK_LIBS} - ${GFLAGS_LIBRARY} - ${GLOG_LIBRARY}) - add_dependencies( - ${name} - google_benchmark) - endfunction() - - add_subdirectory(benchmarks) - add_subdirectory(yarpl/perf) -endif(BUILD_BENCHMARKS) +if (BUILD_BENCHMARKS) + add_subdirectory(rsocket/benchmarks) +endif () diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..d1abc700d --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,77 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq + diff --git a/LICENSE b/LICENSE index 4d4a15fb0..989e2c59e 100644 --- a/LICENSE +++ b/LICENSE @@ -1,30 +1,201 @@ -BSD License +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ -For reactivesocket-cpp software + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION -Copyright (c) 2016-present, Facebook, Inc. All rights reserved. + 1. Definitions. -Redistribution and use in source and binary forms, with or without modification, -are permitted provided that the following conditions are met: + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. - * Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. - * Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. - * Neither the name Facebook nor the names of its contributors may be used to - endorse or promote products derived from this software without specific - prior written permission. + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR -ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON -ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/PATENTS b/PATENTS deleted file mode 100644 index 3d7f19408..000000000 --- a/PATENTS +++ /dev/null @@ -1,33 +0,0 @@ -Additional Grant of Patent Rights Version 2 - -"Software" means the reactivesocket-cpp software distributed by Facebook, Inc. - -Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software -("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable -(subject to the termination provision below) license under any Necessary -Claims, to make, have made, use, sell, offer to sell, import, and otherwise -transfer the Software. For avoidance of doubt, no license is granted under -Facebook’s rights in any patent claims that are infringed by (i) modifications -to the Software made by you or any third party or (ii) the Software in -combination with any software or other technology. - -The license granted hereunder will terminate, automatically and without notice, -if you (or any of your subsidiaries, corporate affiliates or agents) initiate -directly or indirectly, or take a direct financial interest in, any Patent -Assertion: (i) against Facebook or any of its subsidiaries or corporate -affiliates, (ii) against any party if such Patent Assertion arises in whole or -in part from any software, technology, product or service of Facebook or any of -its subsidiaries or corporate affiliates, or (iii) against any party relating -to the Software. Notwithstanding the foregoing, if Facebook or any of its -subsidiaries or corporate affiliates files a lawsuit alleging patent -infringement against you in the first instance, and you respond by filing a -patent infringement counterclaim in that lawsuit against that party that is -unrelated to the Software, the license granted hereunder will not terminate -under section (i) of this paragraph due to such counterclaim. - -A "Necessary Claim" is a claim of a patent owned by Facebook that is -necessarily infringed by the Software standing alone. - -A "Patent Assertion" is any lawsuit or other action alleging direct, indirect, -or contributory infringement or inducement to infringe any patent, including a -cross-claim or counterclaim. diff --git a/PULL_REQUEST_TEMPLATE.md b/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 000000000..230862230 --- /dev/null +++ b/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,28 @@ +## Motivation and Context + + + + + +## How Has This Been Tested + + + +## Types of changes + + +- [ ] Docs change / refactoring / dependency upgrade +- [ ] Bug fix (non-breaking change which fixes an issue) +- [ ] New feature (non-breaking change which adds functionality) +- [ ] Breaking change (fix or feature that would cause existing functionality to change) + +## Checklist + + + +- [ ] My code follows the code style of this project. +- [ ] My change requires a change to the documentation. +- [ ] I have updated the documentation accordingly. +- [ ] I have read the **CONTRIBUTING** document. +- [ ] I have added tests to cover my changes. +- [ ] All new and existing tests passed. diff --git a/README.md b/README.md index 3b811aca9..1a5339e1c 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ C++ implementation of [RSocket](https://rsocket.io) Install `folly`: ``` -brew install folly +brew install --HEAD folly ``` # Building and running tests @@ -25,3 +25,8 @@ cmake -DCMAKE_BUILD_TYPE=DEBUG ../ make -j ./tests ``` + +# License + +By contributing to rsocket-cpp, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. diff --git a/benchmark-1.1.0.zip b/benchmark-1.1.0.zip deleted file mode 100644 index 3148789bb..000000000 Binary files a/benchmark-1.1.0.zip and /dev/null differ diff --git a/benchmarks/Baselines.cpp b/benchmarks/Baselines.cpp deleted file mode 100644 index 7bdebf88d..000000000 --- a/benchmarks/Baselines.cpp +++ /dev/null @@ -1,307 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#define MAX_MESSAGE_LENGTH (8 * 1024) -#define PORT (35437) - -static void BM_Baseline_TCP_Throughput(benchmark::State& state) { - std::atomic accepting{false}; - std::atomic accepted{false}; - std::atomic running{true}; - std::uint64_t totalBytesReceived = 0; - std::size_t msgLength = static_cast(state.range(0)); - std::size_t recvLength = static_cast(state.range(1)); - - std::thread t([&]() { - int serverSock = socket(AF_INET, SOCK_STREAM, 0); - int sock = -1; - struct sockaddr_in addr; - socklen_t addrlen = sizeof(addr); - char message[MAX_MESSAGE_LENGTH]; - - std::memset(message, 0, sizeof(message)); - std::memset(&addr, 0, sizeof(addr)); - - if (serverSock < 0) { - state.SkipWithError("socket acceptor"); - perror("acceptor socket"); - return; - } - - int enable = 1; - if (setsockopt( - serverSock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable)) < - 0) { - state.SkipWithError("setsockopt SO_REUSEADDR"); - perror("setsocketopt SO_REUSEADDR"); - return; - } - - addr.sin_family = AF_INET; - addr.sin_addr.s_addr = htonl(INADDR_ANY); - addr.sin_port = htons(PORT); - if (bind(serverSock, reinterpret_cast(&addr), addrlen) < - 0) { - state.SkipWithError("bind"); - perror("bind"); - return; - } - - if (listen(serverSock, 1) < 0) { - state.SkipWithError("listen"); - perror("listen"); - return; - } - - accepting.store(true); - - if ((sock = accept( - serverSock, reinterpret_cast(&addr), &addrlen)) < - 0) { - state.SkipWithError("accept"); - perror("accept"); - return; - } - - accepted.store(true); - - while (running) { - if (send(sock, message, msgLength, 0) != - static_cast(msgLength)) { - state.SkipWithError("send too short"); - perror("send"); - return; - } - } - - close(sock); - close(serverSock); - }); - - while (!accepting) { - std::this_thread::yield(); - } - - int sock = socket(AF_INET, SOCK_STREAM, 0); - struct sockaddr_in addr; - socklen_t addrlen = sizeof(addr); - char message[MAX_MESSAGE_LENGTH]; - - std::memset(message, 0, sizeof(message)); - std::memset(&addr, 0, sizeof(addr)); - - if (sock < 0) { - state.SkipWithError("socket connector"); - perror("connector socket"); - return; - } - - addr.sin_family = AF_INET; - addr.sin_addr.s_addr = inet_addr("127.0.0.1"); - addr.sin_port = htons(PORT); - if (connect(sock, reinterpret_cast(&addr), addrlen) < 0) { - state.SkipWithError("connect"); - perror("connect"); - return; - } - - while (!accepted) { - std::this_thread::yield(); - } - - while (state.KeepRunning()) { - ssize_t recved = recv(sock, message, recvLength, 0); - - if (recved < 0) { - state.SkipWithError("recv"); - perror("recv"); - return; - } - - totalBytesReceived += recved; - } - - running.store(false); - - close(sock); - - state.SetBytesProcessed(totalBytesReceived); - state.SetItemsProcessed(totalBytesReceived / msgLength); - - t.join(); -} - -BENCHMARK(BM_Baseline_TCP_Throughput) - ->Args({40, 1024}) - ->Args({40, 4096}) - ->Args({80, 4096}) - ->Args({4096, 4096}); - -static void BM_Baseline_TCP_Latency(benchmark::State& state) { - std::atomic accepting{false}; - std::atomic accepted{false}; - std::atomic running{true}; - std::uint64_t totalBytesReceived = 0; - std::uint64_t totalMsgsExchanged = 0; - std::size_t msgLength = static_cast(state.range(0)); - - std::thread t([&]() { - int serverSock = socket(AF_INET, SOCK_STREAM, 0); - int sock = -1; - struct sockaddr_in addr; - socklen_t addrlen = sizeof(addr); - char message[MAX_MESSAGE_LENGTH]; - - std::memset(message, 0, sizeof(message)); - std::memset(&addr, 0, sizeof(addr)); - - if (serverSock < 0) { - state.SkipWithError("socket acceptor"); - perror("acceptor socket"); - return; - } - - int enable = 1; -#if defined(SO_REUSEADDR) - if (setsockopt( - serverSock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable)) < - 0) { - state.SkipWithError("setsockopt SO_REUSEADDR"); - perror("setsocketopt SO_REUSEADDR"); - return; - } -#endif -#if defined(SO_REUSEPORT) - enable = 1; - if (setsockopt( - serverSock, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(enable)) < - 0) { - state.SkipWithError("setsockopt SO_REUSEPORT"); - perror("setsocketopt SO_REUSEPORT"); - return; - } -#endif - - addr.sin_family = AF_INET; - addr.sin_addr.s_addr = htonl(INADDR_ANY); - addr.sin_port = htons(PORT); - if (bind(serverSock, reinterpret_cast(&addr), addrlen) < - 0) { - state.SkipWithError("bind"); - perror("bind"); - return; - } - - if (listen(serverSock, 1) < 0) { - state.SkipWithError("listen"); - perror("listen"); - return; - } - - accepting.store(true); - - if ((sock = accept( - serverSock, reinterpret_cast(&addr), &addrlen)) < - 0) { - state.SkipWithError("accept"); - perror("accept"); - return; - } - - accepted.store(true); - - while (running) { - if (send(sock, message, msgLength, 0) != - static_cast(msgLength)) { - state.SkipWithError("thread send too short"); - perror("thread send"); - break; - } - - ssize_t recved = recv(sock, message, sizeof(message), 0); - - if (recved < 0 && running) // may end while blocked on recv, so ignore - // error if that happens - { - state.SkipWithError("thread recv"); - perror("thread recv"); - break; - } - } - - close(sock); - close(serverSock); - }); - - while (!accepting) { - std::this_thread::yield(); - } - - int sock = socket(AF_INET, SOCK_STREAM, 0); - struct sockaddr_in addr; - socklen_t addrlen = sizeof(addr); - char message[MAX_MESSAGE_LENGTH]; - - std::memset(message, 0, sizeof(message)); - std::memset(&addr, 0, sizeof(addr)); - - if (sock < 0) { - state.SkipWithError("socket connector"); - perror("connector socket"); - return; - } - - addr.sin_family = AF_INET; - addr.sin_addr.s_addr = inet_addr("127.0.0.1"); - addr.sin_port = htons(PORT); - if (connect(sock, reinterpret_cast(&addr), addrlen) < 0) { - state.SkipWithError("connect"); - perror("connect"); - return; - } - - while (!accepted) { - std::this_thread::yield(); - } - - while (state.KeepRunning()) { - ssize_t recved = recv(sock, message, sizeof(message), 0); - - if (recved < 0) { - state.SkipWithError("main recv"); - perror("main recv"); - break; - } - - if (send(sock, message, msgLength, 0) != static_cast(msgLength)) { - state.SkipWithError("main send too short"); - perror("main send"); - break; - } - - totalMsgsExchanged++; - totalBytesReceived += recved; - } - - running.store(false); - - close(sock); - - state.SetBytesProcessed(totalBytesReceived); - state.SetItemsProcessed(totalMsgsExchanged); - - t.join(); -} - -BENCHMARK(BM_Baseline_TCP_Latency)->Arg(32)->Arg(128)->Arg(4096); - -BENCHMARK_MAIN() diff --git a/benchmarks/CMakeLists.txt b/benchmarks/CMakeLists.txt deleted file mode 100644 index ffdcda56d..000000000 --- a/benchmarks/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ - -########################################################## -# Google benchmark - - -benchmark(baselines Baselines.cpp) -benchmark(streamthroughput StreamThroughput.cpp) -benchmark(reqrespthroughput RequestResponseThroughput.cpp) -benchmark(reqresplatency RequestResponseLatency.cpp) diff --git a/benchmarks/RequestResponseLatency.cpp b/benchmarks/RequestResponseLatency.cpp deleted file mode 100644 index 90ee758c9..000000000 --- a/benchmarks/RequestResponseLatency.cpp +++ /dev/null @@ -1,218 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include -#include - -#include -#include -#include -#include - -#include "rsocket/RSocket.h" -#include "rsocket/transports/tcp/TcpConnectionAcceptor.h" -#include "rsocket/transports/tcp/TcpConnectionFactory.h" -#include "yarpl/Flowable.h" -#include "yarpl/utils/ExceptionString.h" - -using namespace ::folly; -using namespace ::rsocket; -using namespace yarpl; - -#define MESSAGE_LENGTH (32) - -DEFINE_string(host, "localhost", "host to connect to"); -DEFINE_int32(port, 9898, "host:port to connect to"); - -class BM_Subscription : public yarpl::flowable::Subscription { - public: - explicit BM_Subscription( - yarpl::Reference> subscriber, - size_t length) - : subscriber_(std::move(subscriber)), - data_(length, 'a'), - cancelled_(false) {} - - private: - void request(int64_t n) noexcept override { - LOG(INFO) << "requested=" << n; - - if (cancelled_) { - LOG(INFO) << "emission stopped by cancellation"; - return; - } - - subscriber_->onNext(Payload(data_)); - subscriber_->onComplete(); - } - - void cancel() noexcept override { - LOG(INFO) << "cancellation received"; - cancelled_ = true; - } - - yarpl::Reference> subscriber_; - std::string data_; - std::atomic_bool cancelled_; -}; - -class BM_RequestHandler : public RSocketResponder { - public: - // TODO(lehecka): enable when we have support for request-response - yarpl::Reference> handleRequestStream( - Payload, - StreamId) override { - CHECK(false) << "not implemented"; - } - - // void handleRequestResponse( - // Payload request, StreamId streamId, const - // yarpl::Reference> &response) - // noexcept override - // { - // LOG(INFO) << "BM_RequestHandler.handleRequestResponse " << request; - - // response->onSubscribe( - // std::make_shared(response, MESSAGE_LENGTH)); - // } - - // std::shared_ptr handleSetupPayload( - // ReactiveSocket &socket, ConnectionSetupPayload request) noexcept - // override - // { - // LOG(INFO) << "BM_RequestHandler.handleSetupPayload " << request; - // return nullptr; - // } -}; - -class BM_Subscriber : public yarpl::flowable::Subscriber { - public: - ~BM_Subscriber() { - LOG(INFO) << "BM_Subscriber destroy " << this; - } - - BM_Subscriber() - : initialRequest_(8), thresholdForRequest_(initialRequest_ * 0.75) { - LOG(INFO) << "BM_Subscriber " << this << " created with => " - << " Initial Request: " << initialRequest_ - << " Threshold for re-request: " << thresholdForRequest_; - } - - void onSubscribe(yarpl::Reference - subscription) noexcept override { - LOG(INFO) << "BM_Subscriber " << this << " onSubscribe"; - subscription_ = std::move(subscription); - requested_ = initialRequest_; - subscription_->request(initialRequest_); - } - - void onNext(Payload element) noexcept override { - LOG(INFO) << "BM_Subscriber " << this - << " onNext as string: " << element.moveDataToString(); - - if (--requested_ == thresholdForRequest_) { - int toRequest = (initialRequest_ - thresholdForRequest_); - LOG(INFO) << "BM_Subscriber " << this << " requesting " << toRequest - << " more items"; - requested_ += toRequest; - subscription_->request(toRequest); - } - } - - void onComplete() noexcept override { - LOG(INFO) << "BM_Subscriber " << this << " onComplete"; - terminated_ = true; - completed_ = true; - terminalEventCV_.notify_all(); - } - - void onError(std::exception_ptr ex) noexcept override { - LOG(INFO) << "BM_Subscriber " << this << " onError: " - << yarpl::exceptionStr(ex); - terminated_ = true; - terminalEventCV_.notify_all(); - } - - void awaitTerminalEvent() { - LOG(INFO) << "BM_Subscriber " << this << " block thread"; - // now block this thread - std::unique_lock lk(m_); - // if shutdown gets implemented this would then be released by it - terminalEventCV_.wait(lk, [this] { return terminated_; }); - LOG(INFO) << "BM_Subscriber " << this << " unblocked"; - } - - bool completed() { - return completed_; - } - - private: - int initialRequest_; - int thresholdForRequest_; - int requested_; - yarpl::Reference subscription_; - bool terminated_{false}; - std::mutex m_; - std::condition_variable terminalEventCV_; - std::atomic_bool completed_{false}; -}; - -class BM_RsFixture : public benchmark::Fixture { - public: - BM_RsFixture() - : host_(FLAGS_host), - port_(static_cast(FLAGS_port)), - serverRs_(RSocket::createServer(std::make_unique( - TcpConnectionAcceptor::Options(port_)))) { - FLAGS_v = 0; - FLAGS_minloglevel = 6; - serverRs_->start([](const SetupParameters&) { - return std::make_shared(); - }); - } - - virtual ~BM_RsFixture() {} - - void SetUp(const benchmark::State&) noexcept override {} - - void TearDown(const benchmark::State&) noexcept override {} - - std::string host_; - uint16_t port_; - std::unique_ptr serverRs_; -}; - -BENCHMARK_F(BM_RsFixture, BM_RequestResponse_Latency)(benchmark::State&) { - // TODO(lehecka): enable test - // folly::SocketAddress address; - // address.setFromHostPort(host_, port_); - // - // auto clientRs = - // RSocket::createClient(std::make_unique( - // std::move(address))); - // int reqs = 0; - // - // auto rs = clientRs->connect().get(); - // - // while (state.KeepRunning()) - // { - // auto sub = make_ref(); - // rs->requestResponse(Payload("BM_RequestResponse"))->subscribe(sub); - // - // while (!sub->completed()) - // { - // std::this_thread::yield(); - // } - // - // reqs++; - // } - // - // char label[256]; - // - // std::snprintf(label, sizeof(label), "Message Length: %d", - // MESSAGE_LENGTH); - // state.SetLabel(label); - // - // state.SetItemsProcessed(reqs); -} - -BENCHMARK_MAIN() diff --git a/benchmarks/RequestResponseThroughput.cpp b/benchmarks/RequestResponseThroughput.cpp deleted file mode 100644 index efafb226d..000000000 --- a/benchmarks/RequestResponseThroughput.cpp +++ /dev/null @@ -1,247 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include -#include - -#include -#include -#include -#include - -#include "rsocket/RSocket.h" -#include "rsocket/transports/tcp/TcpConnectionAcceptor.h" -#include "rsocket/transports/tcp/TcpConnectionFactory.h" -#include "yarpl/Flowable.h" -#include "yarpl/flowable/Subscriber.h" -#include "yarpl/flowable/Subscription.h" -#include "yarpl/utils/ExceptionString.h" - -using namespace ::folly; -using namespace ::rsocket; -using namespace yarpl; -using namespace yarpl::flowable; - -#define MAX_REQUESTS (64) -#define MESSAGE_LENGTH (32) - -DEFINE_string(host, "localhost", "host to connect to"); -DEFINE_int32(port, 9898, "host:port to connect to"); - -class BM_Subscription : public Subscription { - public: - explicit BM_Subscription( - Reference> subscriber, - size_t length) - : subscriber_(std::move(subscriber)), - data_(length, 'a'), - cancelled_(false) {} - - private: - void request(int64_t n) noexcept override { - LOG(INFO) << "requested=" << n; - - if (cancelled_) { - LOG(INFO) << "emission stopped by cancellation"; - return; - } - - subscriber_->onNext(Payload(data_)); - subscriber_->onComplete(); - } - - void cancel() noexcept override { - LOG(INFO) << "cancellation received"; - cancelled_ = true; - } - - Reference> subscriber_; - std::string data_; - std::atomic_bool cancelled_; -}; - -class BM_RequestHandler : public RSocketResponder { - public: - // TODO(lehecka): enable when we have support for request-response - yarpl::Reference> handleRequestStream( - Payload, - StreamId) override { - CHECK(false) << "not implemented"; - } - - // void handleRequestResponse( - // Payload request, StreamId streamId, const - // std::shared_ptr> &response) noexcept override - // { - // LOG(INFO) << "BM_RequestHandler.handleRequestResponse " << request; - - // response->onSubscribe( - // std::make_shared(response, MESSAGE_LENGTH)); - // } - - // std::shared_ptr handleSetupPayload( - // ReactiveSocket &socket, ConnectionSetupPayload request) noexcept - // override - // { - // LOG(INFO) << "BM_RequestHandler.handleSetupPayload " << request; - // return nullptr; - // } -}; - -class BM_Subscriber : public yarpl::flowable::Subscriber { - public: - ~BM_Subscriber() { - LOG(INFO) << "BM_Subscriber destroy " << this; - } - - BM_Subscriber() - : initialRequest_(8), thresholdForRequest_(initialRequest_ * 0.75) { - LOG(INFO) << "BM_Subscriber " << this << " created with => " - << " Initial Request: " << initialRequest_ - << " Threshold for re-request: " << thresholdForRequest_; - } - - void onSubscribe(yarpl::Reference - subscription) noexcept override { - LOG(INFO) << "BM_Subscriber " << this << " onSubscribe"; - subscription_ = std::move(subscription); - requested_ = initialRequest_; - subscription_->request(initialRequest_); - } - - void onNext(Payload element) noexcept override { - LOG(INFO) << "BM_Subscriber " << this - << " onNext as string: " << element.moveDataToString(); - - if (--requested_ == thresholdForRequest_) { - int toRequest = (initialRequest_ - thresholdForRequest_); - LOG(INFO) << "BM_Subscriber " << this << " requesting " << toRequest - << " more items"; - requested_ += toRequest; - subscription_->request(toRequest); - } - } - - void onComplete() noexcept override { - LOG(INFO) << "BM_Subscriber " << this << " onComplete"; - terminated_ = true; - completed_ = true; - terminalEventCV_.notify_all(); - } - - void onError(std::exception_ptr ex) noexcept override { - LOG(INFO) << "BM_Subscriber " << this << " onError " - << yarpl::exceptionStr(ex); - terminated_ = true; - terminalEventCV_.notify_all(); - } - - void awaitTerminalEvent() { - LOG(INFO) << "BM_Subscriber " << this << " block thread"; - // now block this thread - std::unique_lock lk(m_); - // if shutdown gets implemented this would then be released by it - terminalEventCV_.wait(lk, [this] { return terminated_; }); - LOG(INFO) << "BM_Subscriber " << this << " unblocked"; - } - - bool completed() { - return completed_; - } - - private: - int initialRequest_; - int thresholdForRequest_; - int requested_; - yarpl::Reference subscription_; - bool terminated_{false}; - std::mutex m_; - std::condition_variable terminalEventCV_; - std::atomic_bool completed_{false}; -}; - -class BM_RsFixture : public benchmark::Fixture { - public: - BM_RsFixture() - : host_(FLAGS_host), - port_(static_cast(FLAGS_port)), - serverRs_(RSocket::createServer(std::make_unique( - TcpConnectionAcceptor::Options(port_)))) { - FLAGS_v = 0; - FLAGS_minloglevel = 6; - serverRs_->start([](const SetupParameters&) { - return std::make_shared(); - }); - } - - virtual ~BM_RsFixture() {} - - void SetUp(const benchmark::State&) override {} - - void TearDown(const benchmark::State&) override {} - - std::string host_; - uint16_t port_; - std::unique_ptr serverRs_; -}; - -BENCHMARK_DEFINE_F(BM_RsFixture, BM_RequestResponse_Throughput) -(benchmark::State&) { - // TODO(lehecka): enable test - // folly::SocketAddress address; - // address.setFromHostPort(host_, port_); - // - // auto clientRs = - // RSocket::createClient(std::make_unique( - // std::move(address))); - // int reqs = 0; - // int numSubscribers = state.range(0); - // int mask = numSubscribers - 1; - // - // yarpl::Reference subs[MAX_REQUESTS+1]; - // - // auto rs = clientRs->connect().get(); - // - // while (state.KeepRunning()) - // { - // int index = reqs & mask; - // - // if (nullptr != subs[index]) - // { - // while (!subs[index]->completed()) - // { - // std::this_thread::yield(); - // } - // - // subs[index].reset(); - // } - // - // subs[index] = make_ref(); - // rs->requestResponse(Payload("BM_RequestResponse"))->subscribe(subs[index]); - // reqs++; - // } - // - // for (int i = 0; i < numSubscribers; i++) - // { - // if (subs[i]) - // { - // subs[i]->awaitTerminalEvent(); - // } - // } - // - // char label[256]; - // - // std::snprintf(label, sizeof(label), "Max Requests: %d, Message Length: - // %d", numSubscribers, MESSAGE_LENGTH); - // state.SetLabel(label); - // - // state.SetItemsProcessed(reqs); -} - -BENCHMARK_REGISTER_F(BM_RsFixture, BM_RequestResponse_Throughput) - ->Arg(1) - ->Arg(2) - ->Arg(8) - ->Arg(16) - ->Arg(32); - -BENCHMARK_MAIN() diff --git a/benchmarks/StreamThroughput.cpp b/benchmarks/StreamThroughput.cpp deleted file mode 100644 index e3a2a4d05..000000000 --- a/benchmarks/StreamThroughput.cpp +++ /dev/null @@ -1,223 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include -#include - -#include -#include -#include -#include - -#include "rsocket/RSocket.h" -#include "rsocket/transports/tcp/TcpConnectionAcceptor.h" -#include "rsocket/transports/tcp/TcpConnectionFactory.h" -#include "yarpl/Flowable.h" -#include "yarpl/utils/ExceptionString.h" - -using namespace ::folly; -using namespace ::rsocket; -using namespace yarpl; - -#define MESSAGE_LENGTH (32) - -DEFINE_string(host, "localhost", "host to connect to"); -DEFINE_int32(port, 9898, "host:port to connect to"); - -class BM_Subscription : public yarpl::flowable::Subscription { - public: - explicit BM_Subscription( - yarpl::Reference> subscriber, - size_t length) - : subscriber_(std::move(subscriber)), - data_(length, 'a'), - cancelled_(false) {} - - private: - void request(int64_t n) noexcept override { - LOG(INFO) << "requested=" << n << " currentElem=" << currentElem_; - - for (int64_t i = 0; i < n; i++) { - if (cancelled_) { - LOG(INFO) << "emission stopped by cancellation"; - return; - } - subscriber_->onNext(Payload(data_)); - currentElem_++; - } - } - - void cancel() noexcept override { - LOG(INFO) << "cancellation received"; - cancelled_ = true; - } - - yarpl::Reference> subscriber_; - std::string data_; - size_t currentElem_ = 0; - std::atomic_bool cancelled_; -}; - -class BM_RequestHandler : public RSocketResponder { - public: - yarpl::Reference> handleRequestStream( - Payload, - StreamId) override { - CHECK(false) << "not implemented"; - // TODO(lehecka) need to implement new operator fromGenerator - // return yarpl::flowable::Flowables::fromGenerator< Payload>( - // []{return Payload(std::string(MESSAGE_LENGTH, 'a')); }); - } -}; - -class BM_Subscriber : public yarpl::flowable::Subscriber { - public: - ~BM_Subscriber() { - LOG(INFO) << "BM_Subscriber destroy " << this; - } - - explicit BM_Subscriber(int initialRequest) - : initialRequest_(initialRequest), - thresholdForRequest_(initialRequest * 0.75), - received_(0) { - LOG(INFO) << "BM_Subscriber " << this << " created with => " - << " Initial Request: " << initialRequest - << " Threshold for re-request: " << thresholdForRequest_; - } - - void onSubscribe(yarpl::Reference - subscription) noexcept override { - LOG(INFO) << "BM_Subscriber " << this << " onSubscribe"; - subscription_ = std::move(subscription); - requested_ = initialRequest_; - subscription_->request(initialRequest_); - } - - void onNext(Payload element) noexcept override { - LOG(INFO) << "BM_Subscriber " << this - << " onNext as string: " << element.moveDataToString(); - - received_.store(received_ + 1, std::memory_order_release); - - if (--requested_ == thresholdForRequest_) { - int toRequest = (initialRequest_ - thresholdForRequest_); - LOG(INFO) << "BM_Subscriber " << this << " requesting " << toRequest - << " more items"; - requested_ += toRequest; - subscription_->request(toRequest); - }; - - if (cancel_) { - subscription_->cancel(); - } - } - - void onComplete() noexcept override { - LOG(INFO) << "BM_Subscriber " << this << " onComplete"; - terminated_ = true; - terminalEventCV_.notify_all(); - } - - void onError(std::exception_ptr ex) noexcept override { - LOG(INFO) << "BM_Subscriber " << this << " onError " - << yarpl::exceptionStr(ex); - terminated_ = true; - terminalEventCV_.notify_all(); - } - - void awaitTerminalEvent() { - LOG(INFO) << "BM_Subscriber " << this << " block thread"; - // now block this thread - std::unique_lock lk(m_); - // if shutdown gets implemented this would then be released by it - terminalEventCV_.wait(lk, [this] { return terminated_; }); - LOG(INFO) << "BM_Subscriber " << this << " unblocked"; - } - - void cancel() { - cancel_ = true; - } - - size_t received() { - return received_.load(std::memory_order_acquire); - } - - private: - int initialRequest_; - int thresholdForRequest_; - int requested_; - yarpl::Reference subscription_; - bool terminated_{false}; - std::mutex m_; - std::condition_variable terminalEventCV_; - std::atomic_bool cancel_{false}; - std::atomic received_; -}; - -class BM_RsFixture : public benchmark::Fixture { - public: - BM_RsFixture() - : host_(FLAGS_host), - port_(static_cast(FLAGS_port)), - serverRs_(RSocket::createServer(std::make_unique( - TcpConnectionAcceptor::Options(port_)))) { - FLAGS_minloglevel = 100; - serverRs_->start([](const SetupParameters&) { - return std::make_shared(); - }); - } - - virtual ~BM_RsFixture() {} - - void SetUp(const benchmark::State&) noexcept override {} - - void TearDown(const benchmark::State&) noexcept override {} - - std::string host_; - uint16_t port_; - std::unique_ptr serverRs_; -}; - -BENCHMARK_DEFINE_F(BM_RsFixture, BM_Stream_Throughput) -(benchmark::State& state) { - folly::SocketAddress address; - address.setFromHostPort(host_, port_); - - auto s = make_ref(state.range(0)); - - std::shared_ptr client; - - try { - client = RSocket::createConnectedClient( - std::make_unique(std::move(address))) - .get(); - client->getRequester() - ->requestStream(Payload("BM_Stream")) - ->subscribe(std::move(s)); - } catch (const std::exception& ex) { - LOG(INFO) << "Exception received " << ex; - return; - } - - while (state.KeepRunning()) { - std::this_thread::yield(); - } - - size_t rcved = s->received(); - - s->cancel(); - s->awaitTerminalEvent(); - - char label[256]; - - std::snprintf(label, sizeof(label), "Message Length: %d", MESSAGE_LENGTH); - state.SetLabel(label); - - state.SetItemsProcessed(rcved); -} - -BENCHMARK_REGISTER_F(BM_RsFixture, BM_Stream_Throughput) - ->Arg(8) - ->Arg(32) - ->Arg(128); - -BENCHMARK_MAIN() diff --git a/build/README.md b/build/README.md new file mode 100644 index 000000000..fdcb9fdcb --- /dev/null +++ b/build/README.md @@ -0,0 +1,10 @@ +# Building using `fbcode_builder` + +Continuous integration builds are powered by `fbcode_builder`, a tiny tool +shared by several Facebook projects. Its files are in `./fbcode_builder` +(on Github) or in `fbcode/opensource/fbcode_builder` (inside Facebook's +repo). + +Start with the READMEs in the `fbcode_builder` directory. + +`./fbcode_builder_config.py` contains the project-specific configuration. diff --git a/build/deps/github_hashes/facebook/folly-rev.txt b/build/deps/github_hashes/facebook/folly-rev.txt new file mode 100644 index 000000000..cd836348c --- /dev/null +++ b/build/deps/github_hashes/facebook/folly-rev.txt @@ -0,0 +1 @@ +Subproject commit 2a20a79adf8480dffc165aebc02a93937e15ca94 diff --git a/build/fbcode_builder/.gitignore b/build/fbcode_builder/.gitignore new file mode 100644 index 000000000..b98f3edfa --- /dev/null +++ b/build/fbcode_builder/.gitignore @@ -0,0 +1,5 @@ +# Facebook-internal CI builds don't have write permission outside of the +# source tree, so we install all projects into this directory. +/facebook_ci +__pycache__/ +*.pyc diff --git a/build/fbcode_builder/CMake/FBBuildOptions.cmake b/build/fbcode_builder/CMake/FBBuildOptions.cmake new file mode 100644 index 000000000..dbaa29933 --- /dev/null +++ b/build/fbcode_builder/CMake/FBBuildOptions.cmake @@ -0,0 +1,15 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +function (fb_activate_static_library_option) + option(USE_STATIC_DEPS_ON_UNIX + "If enabled, use static dependencies on unix systems. This is generally discouraged." + OFF + ) + # Mark USE_STATIC_DEPS_ON_UNIX as an "advanced" option, since enabling it + # is generally discouraged. + mark_as_advanced(USE_STATIC_DEPS_ON_UNIX) + + if(UNIX AND USE_STATIC_DEPS_ON_UNIX) + SET(CMAKE_FIND_LIBRARY_SUFFIXES ".a" PARENT_SCOPE) + endif() +endfunction() diff --git a/build/fbcode_builder/CMake/FBCMakeParseArgs.cmake b/build/fbcode_builder/CMake/FBCMakeParseArgs.cmake new file mode 100644 index 000000000..933180189 --- /dev/null +++ b/build/fbcode_builder/CMake/FBCMakeParseArgs.cmake @@ -0,0 +1,141 @@ +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Helper function for parsing arguments to a CMake function. +# +# This function is very similar to CMake's built-in cmake_parse_arguments() +# function, with some improvements: +# - This function correctly handles empty arguments. (cmake_parse_arguments() +# ignores empty arguments.) +# - If a multi-value argument is specified more than once, the subsequent +# arguments are appended to the original list rather than replacing it. e.g. +# if "SOURCES" is a multi-value argument, and the argument list contains +# "SOURCES a b c SOURCES x y z" then the resulting value for SOURCES will be +# "a;b;c;x;y;z" rather than "x;y;z" +# - This function errors out by default on unrecognized arguments. You can +# pass in an extra "ALLOW_UNPARSED_ARGS" argument to make it behave like +# cmake_parse_arguments(), and return the unparsed arguments in a +# _UNPARSED_ARGUMENTS variable instead. +# +# It does look like cmake_parse_arguments() handled empty arguments correctly +# from CMake 3.0 through 3.3, but it seems like this was probably broken when +# it was turned into a built-in function in CMake 3.4. Here is discussion and +# patches that fixed this behavior prior to CMake 3.0: +# https://cmake.org/pipermail/cmake-developers/2013-November/020607.html +# +# The one downside to this function over the built-in cmake_parse_arguments() +# is that I don't think we can achieve the PARSE_ARGV behavior in a non-builtin +# function, so we can't properly handle arguments that contain ";". CMake will +# treat the ";" characters as list element separators, and treat it as multiple +# separate arguments. +# +function(fb_cmake_parse_args PREFIX OPTIONS ONE_VALUE_ARGS MULTI_VALUE_ARGS ARGS) + foreach(option IN LISTS ARGN) + if ("${option}" STREQUAL "ALLOW_UNPARSED_ARGS") + set(ALLOW_UNPARSED_ARGS TRUE) + else() + message( + FATAL_ERROR + "unknown optional argument for fb_cmake_parse_args(): ${option}" + ) + endif() + endforeach() + + # Define all options as FALSE in the parent scope to start with + foreach(var_name IN LISTS OPTIONS) + set("${PREFIX}_${var_name}" "FALSE" PARENT_SCOPE) + endforeach() + + # TODO: We aren't extremely strict about error checking for one-value + # arguments here. e.g., we don't complain if a one-value argument is + # followed by another option/one-value/multi-value name rather than an + # argument. We also don't complain if a one-value argument is the last + # argument and isn't followed by a value. + + list(APPEND all_args ${ONE_VALUE_ARGS}) + list(APPEND all_args ${MULTI_VALUE_ARGS}) + set(current_variable) + set(unparsed_args) + foreach(arg IN LISTS ARGS) + list(FIND OPTIONS "${arg}" opt_index) + if("${opt_index}" EQUAL -1) + list(FIND all_args "${arg}" arg_index) + if("${arg_index}" EQUAL -1) + # This argument does not match an argument name, + # must be an argument value + if("${current_variable}" STREQUAL "") + list(APPEND unparsed_args "${arg}") + else() + # Ugh, CMake lists have a pretty fundamental flaw: they cannot + # distinguish between an empty list and a list with a single empty + # element. We track our own SEEN_VALUES_arg setting to help + # distinguish this and behave properly here. + if ("${SEEN_${current_variable}}" AND "${${current_variable}}" STREQUAL "") + set("${current_variable}" ";${arg}") + else() + list(APPEND "${current_variable}" "${arg}") + endif() + set("SEEN_${current_variable}" TRUE) + endif() + else() + # We found a single- or multi-value argument name + set(current_variable "VALUES_${arg}") + set("SEEN_${arg}" TRUE) + endif() + else() + # We found an option variable + set("${PREFIX}_${arg}" "TRUE" PARENT_SCOPE) + set(current_variable) + endif() + endforeach() + + foreach(arg_name IN LISTS ONE_VALUE_ARGS) + if(NOT "${SEEN_${arg_name}}") + unset("${PREFIX}_${arg_name}" PARENT_SCOPE) + elseif(NOT "${SEEN_VALUES_${arg_name}}") + # If the argument was seen but a value wasn't specified, error out. + # We require exactly one value to be specified. + message( + FATAL_ERROR "argument ${arg_name} was specified without a value" + ) + else() + list(LENGTH "VALUES_${arg_name}" num_args) + if("${num_args}" EQUAL 0) + # We know an argument was specified and that we called list(APPEND). + # If CMake thinks the list is empty that means there is really a single + # empty element in the list. + set("${PREFIX}_${arg_name}" "" PARENT_SCOPE) + elseif("${num_args}" EQUAL 1) + list(GET "VALUES_${arg_name}" 0 arg_value) + set("${PREFIX}_${arg_name}" "${arg_value}" PARENT_SCOPE) + else() + message( + FATAL_ERROR "too many arguments specified for ${arg_name}: " + "${VALUES_${arg_name}}" + ) + endif() + endif() + endforeach() + + foreach(arg_name IN LISTS MULTI_VALUE_ARGS) + # If this argument name was never seen, then unset the parent scope + if (NOT "${SEEN_${arg_name}}") + unset("${PREFIX}_${arg_name}" PARENT_SCOPE) + else() + # TODO: Our caller still won't be able to distinguish between an empty + # list and a list with a single empty element. We can tell which is + # which, but CMake lists don't make it easy to show this to our caller. + set("${PREFIX}_${arg_name}" "${VALUES_${arg_name}}" PARENT_SCOPE) + endif() + endforeach() + + # By default we fatal out on unparsed arguments, but return them to the + # caller if ALLOW_UNPARSED_ARGS was specified. + if (DEFINED unparsed_args) + if ("${ALLOW_UNPARSED_ARGS}") + set("${PREFIX}_UNPARSED_ARGUMENTS" "${unparsed_args}" PARENT_SCOPE) + else() + message(FATAL_ERROR "unrecognized arguments: ${unparsed_args}") + endif() + endif() +endfunction() diff --git a/build/fbcode_builder/CMake/FBCompilerSettings.cmake b/build/fbcode_builder/CMake/FBCompilerSettings.cmake new file mode 100644 index 000000000..585c95320 --- /dev/null +++ b/build/fbcode_builder/CMake/FBCompilerSettings.cmake @@ -0,0 +1,13 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +# This file applies common compiler settings that are shared across +# a number of Facebook opensource projects. +# Please use caution and your best judgement before making changes +# to these shared compiler settings in order to avoid accidentally +# breaking a build in another project! + +if (WIN32) + include(FBCompilerSettingsMSVC) +else() + include(FBCompilerSettingsUnix) +endif() diff --git a/build/fbcode_builder/CMake/FBCompilerSettingsMSVC.cmake b/build/fbcode_builder/CMake/FBCompilerSettingsMSVC.cmake new file mode 100644 index 000000000..4efd7e966 --- /dev/null +++ b/build/fbcode_builder/CMake/FBCompilerSettingsMSVC.cmake @@ -0,0 +1,11 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +# This file applies common compiler settings that are shared across +# a number of Facebook opensource projects. +# Please use caution and your best judgement before making changes +# to these shared compiler settings in order to avoid accidentally +# breaking a build in another project! + +add_compile_options( + /wd4250 # 'class1' : inherits 'class2::member' via dominance +) diff --git a/build/fbcode_builder/CMake/FBCompilerSettingsUnix.cmake b/build/fbcode_builder/CMake/FBCompilerSettingsUnix.cmake new file mode 100644 index 000000000..c26ce78b1 --- /dev/null +++ b/build/fbcode_builder/CMake/FBCompilerSettingsUnix.cmake @@ -0,0 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +# This file applies common compiler settings that are shared across +# a number of Facebook opensource projects. +# Please use caution and your best judgement before making changes +# to these shared compiler settings in order to avoid accidentally +# breaking a build in another project! + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Wextra -Wno-deprecated -Wno-deprecated-declarations") diff --git a/build/fbcode_builder/CMake/FBPythonBinary.cmake b/build/fbcode_builder/CMake/FBPythonBinary.cmake new file mode 100644 index 000000000..99c33fb8c --- /dev/null +++ b/build/fbcode_builder/CMake/FBPythonBinary.cmake @@ -0,0 +1,697 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +include(FBCMakeParseArgs) + +# +# This file contains helper functions for building self-executing Python +# binaries. +# +# This is somewhat different than typical python installation with +# distutils/pip/virtualenv/etc. We primarily want to build a standalone +# executable, isolated from other Python packages on the system. We don't want +# to install files into the standard library python paths. This is more +# similar to PEX (https://github.com/pantsbuild/pex) and XAR +# (https://github.com/facebookincubator/xar). (In the future it would be nice +# to update this code to also support directly generating XAR files if XAR is +# available.) +# +# We also want to be able to easily define "libraries" of python files that can +# be shared and re-used between these standalone python executables, and can be +# shared across projects in different repositories. This means that we do need +# a way to "install" libraries so that they are visible to CMake builds in +# other repositories, without actually installing them in the standard python +# library paths. +# + +# If the caller has not already found Python, do so now. +# If we fail to find python now we won't fail immediately, but +# add_fb_python_executable() or add_fb_python_library() will fatal out if they +# are used. +if(NOT TARGET Python3::Interpreter) + # CMake 3.12+ ships with a FindPython3.cmake module. Try using it first. + # We find with QUIET here, since otherwise this generates some noisy warnings + # on versions of CMake before 3.12 + if (WIN32) + # On Windows we need both the Intepreter as well as the Development + # libraries. + find_package(Python3 COMPONENTS Interpreter Development QUIET) + else() + find_package(Python3 COMPONENTS Interpreter QUIET) + endif() + if(Python3_Interpreter_FOUND) + message(STATUS "Found Python 3: ${Python3_EXECUTABLE}") + else() + # Try with the FindPythonInterp.cmake module available in older CMake + # versions. Check to see if the caller has already searched for this + # themselves first. + if(NOT PYTHONINTERP_FOUND) + set(Python_ADDITIONAL_VERSIONS 3 3.6 3.5 3.4 3.3 3.2 3.1) + find_package(PythonInterp) + # TODO: On Windows we require the Python libraries as well. + # We currently do not search for them on this code path. + # For now we require building with CMake 3.12+ on Windows, so that the + # FindPython3 code path above is available. + endif() + if(PYTHONINTERP_FOUND) + if("${PYTHON_VERSION_MAJOR}" GREATER_EQUAL 3) + set(Python3_EXECUTABLE "${PYTHON_EXECUTABLE}") + add_custom_target(Python3::Interpreter) + else() + string( + CONCAT FBPY_FIND_PYTHON_ERR + "found Python ${PYTHON_VERSION_MAJOR}.${PYTHON_VERSION_MINOR}, " + "but need Python 3" + ) + endif() + endif() + endif() +endif() + +# Find our helper program. +# We typically install this in the same directory as this .cmake file. +find_program( + FB_MAKE_PYTHON_ARCHIVE "make_fbpy_archive.py" + PATHS ${CMAKE_MODULE_PATH} +) +set(FB_PY_TEST_MAIN "${CMAKE_CURRENT_LIST_DIR}/fb_py_test_main.py") +set( + FB_PY_TEST_DISCOVER_SCRIPT + "${CMAKE_CURRENT_LIST_DIR}/FBPythonTestAddTests.cmake" +) +set( + FB_PY_WIN_MAIN_C + "${CMAKE_CURRENT_LIST_DIR}/fb_py_win_main.c" +) + +# An option to control the default installation location for +# install_fb_python_library(). This is relative to ${CMAKE_INSTALL_PREFIX} +set( + FBPY_LIB_INSTALL_DIR "lib/fb-py-libs" CACHE STRING + "The subdirectory where FB python libraries should be installed" +) + +# +# Build a self-executing python binary. +# +# This accepts the same arguments as add_fb_python_library(). +# +# In addition, a MAIN_MODULE argument is accepted. This argument specifies +# which module should be started as the __main__ module when the executable is +# run. If left unspecified, a __main__.py script must be present in the +# manifest. +# +function(add_fb_python_executable TARGET) + fb_py_check_available() + + # Parse the arguments + set(one_value_args BASE_DIR NAMESPACE MAIN_MODULE TYPE) + set(multi_value_args SOURCES DEPENDS) + fb_cmake_parse_args( + ARG "" "${one_value_args}" "${multi_value_args}" "${ARGN}" + ) + fb_py_process_default_args(ARG_NAMESPACE ARG_BASE_DIR) + + # Use add_fb_python_library() to perform most of our source handling + add_fb_python_library( + "${TARGET}.main_lib" + BASE_DIR "${ARG_BASE_DIR}" + NAMESPACE "${ARG_NAMESPACE}" + SOURCES ${ARG_SOURCES} + DEPENDS ${ARG_DEPENDS} + ) + + set( + manifest_files + "$" + ) + set( + source_files + "$" + ) + + # The command to build the executable archive. + # + # If we are using CMake 3.8+ we can use COMMAND_EXPAND_LISTS. + # CMP0067 isn't really the policy we care about, but seems like the best way + # to check if we are running 3.8+. + if (POLICY CMP0067) + set(extra_cmd_params COMMAND_EXPAND_LISTS) + set(make_py_args "${manifest_files}") + else() + set(extra_cmd_params) + set(make_py_args --manifest-separator "::" "$") + endif() + + set(output_file "${TARGET}${CMAKE_EXECUTABLE_SUFFIX}") + if(WIN32) + set(zipapp_output "${TARGET}.py_zipapp") + else() + set(zipapp_output "${output_file}") + endif() + set(zipapp_output_file "${zipapp_output}") + + set(is_dir_output FALSE) + if(DEFINED ARG_TYPE) + list(APPEND make_py_args "--type" "${ARG_TYPE}") + if ("${ARG_TYPE}" STREQUAL "dir") + set(is_dir_output TRUE) + # CMake doesn't really seem to like having a directory specified as an + # output; specify the __main__.py file as the output instead. + set(zipapp_output_file "${zipapp_output}/__main__.py") + list(APPEND + extra_cmd_params + COMMAND "${CMAKE_COMMAND}" -E remove_directory "${zipapp_output}" + ) + endif() + endif() + + if(DEFINED ARG_MAIN_MODULE) + list(APPEND make_py_args "--main" "${ARG_MAIN_MODULE}") + endif() + + add_custom_command( + OUTPUT "${zipapp_output_file}" + ${extra_cmd_params} + COMMAND + "${Python3_EXECUTABLE}" "${FB_MAKE_PYTHON_ARCHIVE}" + -o "${zipapp_output}" + ${make_py_args} + DEPENDS + ${source_files} + "${TARGET}.main_lib.py_sources_built" + "${FB_MAKE_PYTHON_ARCHIVE}" + ) + + if(WIN32) + if(is_dir_output) + # TODO: generate a main executable that will invoke Python3 + # with the correct main module inside the output directory + else() + add_executable("${TARGET}.winmain" "${FB_PY_WIN_MAIN_C}") + target_link_libraries("${TARGET}.winmain" Python3::Python) + # The Python3::Python target doesn't seem to be set up completely + # correctly on Windows for some reason, and we have to explicitly add + # ${Python3_LIBRARY_DIRS} to the target link directories. + target_link_directories( + "${TARGET}.winmain" + PUBLIC ${Python3_LIBRARY_DIRS} + ) + add_custom_command( + OUTPUT "${output_file}" + DEPENDS "${TARGET}.winmain" "${zipapp_output_file}" + COMMAND + "cmd.exe" "/c" "copy" "/b" + "${TARGET}.winmain${CMAKE_EXECUTABLE_SUFFIX}+${zipapp_output}" + "${output_file}" + ) + endif() + endif() + + # Add an "ALL" target that depends on force ${TARGET}, + # so that ${TARGET} will be included in the default list of build targets. + add_custom_target("${TARGET}.GEN_PY_EXE" ALL DEPENDS "${output_file}") + + # Allow resolving the executable path for the target that we generate + # via a generator expression like: + # "WATCHMAN_WAIT_PATH=$" + set_property(TARGET "${TARGET}.GEN_PY_EXE" + PROPERTY EXECUTABLE "${CMAKE_CURRENT_BINARY_DIR}/${output_file}") +endfunction() + +# Define a python unittest executable. +# The executable is built using add_fb_python_executable and has the +# following differences: +# +# Each of the source files specified in SOURCES will be imported +# and have unittest discovery performed upon them. +# Those sources will be imported in the top level namespace. +# +# The ENV argument allows specifying a list of "KEY=VALUE" +# pairs that will be used by the test runner to set up the environment +# in the child process prior to running the test. This is useful for +# passing additional configuration to the test. +function(add_fb_python_unittest TARGET) + # Parse the arguments + set(multi_value_args SOURCES DEPENDS ENV PROPERTIES) + set( + one_value_args + WORKING_DIRECTORY BASE_DIR NAMESPACE TEST_LIST DISCOVERY_TIMEOUT + ) + fb_cmake_parse_args( + ARG "" "${one_value_args}" "${multi_value_args}" "${ARGN}" + ) + fb_py_process_default_args(ARG_NAMESPACE ARG_BASE_DIR) + if(NOT ARG_WORKING_DIRECTORY) + # Default the working directory to the current binary directory. + # This matches the default behavior of add_test() and other standard + # test functions like gtest_discover_tests() + set(ARG_WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}") + endif() + if(NOT ARG_TEST_LIST) + set(ARG_TEST_LIST "${TARGET}_TESTS") + endif() + if(NOT ARG_DISCOVERY_TIMEOUT) + set(ARG_DISCOVERY_TIMEOUT 5) + endif() + + # Tell our test program the list of modules to scan for tests. + # We scan all modules directly listed in our SOURCES argument, and skip + # modules that came from dependencies in the DEPENDS list. + # + # This is written into a __test_modules__.py module that the test runner + # will look at. + set( + test_modules_path + "${CMAKE_CURRENT_BINARY_DIR}/${TARGET}_test_modules.py" + ) + file(WRITE "${test_modules_path}" "TEST_MODULES = [\n") + string(REPLACE "." "/" namespace_dir "${ARG_NAMESPACE}") + if (NOT "${namespace_dir}" STREQUAL "") + set(namespace_dir "${namespace_dir}/") + endif() + set(test_modules) + foreach(src_path IN LISTS ARG_SOURCES) + fb_py_compute_dest_path( + abs_source dest_path + "${src_path}" "${namespace_dir}" "${ARG_BASE_DIR}" + ) + string(REPLACE "/" "." module_name "${dest_path}") + string(REGEX REPLACE "\\.py$" "" module_name "${module_name}") + list(APPEND test_modules "${module_name}") + file(APPEND "${test_modules_path}" " '${module_name}',\n") + endforeach() + file(APPEND "${test_modules_path}" "]\n") + + # The __main__ is provided by our runner wrapper/bootstrap + list(APPEND ARG_SOURCES "${FB_PY_TEST_MAIN}=__main__.py") + list(APPEND ARG_SOURCES "${test_modules_path}=__test_modules__.py") + + add_fb_python_executable( + "${TARGET}" + NAMESPACE "${ARG_NAMESPACE}" + BASE_DIR "${ARG_BASE_DIR}" + SOURCES ${ARG_SOURCES} + DEPENDS ${ARG_DEPENDS} + ) + + # Run test discovery after the test executable is built. + # This logic is based on the code for gtest_discover_tests() + set(ctest_file_base "${CMAKE_CURRENT_BINARY_DIR}/${TARGET}") + set(ctest_include_file "${ctest_file_base}_include.cmake") + set(ctest_tests_file "${ctest_file_base}_tests.cmake") + add_custom_command( + TARGET "${TARGET}.GEN_PY_EXE" POST_BUILD + BYPRODUCTS "${ctest_tests_file}" + COMMAND + "${CMAKE_COMMAND}" + -D "TEST_TARGET=${TARGET}" + -D "TEST_INTERPRETER=${Python3_EXECUTABLE}" + -D "TEST_ENV=${ARG_ENV}" + -D "TEST_EXECUTABLE=$" + -D "TEST_WORKING_DIR=${ARG_WORKING_DIRECTORY}" + -D "TEST_LIST=${ARG_TEST_LIST}" + -D "TEST_PREFIX=${TARGET}::" + -D "TEST_PROPERTIES=${ARG_PROPERTIES}" + -D "CTEST_FILE=${ctest_tests_file}" + -P "${FB_PY_TEST_DISCOVER_SCRIPT}" + VERBATIM + ) + + file( + WRITE "${ctest_include_file}" + "if(EXISTS \"${ctest_tests_file}\")\n" + " include(\"${ctest_tests_file}\")\n" + "else()\n" + " add_test(\"${TARGET}_NOT_BUILT\" \"${TARGET}_NOT_BUILT\")\n" + "endif()\n" + ) + set_property( + DIRECTORY APPEND PROPERTY TEST_INCLUDE_FILES + "${ctest_include_file}" + ) +endfunction() + +# +# Define a python library. +# +# If you want to install a python library generated from this rule note that +# you need to use install_fb_python_library() rather than CMake's built-in +# install() function. This will make it available for other downstream +# projects to use in their add_fb_python_executable() and +# add_fb_python_library() calls. (You do still need to use `install(EXPORT)` +# later to install the CMake exports.) +# +# Parameters: +# - BASE_DIR : +# The base directory path to strip off from each source path. All source +# files must be inside this directory. If not specified it defaults to +# ${CMAKE_CURRENT_SOURCE_DIR}. +# - NAMESPACE : +# The destination namespace where these files should be installed in python +# binaries. If not specified, this defaults to the current relative path of +# ${CMAKE_CURRENT_SOURCE_DIR} inside ${CMAKE_SOURCE_DIR}. e.g., a python +# library defined in the directory repo_root/foo/bar will use a default +# namespace of "foo.bar" +# - SOURCES <...>: +# The python source files. +# You may optionally specify as source using the form: PATH=ALIAS where +# PATH is a relative path in the source tree and ALIAS is the relative +# path into which PATH should be rewritten. This is useful for mapping +# an executable script to the main module in a python executable. +# e.g.: `python/bin/watchman-wait=__main__.py` +# - DEPENDS <...>: +# Other python libraries that this one depends on. +# - INSTALL_DIR : +# The directory where this library should be installed. +# install_fb_python_library() must still be called later to perform the +# installation. If a relative path is given it will be treated relative to +# ${CMAKE_INSTALL_PREFIX} +# +# CMake is unfortunately pretty crappy at being able to define custom build +# rules & behaviors. It doesn't support transitive property propagation +# between custom targets; only the built-in add_executable() and add_library() +# targets support transitive properties. +# +# We hack around this janky CMake behavior by (ab)using interface libraries to +# propagate some of the data we want between targets, without actually +# generating a C library. +# +# add_fb_python_library(SOMELIB) generates the following things: +# - An INTERFACE library rule named SOMELIB.py_lib which tracks some +# information about transitive dependencies: +# - the transitive set of source files in the INTERFACE_SOURCES property +# - the transitive set of manifest files that this library depends on in +# the INTERFACE_INCLUDE_DIRECTORIES property. +# - A custom command that generates a SOMELIB.manifest file. +# This file contains the mapping of source files to desired destination +# locations in executables that depend on this library. This manifest file +# will then be read at build-time in order to build executables. +# +function(add_fb_python_library LIB_NAME) + fb_py_check_available() + + # Parse the arguments + # We use fb_cmake_parse_args() rather than cmake_parse_arguments() since + # cmake_parse_arguments() does not handle empty arguments, and it is common + # for callers to want to specify an empty NAMESPACE parameter. + set(one_value_args BASE_DIR NAMESPACE INSTALL_DIR) + set(multi_value_args SOURCES DEPENDS) + fb_cmake_parse_args( + ARG "" "${one_value_args}" "${multi_value_args}" "${ARGN}" + ) + fb_py_process_default_args(ARG_NAMESPACE ARG_BASE_DIR) + + string(REPLACE "." "/" namespace_dir "${ARG_NAMESPACE}") + if (NOT "${namespace_dir}" STREQUAL "") + set(namespace_dir "${namespace_dir}/") + endif() + + if(NOT DEFINED ARG_INSTALL_DIR) + set(install_dir "${FBPY_LIB_INSTALL_DIR}/") + elseif("${ARG_INSTALL_DIR}" STREQUAL "") + set(install_dir "") + else() + set(install_dir "${ARG_INSTALL_DIR}/") + endif() + + # message(STATUS "fb py library ${LIB_NAME}: " + # "NS=${namespace_dir} BASE=${ARG_BASE_DIR}") + + # TODO: In the future it would be nice to support pre-compiling the source + # files. We could emit a rule to compile each source file and emit a + # .pyc/.pyo file here, and then have the manifest reference the pyc/pyo + # files. + + # Define a library target to help pass around information about the library, + # and propagate dependency information. + # + # CMake make a lot of assumptions that libraries are C++ libraries. To help + # avoid confusion we name our target "${LIB_NAME}.py_lib" rather than just + # "${LIB_NAME}". This helps avoid confusion if callers try to use + # "${LIB_NAME}" on their own as a target name. (e.g., attempting to install + # it directly with install(TARGETS) won't work. Callers must use + # install_fb_python_library() instead.) + add_library("${LIB_NAME}.py_lib" INTERFACE) + + # Emit the manifest file. + # + # We write the manifest file to a temporary path first, then copy it with + # configure_file(COPYONLY). This is necessary to get CMake to understand + # that "${manifest_path}" is generated by the CMake configure phase, + # and allow using it as a dependency for add_custom_command(). + # (https://gitlab.kitware.com/cmake/cmake/issues/16367) + set(manifest_path "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}.manifest") + set(tmp_manifest "${manifest_path}.tmp") + file(WRITE "${tmp_manifest}" "FBPY_MANIFEST 1\n") + set(abs_sources) + foreach(src_path IN LISTS ARG_SOURCES) + fb_py_compute_dest_path( + abs_source dest_path + "${src_path}" "${namespace_dir}" "${ARG_BASE_DIR}" + ) + list(APPEND abs_sources "${abs_source}") + target_sources( + "${LIB_NAME}.py_lib" INTERFACE + "$" + "$" + ) + file( + APPEND "${tmp_manifest}" + "${abs_source} :: ${dest_path}\n" + ) + endforeach() + configure_file("${tmp_manifest}" "${manifest_path}" COPYONLY) + + target_include_directories( + "${LIB_NAME}.py_lib" INTERFACE + "$" + "$" + ) + + # Add a target that depends on all of the source files. + # This is needed in case some of the source files are generated. This will + # ensure that these source files are brought up-to-date before we build + # any python binaries that depend on this library. + add_custom_target("${LIB_NAME}.py_sources_built" DEPENDS ${abs_sources}) + add_dependencies("${LIB_NAME}.py_lib" "${LIB_NAME}.py_sources_built") + + # Hook up library dependencies, and also make the *.py_sources_built target + # depend on the sources for all of our dependencies also being up-to-date. + foreach(dep IN LISTS ARG_DEPENDS) + target_link_libraries("${LIB_NAME}.py_lib" INTERFACE "${dep}.py_lib") + + # Mark that our .py_sources_built target depends on each our our dependent + # libraries. This serves two functions: + # - This causes CMake to generate an error message if one of the + # dependencies is never defined. The target_link_libraries() call above + # won't complain if one of the dependencies doesn't exist (since it is + # intended to allow passing in file names for plain library files rather + # than just targets). + # - It ensures that sources for our depencencies are built before any + # executable that depends on us. Note that we depend on "${dep}.py_lib" + # rather than "${dep}.py_sources_built" for this purpose because the + # ".py_sources_built" target won't be available for imported targets. + add_dependencies("${LIB_NAME}.py_sources_built" "${dep}.py_lib") + endforeach() + + # Add a custom command to help with library installation, in case + # install_fb_python_library() is called later for this library. + # add_custom_command() only works with file dependencies defined in the same + # CMakeLists.txt file, so we want to make sure this is defined here, rather + # then where install_fb_python_library() is called. + # This command won't be run by default, but will only be run if it is needed + # by a subsequent install_fb_python_library() call. + # + # This command copies the library contents into the build directory. + # It would be nicer if we could skip this intermediate copy, and just run + # make_fbpy_archive.py at install time to copy them directly to the desired + # installation directory. Unfortunately this is difficult to do, and seems + # to interfere with some of the CMake code that wants to generate a manifest + # of installed files. + set(build_install_dir "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}.lib_install") + add_custom_command( + OUTPUT + "${build_install_dir}/${LIB_NAME}.manifest" + COMMAND "${CMAKE_COMMAND}" -E remove_directory "${build_install_dir}" + COMMAND + "${Python3_EXECUTABLE}" "${FB_MAKE_PYTHON_ARCHIVE}" --type lib-install + --install-dir "${LIB_NAME}" + -o "${build_install_dir}/${LIB_NAME}" "${manifest_path}" + DEPENDS + "${abs_sources}" + "${manifest_path}" + "${FB_MAKE_PYTHON_ARCHIVE}" + ) + add_custom_target( + "${LIB_NAME}.py_lib_install" + DEPENDS "${build_install_dir}/${LIB_NAME}.manifest" + ) + + # Set some properties to pass through the install paths to + # install_fb_python_library() + # + # Passing through ${build_install_dir} allows install_fb_python_library() + # to work even if used from a different CMakeLists.txt file than where + # add_fb_python_library() was called (i.e. such that + # ${CMAKE_CURRENT_BINARY_DIR} is different between the two calls). + set(abs_install_dir "${install_dir}") + if(NOT IS_ABSOLUTE "${abs_install_dir}") + set(abs_install_dir "${CMAKE_INSTALL_PREFIX}/${abs_install_dir}") + endif() + string(REGEX REPLACE "/$" "" abs_install_dir "${abs_install_dir}") + set_target_properties( + "${LIB_NAME}.py_lib_install" + PROPERTIES + INSTALL_DIR "${abs_install_dir}" + BUILD_INSTALL_DIR "${build_install_dir}" + ) +endfunction() + +# +# Install an FB-style packaged python binary. +# +# - DESTINATION : +# Associate the installed target files with the given export-name. +# +function(install_fb_python_executable TARGET) + # Parse the arguments + set(one_value_args DESTINATION) + set(multi_value_args) + fb_cmake_parse_args( + ARG "" "${one_value_args}" "${multi_value_args}" "${ARGN}" + ) + + if(NOT DEFINED ARG_DESTINATION) + set(ARG_DESTINATION bin) + endif() + + install( + PROGRAMS "$" + DESTINATION "${ARG_DESTINATION}" + ) +endfunction() + +# +# Install a python library. +# +# - EXPORT : +# Associate the installed target files with the given export-name. +# +# Note that unlike the built-in CMake install() function we do not accept a +# DESTINATION parameter. Instead, use the INSTALL_DIR parameter to +# add_fb_python_library() to set the installation location. +# +function(install_fb_python_library LIB_NAME) + set(one_value_args EXPORT) + fb_cmake_parse_args(ARG "" "${one_value_args}" "" "${ARGN}") + + # Export our "${LIB_NAME}.py_lib" target so that it will be available to + # downstream projects in our installed CMake config files. + if(DEFINED ARG_EXPORT) + install(TARGETS "${LIB_NAME}.py_lib" EXPORT "${ARG_EXPORT}") + endif() + + # add_fb_python_library() emits a .py_lib_install target that will prepare + # the installation directory. However, it isn't part of the "ALL" target and + # therefore isn't built by default. + # + # Make sure the ALL target depends on it now. We have to do this by + # introducing yet another custom target. + # Add it as a dependency to the ALL target now. + add_custom_target("${LIB_NAME}.py_lib_install_all" ALL) + add_dependencies( + "${LIB_NAME}.py_lib_install_all" "${LIB_NAME}.py_lib_install" + ) + + # Copy the intermediate install directory generated at build time into + # the desired install location. + get_target_property(dest_dir "${LIB_NAME}.py_lib_install" "INSTALL_DIR") + get_target_property( + build_install_dir "${LIB_NAME}.py_lib_install" "BUILD_INSTALL_DIR" + ) + install( + DIRECTORY "${build_install_dir}/${LIB_NAME}" + DESTINATION "${dest_dir}" + ) + install( + FILES "${build_install_dir}/${LIB_NAME}.manifest" + DESTINATION "${dest_dir}" + ) +endfunction() + +# Helper macro to process the BASE_DIR and NAMESPACE arguments for +# add_fb_python_executable() and add_fb_python_executable() +macro(fb_py_process_default_args NAMESPACE_VAR BASE_DIR_VAR) + # If the namespace was not specified, default to the relative path to the + # current directory (starting from the repository root). + if(NOT DEFINED "${NAMESPACE_VAR}") + file( + RELATIVE_PATH "${NAMESPACE_VAR}" + "${CMAKE_SOURCE_DIR}" + "${CMAKE_CURRENT_SOURCE_DIR}" + ) + endif() + + if(NOT DEFINED "${BASE_DIR_VAR}") + # If the base directory was not specified, default to the current directory + set("${BASE_DIR_VAR}" "${CMAKE_CURRENT_SOURCE_DIR}") + else() + # If the base directory was specified, always convert it to an + # absolute path. + get_filename_component("${BASE_DIR_VAR}" "${${BASE_DIR_VAR}}" ABSOLUTE) + endif() +endmacro() + +function(fb_py_check_available) + # Make sure that Python 3 and our make_fbpy_archive.py helper script are + # available. + if(NOT Python3_EXECUTABLE) + if(FBPY_FIND_PYTHON_ERR) + message(FATAL_ERROR "Unable to find Python 3: ${FBPY_FIND_PYTHON_ERR}") + else() + message(FATAL_ERROR "Unable to find Python 3") + endif() + endif() + + if (NOT FB_MAKE_PYTHON_ARCHIVE) + message( + FATAL_ERROR "unable to find make_fbpy_archive.py helper program (it " + "should be located in the same directory as FBPythonBinary.cmake)" + ) + endif() +endfunction() + +function( + fb_py_compute_dest_path + src_path_output dest_path_output src_path namespace_dir base_dir +) + if("${src_path}" MATCHES "=") + # We want to split the string on the `=` sign, but cmake doesn't + # provide much in the way of helpers for this, so we rewrite the + # `=` sign to `;` so that we can treat it as a cmake list and + # then index into the components + string(REPLACE "=" ";" src_path_list "${src_path}") + list(GET src_path_list 0 src_path) + # Note that we ignore the `namespace_dir` in the alias case + # in order to allow aliasing a source to the top level `__main__.py` + # filename. + list(GET src_path_list 1 dest_path) + else() + unset(dest_path) + endif() + + get_filename_component(abs_source "${src_path}" ABSOLUTE) + if(NOT DEFINED dest_path) + file(RELATIVE_PATH rel_src "${ARG_BASE_DIR}" "${abs_source}") + if("${rel_src}" MATCHES "^../") + message( + FATAL_ERROR "${LIB_NAME}: source file \"${abs_source}\" is not inside " + "the base directory ${ARG_BASE_DIR}" + ) + endif() + set(dest_path "${namespace_dir}${rel_src}") + endif() + + set("${src_path_output}" "${abs_source}" PARENT_SCOPE) + set("${dest_path_output}" "${dest_path}" PARENT_SCOPE) +endfunction() diff --git a/build/fbcode_builder/CMake/FBPythonTestAddTests.cmake b/build/fbcode_builder/CMake/FBPythonTestAddTests.cmake new file mode 100644 index 000000000..d73c055d8 --- /dev/null +++ b/build/fbcode_builder/CMake/FBPythonTestAddTests.cmake @@ -0,0 +1,59 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +# Add a command to be emitted to the CTest file +set(ctest_script) +function(add_command CMD) + set(escaped_args "") + foreach(arg ${ARGN}) + # Escape all arguments using "Bracket Argument" syntax + # We could skip this for argument that don't contain any special + # characters if we wanted to make the output slightly more human-friendly. + set(escaped_args "${escaped_args} [==[${arg}]==]") + endforeach() + set(ctest_script "${ctest_script}${CMD}(${escaped_args})\n" PARENT_SCOPE) +endfunction() + +if(NOT EXISTS "${TEST_EXECUTABLE}") + message(FATAL_ERROR "Test executable does not exist: ${TEST_EXECUTABLE}") +endif() +execute_process( + COMMAND ${CMAKE_COMMAND} -E env ${TEST_ENV} "${TEST_INTERPRETER}" "${TEST_EXECUTABLE}" --list-tests + WORKING_DIRECTORY "${TEST_WORKING_DIR}" + OUTPUT_VARIABLE output + RESULT_VARIABLE result +) +if(NOT "${result}" EQUAL 0) + string(REPLACE "\n" "\n " output "${output}") + message( + FATAL_ERROR + "Error running test executable: ${TEST_EXECUTABLE}\n" + "Output:\n" + " ${output}\n" + ) +endif() + +# Parse output +string(REPLACE "\n" ";" tests_list "${output}") +foreach(test_name ${tests_list}) + add_command( + add_test + "${TEST_PREFIX}${test_name}" + ${CMAKE_COMMAND} -E env ${TEST_ENV} + "${TEST_INTERPRETER}" "${TEST_EXECUTABLE}" "${test_name}" + ) + add_command( + set_tests_properties + "${TEST_PREFIX}${test_name}" + PROPERTIES + WORKING_DIRECTORY "${TEST_WORKING_DIR}" + ${TEST_PROPERTIES} + ) +endforeach() + +# Set a list of discovered tests in the parent scope, in case users +# want access to this list as a CMake variable +if(TEST_LIST) + add_command(set ${TEST_LIST} ${tests_list}) +endif() + +file(WRITE "${CTEST_FILE}" "${ctest_script}") diff --git a/build/fbcode_builder/CMake/FBThriftCppLibrary.cmake b/build/fbcode_builder/CMake/FBThriftCppLibrary.cmake new file mode 100644 index 000000000..670771a46 --- /dev/null +++ b/build/fbcode_builder/CMake/FBThriftCppLibrary.cmake @@ -0,0 +1,194 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +include(FBCMakeParseArgs) + +# Generate a C++ library from a thrift file +# +# Parameters: +# - SERVICES [ ...] +# The names of the services defined in the thrift file. +# - DEPENDS [ ...] +# A list of other thrift C++ libraries that this library depends on. +# - OPTIONS [ ...] +# A list of options to pass to the thrift compiler. +# - INCLUDE_DIR +# The sub-directory where generated headers will be installed. +# Defaults to "include" if not specified. The caller must still call +# install() to install the thrift library if desired. +# - THRIFT_INCLUDE_DIR +# The sub-directory where generated headers will be installed. +# Defaults to "${INCLUDE_DIR}/thrift-files" if not specified. +# The caller must still call install() to install the thrift library if +# desired. +function(add_fbthrift_cpp_library LIB_NAME THRIFT_FILE) + # Parse the arguments + set(one_value_args INCLUDE_DIR THRIFT_INCLUDE_DIR) + set(multi_value_args SERVICES DEPENDS OPTIONS) + fb_cmake_parse_args( + ARG "" "${one_value_args}" "${multi_value_args}" "${ARGN}" + ) + if(NOT DEFINED ARG_INCLUDE_DIR) + set(ARG_INCLUDE_DIR "include") + endif() + if(NOT DEFINED ARG_THRIFT_INCLUDE_DIR) + set(ARG_THRIFT_INCLUDE_DIR "${ARG_INCLUDE_DIR}/thrift-files") + endif() + + get_filename_component(base ${THRIFT_FILE} NAME_WE) + get_filename_component( + output_dir + ${CMAKE_CURRENT_BINARY_DIR}/${THRIFT_FILE} + DIRECTORY + ) + + # Generate relative paths in #includes + file( + RELATIVE_PATH include_prefix + "${CMAKE_SOURCE_DIR}" + "${CMAKE_CURRENT_SOURCE_DIR}/${THRIFT_FILE}" + ) + get_filename_component(include_prefix ${include_prefix} DIRECTORY) + + if (NOT "${include_prefix}" STREQUAL "") + list(APPEND ARG_OPTIONS "include_prefix=${include_prefix}") + endif() + # CMake 3.12 is finally getting a list(JOIN) function, but until then + # treating the list as a string and replacing the semicolons is good enough. + string(REPLACE ";" "," GEN_ARG_STR "${ARG_OPTIONS}") + + # Compute the list of generated files + list(APPEND generated_headers + "${output_dir}/gen-cpp2/${base}_constants.h" + "${output_dir}/gen-cpp2/${base}_types.h" + "${output_dir}/gen-cpp2/${base}_types.tcc" + "${output_dir}/gen-cpp2/${base}_types_custom_protocol.h" + "${output_dir}/gen-cpp2/${base}_metadata.h" + ) + list(APPEND generated_sources + "${output_dir}/gen-cpp2/${base}_constants.cpp" + "${output_dir}/gen-cpp2/${base}_data.h" + "${output_dir}/gen-cpp2/${base}_data.cpp" + "${output_dir}/gen-cpp2/${base}_types.cpp" + "${output_dir}/gen-cpp2/${base}_metadata.cpp" + ) + foreach(service IN LISTS ARG_SERVICES) + list(APPEND generated_headers + "${output_dir}/gen-cpp2/${service}.h" + "${output_dir}/gen-cpp2/${service}.tcc" + "${output_dir}/gen-cpp2/${service}AsyncClient.h" + "${output_dir}/gen-cpp2/${service}_custom_protocol.h" + ) + list(APPEND generated_sources + "${output_dir}/gen-cpp2/${service}.cpp" + "${output_dir}/gen-cpp2/${service}AsyncClient.cpp" + "${output_dir}/gen-cpp2/${service}_processmap_binary.cpp" + "${output_dir}/gen-cpp2/${service}_processmap_compact.cpp" + ) + endforeach() + + # This generator expression gets the list of include directories required + # for all of our dependencies. + # It requires using COMMAND_EXPAND_LISTS in the add_custom_command() call + # below. COMMAND_EXPAND_LISTS is only available in CMake 3.8+ + # If we really had to support older versions of CMake we would probably need + # to use a wrapper script around the thrift compiler that could take the + # include list as a single argument and split it up before invoking the + # thrift compiler. + if (NOT POLICY CMP0067) + message(FATAL_ERROR "add_fbthrift_cpp_library() requires CMake 3.8+") + endif() + set( + thrift_include_options + "-I;$,;-I;>" + ) + + # Emit the rule to run the thrift compiler + add_custom_command( + OUTPUT + ${generated_headers} + ${generated_sources} + COMMAND_EXPAND_LISTS + COMMAND + "${CMAKE_COMMAND}" -E make_directory "${output_dir}" + COMMAND + "${FBTHRIFT_COMPILER}" + --strict + --gen "mstch_cpp2:${GEN_ARG_STR}" + "${thrift_include_options}" + -o "${output_dir}" + "${CMAKE_CURRENT_SOURCE_DIR}/${THRIFT_FILE}" + WORKING_DIRECTORY + "${CMAKE_BINARY_DIR}" + MAIN_DEPENDENCY + "${THRIFT_FILE}" + DEPENDS + ${ARG_DEPENDS} + "${FBTHRIFT_COMPILER}" + ) + + # Now emit the library rule to compile the sources + if (BUILD_SHARED_LIBS) + set(LIB_TYPE SHARED) + else () + set(LIB_TYPE STATIC) + endif () + + add_library( + "${LIB_NAME}" ${LIB_TYPE} + ${generated_sources} + ) + + target_include_directories( + "${LIB_NAME}" + PUBLIC + "$" + "$" + ) + target_link_libraries( + "${LIB_NAME}" + PUBLIC + ${ARG_DEPENDS} + FBThrift::thriftcpp2 + Folly::folly + ) + + # Add ${generated_headers} to the PUBLIC_HEADER property for ${LIB_NAME} + # + # This allows callers to install it using + # "install(TARGETS ${LIB_NAME} PUBLIC_HEADER)" + # However, note that CMake's PUBLIC_HEADER behavior is rather inflexible, + # and does have any way to preserve header directory structure. Callers + # must be careful to use the correct PUBLIC_HEADER DESTINATION parameter + # when doing this, to put the files the correct directory themselves. + # We define a HEADER_INSTALL_DIR property with the include directory prefix, + # so typically callers should specify the PUBLIC_HEADER DESTINATION as + # "$" + set_property( + TARGET "${LIB_NAME}" + PROPERTY PUBLIC_HEADER ${generated_headers} + ) + + # Define a dummy interface library to help propagate the thrift include + # directories between dependencies. + add_library("${LIB_NAME}.thrift_includes" INTERFACE) + target_include_directories( + "${LIB_NAME}.thrift_includes" + INTERFACE + "$" + "$" + ) + foreach(dep IN LISTS ARG_DEPENDS) + target_link_libraries( + "${LIB_NAME}.thrift_includes" + INTERFACE "${dep}.thrift_includes" + ) + endforeach() + + set_target_properties( + "${LIB_NAME}" + PROPERTIES + EXPORT_PROPERTIES "THRIFT_INSTALL_DIR" + THRIFT_INSTALL_DIR "${ARG_THRIFT_INCLUDE_DIR}/${include_prefix}" + HEADER_INSTALL_DIR "${ARG_INCLUDE_DIR}/${include_prefix}/gen-cpp2" + ) +endfunction() diff --git a/build/fbcode_builder/CMake/FBThriftLibrary.cmake b/build/fbcode_builder/CMake/FBThriftLibrary.cmake new file mode 100644 index 000000000..e4280e2a4 --- /dev/null +++ b/build/fbcode_builder/CMake/FBThriftLibrary.cmake @@ -0,0 +1,77 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +include(FBCMakeParseArgs) +include(FBThriftPyLibrary) +include(FBThriftCppLibrary) + +# +# add_fbthrift_library() +# +# This is a convenience function that generates thrift libraries for multiple +# languages. +# +# For example: +# add_fbthrift_library( +# foo foo.thrift +# LANGUAGES cpp py +# SERVICES Foo +# DEPENDS bar) +# +# will be expanded into two separate calls: +# +# add_fbthrift_cpp_library(foo_cpp foo.thrift SERVICES Foo DEPENDS bar_cpp) +# add_fbthrift_py_library(foo_py foo.thrift SERVICES Foo DEPENDS bar_py) +# +function(add_fbthrift_library LIB_NAME THRIFT_FILE) + # Parse the arguments + set(one_value_args PY_NAMESPACE INCLUDE_DIR THRIFT_INCLUDE_DIR) + set(multi_value_args SERVICES DEPENDS LANGUAGES CPP_OPTIONS PY_OPTIONS) + fb_cmake_parse_args( + ARG "" "${one_value_args}" "${multi_value_args}" "${ARGN}" + ) + + if(NOT DEFINED ARG_INCLUDE_DIR) + set(ARG_INCLUDE_DIR "include") + endif() + if(NOT DEFINED ARG_THRIFT_INCLUDE_DIR) + set(ARG_THRIFT_INCLUDE_DIR "${ARG_INCLUDE_DIR}/thrift-files") + endif() + + # CMake 3.12+ adds list(TRANSFORM) which would be nice to use here, but for + # now we still want to support older versions of CMake. + set(CPP_DEPENDS) + set(PY_DEPENDS) + foreach(dep IN LISTS ARG_DEPENDS) + list(APPEND CPP_DEPENDS "${dep}_cpp") + list(APPEND PY_DEPENDS "${dep}_py") + endforeach() + + foreach(lang IN LISTS ARG_LANGUAGES) + if ("${lang}" STREQUAL "cpp") + add_fbthrift_cpp_library( + "${LIB_NAME}_cpp" "${THRIFT_FILE}" + SERVICES ${ARG_SERVICES} + DEPENDS ${CPP_DEPENDS} + OPTIONS ${ARG_CPP_OPTIONS} + INCLUDE_DIR "${ARG_INCLUDE_DIR}" + THRIFT_INCLUDE_DIR "${ARG_THRIFT_INCLUDE_DIR}" + ) + elseif ("${lang}" STREQUAL "py" OR "${lang}" STREQUAL "python") + if (DEFINED ARG_PY_NAMESPACE) + set(namespace_args NAMESPACE "${ARG_PY_NAMESPACE}") + endif() + add_fbthrift_py_library( + "${LIB_NAME}_py" "${THRIFT_FILE}" + SERVICES ${ARG_SERVICES} + ${namespace_args} + DEPENDS ${PY_DEPENDS} + OPTIONS ${ARG_PY_OPTIONS} + THRIFT_INCLUDE_DIR "${ARG_THRIFT_INCLUDE_DIR}" + ) + else() + message( + FATAL_ERROR "unknown language for thrift library ${LIB_NAME}: ${lang}" + ) + endif() + endforeach() +endfunction() diff --git a/build/fbcode_builder/CMake/FBThriftPyLibrary.cmake b/build/fbcode_builder/CMake/FBThriftPyLibrary.cmake new file mode 100644 index 000000000..7bd8879ee --- /dev/null +++ b/build/fbcode_builder/CMake/FBThriftPyLibrary.cmake @@ -0,0 +1,111 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +include(FBCMakeParseArgs) +include(FBPythonBinary) + +# Generate a Python library from a thrift file +function(add_fbthrift_py_library LIB_NAME THRIFT_FILE) + # Parse the arguments + set(one_value_args NAMESPACE THRIFT_INCLUDE_DIR) + set(multi_value_args SERVICES DEPENDS OPTIONS) + fb_cmake_parse_args( + ARG "" "${one_value_args}" "${multi_value_args}" "${ARGN}" + ) + + if(NOT DEFINED ARG_THRIFT_INCLUDE_DIR) + set(ARG_THRIFT_INCLUDE_DIR "include/thrift-files") + endif() + + get_filename_component(base ${THRIFT_FILE} NAME_WE) + set(output_dir "${CMAKE_CURRENT_BINARY_DIR}/${THRIFT_FILE}-py") + + # Parse the namespace value + if (NOT DEFINED ARG_NAMESPACE) + set(ARG_NAMESPACE "${base}") + endif() + + string(REPLACE "." "/" namespace_dir "${ARG_NAMESPACE}") + set(py_output_dir "${output_dir}/gen-py/${namespace_dir}") + list(APPEND generated_sources + "${py_output_dir}/__init__.py" + "${py_output_dir}/ttypes.py" + "${py_output_dir}/constants.py" + ) + foreach(service IN LISTS ARG_SERVICES) + list(APPEND generated_sources + ${py_output_dir}/${service}.py + ) + endforeach() + + # Define a dummy interface library to help propagate the thrift include + # directories between dependencies. + add_library("${LIB_NAME}.thrift_includes" INTERFACE) + target_include_directories( + "${LIB_NAME}.thrift_includes" + INTERFACE + "$" + "$" + ) + foreach(dep IN LISTS ARG_DEPENDS) + target_link_libraries( + "${LIB_NAME}.thrift_includes" + INTERFACE "${dep}.thrift_includes" + ) + endforeach() + + # This generator expression gets the list of include directories required + # for all of our dependencies. + # It requires using COMMAND_EXPAND_LISTS in the add_custom_command() call + # below. COMMAND_EXPAND_LISTS is only available in CMake 3.8+ + # If we really had to support older versions of CMake we would probably need + # to use a wrapper script around the thrift compiler that could take the + # include list as a single argument and split it up before invoking the + # thrift compiler. + if (NOT POLICY CMP0067) + message(FATAL_ERROR "add_fbthrift_py_library() requires CMake 3.8+") + endif() + set( + thrift_include_options + "-I;$,;-I;>" + ) + + # Always force generation of "new-style" python classes for Python 2 + list(APPEND ARG_OPTIONS "new_style") + # CMake 3.12 is finally getting a list(JOIN) function, but until then + # treating the list as a string and replacing the semicolons is good enough. + string(REPLACE ";" "," GEN_ARG_STR "${ARG_OPTIONS}") + + # Emit the rule to run the thrift compiler + add_custom_command( + OUTPUT + ${generated_sources} + COMMAND_EXPAND_LISTS + COMMAND + "${CMAKE_COMMAND}" -E make_directory "${output_dir}" + COMMAND + "${FBTHRIFT_COMPILER}" + --strict + --gen "py:${GEN_ARG_STR}" + "${thrift_include_options}" + -o "${output_dir}" + "${CMAKE_CURRENT_SOURCE_DIR}/${THRIFT_FILE}" + WORKING_DIRECTORY + "${CMAKE_BINARY_DIR}" + MAIN_DEPENDENCY + "${THRIFT_FILE}" + DEPENDS + "${FBTHRIFT_COMPILER}" + ) + + # We always want to pass the namespace as "" to this call: + # thrift will already emit the files with the desired namespace prefix under + # gen-py. We don't want add_fb_python_library() to prepend the namespace a + # second time. + add_fb_python_library( + "${LIB_NAME}" + BASE_DIR "${output_dir}/gen-py" + NAMESPACE "" + SOURCES ${generated_sources} + DEPENDS ${ARG_DEPENDS} FBThrift::thrift_py + ) +endfunction() diff --git a/build/fbcode_builder/CMake/FindGMock.cmake b/build/fbcode_builder/CMake/FindGMock.cmake new file mode 100644 index 000000000..cd042dd9c --- /dev/null +++ b/build/fbcode_builder/CMake/FindGMock.cmake @@ -0,0 +1,80 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Find libgmock +# +# LIBGMOCK_DEFINES - List of defines when using libgmock. +# LIBGMOCK_INCLUDE_DIR - where to find gmock/gmock.h, etc. +# LIBGMOCK_LIBRARIES - List of libraries when using libgmock. +# LIBGMOCK_FOUND - True if libgmock found. + +IF (LIBGMOCK_INCLUDE_DIR) + # Already in cache, be silent + SET(LIBGMOCK_FIND_QUIETLY TRUE) +ENDIF () + +find_package(GTest CONFIG QUIET) +if (TARGET GTest::gmock) + get_target_property(LIBGMOCK_DEFINES GTest::gtest INTERFACE_COMPILE_DEFINITIONS) + if (NOT ${LIBGMOCK_DEFINES}) + # Explicitly set to empty string if not found to avoid it being + # set to NOTFOUND and breaking compilation + set(LIBGMOCK_DEFINES "") + endif() + get_target_property(LIBGMOCK_INCLUDE_DIR GTest::gtest INTERFACE_INCLUDE_DIRECTORIES) + set(LIBGMOCK_LIBRARIES GTest::gmock_main GTest::gmock GTest::gtest) + set(LIBGMOCK_FOUND ON) + message(STATUS "Found gmock via config, defines=${LIBGMOCK_DEFINES}, include=${LIBGMOCK_INCLUDE_DIR}, libs=${LIBGMOCK_LIBRARIES}") +else() + + FIND_PATH(LIBGMOCK_INCLUDE_DIR gmock/gmock.h) + + FIND_LIBRARY(LIBGMOCK_MAIN_LIBRARY_DEBUG NAMES gmock_maind) + FIND_LIBRARY(LIBGMOCK_MAIN_LIBRARY_RELEASE NAMES gmock_main) + FIND_LIBRARY(LIBGMOCK_LIBRARY_DEBUG NAMES gmockd) + FIND_LIBRARY(LIBGMOCK_LIBRARY_RELEASE NAMES gmock) + FIND_LIBRARY(LIBGTEST_LIBRARY_DEBUG NAMES gtestd) + FIND_LIBRARY(LIBGTEST_LIBRARY_RELEASE NAMES gtest) + + find_package(Threads REQUIRED) + INCLUDE(SelectLibraryConfigurations) + SELECT_LIBRARY_CONFIGURATIONS(LIBGMOCK_MAIN) + SELECT_LIBRARY_CONFIGURATIONS(LIBGMOCK) + SELECT_LIBRARY_CONFIGURATIONS(LIBGTEST) + + set(LIBGMOCK_LIBRARIES + ${LIBGMOCK_MAIN_LIBRARY} + ${LIBGMOCK_LIBRARY} + ${LIBGTEST_LIBRARY} + Threads::Threads + ) + + if(CMAKE_SYSTEM_NAME STREQUAL "Windows") + # The GTEST_LINKED_AS_SHARED_LIBRARY macro must be set properly on Windows. + # + # There isn't currently an easy way to determine if a library was compiled as + # a shared library on Windows, so just assume we've been built against a + # shared build of gmock for now. + SET(LIBGMOCK_DEFINES "GTEST_LINKED_AS_SHARED_LIBRARY=1" CACHE STRING "") + endif() + + # handle the QUIETLY and REQUIRED arguments and set LIBGMOCK_FOUND to TRUE if + # all listed variables are TRUE + INCLUDE(FindPackageHandleStandardArgs) + FIND_PACKAGE_HANDLE_STANDARD_ARGS( + GMock + DEFAULT_MSG + LIBGMOCK_MAIN_LIBRARY + LIBGMOCK_LIBRARY + LIBGTEST_LIBRARY + LIBGMOCK_LIBRARIES + LIBGMOCK_INCLUDE_DIR + ) + + MARK_AS_ADVANCED( + LIBGMOCK_DEFINES + LIBGMOCK_MAIN_LIBRARY + LIBGMOCK_LIBRARY + LIBGTEST_LIBRARY + LIBGMOCK_LIBRARIES + LIBGMOCK_INCLUDE_DIR + ) +endif() diff --git a/build/fbcode_builder/CMake/FindGflags.cmake b/build/fbcode_builder/CMake/FindGflags.cmake new file mode 100644 index 000000000..c00896a34 --- /dev/null +++ b/build/fbcode_builder/CMake/FindGflags.cmake @@ -0,0 +1,105 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Find libgflags. +# There's a lot of compatibility cruft going on in here, both +# to deal with changes across the FB consumers of this and also +# to deal with variances in behavior of cmake itself. +# +# Since this file is named FindGflags.cmake the cmake convention +# is for the module to export both GFLAGS_FOUND and Gflags_FOUND. +# The convention expected by consumers is that we export the +# following variables, even though these do not match the cmake +# conventions: +# +# LIBGFLAGS_INCLUDE_DIR - where to find gflags/gflags.h, etc. +# LIBGFLAGS_LIBRARY - List of libraries when using libgflags. +# LIBGFLAGS_FOUND - True if libgflags found. +# +# We need to be able to locate gflags both from an installed +# cmake config file and just from the raw headers and libs, so +# test for the former and then the latter, and then stick +# the results together and export them into the variables +# listed above. +# +# For forwards compatibility, we export the following variables: +# +# gflags_INCLUDE_DIR - where to find gflags/gflags.h, etc. +# gflags_TARGET / GFLAGS_TARGET / gflags_LIBRARIES +# - List of libraries when using libgflags. +# gflags_FOUND - True if libgflags found. +# + +IF (LIBGFLAGS_INCLUDE_DIR) + # Already in cache, be silent + SET(Gflags_FIND_QUIETLY TRUE) +ENDIF () + +find_package(gflags CONFIG QUIET) +if (gflags_FOUND) + if (NOT Gflags_FIND_QUIETLY) + message(STATUS "Found gflags from package config ${gflags_CONFIG}") + endif() + # Re-export the config-specified libs with our local names + set(LIBGFLAGS_LIBRARY ${gflags_LIBRARIES}) + set(LIBGFLAGS_INCLUDE_DIR ${gflags_INCLUDE_DIR}) + if(NOT EXISTS "${gflags_INCLUDE_DIR}") + # The gflags-devel RPM on recent RedHat-based systems is somewhat broken. + # RedHat symlinks /lib64 to /usr/lib64, and this breaks some of the + # relative path computation performed in gflags-config.cmake. The package + # config file ends up being found via /lib64, but the relative path + # computation it does only works if it was found in /usr/lib64. + # If gflags_INCLUDE_DIR does not actually exist, simply default it to + # /usr/include on these systems. + set(LIBGFLAGS_INCLUDE_DIR "/usr/include") + endif() + set(LIBGFLAGS_FOUND ${gflags_FOUND}) + # cmake module compat + set(GFLAGS_FOUND ${gflags_FOUND}) + set(Gflags_FOUND ${gflags_FOUND}) +else() + FIND_PATH(LIBGFLAGS_INCLUDE_DIR gflags/gflags.h) + + FIND_LIBRARY(LIBGFLAGS_LIBRARY_DEBUG NAMES gflagsd gflags_staticd) + FIND_LIBRARY(LIBGFLAGS_LIBRARY_RELEASE NAMES gflags gflags_static) + + INCLUDE(SelectLibraryConfigurations) + SELECT_LIBRARY_CONFIGURATIONS(LIBGFLAGS) + + # handle the QUIETLY and REQUIRED arguments and set LIBGFLAGS_FOUND to TRUE if + # all listed variables are TRUE + INCLUDE(FindPackageHandleStandardArgs) + FIND_PACKAGE_HANDLE_STANDARD_ARGS(gflags DEFAULT_MSG LIBGFLAGS_LIBRARY LIBGFLAGS_INCLUDE_DIR) + # cmake module compat + set(Gflags_FOUND ${GFLAGS_FOUND}) + # compat with some existing FindGflags consumers + set(LIBGFLAGS_FOUND ${GFLAGS_FOUND}) + + # Compat with the gflags CONFIG based detection + set(gflags_FOUND ${GFLAGS_FOUND}) + set(gflags_INCLUDE_DIR ${LIBGFLAGS_INCLUDE_DIR}) + set(gflags_LIBRARIES ${LIBGFLAGS_LIBRARY}) + set(GFLAGS_TARGET ${LIBGFLAGS_LIBRARY}) + set(gflags_TARGET ${LIBGFLAGS_LIBRARY}) + + MARK_AS_ADVANCED(LIBGFLAGS_LIBRARY LIBGFLAGS_INCLUDE_DIR) +endif() + +# Compat with the gflags CONFIG based detection +if (LIBGFLAGS_FOUND AND NOT TARGET gflags) + add_library(gflags UNKNOWN IMPORTED) + if(TARGET gflags-shared) + # If the installed gflags CMake package config defines a gflags-shared + # target but not gflags, just make the gflags target that we define + # depend on the gflags-shared target. + target_link_libraries(gflags INTERFACE gflags-shared) + # Export LIBGFLAGS_LIBRARY as the gflags-shared target in this case. + set(LIBGFLAGS_LIBRARY gflags-shared) + else() + set_target_properties( + gflags + PROPERTIES + IMPORTED_LINK_INTERFACE_LANGUAGES "C" + IMPORTED_LOCATION "${LIBGFLAGS_LIBRARY}" + INTERFACE_INCLUDE_DIRECTORIES "${LIBGFLAGS_INCLUDE_DIR}" + ) + endif() +endif() diff --git a/build/fbcode_builder/CMake/FindGlog.cmake b/build/fbcode_builder/CMake/FindGlog.cmake new file mode 100644 index 000000000..752647cb3 --- /dev/null +++ b/build/fbcode_builder/CMake/FindGlog.cmake @@ -0,0 +1,37 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# - Try to find Glog +# Once done, this will define +# +# GLOG_FOUND - system has Glog +# GLOG_INCLUDE_DIRS - the Glog include directories +# GLOG_LIBRARIES - link these to use Glog + +include(FindPackageHandleStandardArgs) +include(SelectLibraryConfigurations) + +find_library(GLOG_LIBRARY_RELEASE glog + PATHS ${GLOG_LIBRARYDIR}) +find_library(GLOG_LIBRARY_DEBUG glogd + PATHS ${GLOG_LIBRARYDIR}) + +find_path(GLOG_INCLUDE_DIR glog/logging.h + PATHS ${GLOG_INCLUDEDIR}) + +select_library_configurations(GLOG) + +find_package_handle_standard_args(glog DEFAULT_MSG + GLOG_LIBRARY + GLOG_INCLUDE_DIR) + +mark_as_advanced( + GLOG_LIBRARY + GLOG_INCLUDE_DIR) + +set(GLOG_LIBRARIES ${GLOG_LIBRARY}) +set(GLOG_INCLUDE_DIRS ${GLOG_INCLUDE_DIR}) + +if (NOT TARGET glog::glog) + add_library(glog::glog UNKNOWN IMPORTED) + set_target_properties(glog::glog PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${GLOG_INCLUDE_DIRS}") + set_target_properties(glog::glog PROPERTIES IMPORTED_LINK_INTERFACE_LANGUAGES "C" IMPORTED_LOCATION "${GLOG_LIBRARIES}") +endif() diff --git a/build/fbcode_builder/CMake/FindLibEvent.cmake b/build/fbcode_builder/CMake/FindLibEvent.cmake new file mode 100644 index 000000000..dd11ebd84 --- /dev/null +++ b/build/fbcode_builder/CMake/FindLibEvent.cmake @@ -0,0 +1,77 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# - Find LibEvent (a cross event library) +# This module defines +# LIBEVENT_INCLUDE_DIR, where to find LibEvent headers +# LIBEVENT_LIB, LibEvent libraries +# LibEvent_FOUND, If false, do not try to use libevent + +set(LibEvent_EXTRA_PREFIXES /usr/local /opt/local "$ENV{HOME}") +foreach(prefix ${LibEvent_EXTRA_PREFIXES}) + list(APPEND LibEvent_INCLUDE_PATHS "${prefix}/include") + list(APPEND LibEvent_LIB_PATHS "${prefix}/lib") +endforeach() + +find_package(Libevent CONFIG QUIET) +if (TARGET event) + # Re-export the config under our own names + + # Somewhat gross, but some vcpkg installed libevents have a relative + # `include` path exported into LIBEVENT_INCLUDE_DIRS, which triggers + # a cmake error because it resolves to the `include` dir within the + # folly repo, which is not something cmake allows to be in the + # INTERFACE_INCLUDE_DIRECTORIES. Thankfully on such a system the + # actual include directory is already part of the global include + # directories, so we can just skip it. + if (NOT "${LIBEVENT_INCLUDE_DIRS}" STREQUAL "include") + set(LIBEVENT_INCLUDE_DIR ${LIBEVENT_INCLUDE_DIRS}) + else() + set(LIBEVENT_INCLUDE_DIR) + endif() + + # Unfortunately, with a bare target name `event`, downstream consumers + # of the package that depends on `Libevent` located via CONFIG end + # up exporting just a bare `event` in their libraries. This is problematic + # because this in interpreted as just `-levent` with no library path. + # When libevent is not installed in the default installation prefix + # this results in linker errors. + # To resolve this, we ask cmake to lookup the full path to the library + # and use that instead. + cmake_policy(PUSH) + if(POLICY CMP0026) + # Allow reading the LOCATION property + cmake_policy(SET CMP0026 OLD) + endif() + get_target_property(LIBEVENT_LIB event LOCATION) + cmake_policy(POP) + + set(LibEvent_FOUND ${Libevent_FOUND}) + if (NOT LibEvent_FIND_QUIETLY) + message(STATUS "Found libevent from package config include=${LIBEVENT_INCLUDE_DIRS} lib=${LIBEVENT_LIB}") + endif() +else() + find_path(LIBEVENT_INCLUDE_DIR event.h PATHS ${LibEvent_INCLUDE_PATHS}) + find_library(LIBEVENT_LIB NAMES event PATHS ${LibEvent_LIB_PATHS}) + + if (LIBEVENT_LIB AND LIBEVENT_INCLUDE_DIR) + set(LibEvent_FOUND TRUE) + set(LIBEVENT_LIB ${LIBEVENT_LIB}) + else () + set(LibEvent_FOUND FALSE) + endif () + + if (LibEvent_FOUND) + if (NOT LibEvent_FIND_QUIETLY) + message(STATUS "Found libevent: ${LIBEVENT_LIB}") + endif () + else () + if (LibEvent_FIND_REQUIRED) + message(FATAL_ERROR "Could NOT find libevent.") + endif () + message(STATUS "libevent NOT found.") + endif () + + mark_as_advanced( + LIBEVENT_LIB + LIBEVENT_INCLUDE_DIR + ) +endif() diff --git a/build/fbcode_builder/CMake/FindLibUnwind.cmake b/build/fbcode_builder/CMake/FindLibUnwind.cmake new file mode 100644 index 000000000..b01a674a5 --- /dev/null +++ b/build/fbcode_builder/CMake/FindLibUnwind.cmake @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +find_path(LIBUNWIND_INCLUDE_DIR NAMES libunwind.h) +mark_as_advanced(LIBUNWIND_INCLUDE_DIR) + +find_library(LIBUNWIND_LIBRARY NAMES unwind) +mark_as_advanced(LIBUNWIND_LIBRARY) + +include(FindPackageHandleStandardArgs) +FIND_PACKAGE_HANDLE_STANDARD_ARGS( + LIBUNWIND + REQUIRED_VARS LIBUNWIND_LIBRARY LIBUNWIND_INCLUDE_DIR) + +if(LIBUNWIND_FOUND) + set(LIBUNWIND_LIBRARIES ${LIBUNWIND_LIBRARY}) + set(LIBUNWIND_INCLUDE_DIRS ${LIBUNWIND_INCLUDE_DIR}) +endif() diff --git a/build/fbcode_builder/CMake/FindPCRE.cmake b/build/fbcode_builder/CMake/FindPCRE.cmake new file mode 100644 index 000000000..32ccb3725 --- /dev/null +++ b/build/fbcode_builder/CMake/FindPCRE.cmake @@ -0,0 +1,11 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +include(FindPackageHandleStandardArgs) +find_path(PCRE_INCLUDE_DIR NAMES pcre.h) +find_library(PCRE_LIBRARY NAMES pcre) +find_package_handle_standard_args( + PCRE + DEFAULT_MSG + PCRE_LIBRARY + PCRE_INCLUDE_DIR +) +mark_as_advanced(PCRE_INCLUDE_DIR PCRE_LIBRARY) diff --git a/build/fbcode_builder/CMake/FindRe2.cmake b/build/fbcode_builder/CMake/FindRe2.cmake new file mode 100644 index 000000000..013ae7761 --- /dev/null +++ b/build/fbcode_builder/CMake/FindRe2.cmake @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This software may be used and distributed according to the terms of the +# GNU General Public License version 2. + +find_library(RE2_LIBRARY re2) +mark_as_advanced(RE2_LIBRARY) + +find_path(RE2_INCLUDE_DIR NAMES re2/re2.h) +mark_as_advanced(RE2_INCLUDE_DIR) + +include(FindPackageHandleStandardArgs) +FIND_PACKAGE_HANDLE_STANDARD_ARGS( + RE2 + REQUIRED_VARS RE2_LIBRARY RE2_INCLUDE_DIR) + +if(RE2_FOUND) + set(RE2_LIBRARY ${RE2_LIBRARY}) + set(RE2_INCLUDE_DIR, ${RE2_INCLUDE_DIR}) +endif() diff --git a/build/fbcode_builder/CMake/FindSodium.cmake b/build/fbcode_builder/CMake/FindSodium.cmake new file mode 100644 index 000000000..3c3f1245c --- /dev/null +++ b/build/fbcode_builder/CMake/FindSodium.cmake @@ -0,0 +1,297 @@ +# Written in 2016 by Henrik Steffen Gaßmann +# +# To the extent possible under law, the author(s) have dedicated all +# copyright and related and neighboring rights to this software to the +# public domain worldwide. This software is distributed without any warranty. +# +# You should have received a copy of the CC0 Public Domain Dedication +# along with this software. If not, see +# +# http://creativecommons.org/publicdomain/zero/1.0/ +# +######################################################################## +# Tries to find the local libsodium installation. +# +# On Windows the sodium_DIR environment variable is used as a default +# hint which can be overridden by setting the corresponding cmake variable. +# +# Once done the following variables will be defined: +# +# sodium_FOUND +# sodium_INCLUDE_DIR +# sodium_LIBRARY_DEBUG +# sodium_LIBRARY_RELEASE +# +# +# Furthermore an imported "sodium" target is created. +# + +if (CMAKE_C_COMPILER_ID STREQUAL "GNU" + OR CMAKE_C_COMPILER_ID STREQUAL "Clang") + set(_GCC_COMPATIBLE 1) +endif() + +# static library option +if (NOT DEFINED sodium_USE_STATIC_LIBS) + option(sodium_USE_STATIC_LIBS "enable to statically link against sodium" OFF) +endif() +if(NOT (sodium_USE_STATIC_LIBS EQUAL sodium_USE_STATIC_LIBS_LAST)) + unset(sodium_LIBRARY CACHE) + unset(sodium_LIBRARY_DEBUG CACHE) + unset(sodium_LIBRARY_RELEASE CACHE) + unset(sodium_DLL_DEBUG CACHE) + unset(sodium_DLL_RELEASE CACHE) + set(sodium_USE_STATIC_LIBS_LAST ${sodium_USE_STATIC_LIBS} CACHE INTERNAL "internal change tracking variable") +endif() + + +######################################################################## +# UNIX +if (UNIX) + # import pkg-config + find_package(PkgConfig QUIET) + if (PKG_CONFIG_FOUND) + pkg_check_modules(sodium_PKG QUIET libsodium) + endif() + + if(sodium_USE_STATIC_LIBS) + foreach(_libname ${sodium_PKG_STATIC_LIBRARIES}) + if (NOT _libname MATCHES "^lib.*\\.a$") # ignore strings already ending with .a + list(INSERT sodium_PKG_STATIC_LIBRARIES 0 "lib${_libname}.a") + endif() + endforeach() + list(REMOVE_DUPLICATES sodium_PKG_STATIC_LIBRARIES) + + # if pkgconfig for libsodium doesn't provide + # static lib info, then override PKG_STATIC here.. + if (NOT sodium_PKG_STATIC_FOUND) + set(sodium_PKG_STATIC_LIBRARIES libsodium.a) + endif() + + set(XPREFIX sodium_PKG_STATIC) + else() + if (NOT sodium_PKG_FOUND) + set(sodium_PKG_LIBRARIES sodium) + endif() + + set(XPREFIX sodium_PKG) + endif() + + find_path(sodium_INCLUDE_DIR sodium.h + HINTS ${${XPREFIX}_INCLUDE_DIRS} + ) + find_library(sodium_LIBRARY_DEBUG NAMES ${${XPREFIX}_LIBRARIES} + HINTS ${${XPREFIX}_LIBRARY_DIRS} + ) + find_library(sodium_LIBRARY_RELEASE NAMES ${${XPREFIX}_LIBRARIES} + HINTS ${${XPREFIX}_LIBRARY_DIRS} + ) + + +######################################################################## +# Windows +elseif (WIN32) + set(sodium_DIR "$ENV{sodium_DIR}" CACHE FILEPATH "sodium install directory") + mark_as_advanced(sodium_DIR) + + find_path(sodium_INCLUDE_DIR sodium.h + HINTS ${sodium_DIR} + PATH_SUFFIXES include + ) + + if (MSVC) + # detect target architecture + file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/arch.cpp" [=[ + #if defined _M_IX86 + #error ARCH_VALUE x86_32 + #elif defined _M_X64 + #error ARCH_VALUE x86_64 + #endif + #error ARCH_VALUE unknown + ]=]) + try_compile(_UNUSED_VAR "${CMAKE_CURRENT_BINARY_DIR}" "${CMAKE_CURRENT_BINARY_DIR}/arch.cpp" + OUTPUT_VARIABLE _COMPILATION_LOG + ) + string(REGEX REPLACE ".*ARCH_VALUE ([a-zA-Z0-9_]+).*" "\\1" _TARGET_ARCH "${_COMPILATION_LOG}") + + # construct library path + if (_TARGET_ARCH STREQUAL "x86_32") + string(APPEND _PLATFORM_PATH "Win32") + elseif(_TARGET_ARCH STREQUAL "x86_64") + string(APPEND _PLATFORM_PATH "x64") + else() + message(FATAL_ERROR "the ${_TARGET_ARCH} architecture is not supported by Findsodium.cmake.") + endif() + string(APPEND _PLATFORM_PATH "/$$CONFIG$$") + + if (MSVC_VERSION LESS 1900) + math(EXPR _VS_VERSION "${MSVC_VERSION} / 10 - 60") + else() + math(EXPR _VS_VERSION "${MSVC_VERSION} / 10 - 50") + endif() + string(APPEND _PLATFORM_PATH "/v${_VS_VERSION}") + + if (sodium_USE_STATIC_LIBS) + string(APPEND _PLATFORM_PATH "/static") + else() + string(APPEND _PLATFORM_PATH "/dynamic") + endif() + + string(REPLACE "$$CONFIG$$" "Debug" _DEBUG_PATH_SUFFIX "${_PLATFORM_PATH}") + string(REPLACE "$$CONFIG$$" "Release" _RELEASE_PATH_SUFFIX "${_PLATFORM_PATH}") + + find_library(sodium_LIBRARY_DEBUG libsodium.lib + HINTS ${sodium_DIR} + PATH_SUFFIXES ${_DEBUG_PATH_SUFFIX} + ) + find_library(sodium_LIBRARY_RELEASE libsodium.lib + HINTS ${sodium_DIR} + PATH_SUFFIXES ${_RELEASE_PATH_SUFFIX} + ) + if (NOT sodium_USE_STATIC_LIBS) + set(CMAKE_FIND_LIBRARY_SUFFIXES_BCK ${CMAKE_FIND_LIBRARY_SUFFIXES}) + set(CMAKE_FIND_LIBRARY_SUFFIXES ".dll") + find_library(sodium_DLL_DEBUG libsodium + HINTS ${sodium_DIR} + PATH_SUFFIXES ${_DEBUG_PATH_SUFFIX} + ) + find_library(sodium_DLL_RELEASE libsodium + HINTS ${sodium_DIR} + PATH_SUFFIXES ${_RELEASE_PATH_SUFFIX} + ) + set(CMAKE_FIND_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES_BCK}) + endif() + + elseif(_GCC_COMPATIBLE) + if (sodium_USE_STATIC_LIBS) + find_library(sodium_LIBRARY_DEBUG libsodium.a + HINTS ${sodium_DIR} + PATH_SUFFIXES lib + ) + find_library(sodium_LIBRARY_RELEASE libsodium.a + HINTS ${sodium_DIR} + PATH_SUFFIXES lib + ) + else() + find_library(sodium_LIBRARY_DEBUG libsodium.dll.a + HINTS ${sodium_DIR} + PATH_SUFFIXES lib + ) + find_library(sodium_LIBRARY_RELEASE libsodium.dll.a + HINTS ${sodium_DIR} + PATH_SUFFIXES lib + ) + + file(GLOB _DLL + LIST_DIRECTORIES false + RELATIVE "${sodium_DIR}/bin" + "${sodium_DIR}/bin/libsodium*.dll" + ) + find_library(sodium_DLL_DEBUG ${_DLL} libsodium + HINTS ${sodium_DIR} + PATH_SUFFIXES bin + ) + find_library(sodium_DLL_RELEASE ${_DLL} libsodium + HINTS ${sodium_DIR} + PATH_SUFFIXES bin + ) + endif() + else() + message(FATAL_ERROR "this platform is not supported by FindSodium.cmake") + endif() + + +######################################################################## +# unsupported +else() + message(FATAL_ERROR "this platform is not supported by FindSodium.cmake") +endif() + + +######################################################################## +# common stuff + +# extract sodium version +if (sodium_INCLUDE_DIR) + set(_VERSION_HEADER "${_INCLUDE_DIR}/sodium/version.h") + if (EXISTS _VERSION_HEADER) + file(READ "${_VERSION_HEADER}" _VERSION_HEADER_CONTENT) + string(REGEX REPLACE ".*#[ \t]*define[ \t]*SODIUM_VERSION_STRING[ \t]*\"([^\n]*)\".*" "\\1" + sodium_VERSION "${_VERSION_HEADER_CONTENT}") + set(sodium_VERSION "${sodium_VERSION}" PARENT_SCOPE) + endif() +endif() + +# communicate results +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args( + Sodium # The name must be either uppercase or match the filename case. + REQUIRED_VARS + sodium_LIBRARY_RELEASE + sodium_LIBRARY_DEBUG + sodium_INCLUDE_DIR + VERSION_VAR + sodium_VERSION +) + +if(Sodium_FOUND) + set(sodium_LIBRARIES + optimized ${sodium_LIBRARY_RELEASE} debug ${sodium_LIBRARY_DEBUG}) +endif() + +# mark file paths as advanced +mark_as_advanced(sodium_INCLUDE_DIR) +mark_as_advanced(sodium_LIBRARY_DEBUG) +mark_as_advanced(sodium_LIBRARY_RELEASE) +if (WIN32) + mark_as_advanced(sodium_DLL_DEBUG) + mark_as_advanced(sodium_DLL_RELEASE) +endif() + +# create imported target +if(sodium_USE_STATIC_LIBS) + set(_LIB_TYPE STATIC) +else() + set(_LIB_TYPE SHARED) +endif() + +if(NOT TARGET sodium) + add_library(sodium ${_LIB_TYPE} IMPORTED) +endif() + +set_target_properties(sodium PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${sodium_INCLUDE_DIR}" + IMPORTED_LINK_INTERFACE_LANGUAGES "C" +) + +if (sodium_USE_STATIC_LIBS) + set_target_properties(sodium PROPERTIES + INTERFACE_COMPILE_DEFINITIONS "SODIUM_STATIC" + IMPORTED_LOCATION "${sodium_LIBRARY_RELEASE}" + IMPORTED_LOCATION_DEBUG "${sodium_LIBRARY_DEBUG}" + ) +else() + if (UNIX) + set_target_properties(sodium PROPERTIES + IMPORTED_LOCATION "${sodium_LIBRARY_RELEASE}" + IMPORTED_LOCATION_DEBUG "${sodium_LIBRARY_DEBUG}" + ) + elseif (WIN32) + set_target_properties(sodium PROPERTIES + IMPORTED_IMPLIB "${sodium_LIBRARY_RELEASE}" + IMPORTED_IMPLIB_DEBUG "${sodium_LIBRARY_DEBUG}" + ) + if (NOT (sodium_DLL_DEBUG MATCHES ".*-NOTFOUND")) + set_target_properties(sodium PROPERTIES + IMPORTED_LOCATION_DEBUG "${sodium_DLL_DEBUG}" + ) + endif() + if (NOT (sodium_DLL_RELEASE MATCHES ".*-NOTFOUND")) + set_target_properties(sodium PROPERTIES + IMPORTED_LOCATION_RELWITHDEBINFO "${sodium_DLL_RELEASE}" + IMPORTED_LOCATION_MINSIZEREL "${sodium_DLL_RELEASE}" + IMPORTED_LOCATION_RELEASE "${sodium_DLL_RELEASE}" + ) + endif() + endif() +endif() diff --git a/build/fbcode_builder/CMake/FindZstd.cmake b/build/fbcode_builder/CMake/FindZstd.cmake new file mode 100644 index 000000000..89300ddfd --- /dev/null +++ b/build/fbcode_builder/CMake/FindZstd.cmake @@ -0,0 +1,41 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# +# - Try to find Facebook zstd library +# This will define +# ZSTD_FOUND +# ZSTD_INCLUDE_DIR +# ZSTD_LIBRARY +# + +find_path(ZSTD_INCLUDE_DIR NAMES zstd.h) + +find_library(ZSTD_LIBRARY_DEBUG NAMES zstdd zstd_staticd) +find_library(ZSTD_LIBRARY_RELEASE NAMES zstd zstd_static) + +include(SelectLibraryConfigurations) +SELECT_LIBRARY_CONFIGURATIONS(ZSTD) + +include(FindPackageHandleStandardArgs) +FIND_PACKAGE_HANDLE_STANDARD_ARGS( + ZSTD DEFAULT_MSG + ZSTD_LIBRARY ZSTD_INCLUDE_DIR +) + +if (ZSTD_FOUND) + message(STATUS "Found Zstd: ${ZSTD_LIBRARY}") +endif() + +mark_as_advanced(ZSTD_INCLUDE_DIR ZSTD_LIBRARY) diff --git a/build/fbcode_builder/CMake/RustStaticLibrary.cmake b/build/fbcode_builder/CMake/RustStaticLibrary.cmake new file mode 100644 index 000000000..8546fe2fb --- /dev/null +++ b/build/fbcode_builder/CMake/RustStaticLibrary.cmake @@ -0,0 +1,291 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +include(FBCMakeParseArgs) + +set( + USE_CARGO_VENDOR AUTO CACHE STRING + "Download Rust Crates from an internally vendored location" +) +set_property(CACHE USE_CARGO_VENDOR PROPERTY STRINGS AUTO ON OFF) + +set(RUST_VENDORED_CRATES_DIR "$ENV{RUST_VENDORED_CRATES_DIR}") +if("${USE_CARGO_VENDOR}" STREQUAL "AUTO") + if(EXISTS "${RUST_VENDORED_CRATES_DIR}") + set(USE_CARGO_VENDOR ON) + else() + set(USE_CARGO_VENDOR OFF) + endif() +endif() + +if(USE_CARGO_VENDOR) + if(NOT EXISTS "${RUST_VENDORED_CRATES_DIR}") + message( + FATAL "vendored rust crates not present: " + "${RUST_VENDORED_CRATES_DIR}" + ) + endif() + + set(RUST_CARGO_HOME "${CMAKE_BINARY_DIR}/_cargo_home") + file(MAKE_DIRECTORY "${RUST_CARGO_HOME}") + + file( + TO_NATIVE_PATH "${RUST_VENDORED_CRATES_DIR}" + ESCAPED_RUST_VENDORED_CRATES_DIR + ) + string( + REPLACE "\\" "\\\\" + ESCAPED_RUST_VENDORED_CRATES_DIR + "${ESCAPED_RUST_VENDORED_CRATES_DIR}" + ) + file( + WRITE "${RUST_CARGO_HOME}/config" + "[source.crates-io]\n" + "replace-with = \"vendored-sources\"\n" + "\n" + "[source.vendored-sources]\n" + "directory = \"${ESCAPED_RUST_VENDORED_CRATES_DIR}\"\n" + ) +endif() + +# Cargo is a build system in itself, and thus will try to take advantage of all +# the cores on the system. Unfortunately, this conflicts with Ninja, since it +# also tries to utilize all the cores. This can lead to a system that is +# completely overloaded with compile jobs to the point where nothing else can +# be achieved on the system. +# +# Let's inform Ninja of this fact so it won't try to spawn other jobs while +# Rust being compiled. +set_property(GLOBAL APPEND PROPERTY JOB_POOLS rust_job_pool=1) + +# This function creates an interface library target based on the static library +# built by Cargo. It will call Cargo to build a staticlib and generate a CMake +# interface library with it. +# +# This function requires `find_package(Python COMPONENTS Interpreter)`. +# +# You need to set `lib:crate-type = ["staticlib"]` in your Cargo.toml to make +# Cargo build static library. +# +# ```cmake +# rust_static_library( [CRATE ]) +# ``` +# +# Parameters: +# - TARGET: +# Name of the target name. This function will create an interface library +# target with this name. +# - CRATE_NAME: +# Name of the crate. This parameter is optional. If unspecified, it will +# fallback to `${TARGET}`. +# +# This function creates two targets: +# - "${TARGET}": an interface library target contains the static library built +# from Cargo. +# - "${TARGET}.cargo": an internal custom target that invokes Cargo. +# +# If you are going to use this static library from C/C++, you will need to +# write header files for the library (or generate with cbindgen) and bind these +# headers with the interface library. +# +function(rust_static_library TARGET) + fb_cmake_parse_args(ARG "" "CRATE" "" "${ARGN}") + + if(DEFINED ARG_CRATE) + set(crate_name "${ARG_CRATE}") + else() + set(crate_name "${TARGET}") + endif() + + set(cargo_target "${TARGET}.cargo") + set(target_dir $,debug,release>) + set(staticlib_name "${CMAKE_STATIC_LIBRARY_PREFIX}${crate_name}${CMAKE_STATIC_LIBRARY_SUFFIX}") + set(rust_staticlib "${CMAKE_CURRENT_BINARY_DIR}/${target_dir}/${staticlib_name}") + + set(cargo_cmd cargo) + if(WIN32) + set(cargo_cmd cargo.exe) + endif() + + set(cargo_flags build $,,--release> -p ${crate_name}) + if(USE_CARGO_VENDOR) + set(extra_cargo_env "CARGO_HOME=${RUST_CARGO_HOME}") + set(cargo_flags ${cargo_flags}) + endif() + + add_custom_target( + ${cargo_target} + COMMAND + "${CMAKE_COMMAND}" -E remove -f "${CMAKE_CURRENT_SOURCE_DIR}/Cargo.lock" + COMMAND + "${CMAKE_COMMAND}" -E env + "CARGO_TARGET_DIR=${CMAKE_CURRENT_BINARY_DIR}" + ${extra_cargo_env} + ${cargo_cmd} + ${cargo_flags} + COMMENT "Building Rust crate '${crate_name}'..." + JOB_POOL rust_job_pool + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + BYPRODUCTS + "${CMAKE_CURRENT_BINARY_DIR}/debug/${staticlib_name}" + "${CMAKE_CURRENT_BINARY_DIR}/release/${staticlib_name}" + ) + + add_library(${TARGET} INTERFACE) + add_dependencies(${TARGET} ${cargo_target}) + set_target_properties( + ${TARGET} + PROPERTIES + INTERFACE_STATICLIB_OUTPUT_PATH "${rust_staticlib}" + INTERFACE_INSTALL_LIBNAME + "${CMAKE_STATIC_LIBRARY_PREFIX}${crate_name}_rs${CMAKE_STATIC_LIBRARY_SUFFIX}" + ) + target_link_libraries( + ${TARGET} + INTERFACE "$" + ) +endfunction() + +# This function instructs cmake to define a target that will use `cargo build` +# to build a bin crate referenced by the Cargo.toml file in the current source +# directory. +# It accepts a single `TARGET` parameter which will be passed as the package +# name to `cargo build -p TARGET`. If binary has different name as package, +# use optional flag BINARY_NAME to override it. +# The cmake target will be registered to build by default as part of the +# ALL target. +function(rust_executable TARGET) + fb_cmake_parse_args(ARG "" "BINARY_NAME" "" "${ARGN}") + + set(crate_name "${TARGET}") + set(cargo_target "${TARGET}.cargo") + set(target_dir $,debug,release>) + + if(DEFINED ARG_BINARY_NAME) + set(executable_name "${ARG_BINARY_NAME}${CMAKE_EXECUTABLE_SUFFIX}") + else() + set(executable_name "${crate_name}${CMAKE_EXECUTABLE_SUFFIX}") + endif() + + set(cargo_cmd cargo) + if(WIN32) + set(cargo_cmd cargo.exe) + endif() + + set(cargo_flags build $,,--release> -p ${crate_name}) + if(USE_CARGO_VENDOR) + set(extra_cargo_env "CARGO_HOME=${RUST_CARGO_HOME}") + set(cargo_flags ${cargo_flags}) + endif() + + add_custom_target( + ${cargo_target} + ALL + COMMAND + "${CMAKE_COMMAND}" -E remove -f "${CMAKE_CURRENT_SOURCE_DIR}/Cargo.lock" + COMMAND + "${CMAKE_COMMAND}" -E env + "CARGO_TARGET_DIR=${CMAKE_CURRENT_BINARY_DIR}" + ${extra_cargo_env} + ${cargo_cmd} + ${cargo_flags} + COMMENT "Building Rust executable '${crate_name}'..." + JOB_POOL rust_job_pool + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + BYPRODUCTS + "${CMAKE_CURRENT_BINARY_DIR}/debug/${executable_name}" + "${CMAKE_CURRENT_BINARY_DIR}/release/${executable_name}" + ) + + set_property(TARGET "${cargo_target}" + PROPERTY EXECUTABLE "${CMAKE_CURRENT_BINARY_DIR}/${target_dir}/${executable_name}") +endfunction() + +# This function can be used to install the executable generated by a prior +# call to the `rust_executable` function. +# It requires a `TARGET` parameter to identify the target to be installed, +# and an optional `DESTINATION` parameter to specify the installation +# directory. If DESTINATION is not specified then the `bin` directory +# will be assumed. +function(install_rust_executable TARGET) + # Parse the arguments + set(one_value_args DESTINATION) + set(multi_value_args) + fb_cmake_parse_args( + ARG "" "${one_value_args}" "${multi_value_args}" "${ARGN}" + ) + + if(NOT DEFINED ARG_DESTINATION) + set(ARG_DESTINATION bin) + endif() + + get_target_property(foo "${TARGET}.cargo" EXECUTABLE) + + install( + PROGRAMS "${foo}" + DESTINATION "${ARG_DESTINATION}" + ) +endfunction() + +# This function installs the interface target generated from the function +# `rust_static_library`. Use this function if you want to export your Rust +# target to external CMake targets. +# +# ```cmake +# install_rust_static_library( +# +# INSTALL_DIR +# [EXPORT ] +# ) +# ``` +# +# Parameters: +# - TARGET: Name of the Rust static library target. +# - EXPORT_NAME: Name of the exported target. +# - INSTALL_DIR: Path to the directory where this library will be installed. +# +function(install_rust_static_library TARGET) + fb_cmake_parse_args(ARG "" "EXPORT;INSTALL_DIR" "" "${ARGN}") + + get_property( + staticlib_output_path + TARGET "${TARGET}" + PROPERTY INTERFACE_STATICLIB_OUTPUT_PATH + ) + get_property( + staticlib_output_name + TARGET "${TARGET}" + PROPERTY INTERFACE_INSTALL_LIBNAME + ) + + if(NOT DEFINED staticlib_output_path) + message(FATAL_ERROR "Not a rust_static_library target.") + endif() + + if(NOT DEFINED ARG_INSTALL_DIR) + message(FATAL_ERROR "Missing required argument.") + endif() + + if(DEFINED ARG_EXPORT) + set(install_export_args EXPORT "${ARG_EXPORT}") + endif() + + set(install_interface_dir "${ARG_INSTALL_DIR}") + if(NOT IS_ABSOLUTE "${install_interface_dir}") + set(install_interface_dir "\${_IMPORT_PREFIX}/${install_interface_dir}") + endif() + + target_link_libraries( + ${TARGET} INTERFACE + "$" + ) + install( + TARGETS ${TARGET} + ${install_export_args} + LIBRARY DESTINATION ${ARG_INSTALL_DIR} + ) + install( + FILES ${staticlib_output_path} + RENAME ${staticlib_output_name} + DESTINATION ${ARG_INSTALL_DIR} + ) +endfunction() diff --git a/build/fbcode_builder/CMake/fb_py_test_main.py b/build/fbcode_builder/CMake/fb_py_test_main.py new file mode 100644 index 000000000..1f3563aff --- /dev/null +++ b/build/fbcode_builder/CMake/fb_py_test_main.py @@ -0,0 +1,820 @@ +#!/usr/bin/env python +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +""" +This file contains the main module code for Python test programs. +""" + +from __future__ import print_function + +import contextlib +import ctypes +import fnmatch +import json +import logging +import optparse +import os +import platform +import re +import sys +import tempfile +import time +import traceback +import unittest +import warnings + +# Hide warning about importing "imp"; remove once python2 is gone. +with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + import imp + +try: + from StringIO import StringIO +except ImportError: + from io import StringIO +try: + import coverage +except ImportError: + coverage = None # type: ignore +try: + from importlib.machinery import SourceFileLoader +except ImportError: + SourceFileLoader = None # type: ignore + + +class get_cpu_instr_counter(object): + def read(self): + # TODO + return 0 + + +EXIT_CODE_SUCCESS = 0 +EXIT_CODE_TEST_FAILURE = 70 + + +class TestStatus(object): + + ABORTED = "FAILURE" + PASSED = "SUCCESS" + FAILED = "FAILURE" + EXPECTED_FAILURE = "SUCCESS" + UNEXPECTED_SUCCESS = "FAILURE" + SKIPPED = "ASSUMPTION_VIOLATION" + + +class PathMatcher(object): + def __init__(self, include_patterns, omit_patterns): + self.include_patterns = include_patterns + self.omit_patterns = omit_patterns + + def omit(self, path): + """ + Omit iff matches any of the omit_patterns or the include patterns are + not empty and none is matched + """ + path = os.path.realpath(path) + return any(fnmatch.fnmatch(path, p) for p in self.omit_patterns) or ( + self.include_patterns + and not any(fnmatch.fnmatch(path, p) for p in self.include_patterns) + ) + + def include(self, path): + return not self.omit(path) + + +class DebugWipeFinder(object): + """ + PEP 302 finder that uses a DebugWipeLoader for all files which do not need + coverage + """ + + def __init__(self, matcher): + self.matcher = matcher + + def find_module(self, fullname, path=None): + _, _, basename = fullname.rpartition(".") + try: + fd, pypath, (_, _, kind) = imp.find_module(basename, path) + except Exception: + # Finding without hooks using the imp module failed. One reason + # could be that there is a zip file on sys.path. The imp module + # does not support loading from there. Leave finding this module to + # the others finders in sys.meta_path. + return None + + if hasattr(fd, "close"): + fd.close() + if kind != imp.PY_SOURCE: + return None + if self.matcher.include(pypath): + return None + + """ + This is defined to match CPython's PyVarObject struct + """ + + class PyVarObject(ctypes.Structure): + _fields_ = [ + ("ob_refcnt", ctypes.c_long), + ("ob_type", ctypes.c_void_p), + ("ob_size", ctypes.c_ulong), + ] + + class DebugWipeLoader(SourceFileLoader): + """ + PEP302 loader that zeros out debug information before execution + """ + + def get_code(self, fullname): + code = super(DebugWipeLoader, self).get_code(fullname) + if code: + # Ideally we'd do + # code.co_lnotab = b'' + # But code objects are READONLY. Not to worry though; we'll + # directly modify CPython's object + code_impl = PyVarObject.from_address(id(code.co_lnotab)) + code_impl.ob_size = 0 + return code + + return DebugWipeLoader(fullname, pypath) + + +def optimize_for_coverage(cov, include_patterns, omit_patterns): + """ + We get better performance if we zero out debug information for files which + we're not interested in. Only available in CPython 3.3+ + """ + matcher = PathMatcher(include_patterns, omit_patterns) + if SourceFileLoader and platform.python_implementation() == "CPython": + sys.meta_path.insert(0, DebugWipeFinder(matcher)) + + +class TeeStream(object): + def __init__(self, *streams): + self._streams = streams + + def write(self, data): + for stream in self._streams: + stream.write(data) + + def flush(self): + for stream in self._streams: + stream.flush() + + def isatty(self): + return False + + +class CallbackStream(object): + def __init__(self, callback, bytes_callback=None, orig=None): + self._callback = callback + self._fileno = orig.fileno() if orig else None + + # Python 3 APIs: + # - `encoding` is a string holding the encoding name + # - `errors` is a string holding the error-handling mode for encoding + # - `buffer` should look like an io.BufferedIOBase object + + self.errors = orig.errors if orig else None + if bytes_callback: + # those members are only on the io.TextIOWrapper + self.encoding = orig.encoding if orig else "UTF-8" + self.buffer = CallbackStream(bytes_callback, orig=orig) + + def write(self, data): + self._callback(data) + + def flush(self): + pass + + def isatty(self): + return False + + def fileno(self): + return self._fileno + + +class BuckTestResult(unittest._TextTestResult): + """ + Our own TestResult class that outputs data in a format that can be easily + parsed by buck's test runner. + """ + + _instr_counter = get_cpu_instr_counter() + + def __init__( + self, stream, descriptions, verbosity, show_output, main_program, suite + ): + super(BuckTestResult, self).__init__(stream, descriptions, verbosity) + self._main_program = main_program + self._suite = suite + self._results = [] + self._current_test = None + self._saved_stdout = sys.stdout + self._saved_stderr = sys.stderr + self._show_output = show_output + + def getResults(self): + return self._results + + def startTest(self, test): + super(BuckTestResult, self).startTest(test) + + # Pass in the real stdout and stderr filenos. We can't really do much + # here to intercept callers who directly operate on these fileno + # objects. + sys.stdout = CallbackStream( + self.addStdout, self.addStdoutBytes, orig=sys.stdout + ) + sys.stderr = CallbackStream( + self.addStderr, self.addStderrBytes, orig=sys.stderr + ) + self._current_test = test + self._test_start_time = time.time() + self._current_status = TestStatus.ABORTED + self._messages = [] + self._stacktrace = None + self._stdout = "" + self._stderr = "" + self._start_instr_count = self._instr_counter.read() + + def _find_next_test(self, suite): + """ + Find the next test that has not been run. + """ + + for test in suite: + + # We identify test suites by test that are iterable (as is done in + # the builtin python test harness). If we see one, recurse on it. + if hasattr(test, "__iter__"): + test = self._find_next_test(test) + + # The builtin python test harness sets test references to `None` + # after they have run, so we know we've found the next test up + # if it's not `None`. + if test is not None: + return test + + def stopTest(self, test): + sys.stdout = self._saved_stdout + sys.stderr = self._saved_stderr + + super(BuckTestResult, self).stopTest(test) + + # If a failure occured during module/class setup, then this "test" may + # actually be a `_ErrorHolder`, which doesn't contain explicit info + # about the upcoming test. Since we really only care about the test + # name field (i.e. `_testMethodName`), we use that to detect an actual + # test cases, and fall back to looking the test up from the suite + # otherwise. + if not hasattr(test, "_testMethodName"): + test = self._find_next_test(self._suite) + + result = { + "testCaseName": "{0}.{1}".format( + test.__class__.__module__, test.__class__.__name__ + ), + "testCase": test._testMethodName, + "type": self._current_status, + "time": int((time.time() - self._test_start_time) * 1000), + "message": os.linesep.join(self._messages), + "stacktrace": self._stacktrace, + "stdOut": self._stdout, + "stdErr": self._stderr, + } + + # TestPilot supports an instruction count field. + if "TEST_PILOT" in os.environ: + result["instrCount"] = ( + int(self._instr_counter.read() - self._start_instr_count), + ) + + self._results.append(result) + self._current_test = None + + def stopTestRun(self): + cov = self._main_program.get_coverage() + if cov is not None: + self._results.append({"coverage": cov}) + + @contextlib.contextmanager + def _withTest(self, test): + self.startTest(test) + yield + self.stopTest(test) + + def _setStatus(self, test, status, message=None, stacktrace=None): + assert test == self._current_test + self._current_status = status + self._stacktrace = stacktrace + if message is not None: + if message.endswith(os.linesep): + message = message[:-1] + self._messages.append(message) + + def setStatus(self, test, status, message=None, stacktrace=None): + # addError() may be called outside of a test if one of the shared + # fixtures (setUpClass/tearDownClass/setUpModule/tearDownModule) + # throws an error. + # + # In this case, create a fake test result to record the error. + if self._current_test is None: + with self._withTest(test): + self._setStatus(test, status, message, stacktrace) + else: + self._setStatus(test, status, message, stacktrace) + + def setException(self, test, status, excinfo): + exctype, value, tb = excinfo + self.setStatus( + test, + status, + "{0}: {1}".format(exctype.__name__, value), + "".join(traceback.format_tb(tb)), + ) + + def addSuccess(self, test): + super(BuckTestResult, self).addSuccess(test) + self.setStatus(test, TestStatus.PASSED) + + def addError(self, test, err): + super(BuckTestResult, self).addError(test, err) + self.setException(test, TestStatus.ABORTED, err) + + def addFailure(self, test, err): + super(BuckTestResult, self).addFailure(test, err) + self.setException(test, TestStatus.FAILED, err) + + def addSkip(self, test, reason): + super(BuckTestResult, self).addSkip(test, reason) + self.setStatus(test, TestStatus.SKIPPED, "Skipped: %s" % (reason,)) + + def addExpectedFailure(self, test, err): + super(BuckTestResult, self).addExpectedFailure(test, err) + self.setException(test, TestStatus.EXPECTED_FAILURE, err) + + def addUnexpectedSuccess(self, test): + super(BuckTestResult, self).addUnexpectedSuccess(test) + self.setStatus(test, TestStatus.UNEXPECTED_SUCCESS, "Unexpected success") + + def addStdout(self, val): + self._stdout += val + if self._show_output: + self._saved_stdout.write(val) + self._saved_stdout.flush() + + def addStdoutBytes(self, val): + string = val.decode("utf-8", errors="backslashreplace") + self.addStdout(string) + + def addStderr(self, val): + self._stderr += val + if self._show_output: + self._saved_stderr.write(val) + self._saved_stderr.flush() + + def addStderrBytes(self, val): + string = val.decode("utf-8", errors="backslashreplace") + self.addStderr(string) + + +class BuckTestRunner(unittest.TextTestRunner): + def __init__(self, main_program, suite, show_output=True, **kwargs): + super(BuckTestRunner, self).__init__(**kwargs) + self.show_output = show_output + self._main_program = main_program + self._suite = suite + + def _makeResult(self): + return BuckTestResult( + self.stream, + self.descriptions, + self.verbosity, + self.show_output, + self._main_program, + self._suite, + ) + + +def _format_test_name(test_class, attrname): + return "{0}.{1}.{2}".format(test_class.__module__, test_class.__name__, attrname) + + +class StderrLogHandler(logging.StreamHandler): + """ + This class is very similar to logging.StreamHandler, except that it + always uses the current sys.stderr object. + + StreamHandler caches the current sys.stderr object when it is constructed. + This makes it behave poorly in unit tests, which may replace sys.stderr + with a StringIO buffer during tests. The StreamHandler will continue using + the old sys.stderr object instead of the desired StringIO buffer. + """ + + def __init__(self): + logging.Handler.__init__(self) + + @property + def stream(self): + return sys.stderr + + +class RegexTestLoader(unittest.TestLoader): + def __init__(self, regex=None): + self.regex = regex + super(RegexTestLoader, self).__init__() + + def getTestCaseNames(self, testCaseClass): + """ + Return a sorted sequence of method names found within testCaseClass + """ + + testFnNames = super(RegexTestLoader, self).getTestCaseNames(testCaseClass) + if self.regex is None: + return testFnNames + robj = re.compile(self.regex) + matched = [] + for attrname in testFnNames: + fullname = _format_test_name(testCaseClass, attrname) + if robj.search(fullname): + matched.append(attrname) + return matched + + +class Loader(object): + + suiteClass = unittest.TestSuite + + def __init__(self, modules, regex=None): + self.modules = modules + self.regex = regex + + def load_all(self): + loader = RegexTestLoader(self.regex) + test_suite = self.suiteClass() + for module_name in self.modules: + __import__(module_name, level=0) + module = sys.modules[module_name] + module_suite = loader.loadTestsFromModule(module) + test_suite.addTest(module_suite) + return test_suite + + def load_args(self, args): + loader = RegexTestLoader(self.regex) + + suites = [] + for arg in args: + suite = loader.loadTestsFromName(arg) + # loadTestsFromName() can only process names that refer to + # individual test functions or modules. It can't process package + # names. If there were no module/function matches, check to see if + # this looks like a package name. + if suite.countTestCases() != 0: + suites.append(suite) + continue + + # Load all modules whose name is . + prefix = arg + "." + for module in self.modules: + if module.startswith(prefix): + suite = loader.loadTestsFromName(module) + suites.append(suite) + + return loader.suiteClass(suites) + + +_COVERAGE_INI = """\ +[report] +exclude_lines = + pragma: no cover + pragma: nocover + pragma:.*no${PLATFORM} + pragma:.*no${PY_IMPL}${PY_MAJOR}${PY_MINOR} + pragma:.*no${PY_IMPL}${PY_MAJOR} + pragma:.*nopy${PY_MAJOR} + pragma:.*nopy${PY_MAJOR}${PY_MINOR} +""" + + +class MainProgram(object): + """ + This class implements the main program. It can be subclassed by + users who wish to customize some parts of the main program. + (Adding additional command line options, customizing test loading, etc.) + """ + + DEFAULT_VERBOSITY = 2 + + def __init__(self, argv): + self.init_option_parser() + self.parse_options(argv) + self.setup_logging() + + def init_option_parser(self): + usage = "%prog [options] [TEST] ..." + op = optparse.OptionParser(usage=usage, add_help_option=False) + self.option_parser = op + + op.add_option( + "--hide-output", + dest="show_output", + action="store_false", + default=True, + help="Suppress data that tests print to stdout/stderr, and only " + "show it if the test fails.", + ) + op.add_option( + "-o", + "--output", + help="Write results to a file in a JSON format to be read by Buck", + ) + op.add_option( + "-f", + "--failfast", + action="store_true", + default=False, + help="Stop after the first failure", + ) + op.add_option( + "-l", + "--list-tests", + action="store_true", + dest="list", + default=False, + help="List tests and exit", + ) + op.add_option( + "-r", + "--regex", + default=None, + help="Regex to apply to tests, to only run those tests", + ) + op.add_option( + "--collect-coverage", + action="store_true", + default=False, + help="Collect test coverage information", + ) + op.add_option( + "--coverage-include", + default="*", + help='File globs to include in converage (split by ",")', + ) + op.add_option( + "--coverage-omit", + default="", + help='File globs to omit from converage (split by ",")', + ) + op.add_option( + "--logger", + action="append", + metavar="=", + default=[], + help="Configure log levels for specific logger categories", + ) + op.add_option( + "-q", + "--quiet", + action="count", + default=0, + help="Decrease the verbosity (may be specified multiple times)", + ) + op.add_option( + "-v", + "--verbosity", + action="count", + default=self.DEFAULT_VERBOSITY, + help="Increase the verbosity (may be specified multiple times)", + ) + op.add_option( + "-?", "--help", action="help", help="Show this help message and exit" + ) + + def parse_options(self, argv): + self.options, self.test_args = self.option_parser.parse_args(argv[1:]) + self.options.verbosity -= self.options.quiet + + if self.options.collect_coverage and coverage is None: + self.option_parser.error("coverage module is not available") + self.options.coverage_include = self.options.coverage_include.split(",") + if self.options.coverage_omit == "": + self.options.coverage_omit = [] + else: + self.options.coverage_omit = self.options.coverage_omit.split(",") + + def setup_logging(self): + # Configure the root logger to log at INFO level. + # This is similar to logging.basicConfig(), but uses our + # StderrLogHandler instead of a StreamHandler. + fmt = logging.Formatter("%(pathname)s:%(lineno)s: %(message)s") + log_handler = StderrLogHandler() + log_handler.setFormatter(fmt) + root_logger = logging.getLogger() + root_logger.addHandler(log_handler) + root_logger.setLevel(logging.INFO) + + level_names = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warn": logging.WARNING, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, + "fatal": logging.FATAL, + } + + for value in self.options.logger: + parts = value.rsplit("=", 1) + if len(parts) != 2: + self.option_parser.error( + "--logger argument must be of the " + "form =: %s" % value + ) + name = parts[0] + level_name = parts[1].lower() + level = level_names.get(level_name) + if level is None: + self.option_parser.error( + "invalid log level %r for log " "category %s" % (parts[1], name) + ) + logging.getLogger(name).setLevel(level) + + def create_loader(self): + import __test_modules__ + + return Loader(__test_modules__.TEST_MODULES, self.options.regex) + + def load_tests(self): + loader = self.create_loader() + if self.options.collect_coverage: + self.start_coverage() + include = self.options.coverage_include + omit = self.options.coverage_omit + if include and "*" not in include: + optimize_for_coverage(self.cov, include, omit) + + if self.test_args: + suite = loader.load_args(self.test_args) + else: + suite = loader.load_all() + if self.options.collect_coverage: + self.cov.start() + return suite + + def get_tests(self, test_suite): + tests = [] + + for test in test_suite: + if isinstance(test, unittest.TestSuite): + tests.extend(self.get_tests(test)) + else: + tests.append(test) + + return tests + + def run(self): + test_suite = self.load_tests() + + if self.options.list: + for test in self.get_tests(test_suite): + method_name = getattr(test, "_testMethodName", "") + name = _format_test_name(test.__class__, method_name) + print(name) + return EXIT_CODE_SUCCESS + else: + result = self.run_tests(test_suite) + if self.options.output is not None: + with open(self.options.output, "w") as f: + json.dump(result.getResults(), f, indent=4, sort_keys=True) + if not result.wasSuccessful(): + return EXIT_CODE_TEST_FAILURE + return EXIT_CODE_SUCCESS + + def run_tests(self, test_suite): + # Install a signal handler to catch Ctrl-C and display the results + # (but only if running >2.6). + if sys.version_info[0] > 2 or sys.version_info[1] > 6: + unittest.installHandler() + + # Run the tests + runner = BuckTestRunner( + self, + test_suite, + verbosity=self.options.verbosity, + show_output=self.options.show_output, + ) + result = runner.run(test_suite) + + if self.options.collect_coverage and self.options.show_output: + self.cov.stop() + try: + self.cov.report(file=sys.stdout) + except coverage.misc.CoverageException: + print("No lines were covered, potentially restricted by file filters") + + return result + + def get_abbr_impl(self): + """Return abbreviated implementation name.""" + impl = platform.python_implementation() + if impl == "PyPy": + return "pp" + elif impl == "Jython": + return "jy" + elif impl == "IronPython": + return "ip" + elif impl == "CPython": + return "cp" + else: + raise RuntimeError("unknown python runtime") + + def start_coverage(self): + if not self.options.collect_coverage: + return + + with tempfile.NamedTemporaryFile("w", delete=False) as coverage_ini: + coverage_ini.write(_COVERAGE_INI) + self._coverage_ini_path = coverage_ini.name + + # Keep the original working dir in case tests use os.chdir + self._original_working_dir = os.getcwd() + + # for coverage config ignores by platform/python version + os.environ["PLATFORM"] = sys.platform + os.environ["PY_IMPL"] = self.get_abbr_impl() + os.environ["PY_MAJOR"] = str(sys.version_info.major) + os.environ["PY_MINOR"] = str(sys.version_info.minor) + + self.cov = coverage.Coverage( + include=self.options.coverage_include, + omit=self.options.coverage_omit, + config_file=coverage_ini.name, + ) + self.cov.erase() + self.cov.start() + + def get_coverage(self): + if not self.options.collect_coverage: + return None + + try: + os.remove(self._coverage_ini_path) + except OSError: + pass # Better to litter than to fail the test + + # Switch back to the original working directory. + os.chdir(self._original_working_dir) + + result = {} + + self.cov.stop() + + try: + f = StringIO() + self.cov.report(file=f) + lines = f.getvalue().split("\n") + except coverage.misc.CoverageException: + # Nothing was covered. That's fine by us + return result + + # N.B.: the format of the coverage library's output differs + # depending on whether one or more files are in the results + for line in lines[2:]: + if line.strip("-") == "": + break + r = line.split()[0] + analysis = self.cov.analysis2(r) + covString = self.convert_to_diff_cov_str(analysis) + if covString: + result[r] = covString + + return result + + def convert_to_diff_cov_str(self, analysis): + # Info on the format of analysis: + # http://nedbatchelder.com/code/coverage/api.html + if not analysis: + return None + numLines = max( + analysis[1][-1] if len(analysis[1]) else 0, + analysis[2][-1] if len(analysis[2]) else 0, + analysis[3][-1] if len(analysis[3]) else 0, + ) + lines = ["N"] * numLines + for l in analysis[1]: + lines[l - 1] = "C" + for l in analysis[2]: + lines[l - 1] = "X" + for l in analysis[3]: + lines[l - 1] = "U" + return "".join(lines) + + +def main(argv): + return MainProgram(sys.argv).run() + + +if __name__ == "__main__": + sys.exit(main(sys.argv)) diff --git a/build/fbcode_builder/CMake/fb_py_win_main.c b/build/fbcode_builder/CMake/fb_py_win_main.c new file mode 100644 index 000000000..8905c3602 --- /dev/null +++ b/build/fbcode_builder/CMake/fb_py_win_main.c @@ -0,0 +1,126 @@ +// Copyright (c) Facebook, Inc. and its affiliates. + +#define WIN32_LEAN_AND_MEAN + +#include +#include +#include + +#define PATH_SIZE 32768 + +typedef int (*Py_Main)(int, wchar_t**); + +// Add the given path to Windows's DLL search path. +// For Windows DLL search path resolution, see: +// https://docs.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-search-order +void add_search_path(const wchar_t* path) { + wchar_t buffer[PATH_SIZE]; + wchar_t** lppPart = NULL; + + if (!GetFullPathNameW(path, PATH_SIZE, buffer, lppPart)) { + fwprintf(stderr, L"warning: %d unable to expand path %s\n", GetLastError(), path); + return; + } + + if (!AddDllDirectory(buffer)) { + DWORD error = GetLastError(); + if (error != ERROR_FILE_NOT_FOUND) { + fwprintf(stderr, L"warning: %d unable to set DLL search path for %s\n", GetLastError(), path); + } + } +} + +int locate_py_main(int argc, wchar_t **argv) { + /* + * We have to dynamically locate Python3.dll because we may be loading a + * Python native module while running. If that module is built with a + * different Python version, we will end up a DLL import error. To resolve + * this, we can either ship an embedded version of Python with us or + * dynamically look up existing Python distribution installed on user's + * machine. This way, we should be able to get a consistent version of + * Python3.dll and .pyd modules. + */ + HINSTANCE python_dll; + Py_Main pymain; + + // last added directory has highest priority + add_search_path(L"C:\\Python36\\"); + add_search_path(L"C:\\Python37\\"); + add_search_path(L"C:\\Python38\\"); + + python_dll = LoadLibraryExW(L"python3.dll", NULL, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS); + + int returncode = 0; + if (python_dll != NULL) { + pymain = (Py_Main) GetProcAddress(python_dll, "Py_Main"); + + if (pymain != NULL) { + returncode = (pymain)(argc, argv); + } else { + fprintf(stderr, "error: %d unable to load Py_Main\n", GetLastError()); + } + + FreeLibrary(python_dll); + } else { + fprintf(stderr, "error: %d unable to locate python3.dll\n", GetLastError()); + return 1; + } + return returncode; +} + +int wmain() { + /* + * This executable will be prepended to the start of a Python ZIP archive. + * Python will be able to directly execute the ZIP archive, so we simply + * need to tell Py_Main() to run our own file. Duplicate the argument list + * and add our file name to the beginning to tell Python what file to invoke. + */ + wchar_t** pyargv = malloc(sizeof(wchar_t*) * (__argc + 1)); + if (!pyargv) { + fprintf(stderr, "error: failed to allocate argument vector\n"); + return 1; + } + + /* Py_Main wants the wide character version of the argv so we pull those + * values from the global __wargv array that has been prepared by MSVCRT. + * + * In order for the zipapp to run we need to insert an extra argument in + * the front of the argument vector that points to ourselves. + * + * An additional complication is that, depending on who prepared the argument + * string used to start our process, the computed __wargv[0] can be a simple + * shell word like `watchman-wait` which is normally resolved together with + * the PATH by the shell. + * That unresolved path isn't sufficient to start the zipapp on windows; + * we need the fully qualified path. + * + * Given: + * __wargv == {"watchman-wait", "-h"} + * + * we want to pass the following to Py_Main: + * + * { + * "z:\build\watchman\python\watchman-wait.exe", + * "z:\build\watchman\python\watchman-wait.exe", + * "-h" + * } + */ + wchar_t full_path_to_argv0[PATH_SIZE]; + DWORD len = GetModuleFileNameW(NULL, full_path_to_argv0, PATH_SIZE); + if (len == 0 || + len == PATH_SIZE && GetLastError() == ERROR_INSUFFICIENT_BUFFER) { + fprintf( + stderr, + "error: %d while retrieving full path to this executable\n", + GetLastError()); + return 1; + } + + for (int n = 1; n < __argc; ++n) { + pyargv[n + 1] = __wargv[n]; + } + pyargv[0] = full_path_to_argv0; + pyargv[1] = full_path_to_argv0; + + return locate_py_main(__argc + 1, pyargv); +} diff --git a/build/fbcode_builder/CMake/make_fbpy_archive.py b/build/fbcode_builder/CMake/make_fbpy_archive.py new file mode 100755 index 000000000..3724feb21 --- /dev/null +++ b/build/fbcode_builder/CMake/make_fbpy_archive.py @@ -0,0 +1,327 @@ +#!/usr/bin/env python3 +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +import argparse +import collections +import errno +import os +import shutil +import sys +import tempfile +import zipapp + +MANIFEST_SEPARATOR = " :: " +MANIFEST_HEADER_V1 = "FBPY_MANIFEST 1\n" + + +class UsageError(Exception): + def __init__(self, message): + self.message = message + + def __str__(self): + return self.message + + +class BadManifestError(UsageError): + def __init__(self, path, line_num, message): + full_msg = "%s:%s: %s" % (path, line_num, message) + super().__init__(full_msg) + self.path = path + self.line_num = line_num + self.raw_message = message + + +PathInfo = collections.namedtuple( + "PathInfo", ("src", "dest", "manifest_path", "manifest_line") +) + + +def parse_manifest(manifest, path_map): + bad_prefix = ".." + os.path.sep + manifest_dir = os.path.dirname(manifest) + with open(manifest, "r") as f: + line_num = 1 + line = f.readline() + if line != MANIFEST_HEADER_V1: + raise BadManifestError( + manifest, line_num, "Unexpected manifest file header" + ) + + for line in f: + line_num += 1 + if line.startswith("#"): + continue + line = line.rstrip("\n") + parts = line.split(MANIFEST_SEPARATOR) + if len(parts) != 2: + msg = "line must be of the form SRC %s DEST" % MANIFEST_SEPARATOR + raise BadManifestError(manifest, line_num, msg) + src, dest = parts + dest = os.path.normpath(dest) + if dest.startswith(bad_prefix): + msg = "destination path starts with %s: %s" % (bad_prefix, dest) + raise BadManifestError(manifest, line_num, msg) + + if not os.path.isabs(src): + src = os.path.normpath(os.path.join(manifest_dir, src)) + + if dest in path_map: + prev_info = path_map[dest] + msg = ( + "multiple source paths specified for destination " + "path %s. Previous source was %s from %s:%s" + % ( + dest, + prev_info.src, + prev_info.manifest_path, + prev_info.manifest_line, + ) + ) + raise BadManifestError(manifest, line_num, msg) + + info = PathInfo( + src=src, + dest=dest, + manifest_path=manifest, + manifest_line=line_num, + ) + path_map[dest] = info + + +def populate_install_tree(inst_dir, path_map): + os.mkdir(inst_dir) + dest_dirs = {"": False} + + def make_dest_dir(path): + if path in dest_dirs: + return + parent = os.path.dirname(path) + make_dest_dir(parent) + abs_path = os.path.join(inst_dir, path) + os.mkdir(abs_path) + dest_dirs[path] = False + + def install_file(info): + dir_name, base_name = os.path.split(info.dest) + make_dest_dir(dir_name) + if base_name == "__init__.py": + dest_dirs[dir_name] = True + abs_dest = os.path.join(inst_dir, info.dest) + shutil.copy2(info.src, abs_dest) + + # Copy all of the destination files + for info in path_map.values(): + install_file(info) + + # Create __init__ files in any directories that don't have them. + for dir_path, has_init in dest_dirs.items(): + if has_init: + continue + init_path = os.path.join(inst_dir, dir_path, "__init__.py") + with open(init_path, "w"): + pass + + +def build_zipapp(args, path_map): + """Create a self executing python binary using Python 3's built-in + zipapp module. + + This type of Python binary is relatively simple, as zipapp is part of the + standard library, but it does not support native language extensions + (.so/.dll files). + """ + dest_dir = os.path.dirname(args.output) + with tempfile.TemporaryDirectory(prefix="make_fbpy.", dir=dest_dir) as tmpdir: + inst_dir = os.path.join(tmpdir, "tree") + populate_install_tree(inst_dir, path_map) + + tmp_output = os.path.join(tmpdir, "output.exe") + zipapp.create_archive( + inst_dir, target=tmp_output, interpreter=args.python, main=args.main + ) + os.replace(tmp_output, args.output) + + +def create_main_module(args, inst_dir, path_map): + if not args.main: + assert "__main__.py" in path_map + return + + dest_path = os.path.join(inst_dir, "__main__.py") + main_module, main_fn = args.main.split(":") + main_contents = """\ +#!{python} + +if __name__ == "__main__": + import {main_module} + {main_module}.{main_fn}() +""".format( + python=args.python, main_module=main_module, main_fn=main_fn + ) + with open(dest_path, "w") as f: + f.write(main_contents) + os.chmod(dest_path, 0o755) + + +def build_install_dir(args, path_map): + """Create a directory that contains all of the sources, with a __main__ + module to run the program. + """ + # Populate a temporary directory first, then rename to the destination + # location. This ensures that we don't ever leave a halfway-built + # directory behind at the output path if something goes wrong. + dest_dir = os.path.dirname(args.output) + with tempfile.TemporaryDirectory(prefix="make_fbpy.", dir=dest_dir) as tmpdir: + inst_dir = os.path.join(tmpdir, "tree") + populate_install_tree(inst_dir, path_map) + create_main_module(args, inst_dir, path_map) + os.rename(inst_dir, args.output) + + +def ensure_directory(path): + try: + os.makedirs(path) + except OSError as ex: + if ex.errno != errno.EEXIST: + raise + + +def install_library(args, path_map): + """Create an installation directory a python library.""" + out_dir = args.output + out_manifest = args.output + ".manifest" + + install_dir = args.install_dir + if not install_dir: + install_dir = out_dir + + os.makedirs(out_dir) + with open(out_manifest, "w") as manifest: + manifest.write(MANIFEST_HEADER_V1) + for info in path_map.values(): + abs_dest = os.path.join(out_dir, info.dest) + ensure_directory(os.path.dirname(abs_dest)) + print("copy %r --> %r" % (info.src, abs_dest)) + shutil.copy2(info.src, abs_dest) + installed_dest = os.path.join(install_dir, info.dest) + manifest.write("%s%s%s\n" % (installed_dest, MANIFEST_SEPARATOR, info.dest)) + + +def parse_manifests(args): + # Process args.manifest_separator to help support older versions of CMake + if args.manifest_separator: + manifests = [] + for manifest_arg in args.manifests: + split_arg = manifest_arg.split(args.manifest_separator) + manifests.extend(split_arg) + args.manifests = manifests + + path_map = {} + for manifest in args.manifests: + parse_manifest(manifest, path_map) + + return path_map + + +def check_main_module(args, path_map): + # Translate an empty string in the --main argument to None, + # just to allow the CMake logic to be slightly simpler and pass in an + # empty string when it really wants the default __main__.py module to be + # used. + if args.main == "": + args.main = None + + if args.type == "lib-install": + if args.main is not None: + raise UsageError("cannot specify a --main argument with --type=lib-install") + return + + main_info = path_map.get("__main__.py") + if args.main: + if main_info is not None: + msg = ( + "specified an explicit main module with --main, " + "but the file listing already includes __main__.py" + ) + raise BadManifestError( + main_info.manifest_path, main_info.manifest_line, msg + ) + parts = args.main.split(":") + if len(parts) != 2: + raise UsageError( + "argument to --main must be of the form MODULE:CALLABLE " + "(received %s)" % (args.main,) + ) + else: + if main_info is None: + raise UsageError( + "no main module specified with --main, " + "and no __main__.py module present" + ) + + +BUILD_TYPES = { + "zipapp": build_zipapp, + "dir": build_install_dir, + "lib-install": install_library, +} + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("-o", "--output", required=True, help="The output file path") + ap.add_argument( + "--install-dir", + help="When used with --type=lib-install, this parameter specifies the " + "final location where the library where be installed. This can be " + "used to generate the library in one directory first, when you plan " + "to move or copy it to another final location later.", + ) + ap.add_argument( + "--manifest-separator", + help="Split manifest arguments around this separator. This is used " + "to support older versions of CMake that cannot supply the manifests " + "as separate arguments.", + ) + ap.add_argument( + "--main", + help="The main module to run, specified as :. " + "This must be specified if and only if the archive does not contain " + "a __main__.py file.", + ) + ap.add_argument( + "--python", + help="Explicitly specify the python interpreter to use for the " "executable.", + ) + ap.add_argument( + "--type", choices=BUILD_TYPES.keys(), help="The type of output to build." + ) + ap.add_argument( + "manifests", + nargs="+", + help="The manifest files specifying how to construct the archive", + ) + args = ap.parse_args() + + if args.python is None: + args.python = sys.executable + + if args.type is None: + # In the future we might want different default output types + # for different platforms. + args.type = "zipapp" + build_fn = BUILD_TYPES[args.type] + + try: + path_map = parse_manifests(args) + check_main_module(args, path_map) + except UsageError as ex: + print("error: %s" % (ex,), file=sys.stderr) + sys.exit(1) + + build_fn(args, path_map) + + +if __name__ == "__main__": + main() diff --git a/build/fbcode_builder/LICENSE b/build/fbcode_builder/LICENSE new file mode 100644 index 000000000..b96dcb048 --- /dev/null +++ b/build/fbcode_builder/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) Facebook, Inc. and its affiliates. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/build/fbcode_builder/README.docker b/build/fbcode_builder/README.docker new file mode 100644 index 000000000..4e9fa8a29 --- /dev/null +++ b/build/fbcode_builder/README.docker @@ -0,0 +1,44 @@ +## Debugging Docker builds + +To debug a a build failure, start up a shell inside the just-failed image as +follows: + +``` +docker ps -a | head # Grab the container ID +docker commit CONTAINER_ID # Grab the SHA string +docker run -it SHA_STRING /bin/bash +# Debug as usual, e.g. `./run-cmake.sh Debug`, `make`, `apt-get install gdb` +``` + +## A note on Docker security + +While the Dockerfile generated above is quite simple, you must be aware that +using Docker to run arbitrary code can present significant security risks: + + - Code signature validation is off by default (as of 2016), exposing you to + man-in-the-middle malicious code injection. + + - You implicitly trust the world -- a Dockerfile cannot annotate that + you trust the image `debian:8.6` because you trust a particular + certificate -- rather, you trust the name, and that it will never be + hijacked. + + - Sandboxing in the Linux kernel is not perfect, and the builds run code as + root. Any compromised code can likely escalate to the host system. + +Specifically, you must be very careful only to add trusted OS images to the +build flow. + +Consider setting this variable before running any Docker container -- this +will validate a signature on the base image before running code from it: + +``` +export DOCKER_CONTENT_TRUST=1 +``` + +Note that unless you go through the extra steps of notarizing the resulting +images, you will have to disable trust to enter intermediate images, e.g. + +``` +DOCKER_CONTENT_TRUST= docker run -it YOUR_IMAGE_ID /bin/bash +``` diff --git a/build/fbcode_builder/README.md b/build/fbcode_builder/README.md new file mode 100644 index 000000000..d47dd41c0 --- /dev/null +++ b/build/fbcode_builder/README.md @@ -0,0 +1,43 @@ +# Easy builds for Facebook projects + +This directory contains tools designed to simplify continuous-integration +(and other builds) of Facebook open source projects. In particular, this helps +manage builds for cross-project dependencies. + +The main entry point is the `getdeps.py` script. This script has several +subcommands, but the most notable is the `build` command. This will download +and build all dependencies for a project, and then build the project itself. + +## Deployment + +This directory is copied literally into a number of different Facebook open +source repositories. Any change made to code in this directory will be +automatically be replicated by our open source tooling into all GitHub hosted +repositories that use `fbcode_builder`. Typically this directory is copied +into the open source repositories as `build/fbcode_builder/`. + + +# Project Configuration Files + +The `manifests` subdirectory contains configuration files for many different +projects, describing how to build each project. These files also list +dependencies between projects, enabling `getdeps.py` to build all dependencies +for a project before building the project itself. + + +# Shared CMake utilities + +Since this directory is copied into many Facebook open source repositories, +it is also used to help share some CMake utility files across projects. The +`CMake/` subdirectory contains a number of `.cmake` files that are shared by +the CMake-based build systems across several different projects. + + +# Older Build Scripts + +This directory also still contains a handful of older build scripts that +pre-date the current `getdeps.py` build system. Most of the other `.py` files +in this top directory, apart from `getdeps.py` itself, are from this older +build system. This older system is only used by a few remaining projects, and +new projects should generally use the newer `getdeps.py` script, by adding a +new configuration file in the `manifests/` subdirectory. diff --git a/build/fbcode_builder/docker_build_with_ccache.sh b/build/fbcode_builder/docker_build_with_ccache.sh new file mode 100755 index 000000000..e922810d5 --- /dev/null +++ b/build/fbcode_builder/docker_build_with_ccache.sh @@ -0,0 +1,219 @@ +#!/bin/bash -uex +# Copyright (c) Facebook, Inc. and its affiliates. +set -o pipefail # Be sure to `|| :` commands that are allowed to fail. + +# +# Future: port this to Python if you are making significant changes. +# + +# Parse command-line arguments +build_timeout="" # Default to no time-out +print_usage() { + echo "Usage: $0 [--build-timeout TIMEOUT_VAL] SAVE-CCACHE-TO-DIR" + echo "SAVE-CCACHE-TO-DIR is required. An empty string discards the ccache." +} +while [[ $# -gt 0 ]]; do + case "$1" in + --build-timeout) + shift + build_timeout="$1" + if [[ "$build_timeout" != "" ]] ; then + timeout "$build_timeout" true # fail early on invalid timeouts + fi + ;; + -h|--help) + print_usage + exit + ;; + *) + break + ;; + esac + shift +done +# There is one required argument, but an empty string is allowed. +if [[ "$#" != 1 ]] ; then + print_usage + exit 1 +fi +save_ccache_to_dir="$1" +if [[ "$save_ccache_to_dir" != "" ]] ; then + mkdir -p "$save_ccache_to_dir" # fail early if there's nowhere to save +else + echo "WARNING: Will not save /ccache from inside the Docker container" +fi + +rand_guid() { + echo "$(date +%s)_${RANDOM}_${RANDOM}_${RANDOM}_${RANDOM}" +} + +id=fbcode_builder_image_id=$(rand_guid) +logfile=$(mktemp) + +echo " + + +Running build with timeout '$build_timeout', label $id, and log in $logfile + + +" + +if [[ "$build_timeout" != "" ]] ; then + # Kill the container after $build_timeout. Using `/bin/timeout` would cause + # Docker to destroy the most recent container and lose its cache. + ( + sleep "$build_timeout" + echo "Build timed out after $build_timeout" 1>&2 + while true; do + maybe_container=$( + grep -E '^( ---> Running in [0-9a-f]+|FBCODE_BUILDER_EXIT)$' "$logfile" | + tail -n 1 | awk '{print $NF}' + ) + if [[ "$maybe_container" == "FBCODE_BUILDER_EXIT" ]] ; then + echo "Time-out successfully terminated build" 1>&2 + break + fi + echo "Time-out: trying to kill $maybe_container" 1>&2 + # This kill fail if we get unlucky, try again soon. + docker kill "$maybe_container" || sleep 5 + done + ) & +fi + +build_exit_code=0 +# `docker build` is allowed to fail, and `pipefail` means we must check the +# failure explicitly. +if ! docker build --label="$id" . 2>&1 | tee "$logfile" ; then + build_exit_code="${PIPESTATUS[0]}" + # NB: We are going to deliberately forge ahead even if `tee` failed. + # If it did, we have a problem with tempfile creation, and all is sad. + echo "Build failed with code $build_exit_code, trying to save ccache" 1>&2 +fi +# Stop trying to kill the container. +echo $'\nFBCODE_BUILDER_EXIT' >> "$logfile" + +if [[ "$save_ccache_to_dir" == "" ]] ; then + echo "Not inspecting Docker build, since saving the ccache wasn't requested." + exit "$build_exit_code" +fi + +img=$(docker images --filter "label=$id" -a -q) +if [[ "$img" == "" ]] ; then + docker images -a + echo "In the above list, failed to find most recent image with $id" 1>&2 + # Usually, the above `docker kill` will leave us with an up-to-the-second + # container, from which we can extract the cache. However, if that fails + # for any reason, this loop will instead grab the latest available image. + # + # It's possible for this log search to get confused due to the output of + # the build command itself, but since our builds aren't **trying** to + # break cache, we probably won't randomly hit an ID from another build. + img=$( + grep -E '^ ---> (Running in [0-9a-f]+|[0-9a-f]+)$' "$logfile" | tac | + sed 's/Running in /container_/;s/ ---> //;' | ( + while read -r x ; do + # Both docker commands below print an image ID to stdout on + # success, so we just need to know when to stop. + if [[ "$x" =~ container_.* ]] ; then + if docker commit "${x#container_}" ; then + break + fi + elif docker inspect --type image -f '{{.Id}}' "$x" ; then + break + fi + done + ) + ) + if [[ "$img" == "" ]] ; then + echo "Failed to find valid container or image ID in log $logfile" 1>&2 + exit 1 + fi +elif [[ "$(echo "$img" | wc -l)" != 1 ]] ; then + # Shouldn't really happen, but be explicit if it does. + echo "Multiple images with label $id, taking the latest of:" + echo "$img" + img=$(echo "$img" | head -n 1) +fi + +container_name="fbcode_builder_container_$(rand_guid)" +echo "Starting $container_name from latest image of the build with $id --" +echo "$img" + +# ccache collection must be done outside of the Docker build steps because +# we need to be able to kill it on timeout. +# +# This step grows the max cache size to slightly exceed than the working set +# of a successful build. This simple design persists the max size in the +# cache directory itself (the env var CCACHE_MAXSIZE does not even work with +# older ccaches like the one on 14.04). +# +# Future: copy this script into the Docker image via Dockerfile. +( + # By default, fbcode_builder creates an unsigned image, so the `docker + # run` below would fail if DOCKER_CONTENT_TRUST were set. So we unset it + # just for this one run. + export DOCKER_CONTENT_TRUST= + # CAUTION: The inner bash runs without -uex, so code accordingly. + docker run --user root --name "$container_name" "$img" /bin/bash -c ' + build_exit_code='"$build_exit_code"' + + # Might be useful if debugging whether max cache size is too small? + grep " Cleaning up cache directory " /tmp/ccache.log + + export CCACHE_DIR=/ccache + ccache -s + + echo "Total bytes in /ccache:"; + total_bytes=$(du -sb /ccache | awk "{print \$1}") + echo "$total_bytes" + + echo "Used bytes in /ccache:"; + used_bytes=$( + du -sb $(find /ccache -type f -newermt @$( + cat /FBCODE_BUILDER_CCACHE_START_TIME + )) | awk "{t += \$1} END {print t}" + ) + echo "$used_bytes" + + # Goal: set the max cache to 750MB over 125% of the usage of a + # successful build. If this is too small, it takes too long to get a + # cache fully warmed up. Plus, ccache cleans 100-200MB before reaching + # the max cache size, so a large margin is essential to prevent misses. + desired_mb=$(( 750 + used_bytes / 800000 )) # 125% in decimal MB: 1e6/1.25 + if [[ "$build_exit_code" != "0" ]] ; then + # For a bad build, disallow shrinking the max cache size. Instead of + # the max cache size, we use on-disk size, which ccache keeps at least + # 150MB under the actual max size, hence the 400MB safety margin. + cur_max_mb=$(( 400 + total_bytes / 1000000 )) # ccache uses decimal MB + if [[ "$desired_mb" -le "$cur_max_mb" ]] ; then + desired_mb="" + fi + fi + + if [[ "$desired_mb" != "" ]] ; then + echo "Updating cache size to $desired_mb MB" + ccache -M "${desired_mb}M" + ccache -s + fi + + # Subshell because `time` the binary may not be installed. + if (time tar czf /ccache.tgz /ccache) ; then + ls -l /ccache.tgz + else + # This `else` ensures we never overwrite the current cache with + # partial data in case of error, even if somebody adds code below. + rm /ccache.tgz + exit 1 + fi + ' +) + +echo "Updating $save_ccache_to_dir/ccache.tgz" +# This will not delete the existing cache if `docker run` didn't make one +docker cp "$container_name:/ccache.tgz" "$save_ccache_to_dir/" + +# Future: it'd be nice if Travis allowed us to retry if the build timed out, +# since we'll make more progress thanks to the cache. As-is, we have to +# wait for the next commit to land. +echo "Build exited with code $build_exit_code" +exit "$build_exit_code" diff --git a/build/fbcode_builder/docker_builder.py b/build/fbcode_builder/docker_builder.py new file mode 100644 index 000000000..83df7137c --- /dev/null +++ b/build/fbcode_builder/docker_builder.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +""" + +Extends FBCodeBuilder to produce Docker context directories. + +In order to get the largest iteration-time savings from Docker's build +caching, you will want to: + - Use fine-grained steps as appropriate (e.g. separate make & make install), + - Start your action sequence with the lowest-risk steps, and with the steps + that change the least often, and + - Put the steps that you are debugging towards the very end. + +""" +import logging +import os +import shutil +import tempfile + +from fbcode_builder import FBCodeBuilder +from shell_quoting import raw_shell, shell_comment, shell_join, ShellQuoted, path_join +from utils import recursively_flatten_list, run_command + + +class DockerFBCodeBuilder(FBCodeBuilder): + def _user(self): + return self.option("user", "root") + + def _change_user(self): + return ShellQuoted("USER {u}").format(u=self._user()) + + def setup(self): + # Please add RPM-based OSes here as appropriate. + # + # To allow exercising non-root installs -- we change users after the + # system packages are installed. TODO: For users not defined in the + # image, we should probably `useradd`. + return self.step( + "Setup", + [ + # Docker's FROM does not understand shell quoting. + ShellQuoted("FROM {}".format(self.option("os_image"))), + # /bin/sh syntax is a pain + ShellQuoted('SHELL ["/bin/bash", "-c"]'), + ] + + self.install_debian_deps() + + [self._change_user()] + + [self.workdir(self.option("prefix"))] + + self.create_python_venv() + + self.python_venv() + + self.rust_toolchain(), + ) + + def python_venv(self): + # To both avoid calling venv activate on each RUN command AND to ensure + # it is present when the resulting container is run add to PATH + actions = [] + if self.option("PYTHON_VENV", "OFF") == "ON": + actions = ShellQuoted("ENV PATH={p}:$PATH").format( + p=path_join(self.option("prefix"), "venv", "bin") + ) + return actions + + def step(self, name, actions): + assert "\n" not in name, "Name {0} would span > 1 line".format(name) + b = ShellQuoted("") + return [ShellQuoted("### {0} ###".format(name)), b] + actions + [b] + + def run(self, shell_cmd): + return ShellQuoted("RUN {cmd}").format(cmd=shell_cmd) + + def set_env(self, key, value): + return ShellQuoted("ENV {key}={val}").format(key=key, val=value) + + def workdir(self, dir): + return [ + # As late as Docker 1.12.5, this results in `build` being owned + # by root:root -- the explicit `mkdir` works around the bug: + # USER nobody + # WORKDIR build + ShellQuoted("USER root"), + ShellQuoted("RUN mkdir -p {d} && chown {u} {d}").format( + d=dir, u=self._user() + ), + self._change_user(), + ShellQuoted("WORKDIR {dir}").format(dir=dir), + ] + + def comment(self, comment): + # This should not be a command since we don't want comment changes + # to invalidate the Docker build cache. + return shell_comment(comment) + + def copy_local_repo(self, repo_dir, dest_name): + fd, archive_path = tempfile.mkstemp( + prefix="local_repo_{0}_".format(dest_name), + suffix=".tgz", + dir=os.path.abspath(self.option("docker_context_dir")), + ) + os.close(fd) + run_command("tar", "czf", archive_path, ".", cwd=repo_dir) + return [ + ShellQuoted("ADD {archive} {dest_name}").format( + archive=os.path.basename(archive_path), dest_name=dest_name + ), + # Docker permissions make very little sense... see also workdir() + ShellQuoted("USER root"), + ShellQuoted("RUN chown -R {u} {d}").format(d=dest_name, u=self._user()), + self._change_user(), + ] + + def _render_impl(self, steps): + return raw_shell(shell_join("\n", recursively_flatten_list(steps))) + + def debian_ccache_setup_steps(self): + source_ccache_tgz = self.option("ccache_tgz", "") + if not source_ccache_tgz: + logging.info("Docker ccache not enabled") + return [] + + dest_ccache_tgz = os.path.join(self.option("docker_context_dir"), "ccache.tgz") + + try: + try: + os.link(source_ccache_tgz, dest_ccache_tgz) + except OSError: + logging.exception( + "Hard-linking {s} to {d} failed, falling back to copy".format( + s=source_ccache_tgz, d=dest_ccache_tgz + ) + ) + shutil.copyfile(source_ccache_tgz, dest_ccache_tgz) + except Exception: + logging.exception( + "Failed to copy or link {s} to {d}, aborting".format( + s=source_ccache_tgz, d=dest_ccache_tgz + ) + ) + raise + + return [ + # Separate layer so that in development we avoid re-downloads. + self.run(ShellQuoted("apt-get install -yq ccache")), + ShellQuoted("ADD ccache.tgz /"), + ShellQuoted( + # Set CCACHE_DIR before the `ccache` invocations below. + "ENV CCACHE_DIR=/ccache " + # No clang support for now, so it's easiest to hardcode gcc. + 'CC="ccache gcc" CXX="ccache g++" ' + # Always log for ease of debugging. For real FB projects, + # this log is several megabytes, so dumping it to stdout + # would likely exceed the Travis log limit of 4MB. + # + # On a local machine, `docker cp` will get you the data. To + # get the data out from Travis, I would compress and dump + # uuencoded bytes to the log -- for Bistro this was about + # 600kb or 8000 lines: + # + # apt-get install sharutils + # bzip2 -9 < /tmp/ccache.log | uuencode -m ccache.log.bz2 + "CCACHE_LOGFILE=/tmp/ccache.log" + ), + self.run( + ShellQuoted( + # Future: Skipping this part made this Docker step instant, + # saving ~1min of build time. It's unclear if it is the + # chown or the du, but probably the chown -- since a large + # part of the cost is incurred at image save time. + # + # ccache.tgz may be empty, or may have the wrong + # permissions. + "mkdir -p /ccache && time chown -R nobody /ccache && " + "time du -sh /ccache && " + # Reset stats so `docker_build_with_ccache.sh` can print + # useful values at the end of the run. + "echo === Prev run stats === && ccache -s && ccache -z && " + # Record the current time to let travis_build.sh figure out + # the number of bytes in the cache that are actually used -- + # this is crucial for tuning the maximum cache size. + "date +%s > /FBCODE_BUILDER_CCACHE_START_TIME && " + # The build running as `nobody` should be able to write here + "chown nobody /tmp/ccache.log" + ) + ), + ] diff --git a/build/fbcode_builder/docker_enable_ipv6.sh b/build/fbcode_builder/docker_enable_ipv6.sh new file mode 100755 index 000000000..3752f6f5e --- /dev/null +++ b/build/fbcode_builder/docker_enable_ipv6.sh @@ -0,0 +1,13 @@ +#!/bin/sh +# Copyright (c) Facebook, Inc. and its affiliates. + + +# `daemon.json` is normally missing, but let's log it in case that changes. +touch /etc/docker/daemon.json +service docker stop +echo '{"ipv6": true, "fixed-cidr-v6": "2001:db8:1::/64"}' > /etc/docker/daemon.json +service docker start +# Fail early if docker failed on start -- add `- sudo dockerd` to debug. +docker info +# Paranoia log: what if our config got overwritten? +cat /etc/docker/daemon.json diff --git a/build/fbcode_builder/fbcode_builder.py b/build/fbcode_builder/fbcode_builder.py new file mode 100644 index 000000000..742099321 --- /dev/null +++ b/build/fbcode_builder/fbcode_builder.py @@ -0,0 +1,536 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +""" + +This is a small DSL to describe builds of Facebook's open-source projects +that are published to Github from a single internal repo, including projects +that depend on folly, wangle, proxygen, fbthrift, etc. + +This file defines the interface of the DSL, and common utilieis, but you +will have to instantiate a specific builder, with specific options, in +order to get work done -- see e.g. make_docker_context.py. + +== Design notes == + +Goals: + + - A simple declarative language for what needs to be checked out & built, + how, in what order. + + - The same specification should work for external continuous integration + builds (e.g. Travis + Docker) and for internal VM-based continuous + integration builds. + + - One should be able to build without root, and to install to a prefix. + +Non-goals: + + - General usefulness. The only point of this is to make it easier to build + and test Facebook's open-source services. + +Ideas for the future -- these may not be very good :) + + - Especially on Ubuntu 14.04 the current initial setup is inefficient: + we add PPAs after having installed a bunch of packages -- this prompts + reinstalls of large amounts of code. We also `apt-get update` a few + times. + + - A "shell script" builder. Like DockerFBCodeBuilder, but outputs a + shell script that runs outside of a container. Or maybe even + synchronously executes the shell commands, `make`-style. + + - A "Makefile" generator. That might make iterating on builds even quicker + than what you can currently get with Docker build caching. + + - Generate a rebuild script that can be run e.g. inside the built Docker + container by tagging certain steps with list-inheriting Python objects: + * do change directories + * do NOT `git clone` -- if we want to update code this should be a + separate script that e.g. runs rebase on top of specific targets + across all the repos. + * do NOT install software (most / all setup can be skipped) + * do NOT `autoreconf` or `configure` + * do `make` and `cmake` + + - If we get non-Debian OSes, part of ccache setup should be factored out. +""" + +import os +import re + +from shell_quoting import path_join, shell_join, ShellQuoted + + +def _read_project_github_hashes(): + base_dir = "deps/github_hashes/" # trailing slash used in regex below + for dirname, _, files in os.walk(base_dir): + for filename in files: + path = os.path.join(dirname, filename) + with open(path) as f: + m_proj = re.match("^" + base_dir + "(.*)-rev\.txt$", path) + if m_proj is None: + raise RuntimeError("Not a hash file? {0}".format(path)) + m_hash = re.match("^Subproject commit ([0-9a-f]+)\n$", f.read()) + if m_hash is None: + raise RuntimeError("No hash in {0}".format(path)) + yield m_proj.group(1), m_hash.group(1) + + +class FBCodeBuilder(object): + def __init__(self, **kwargs): + self._options_do_not_access = kwargs # Use .option() instead. + # This raises upon detecting options that are specified but unused, + # because otherwise it is very easy to make a typo in option names. + self.options_used = set() + # Mark 'projects_dir' used even if the build installs no github + # projects. This is needed because driver programs like + # `shell_builder.py` unconditionally set this for all builds. + self._github_dir = self.option("projects_dir") + self._github_hashes = dict(_read_project_github_hashes()) + + def __repr__(self): + return "{0}({1})".format( + self.__class__.__name__, + ", ".join( + "{0}={1}".format(k, repr(v)) + for k, v in self._options_do_not_access.items() + ), + ) + + def option(self, name, default=None): + value = self._options_do_not_access.get(name, default) + if value is None: + raise RuntimeError("Option {0} is required".format(name)) + self.options_used.add(name) + return value + + def has_option(self, name): + return name in self._options_do_not_access + + def add_option(self, name, value): + if name in self._options_do_not_access: + raise RuntimeError("Option {0} already set".format(name)) + self._options_do_not_access[name] = value + + # + # Abstract parts common to every installation flow + # + + def render(self, steps): + """ + + Converts nested actions to your builder's expected output format. + Typically takes the output of build(). + + """ + res = self._render_impl(steps) # Implementation-dependent + # Now that the output is rendered, we expect all options to have + # been used. + unused_options = set(self._options_do_not_access) + unused_options -= self.options_used + if unused_options: + raise RuntimeError( + "Unused options: {0} -- please check if you made a typo " + "in any of them. Those that are truly not useful should " + "be not be set so that this typo detection can be useful.".format( + unused_options + ) + ) + return res + + def build(self, steps): + if not steps: + raise RuntimeError( + "Please ensure that the config you are passing " "contains steps" + ) + return [self.setup(), self.diagnostics()] + steps + + def setup(self): + "Your builder may want to install packages here." + raise NotImplementedError + + def diagnostics(self): + "Log some system diagnostics before/after setup for ease of debugging" + # The builder's repr is not used in a command to avoid pointlessly + # invalidating Docker's build cache. + return self.step( + "Diagnostics", + [ + self.comment("Builder {0}".format(repr(self))), + self.run(ShellQuoted("hostname")), + self.run(ShellQuoted("cat /etc/issue || echo no /etc/issue")), + self.run(ShellQuoted("g++ --version || echo g++ not installed")), + self.run(ShellQuoted("cmake --version || echo cmake not installed")), + ], + ) + + def step(self, name, actions): + "A labeled collection of actions or other steps" + raise NotImplementedError + + def run(self, shell_cmd): + "Run this bash command" + raise NotImplementedError + + def set_env(self, key, value): + 'Set the environment "key" to value "value"' + raise NotImplementedError + + def workdir(self, dir): + "Create this directory if it does not exist, and change into it" + raise NotImplementedError + + def copy_local_repo(self, dir, dest_name): + """ + Copy the local repo at `dir` into this step's `workdir()`, analog of: + cp -r /path/to/folly folly + """ + raise NotImplementedError + + def python_deps(self): + return [ + "wheel", + "cython==0.28.6", + ] + + def debian_deps(self): + return [ + "autoconf-archive", + "bison", + "build-essential", + "cmake", + "curl", + "flex", + "git", + "gperf", + "joe", + "libboost-all-dev", + "libcap-dev", + "libdouble-conversion-dev", + "libevent-dev", + "libgflags-dev", + "libgoogle-glog-dev", + "libkrb5-dev", + "libpcre3-dev", + "libpthread-stubs0-dev", + "libnuma-dev", + "libsasl2-dev", + "libsnappy-dev", + "libsqlite3-dev", + "libssl-dev", + "libtool", + "netcat-openbsd", + "pkg-config", + "sudo", + "unzip", + "wget", + "python3-venv", + ] + + # + # Specific build helpers + # + + def install_debian_deps(self): + actions = [ + self.run( + ShellQuoted("apt-get update && apt-get install -yq {deps}").format( + deps=shell_join( + " ", (ShellQuoted(dep) for dep in self.debian_deps()) + ) + ) + ), + ] + gcc_version = self.option("gcc_version") + + # Make the selected GCC the default before building anything + actions.extend( + [ + self.run( + ShellQuoted("apt-get install -yq {c} {cpp}").format( + c=ShellQuoted("gcc-{v}").format(v=gcc_version), + cpp=ShellQuoted("g++-{v}").format(v=gcc_version), + ) + ), + self.run( + ShellQuoted( + "update-alternatives --install /usr/bin/gcc gcc {c} 40 " + "--slave /usr/bin/g++ g++ {cpp}" + ).format( + c=ShellQuoted("/usr/bin/gcc-{v}").format(v=gcc_version), + cpp=ShellQuoted("/usr/bin/g++-{v}").format(v=gcc_version), + ) + ), + self.run(ShellQuoted("update-alternatives --config gcc")), + ] + ) + + actions.extend(self.debian_ccache_setup_steps()) + + return self.step("Install packages for Debian-based OS", actions) + + def create_python_venv(self): + actions = [] + if self.option("PYTHON_VENV", "OFF") == "ON": + actions.append( + self.run( + ShellQuoted("python3 -m venv {p}").format( + p=path_join(self.option("prefix"), "venv") + ) + ) + ) + return actions + + def python_venv(self): + actions = [] + if self.option("PYTHON_VENV", "OFF") == "ON": + actions.append( + ShellQuoted("source {p}").format( + p=path_join(self.option("prefix"), "venv", "bin", "activate") + ) + ) + + actions.append( + self.run( + ShellQuoted("python3 -m pip install {deps}").format( + deps=shell_join( + " ", (ShellQuoted(dep) for dep in self.python_deps()) + ) + ) + ) + ) + return actions + + def enable_rust_toolchain(self, toolchain="stable", is_bootstrap=True): + choices = set(["stable", "beta", "nightly"]) + + assert toolchain in choices, ( + "while enabling rust toolchain: {} is not in {}" + ).format(toolchain, choices) + + rust_toolchain_opt = (toolchain, is_bootstrap) + prev_opt = self.option("rust_toolchain", rust_toolchain_opt) + assert prev_opt == rust_toolchain_opt, ( + "while enabling rust toolchain: previous toolchain already set to" + " {}, but trying to set it to {} now" + ).format(prev_opt, rust_toolchain_opt) + + self.add_option("rust_toolchain", rust_toolchain_opt) + + def rust_toolchain(self): + actions = [] + if self.option("rust_toolchain", False): + (toolchain, is_bootstrap) = self.option("rust_toolchain") + rust_dir = path_join(self.option("prefix"), "rust") + actions = [ + self.set_env("CARGO_HOME", rust_dir), + self.set_env("RUSTUP_HOME", rust_dir), + self.set_env("RUSTC_BOOTSTRAP", "1" if is_bootstrap else "0"), + self.run( + ShellQuoted( + "curl -sSf https://build.travis-ci.com/files/rustup-init.sh" + " | sh -s --" + " --default-toolchain={r} " + " --profile=minimal" + " --no-modify-path" + " -y" + ).format(p=rust_dir, r=toolchain) + ), + self.set_env( + "PATH", + ShellQuoted("{p}:$PATH").format(p=path_join(rust_dir, "bin")), + ), + self.run(ShellQuoted("rustup update")), + self.run(ShellQuoted("rustc --version")), + self.run(ShellQuoted("rustup --version")), + self.run(ShellQuoted("cargo --version")), + ] + return actions + + def debian_ccache_setup_steps(self): + return [] # It's ok to ship a renderer without ccache support. + + def github_project_workdir(self, project, path): + # Only check out a non-default branch if requested. This especially + # makes sense when building from a local repo. + git_hash = self.option( + "{0}:git_hash".format(project), + # Any repo that has a hash in deps/github_hashes defaults to + # that, with the goal of making builds maximally consistent. + self._github_hashes.get(project, ""), + ) + maybe_change_branch = ( + [ + self.run(ShellQuoted("git checkout {hash}").format(hash=git_hash)), + ] + if git_hash + else [] + ) + + local_repo_dir = self.option("{0}:local_repo_dir".format(project), "") + return self.step( + "Check out {0}, workdir {1}".format(project, path), + [ + self.workdir(self._github_dir), + self.run( + ShellQuoted("git clone {opts} https://github.com/{p}").format( + p=project, + opts=ShellQuoted( + self.option("{}:git_clone_opts".format(project), "") + ), + ) + ) + if not local_repo_dir + else self.copy_local_repo(local_repo_dir, os.path.basename(project)), + self.workdir( + path_join(self._github_dir, os.path.basename(project), path), + ), + ] + + maybe_change_branch, + ) + + def fb_github_project_workdir(self, project_and_path, github_org="facebook"): + "This helper lets Facebook-internal CI special-cases FB projects" + project, path = project_and_path.split("/", 1) + return self.github_project_workdir(github_org + "/" + project, path) + + def _make_vars(self, make_vars): + return shell_join( + " ", + ( + ShellQuoted("{k}={v}").format(k=k, v=v) + for k, v in ({} if make_vars is None else make_vars).items() + ), + ) + + def parallel_make(self, make_vars=None): + return self.run( + ShellQuoted("make -j {n} VERBOSE=1 {vars}").format( + n=self.option("make_parallelism"), + vars=self._make_vars(make_vars), + ) + ) + + def make_and_install(self, make_vars=None): + return [ + self.parallel_make(make_vars), + self.run( + ShellQuoted("make install VERBOSE=1 {vars}").format( + vars=self._make_vars(make_vars), + ) + ), + ] + + def configure(self, name=None): + autoconf_options = {} + if name is not None: + autoconf_options.update( + self.option("{0}:autoconf_options".format(name), {}) + ) + return [ + self.run( + ShellQuoted( + 'LDFLAGS="$LDFLAGS -L"{p}"/lib -Wl,-rpath="{p}"/lib" ' + 'CFLAGS="$CFLAGS -I"{p}"/include" ' + 'CPPFLAGS="$CPPFLAGS -I"{p}"/include" ' + "PY_PREFIX={p} " + "./configure --prefix={p} {args}" + ).format( + p=self.option("prefix"), + args=shell_join( + " ", + ( + ShellQuoted("{k}={v}").format(k=k, v=v) + for k, v in autoconf_options.items() + ), + ), + ) + ), + ] + + def autoconf_install(self, name): + return self.step( + "Build and install {0}".format(name), + [ + self.run(ShellQuoted("autoreconf -ivf")), + ] + + self.configure() + + self.make_and_install(), + ) + + def cmake_configure(self, name, cmake_path=".."): + cmake_defines = { + "BUILD_SHARED_LIBS": "ON", + "CMAKE_INSTALL_PREFIX": self.option("prefix"), + } + + # Hacks to add thriftpy3 support + if "BUILD_THRIFT_PY3" in os.environ and "folly" in name: + cmake_defines["PYTHON_EXTENSIONS"] = "True" + + if "BUILD_THRIFT_PY3" in os.environ and "fbthrift" in name: + cmake_defines["thriftpy3"] = "ON" + + cmake_defines.update(self.option("{0}:cmake_defines".format(name), {})) + return [ + self.run( + ShellQuoted( + 'CXXFLAGS="$CXXFLAGS -fPIC -isystem "{p}"/include" ' + 'CFLAGS="$CFLAGS -fPIC -isystem "{p}"/include" ' + "cmake {args} {cmake_path}" + ).format( + p=self.option("prefix"), + args=shell_join( + " ", + ( + ShellQuoted("-D{k}={v}").format(k=k, v=v) + for k, v in cmake_defines.items() + ), + ), + cmake_path=cmake_path, + ) + ), + ] + + def cmake_install(self, name, cmake_path=".."): + return self.step( + "Build and install {0}".format(name), + self.cmake_configure(name, cmake_path) + self.make_and_install(), + ) + + def cargo_build(self, name): + return self.step( + "Build {0}".format(name), + [ + self.run( + ShellQuoted("cargo build -j {n}").format( + n=self.option("make_parallelism") + ) + ) + ], + ) + + def fb_github_autoconf_install(self, project_and_path, github_org="facebook"): + return [ + self.fb_github_project_workdir(project_and_path, github_org), + self.autoconf_install(project_and_path), + ] + + def fb_github_cmake_install( + self, project_and_path, cmake_path="..", github_org="facebook" + ): + return [ + self.fb_github_project_workdir(project_and_path, github_org), + self.cmake_install(project_and_path, cmake_path), + ] + + def fb_github_cargo_build(self, project_and_path, github_org="facebook"): + return [ + self.fb_github_project_workdir(project_and_path, github_org), + self.cargo_build(project_and_path), + ] diff --git a/build/fbcode_builder/fbcode_builder_config.py b/build/fbcode_builder/fbcode_builder_config.py new file mode 100644 index 000000000..5ba6e607a --- /dev/null +++ b/build/fbcode_builder/fbcode_builder_config.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +"Demo config, so that `make_docker_context.py --help` works in this directory." + +config = { + "fbcode_builder_spec": lambda _builder: { + "depends_on": [], + "steps": [], + }, + "github_project": "demo/project", +} diff --git a/build/fbcode_builder/getdeps.py b/build/fbcode_builder/getdeps.py new file mode 100755 index 000000000..1b539735f --- /dev/null +++ b/build/fbcode_builder/getdeps.py @@ -0,0 +1,1071 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import argparse +import json +import os +import shutil +import subprocess +import sys +import tarfile +import tempfile + +# We don't import cache.create_cache directly as the facebook +# specific import below may monkey patch it, and we want to +# observe the patched version of this function! +import getdeps.cache as cache_module +from getdeps.buildopts import setup_build_options +from getdeps.dyndeps import create_dyn_dep_munger +from getdeps.errors import TransientFailure +from getdeps.fetcher import ( + SystemPackageFetcher, + file_name_is_cmake_file, + list_files_under_dir_newer_than_timestamp, +) +from getdeps.load import ManifestLoader +from getdeps.manifest import ManifestParser +from getdeps.platform import HostType +from getdeps.runcmd import run_cmd +from getdeps.subcmd import SubCmd, add_subcommands, cmd + + +try: + import getdeps.facebook # noqa: F401 +except ImportError: + # we don't ship the facebook specific subdir, + # so allow that to fail silently + pass + + +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "getdeps")) + + +class UsageError(Exception): + pass + + +@cmd("validate-manifest", "parse a manifest and validate that it is correct") +class ValidateManifest(SubCmd): + def run(self, args): + try: + ManifestParser(file_name=args.file_name) + print("OK", file=sys.stderr) + return 0 + except Exception as exc: + print("ERROR: %s" % str(exc), file=sys.stderr) + return 1 + + def setup_parser(self, parser): + parser.add_argument("file_name", help="path to the manifest file") + + +@cmd("show-host-type", "outputs the host type tuple for the host machine") +class ShowHostType(SubCmd): + def run(self, args): + host = HostType() + print("%s" % host.as_tuple_string()) + return 0 + + +class ProjectCmdBase(SubCmd): + def run(self, args): + opts = setup_build_options(args) + + if args.current_project is not None: + opts.repo_project = args.current_project + if args.project is None: + if opts.repo_project is None: + raise UsageError( + "no project name specified, and no .projectid file found" + ) + if opts.repo_project == "fbsource": + # The fbsource repository is a little special. There is no project + # manifest file for it. A specific project must always be explicitly + # specified when building from fbsource. + raise UsageError( + "no project name specified (required when building in fbsource)" + ) + args.project = opts.repo_project + + ctx_gen = opts.get_context_generator(facebook_internal=args.facebook_internal) + if args.test_dependencies: + ctx_gen.set_value_for_all_projects("test", "on") + if args.enable_tests: + ctx_gen.set_value_for_project(args.project, "test", "on") + else: + ctx_gen.set_value_for_project(args.project, "test", "off") + + loader = ManifestLoader(opts, ctx_gen) + self.process_project_dir_arguments(args, loader) + + manifest = loader.load_manifest(args.project) + + self.run_project_cmd(args, loader, manifest) + + def process_project_dir_arguments(self, args, loader): + def parse_project_arg(arg, arg_type): + parts = arg.split(":") + if len(parts) == 2: + project, path = parts + elif len(parts) == 1: + project = args.project + path = parts[0] + # On Windows path contains colon, e.g. C:\open + elif os.name == "nt" and len(parts) == 3: + project = parts[0] + path = parts[1] + ":" + parts[2] + else: + raise UsageError( + "invalid %s argument; too many ':' characters: %s" % (arg_type, arg) + ) + + return project, os.path.abspath(path) + + # If we are currently running from a project repository, + # use the current repository for the project sources. + build_opts = loader.build_opts + if build_opts.repo_project is not None and build_opts.repo_root is not None: + loader.set_project_src_dir(build_opts.repo_project, build_opts.repo_root) + + for arg in args.src_dir: + project, path = parse_project_arg(arg, "--src-dir") + loader.set_project_src_dir(project, path) + + for arg in args.build_dir: + project, path = parse_project_arg(arg, "--build-dir") + loader.set_project_build_dir(project, path) + + for arg in args.install_dir: + project, path = parse_project_arg(arg, "--install-dir") + loader.set_project_install_dir(project, path) + + for arg in args.project_install_prefix: + project, path = parse_project_arg(arg, "--install-prefix") + loader.set_project_install_prefix(project, path) + + def setup_parser(self, parser): + parser.add_argument( + "project", + nargs="?", + help=( + "name of the project or path to a manifest " + "file describing the project" + ), + ) + parser.add_argument( + "--no-tests", + action="store_false", + dest="enable_tests", + default=True, + help="Disable building tests for this project.", + ) + parser.add_argument( + "--test-dependencies", + action="store_true", + help="Enable building tests for dependencies as well.", + ) + parser.add_argument( + "--current-project", + help="Specify the name of the fbcode_builder manifest file for the " + "current repository. If not specified, the code will attempt to find " + "this in a .projectid file in the repository root.", + ) + parser.add_argument( + "--src-dir", + default=[], + action="append", + help="Specify a local directory to use for the project source, " + "rather than fetching it.", + ) + parser.add_argument( + "--build-dir", + default=[], + action="append", + help="Explicitly specify the build directory to use for the " + "project, instead of the default location in the scratch path. " + "This only affects the project specified, and not its dependencies.", + ) + parser.add_argument( + "--install-dir", + default=[], + action="append", + help="Explicitly specify the install directory to use for the " + "project, instead of the default location in the scratch path. " + "This only affects the project specified, and not its dependencies.", + ) + parser.add_argument( + "--project-install-prefix", + default=[], + action="append", + help="Specify the final deployment installation path for a project", + ) + + self.setup_project_cmd_parser(parser) + + def setup_project_cmd_parser(self, parser): + pass + + +class CachedProject(object): + """A helper that allows calling the cache logic for a project + from both the build and the fetch code""" + + def __init__(self, cache, loader, m): + self.m = m + self.inst_dir = loader.get_project_install_dir(m) + self.project_hash = loader.get_project_hash(m) + self.ctx = loader.ctx_gen.get_context(m.name) + self.loader = loader + self.cache = cache + + self.cache_file_name = "-".join( + ( + m.name, + self.ctx.get("os"), + self.ctx.get("distro") or "none", + self.ctx.get("distro_vers") or "none", + self.project_hash, + "buildcache.tgz", + ) + ) + + def is_cacheable(self): + """We only cache third party projects""" + return self.cache and self.m.shipit_project is None + + def was_cached(self): + cached_marker = os.path.join(self.inst_dir, ".getdeps-cached-build") + return os.path.exists(cached_marker) + + def download(self): + if self.is_cacheable() and not os.path.exists(self.inst_dir): + print("check cache for %s" % self.cache_file_name) + dl_dir = os.path.join(self.loader.build_opts.scratch_dir, "downloads") + if not os.path.exists(dl_dir): + os.makedirs(dl_dir) + try: + target_file_name = os.path.join(dl_dir, self.cache_file_name) + if self.cache.download_to_file(self.cache_file_name, target_file_name): + tf = tarfile.open(target_file_name, "r") + print( + "Extracting %s -> %s..." % (self.cache_file_name, self.inst_dir) + ) + tf.extractall(self.inst_dir) + + cached_marker = os.path.join(self.inst_dir, ".getdeps-cached-build") + with open(cached_marker, "w") as f: + f.write("\n") + + return True + except Exception as exc: + print("%s" % str(exc)) + + return False + + def upload(self): + if self.is_cacheable(): + # We can prepare an archive and stick it in LFS + tempdir = tempfile.mkdtemp() + tarfilename = os.path.join(tempdir, self.cache_file_name) + print("Archiving for cache: %s..." % tarfilename) + tf = tarfile.open(tarfilename, "w:gz") + tf.add(self.inst_dir, arcname=".") + tf.close() + try: + self.cache.upload_from_file(self.cache_file_name, tarfilename) + except Exception as exc: + print( + "Failed to upload to cache (%s), continue anyway" % str(exc), + file=sys.stderr, + ) + shutil.rmtree(tempdir) + + +@cmd("fetch", "fetch the code for a given project") +class FetchCmd(ProjectCmdBase): + def setup_project_cmd_parser(self, parser): + parser.add_argument( + "--recursive", + help="fetch the transitive deps also", + action="store_true", + default=False, + ) + parser.add_argument( + "--host-type", + help=( + "When recursively fetching, fetch deps for " + "this host type rather than the current system" + ), + ) + + def run_project_cmd(self, args, loader, manifest): + if args.recursive: + projects = loader.manifests_in_dependency_order() + else: + projects = [manifest] + + cache = cache_module.create_cache() + for m in projects: + cached_project = CachedProject(cache, loader, m) + if cached_project.download(): + continue + + inst_dir = loader.get_project_install_dir(m) + built_marker = os.path.join(inst_dir, ".built-by-getdeps") + if os.path.exists(built_marker): + with open(built_marker, "r") as f: + built_hash = f.read().strip() + + project_hash = loader.get_project_hash(m) + if built_hash == project_hash: + continue + + # We need to fetch the sources + fetcher = loader.create_fetcher(m) + fetcher.update() + + +@cmd("install-system-deps", "Install system packages to satisfy the deps for a project") +class InstallSysDepsCmd(ProjectCmdBase): + def setup_project_cmd_parser(self, parser): + parser.add_argument( + "--recursive", + help="install the transitive deps also", + action="store_true", + default=False, + ) + + def run_project_cmd(self, args, loader, manifest): + if args.recursive: + projects = loader.manifests_in_dependency_order() + else: + projects = [manifest] + + cache = cache_module.create_cache() + all_packages = {} + for m in projects: + ctx = loader.ctx_gen.get_context(m.name) + packages = m.get_required_system_packages(ctx) + for k, v in packages.items(): + merged = all_packages.get(k, []) + merged += v + all_packages[k] = merged + + manager = loader.build_opts.host_type.get_package_manager() + if manager == "rpm": + packages = sorted(list(set(all_packages["rpm"]))) + if packages: + run_cmd(["dnf", "install", "-y"] + packages) + elif manager == "deb": + packages = sorted(list(set(all_packages["deb"]))) + if packages: + run_cmd(["apt", "install", "-y"] + packages) + else: + print("I don't know how to install any packages on this system") + + +@cmd("list-deps", "lists the transitive deps for a given project") +class ListDepsCmd(ProjectCmdBase): + def run_project_cmd(self, args, loader, manifest): + for m in loader.manifests_in_dependency_order(): + print(m.name) + return 0 + + def setup_project_cmd_parser(self, parser): + parser.add_argument( + "--host-type", + help=( + "Produce the list for the specified host type, " + "rather than that of the current system" + ), + ) + + +def clean_dirs(opts): + for d in ["build", "installed", "extracted", "shipit"]: + d = os.path.join(opts.scratch_dir, d) + print("Cleaning %s..." % d) + if os.path.exists(d): + shutil.rmtree(d) + + +@cmd("clean", "clean up the scratch dir") +class CleanCmd(SubCmd): + def run(self, args): + opts = setup_build_options(args) + clean_dirs(opts) + + +@cmd("show-build-dir", "print the build dir for a given project") +class ShowBuildDirCmd(ProjectCmdBase): + def run_project_cmd(self, args, loader, manifest): + if args.recursive: + manifests = loader.manifests_in_dependency_order() + else: + manifests = [manifest] + + for m in manifests: + inst_dir = loader.get_project_build_dir(m) + print(inst_dir) + + def setup_project_cmd_parser(self, parser): + parser.add_argument( + "--recursive", + help="print the transitive deps also", + action="store_true", + default=False, + ) + + +@cmd("show-inst-dir", "print the installation dir for a given project") +class ShowInstDirCmd(ProjectCmdBase): + def run_project_cmd(self, args, loader, manifest): + if args.recursive: + manifests = loader.manifests_in_dependency_order() + else: + manifests = [manifest] + + for m in manifests: + inst_dir = loader.get_project_install_dir_respecting_install_prefix(m) + print(inst_dir) + + def setup_project_cmd_parser(self, parser): + parser.add_argument( + "--recursive", + help="print the transitive deps also", + action="store_true", + default=False, + ) + + +@cmd("show-source-dir", "print the source dir for a given project") +class ShowSourceDirCmd(ProjectCmdBase): + def run_project_cmd(self, args, loader, manifest): + if args.recursive: + manifests = loader.manifests_in_dependency_order() + else: + manifests = [manifest] + + for m in manifests: + fetcher = loader.create_fetcher(m) + print(fetcher.get_src_dir()) + + def setup_project_cmd_parser(self, parser): + parser.add_argument( + "--recursive", + help="print the transitive deps also", + action="store_true", + default=False, + ) + + +@cmd("build", "build a given project") +class BuildCmd(ProjectCmdBase): + def run_project_cmd(self, args, loader, manifest): + if args.clean: + clean_dirs(loader.build_opts) + + print("Building on %s" % loader.ctx_gen.get_context(args.project)) + projects = loader.manifests_in_dependency_order() + + cache = cache_module.create_cache() if args.use_build_cache else None + + # Accumulate the install directories so that the build steps + # can find their dep installation + install_dirs = [] + + for m in projects: + fetcher = loader.create_fetcher(m) + + if isinstance(fetcher, SystemPackageFetcher): + # We are guaranteed that if the fetcher is set to + # SystemPackageFetcher then this item is completely + # satisfied by the appropriate system packages + continue + + if args.clean: + fetcher.clean() + + build_dir = loader.get_project_build_dir(m) + inst_dir = loader.get_project_install_dir(m) + + if ( + m == manifest + and not args.only_deps + or m != manifest + and not args.no_deps + ): + print("Assessing %s..." % m.name) + project_hash = loader.get_project_hash(m) + ctx = loader.ctx_gen.get_context(m.name) + built_marker = os.path.join(inst_dir, ".built-by-getdeps") + + cached_project = CachedProject(cache, loader, m) + + reconfigure, sources_changed = self.compute_source_change_status( + cached_project, fetcher, m, built_marker, project_hash + ) + + if os.path.exists(built_marker) and not cached_project.was_cached(): + # We've previously built this. We may need to reconfigure if + # our deps have changed, so let's check them. + dep_reconfigure, dep_build = self.compute_dep_change_status( + m, built_marker, loader + ) + if dep_reconfigure: + reconfigure = True + if dep_build: + sources_changed = True + + extra_cmake_defines = ( + json.loads(args.extra_cmake_defines) + if args.extra_cmake_defines + else {} + ) + + if sources_changed or reconfigure or not os.path.exists(built_marker): + if os.path.exists(built_marker): + os.unlink(built_marker) + src_dir = fetcher.get_src_dir() + builder = m.create_builder( + loader.build_opts, + src_dir, + build_dir, + inst_dir, + ctx, + loader, + final_install_prefix=loader.get_project_install_prefix(m), + extra_cmake_defines=extra_cmake_defines, + ) + builder.build(install_dirs, reconfigure=reconfigure) + + with open(built_marker, "w") as f: + f.write(project_hash) + + # Only populate the cache from continuous build runs + if args.schedule_type == "continuous": + cached_project.upload() + + install_dirs.append(inst_dir) + + def compute_dep_change_status(self, m, built_marker, loader): + reconfigure = False + sources_changed = False + st = os.lstat(built_marker) + + ctx = loader.ctx_gen.get_context(m.name) + dep_list = sorted(m.get_section_as_dict("dependencies", ctx).keys()) + for dep in dep_list: + if reconfigure and sources_changed: + break + + dep_manifest = loader.load_manifest(dep) + dep_root = loader.get_project_install_dir(dep_manifest) + for dep_file in list_files_under_dir_newer_than_timestamp( + dep_root, st.st_mtime + ): + if os.path.basename(dep_file) == ".built-by-getdeps": + continue + if file_name_is_cmake_file(dep_file): + if not reconfigure: + reconfigure = True + print( + f"Will reconfigure cmake because {dep_file} is newer than {built_marker}" + ) + else: + if not sources_changed: + sources_changed = True + print( + f"Will run build because {dep_file} is newer than {built_marker}" + ) + + if reconfigure and sources_changed: + break + + return reconfigure, sources_changed + + def compute_source_change_status( + self, cached_project, fetcher, m, built_marker, project_hash + ): + reconfigure = False + sources_changed = False + if not cached_project.download(): + check_fetcher = True + if os.path.exists(built_marker): + check_fetcher = False + with open(built_marker, "r") as f: + built_hash = f.read().strip() + if built_hash == project_hash: + if cached_project.is_cacheable(): + # We can blindly trust the build status + reconfigure = False + sources_changed = False + else: + # Otherwise, we may have changed the source, so let's + # check in with the fetcher layer + check_fetcher = True + else: + # Some kind of inconsistency with a prior build, + # let's run it again to be sure + os.unlink(built_marker) + reconfigure = True + sources_changed = True + # While we don't need to consult the fetcher for the + # status in this case, we may still need to have eg: shipit + # run in order to have a correct source tree. + fetcher.update() + + if check_fetcher: + change_status = fetcher.update() + reconfigure = change_status.build_changed() + sources_changed = change_status.sources_changed() + + return reconfigure, sources_changed + + def setup_project_cmd_parser(self, parser): + parser.add_argument( + "--clean", + action="store_true", + default=False, + help=( + "Clean up the build and installation area prior to building, " + "causing the projects to be built from scratch" + ), + ) + parser.add_argument( + "--no-deps", + action="store_true", + default=False, + help=( + "Only build the named project, not its deps. " + "This is most useful after you've built all of the deps, " + "and helps to avoid waiting for relatively " + "slow up-to-date-ness checks" + ), + ) + parser.add_argument( + "--only-deps", + action="store_true", + default=False, + help=( + "Only build the named project's deps. " + "This is most useful when you want to separate out building " + "of all of the deps and your project" + ), + ) + parser.add_argument( + "--no-build-cache", + action="store_false", + default=True, + dest="use_build_cache", + help="Do not attempt to use the build cache.", + ) + parser.add_argument( + "--schedule-type", help="Indicates how the build was activated" + ) + parser.add_argument( + "--extra-cmake-defines", + help=( + "Input json map that contains extra cmake defines to be used " + "when compiling the current project and all its deps. " + 'e.g: \'{"CMAKE_CXX_FLAGS": "--bla"}\'' + ), + ) + + +@cmd("fixup-dyn-deps", "Adjusts dynamic dependencies for packaging purposes") +class FixupDeps(ProjectCmdBase): + def run_project_cmd(self, args, loader, manifest): + projects = loader.manifests_in_dependency_order() + + # Accumulate the install directories so that the build steps + # can find their dep installation + install_dirs = [] + + for m in projects: + inst_dir = loader.get_project_install_dir_respecting_install_prefix(m) + install_dirs.append(inst_dir) + + if m == manifest: + dep_munger = create_dyn_dep_munger( + loader.build_opts, install_dirs, args.strip + ) + dep_munger.process_deps(args.destdir, args.final_install_prefix) + + def setup_project_cmd_parser(self, parser): + parser.add_argument("destdir", help="Where to copy the fixed up executables") + parser.add_argument( + "--final-install-prefix", help="specify the final installation prefix" + ) + parser.add_argument( + "--strip", + action="store_true", + default=False, + help="Strip debug info while processing executables", + ) + + +@cmd("test", "test a given project") +class TestCmd(ProjectCmdBase): + def run_project_cmd(self, args, loader, manifest): + projects = loader.manifests_in_dependency_order() + + # Accumulate the install directories so that the test steps + # can find their dep installation + install_dirs = [] + + for m in projects: + inst_dir = loader.get_project_install_dir(m) + + if m == manifest or args.test_dependencies: + built_marker = os.path.join(inst_dir, ".built-by-getdeps") + if not os.path.exists(built_marker): + print("project %s has not been built" % m.name) + # TODO: we could just go ahead and build it here, but I + # want to tackle that as part of adding build-for-test + # support. + return 1 + fetcher = loader.create_fetcher(m) + src_dir = fetcher.get_src_dir() + ctx = loader.ctx_gen.get_context(m.name) + build_dir = loader.get_project_build_dir(m) + builder = m.create_builder( + loader.build_opts, src_dir, build_dir, inst_dir, ctx, loader + ) + + builder.run_tests( + install_dirs, + schedule_type=args.schedule_type, + owner=args.test_owner, + test_filter=args.filter, + retry=args.retry, + no_testpilot=args.no_testpilot, + ) + + install_dirs.append(inst_dir) + + def setup_project_cmd_parser(self, parser): + parser.add_argument( + "--schedule-type", help="Indicates how the build was activated" + ) + parser.add_argument("--test-owner", help="Owner for testpilot") + parser.add_argument("--filter", help="Only run the tests matching the regex") + parser.add_argument( + "--retry", + type=int, + default=3, + help="Number of immediate retries for failed tests " + "(noop in continuous and testwarden runs)", + ) + parser.add_argument( + "--no-testpilot", + help="Do not use Test Pilot even when available", + action="store_true", + ) + + +@cmd("generate-github-actions", "generate a GitHub actions configuration") +class GenerateGitHubActionsCmd(ProjectCmdBase): + RUN_ON_ALL = """ [push, pull_request]""" + RUN_ON_DEFAULT = """ + push: + branches: + - master + pull_request: + branches: + - master""" + + def run_project_cmd(self, args, loader, manifest): + platforms = [ + HostType("linux", "ubuntu", "18"), + HostType("darwin", None, None), + HostType("windows", None, None), + ] + + for p in platforms: + self.write_job_for_platform(p, args) + + # TODO: Break up complex function + def write_job_for_platform(self, platform, args): # noqa: C901 + build_opts = setup_build_options(args, platform) + ctx_gen = build_opts.get_context_generator(facebook_internal=False) + loader = ManifestLoader(build_opts, ctx_gen) + manifest = loader.load_manifest(args.project) + manifest_ctx = loader.ctx_gen.get_context(manifest.name) + run_on = self.RUN_ON_ALL if args.run_on_all_branches else self.RUN_ON_DEFAULT + + # Some projects don't do anything "useful" as a leaf project, only + # as a dep for a leaf project. Check for those here; we don't want + # to waste the effort scheduling them on CI. + # We do this by looking at the builder type in the manifest file + # rather than creating a builder and checking its type because we + # don't know enough to create the full builder instance here. + if manifest.get("build", "builder", ctx=manifest_ctx) == "nop": + return None + + # We want to be sure that we're running things with python 3 + # but python versioning is honestly a bit of a frustrating mess. + # `python` may be version 2 or version 3 depending on the system. + # python3 may not be a thing at all! + # Assume an optimistic default + py3 = "python3" + + if build_opts.is_linux(): + job_name = "linux" + runs_on = f"ubuntu-{args.ubuntu_version}" + elif build_opts.is_windows(): + # We're targeting the windows-2016 image because it has + # Visual Studio 2017 installed, and at the time of writing, + # the version of boost in the manifests (1.69) is not + # buildable with Visual Studio 2019 + job_name = "windows" + runs_on = "windows-2016" + # The windows runners are python 3 by default; python2.exe + # is available if needed. + py3 = "python" + else: + job_name = "mac" + runs_on = "macOS-latest" + + os.makedirs(args.output_dir, exist_ok=True) + output_file = os.path.join(args.output_dir, f"getdeps_{job_name}.yml") + with open(output_file, "w") as out: + # Deliberate line break here because the @ and the generated + # symbols are meaningful to our internal tooling when they + # appear in a single token + out.write("# This file was @") + out.write("generated by getdeps.py\n") + out.write( + f""" +name: {job_name} + +on:{run_on} + +jobs: +""" + ) + + getdeps = f"{py3} build/fbcode_builder/getdeps.py" + + out.write(" build:\n") + out.write(" runs-on: %s\n" % runs_on) + out.write(" steps:\n") + out.write(" - uses: actions/checkout@v1\n") + + if build_opts.is_windows(): + # cmake relies on BOOST_ROOT but GH deliberately don't set it in order + # to avoid versioning issues: + # https://github.com/actions/virtual-environments/issues/319 + # Instead, set the version we think we need; this is effectively + # coupled with the boost manifest + # This is the unusual syntax for setting an env var for the rest of + # the steps in a workflow: + # https://github.blog/changelog/2020-10-01-github-actions-deprecating-set-env-and-add-path-commands/ + out.write(" - name: Export boost environment\n") + out.write( + ' run: "echo BOOST_ROOT=%BOOST_ROOT_1_69_0% >> %GITHUB_ENV%"\n' + ) + out.write(" shell: cmd\n") + + # The git installation may not like long filenames, so tell it + # that we want it to use them! + out.write(" - name: Fix Git config\n") + out.write(" run: git config --system core.longpaths true\n") + + projects = loader.manifests_in_dependency_order() + + for m in projects: + if m != manifest: + out.write(" - name: Fetch %s\n" % m.name) + out.write(f" run: {getdeps} fetch --no-tests {m.name}\n") + + for m in projects: + if m != manifest: + out.write(" - name: Build %s\n" % m.name) + out.write(f" run: {getdeps} build --no-tests {m.name}\n") + + out.write(" - name: Build %s\n" % manifest.name) + + project_prefix = "" + if not build_opts.is_windows(): + project_prefix = ( + " --project-install-prefix %s:/usr/local" % manifest.name + ) + + out.write( + f" run: {getdeps} build --src-dir=. {manifest.name} {project_prefix}\n" + ) + + out.write(" - name: Copy artifacts\n") + if build_opts.is_linux(): + # Strip debug info from the binaries, but only on linux. + # While the `strip` utility is also available on macOS, + # attempting to strip there results in an error. + # The `strip` utility is not available on Windows. + strip = " --strip" + else: + strip = "" + + out.write( + f" run: {getdeps} fixup-dyn-deps{strip} " + f"--src-dir=. {manifest.name} _artifacts/{job_name} {project_prefix} " + f"--final-install-prefix /usr/local\n" + ) + + out.write(" - uses: actions/upload-artifact@master\n") + out.write(" with:\n") + out.write(" name: %s\n" % manifest.name) + out.write(" path: _artifacts\n") + + out.write(" - name: Test %s\n" % manifest.name) + out.write( + f" run: {getdeps} test --src-dir=. {manifest.name} {project_prefix}\n" + ) + + def setup_project_cmd_parser(self, parser): + parser.add_argument( + "--disallow-system-packages", + help="Disallow satisfying third party deps from installed system packages", + action="store_true", + default=False, + ) + parser.add_argument( + "--output-dir", help="The directory that will contain the yml files" + ) + parser.add_argument( + "--run-on-all-branches", + action="store_true", + help="Allow CI to fire on all branches - Handy for testing", + ) + parser.add_argument( + "--ubuntu-version", default="18.04", help="Version of Ubuntu to use" + ) + + +def get_arg_var_name(args): + for arg in args: + if arg.startswith("--"): + return arg[2:].replace("-", "_") + + raise Exception("unable to determine argument variable name from %r" % (args,)) + + +def parse_args(): + # We want to allow common arguments to be specified either before or after + # the subcommand name. In order to do this we add them to the main parser + # and to subcommand parsers. In order for this to work, we need to tell + # argparse that the default value is SUPPRESS, so that the default values + # from the subparser arguments won't override values set by the user from + # the main parser. We maintain our own list of desired defaults in the + # common_defaults dictionary, and manually set those if the argument wasn't + # present at all. + common_args = argparse.ArgumentParser(add_help=False) + common_defaults = {} + + def add_common_arg(*args, **kwargs): + var_name = get_arg_var_name(args) + default_value = kwargs.pop("default", None) + common_defaults[var_name] = default_value + kwargs["default"] = argparse.SUPPRESS + common_args.add_argument(*args, **kwargs) + + add_common_arg("--scratch-path", help="Where to maintain checkouts and build dirs") + add_common_arg( + "--vcvars-path", default=None, help="Path to the vcvarsall.bat on Windows." + ) + add_common_arg( + "--install-prefix", + help=( + "Where the final build products will be installed " + "(default is [scratch-path]/installed)" + ), + ) + add_common_arg( + "--num-jobs", + type=int, + help=( + "Number of concurrent jobs to use while building. " + "(default=number of cpu cores)" + ), + ) + add_common_arg( + "--use-shipit", + help="use the real ShipIt instead of the simple shipit transformer", + action="store_true", + default=False, + ) + add_common_arg( + "--facebook-internal", + help="Setup the build context as an FB internal build", + action="store_true", + default=None, + ) + add_common_arg( + "--no-facebook-internal", + help="Perform a non-FB internal build, even when in an fbsource repository", + action="store_false", + dest="facebook_internal", + ) + add_common_arg( + "--allow-system-packages", + help="Allow satisfying third party deps from installed system packages", + action="store_true", + default=False, + ) + add_common_arg( + "--lfs-path", + help="Provide a parent directory for lfs when fbsource is unavailable", + default=None, + ) + + ap = argparse.ArgumentParser( + description="Get and build dependencies and projects", parents=[common_args] + ) + sub = ap.add_subparsers( + # metavar suppresses the long and ugly default list of subcommands on a + # single line. We still render the nicer list below where we would + # have shown the nasty one. + metavar="", + title="Available commands", + help="", + ) + + add_subcommands(sub, common_args) + + args = ap.parse_args() + for var_name, default_value in common_defaults.items(): + if not hasattr(args, var_name): + setattr(args, var_name, default_value) + + return ap, args + + +def main(): + ap, args = parse_args() + if getattr(args, "func", None) is None: + ap.print_help() + return 0 + try: + return args.func(args) + except UsageError as exc: + ap.error(str(exc)) + return 1 + except TransientFailure as exc: + print("TransientFailure: %s" % str(exc)) + # This return code is treated as a retryable transient infrastructure + # error by Facebook's internal CI, rather than eg: a build or code + # related error that needs to be fixed before progress can be made. + return 128 + except subprocess.CalledProcessError as exc: + print("%s" % str(exc), file=sys.stderr) + print("!! Failed", file=sys.stderr) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/build/fbcode_builder/getdeps/__init__.py b/build/fbcode_builder/getdeps/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/build/fbcode_builder/getdeps/builder.py b/build/fbcode_builder/getdeps/builder.py new file mode 100644 index 000000000..4e523c2dc --- /dev/null +++ b/build/fbcode_builder/getdeps/builder.py @@ -0,0 +1,1400 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import json +import os +import shutil +import stat +import subprocess +import sys + +from .dyndeps import create_dyn_dep_munger +from .envfuncs import Env, add_path_entry, path_search +from .fetcher import copy_if_different +from .runcmd import run_cmd + + +class BuilderBase(object): + def __init__( + self, + build_opts, + ctx, + manifest, + src_dir, + build_dir, + inst_dir, + env=None, + final_install_prefix=None, + ): + self.env = Env() + if env: + self.env.update(env) + + subdir = manifest.get("build", "subdir", ctx=ctx) + if subdir: + src_dir = os.path.join(src_dir, subdir) + + self.ctx = ctx + self.src_dir = src_dir + self.build_dir = build_dir or src_dir + self.inst_dir = inst_dir + self.build_opts = build_opts + self.manifest = manifest + self.final_install_prefix = final_install_prefix + + def _get_cmd_prefix(self): + if self.build_opts.is_windows(): + vcvarsall = self.build_opts.get_vcvars_path() + if vcvarsall is not None: + # Since it sets rather a large number of variables we mildly abuse + # the cmd quoting rules to assemble a command that calls the script + # to prep the environment and then triggers the actual command that + # we wanted to run. + return [vcvarsall, "amd64", "&&"] + return [] + + def _run_cmd(self, cmd, cwd=None, env=None, use_cmd_prefix=True, allow_fail=False): + if env: + e = self.env.copy() + e.update(env) + env = e + else: + env = self.env + + if use_cmd_prefix: + cmd_prefix = self._get_cmd_prefix() + if cmd_prefix: + cmd = cmd_prefix + cmd + + log_file = os.path.join(self.build_dir, "getdeps_build.log") + return run_cmd( + cmd=cmd, + env=env, + cwd=cwd or self.build_dir, + log_file=log_file, + allow_fail=allow_fail, + ) + + def build(self, install_dirs, reconfigure): + print("Building %s..." % self.manifest.name) + + if self.build_dir is not None: + if not os.path.isdir(self.build_dir): + os.makedirs(self.build_dir) + reconfigure = True + + self._build(install_dirs=install_dirs, reconfigure=reconfigure) + + # On Windows, emit a wrapper script that can be used to run build artifacts + # directly from the build directory, without installing them. On Windows $PATH + # needs to be updated to include all of the directories containing the runtime + # library dependencies in order to run the binaries. + if self.build_opts.is_windows(): + script_path = self.get_dev_run_script_path() + dep_munger = create_dyn_dep_munger(self.build_opts, install_dirs) + dep_dirs = self.get_dev_run_extra_path_dirs(install_dirs, dep_munger) + dep_munger.emit_dev_run_script(script_path, dep_dirs) + + def run_tests( + self, install_dirs, schedule_type, owner, test_filter, retry, no_testpilot + ): + """Execute any tests that we know how to run. If they fail, + raise an exception.""" + pass + + def _build(self, install_dirs, reconfigure): + """Perform the build. + install_dirs contains the list of installation directories for + the dependencies of this project. + reconfigure will be set to true if the fetcher determined + that the sources have changed in such a way that the build + system needs to regenerate its rules.""" + pass + + def _compute_env(self, install_dirs): + # CMAKE_PREFIX_PATH is only respected when passed through the + # environment, so we construct an appropriate path to pass down + return self.build_opts.compute_env_for_install_dirs( + install_dirs, env=self.env, manifest=self.manifest + ) + + def get_dev_run_script_path(self): + assert self.build_opts.is_windows() + return os.path.join(self.build_dir, "run.ps1") + + def get_dev_run_extra_path_dirs(self, install_dirs, dep_munger=None): + assert self.build_opts.is_windows() + if dep_munger is None: + dep_munger = create_dyn_dep_munger(self.build_opts, install_dirs) + return dep_munger.compute_dependency_paths(self.build_dir) + + +class MakeBuilder(BuilderBase): + def __init__( + self, + build_opts, + ctx, + manifest, + src_dir, + build_dir, + inst_dir, + build_args, + install_args, + test_args, + ): + super(MakeBuilder, self).__init__( + build_opts, ctx, manifest, src_dir, build_dir, inst_dir + ) + self.build_args = build_args or [] + self.install_args = install_args or [] + self.test_args = test_args + + def _get_prefix(self): + return ["PREFIX=" + self.inst_dir, "prefix=" + self.inst_dir] + + def _build(self, install_dirs, reconfigure): + env = self._compute_env(install_dirs) + + # Need to ensure that PREFIX is set prior to install because + # libbpf uses it when generating its pkg-config file. + # The lowercase prefix is used by some projects. + cmd = ( + ["make", "-j%s" % self.build_opts.num_jobs] + + self.build_args + + self._get_prefix() + ) + self._run_cmd(cmd, env=env) + + install_cmd = ["make"] + self.install_args + self._get_prefix() + self._run_cmd(install_cmd, env=env) + + def run_tests( + self, install_dirs, schedule_type, owner, test_filter, retry, no_testpilot + ): + if not self.test_args: + return + + env = self._compute_env(install_dirs) + + cmd = ["make"] + self.test_args + self._get_prefix() + self._run_cmd(cmd, env=env) + + +class CMakeBootStrapBuilder(MakeBuilder): + def _build(self, install_dirs, reconfigure): + self._run_cmd(["./bootstrap", "--prefix=" + self.inst_dir]) + super(CMakeBootStrapBuilder, self)._build(install_dirs, reconfigure) + + +class AutoconfBuilder(BuilderBase): + def __init__(self, build_opts, ctx, manifest, src_dir, build_dir, inst_dir, args): + super(AutoconfBuilder, self).__init__( + build_opts, ctx, manifest, src_dir, build_dir, inst_dir + ) + self.args = args or [] + + def _build(self, install_dirs, reconfigure): + configure_path = os.path.join(self.src_dir, "configure") + autogen_path = os.path.join(self.src_dir, "autogen.sh") + + env = self._compute_env(install_dirs) + + if not os.path.exists(configure_path): + print("%s doesn't exist, so reconfiguring" % configure_path) + # This libtoolize call is a bit gross; the issue is that + # `autoreconf` as invoked by libsodium's `autogen.sh` doesn't + # seem to realize that it should invoke libtoolize and then + # error out when the configure script references a libtool + # related symbol. + self._run_cmd(["libtoolize"], cwd=self.src_dir, env=env) + + # We generally prefer to call the `autogen.sh` script provided + # by the project on the basis that it may know more than plain + # autoreconf does. + if os.path.exists(autogen_path): + self._run_cmd(["bash", autogen_path], cwd=self.src_dir, env=env) + else: + self._run_cmd(["autoreconf", "-ivf"], cwd=self.src_dir, env=env) + configure_cmd = [configure_path, "--prefix=" + self.inst_dir] + self.args + self._run_cmd(configure_cmd, env=env) + self._run_cmd(["make", "-j%s" % self.build_opts.num_jobs], env=env) + self._run_cmd(["make", "install"], env=env) + + +class Iproute2Builder(BuilderBase): + # ./configure --prefix does not work for iproute2. + # Thus, explicitly copy sources from src_dir to build_dir, bulid, + # and then install to inst_dir using DESTDIR + # lastly, also copy include from build_dir to inst_dir + def __init__(self, build_opts, ctx, manifest, src_dir, build_dir, inst_dir): + super(Iproute2Builder, self).__init__( + build_opts, ctx, manifest, src_dir, build_dir, inst_dir + ) + + def _patch(self): + # FBOSS build currently depends on an old version of iproute2 (commit + # 7ca63aef7d1b0c808da0040c6b366ef7a61f38c1). This is missing a commit + # (ae717baf15fb4d30749ada3948d9445892bac239) needed to build iproute2 + # successfully. Apply it viz.: include stdint.h + # Reference: https://fburl.com/ilx9g5xm + with open(self.build_dir + "/tc/tc_core.c", "r") as f: + data = f.read() + + with open(self.build_dir + "/tc/tc_core.c", "w") as f: + f.write("#include \n") + f.write(data) + + def _build(self, install_dirs, reconfigure): + configure_path = os.path.join(self.src_dir, "configure") + + env = self.env.copy() + self._run_cmd([configure_path], env=env) + shutil.rmtree(self.build_dir) + shutil.copytree(self.src_dir, self.build_dir) + self._patch() + self._run_cmd(["make", "-j%s" % self.build_opts.num_jobs], env=env) + install_cmd = ["make", "install", "DESTDIR=" + self.inst_dir] + + for d in ["include", "lib"]: + if not os.path.isdir(os.path.join(self.inst_dir, d)): + shutil.copytree( + os.path.join(self.build_dir, d), os.path.join(self.inst_dir, d) + ) + + self._run_cmd(install_cmd, env=env) + + +class BistroBuilder(BuilderBase): + def _build(self, install_dirs, reconfigure): + p = os.path.join(self.src_dir, "bistro", "bistro") + env = self._compute_env(install_dirs) + env["PATH"] = env["PATH"] + ":" + os.path.join(p, "bin") + env["TEMPLATES_PATH"] = os.path.join(p, "include", "thrift", "templates") + self._run_cmd( + [ + os.path.join(".", "cmake", "run-cmake.sh"), + "Release", + "-DCMAKE_INSTALL_PREFIX=" + self.inst_dir, + ], + cwd=p, + env=env, + ) + self._run_cmd( + [ + "make", + "install", + "-j", + str(self.build_opts.num_jobs), + ], + cwd=os.path.join(p, "cmake", "Release"), + env=env, + ) + + def run_tests( + self, install_dirs, schedule_type, owner, test_filter, retry, no_testpilot + ): + env = self._compute_env(install_dirs) + build_dir = os.path.join(self.src_dir, "bistro", "bistro", "cmake", "Release") + NUM_RETRIES = 5 + for i in range(NUM_RETRIES): + cmd = ["ctest", "--output-on-failure"] + if i > 0: + cmd.append("--rerun-failed") + cmd.append(build_dir) + try: + self._run_cmd( + cmd, + cwd=build_dir, + env=env, + ) + except Exception: + print(f"Tests failed... retrying ({i+1}/{NUM_RETRIES})") + else: + return + raise Exception(f"Tests failed even after {NUM_RETRIES} retries") + + +class CMakeBuilder(BuilderBase): + MANUAL_BUILD_SCRIPT = """\ +#!{sys.executable} + +from __future__ import absolute_import, division, print_function, unicode_literals + +import argparse +import subprocess +import sys + +CMAKE = {cmake!r} +CTEST = {ctest!r} +SRC_DIR = {src_dir!r} +BUILD_DIR = {build_dir!r} +INSTALL_DIR = {install_dir!r} +CMD_PREFIX = {cmd_prefix!r} +CMAKE_ENV = {env_str} +CMAKE_DEFINE_ARGS = {define_args_str} + + +def get_jobs_argument(num_jobs_arg: int) -> str: + if num_jobs_arg > 0: + return "-j" + str(num_jobs_arg) + + import multiprocessing + num_jobs = multiprocessing.cpu_count() // 2 + return "-j" + str(num_jobs) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument( + "cmake_args", + nargs=argparse.REMAINDER, + help='Any extra arguments after an "--" argument will be passed ' + "directly to CMake." + ) + ap.add_argument( + "--mode", + choices=["configure", "build", "install", "test"], + default="configure", + help="The mode to run: configure, build, or install. " + "Defaults to configure", + ) + ap.add_argument( + "--build", + action="store_const", + const="build", + dest="mode", + help="An alias for --mode=build", + ) + ap.add_argument( + "-j", + "--num-jobs", + action="store", + type=int, + default=0, + help="Run the build or tests with the specified number of parallel jobs", + ) + ap.add_argument( + "--install", + action="store_const", + const="install", + dest="mode", + help="An alias for --mode=install", + ) + ap.add_argument( + "--test", + action="store_const", + const="test", + dest="mode", + help="An alias for --mode=test", + ) + args = ap.parse_args() + + # Strip off a leading "--" from the additional CMake arguments + if args.cmake_args and args.cmake_args[0] == "--": + args.cmake_args = args.cmake_args[1:] + + env = CMAKE_ENV + + if args.mode == "configure": + full_cmd = CMD_PREFIX + [CMAKE, SRC_DIR] + CMAKE_DEFINE_ARGS + args.cmake_args + elif args.mode in ("build", "install"): + target = "all" if args.mode == "build" else "install" + full_cmd = CMD_PREFIX + [ + CMAKE, + "--build", + BUILD_DIR, + "--target", + target, + "--config", + "Release", + get_jobs_argument(args.num_jobs), + ] + args.cmake_args + elif args.mode == "test": + full_cmd = CMD_PREFIX + [ + {dev_run_script}CTEST, + "--output-on-failure", + get_jobs_argument(args.num_jobs), + ] + args.cmake_args + else: + ap.error("unknown invocation mode: %s" % (args.mode,)) + + cmd_str = " ".join(full_cmd) + print("Running: %r" % (cmd_str,)) + proc = subprocess.run(full_cmd, env=env, cwd=BUILD_DIR) + sys.exit(proc.returncode) + + +if __name__ == "__main__": + main() +""" + + def __init__( + self, + build_opts, + ctx, + manifest, + src_dir, + build_dir, + inst_dir, + defines, + final_install_prefix=None, + extra_cmake_defines=None, + ): + super(CMakeBuilder, self).__init__( + build_opts, + ctx, + manifest, + src_dir, + build_dir, + inst_dir, + final_install_prefix=final_install_prefix, + ) + self.defines = defines or {} + if extra_cmake_defines: + self.defines.update(extra_cmake_defines) + + def _invalidate_cache(self): + for name in [ + "CMakeCache.txt", + "CMakeFiles/CMakeError.log", + "CMakeFiles/CMakeOutput.log", + ]: + name = os.path.join(self.build_dir, name) + if os.path.isdir(name): + shutil.rmtree(name) + elif os.path.exists(name): + os.unlink(name) + + def _needs_reconfigure(self): + for name in ["CMakeCache.txt", "build.ninja"]: + name = os.path.join(self.build_dir, name) + if not os.path.exists(name): + return True + return False + + def _write_build_script(self, **kwargs): + env_lines = [" {!r}: {!r},".format(k, v) for k, v in kwargs["env"].items()] + kwargs["env_str"] = "\n".join(["{"] + env_lines + ["}"]) + + if self.build_opts.is_windows(): + kwargs["dev_run_script"] = '"powershell.exe", {!r}, '.format( + self.get_dev_run_script_path() + ) + else: + kwargs["dev_run_script"] = "" + + define_arg_lines = ["["] + for arg in kwargs["define_args"]: + # Replace the CMAKE_INSTALL_PREFIX argument to use the INSTALL_DIR + # variable that we define in the MANUAL_BUILD_SCRIPT code. + if arg.startswith("-DCMAKE_INSTALL_PREFIX="): + value = " {!r}.format(INSTALL_DIR),".format( + "-DCMAKE_INSTALL_PREFIX={}" + ) + else: + value = " {!r},".format(arg) + define_arg_lines.append(value) + define_arg_lines.append("]") + kwargs["define_args_str"] = "\n".join(define_arg_lines) + + # In order to make it easier for developers to manually run builds for + # CMake-based projects, write out some build scripts that can be used to invoke + # CMake manually. + build_script_path = os.path.join(self.build_dir, "run_cmake.py") + script_contents = self.MANUAL_BUILD_SCRIPT.format(**kwargs) + with open(build_script_path, "wb") as f: + f.write(script_contents.encode()) + os.chmod(build_script_path, 0o755) + + def _compute_cmake_define_args(self, env): + defines = { + "CMAKE_INSTALL_PREFIX": self.final_install_prefix or self.inst_dir, + "BUILD_SHARED_LIBS": "OFF", + # Some of the deps (rsocket) default to UBSAN enabled if left + # unspecified. Some of the deps fail to compile in release mode + # due to warning->error promotion. RelWithDebInfo is the happy + # medium. + "CMAKE_BUILD_TYPE": "RelWithDebInfo", + } + if "SANDCASTLE" not in os.environ: + # We sometimes see intermittent ccache related breakages on some + # of the FB internal CI hosts, so we prefer to disable ccache + # when running in that environment. + ccache = path_search(env, "ccache") + if ccache: + defines["CMAKE_CXX_COMPILER_LAUNCHER"] = ccache + else: + # rocksdb does its own probing for ccache. + # Ensure that it is disabled on sandcastle + env["CCACHE_DISABLE"] = "1" + # Some sandcastle hosts have broken ccache related dirs, and + # even though we've asked for it to be disabled ccache is + # still invoked by rocksdb's cmake. + # Redirect its config directory to somewhere that is guaranteed + # fresh to us, and that won't have any ccache data inside. + env["CCACHE_DIR"] = f"{self.build_opts.scratch_dir}/ccache" + + if "GITHUB_ACTIONS" in os.environ and self.build_opts.is_windows(): + # GitHub actions: the host has both gcc and msvc installed, and + # the default behavior of cmake is to prefer gcc. + # Instruct cmake that we want it to use cl.exe; this is important + # because Boost prefers cl.exe and the mismatch results in cmake + # with gcc not being able to find boost built with cl.exe. + defines["CMAKE_C_COMPILER"] = "cl.exe" + defines["CMAKE_CXX_COMPILER"] = "cl.exe" + + if self.build_opts.is_darwin(): + # Try to persuade cmake to set the rpath to match the lib + # dirs of the dependencies. This isn't automatic, and to + # make things more interesting, cmake uses `;` as the path + # separator, so translate the runtime path to something + # that cmake will parse + defines["CMAKE_INSTALL_RPATH"] = ";".join( + env.get("DYLD_LIBRARY_PATH", "").split(":") + ) + # Tell cmake that we want to set the rpath in the tree + # at build time. Without this the rpath is only set + # at the moment that the binaries are installed. That + # default is problematic for example when using the + # gtest integration in cmake which runs the built test + # executables during the build to discover the set of + # tests. + defines["CMAKE_BUILD_WITH_INSTALL_RPATH"] = "ON" + + defines.update(self.defines) + define_args = ["-D%s=%s" % (k, v) for (k, v) in defines.items()] + + # if self.build_opts.is_windows(): + # define_args += ["-G", "Visual Studio 15 2017 Win64"] + define_args += ["-G", "Ninja"] + + return define_args + + def _build(self, install_dirs, reconfigure): + reconfigure = reconfigure or self._needs_reconfigure() + + env = self._compute_env(install_dirs) + if not self.build_opts.is_windows() and self.final_install_prefix: + env["DESTDIR"] = self.inst_dir + + # Resolve the cmake that we installed + cmake = path_search(env, "cmake") + if cmake is None: + raise Exception("Failed to find CMake") + + if reconfigure: + define_args = self._compute_cmake_define_args(env) + self._write_build_script( + cmd_prefix=self._get_cmd_prefix(), + cmake=cmake, + ctest=path_search(env, "ctest"), + env=env, + define_args=define_args, + src_dir=self.src_dir, + build_dir=self.build_dir, + install_dir=self.inst_dir, + sys=sys, + ) + + self._invalidate_cache() + self._run_cmd([cmake, self.src_dir] + define_args, env=env) + + self._run_cmd( + [ + cmake, + "--build", + self.build_dir, + "--target", + "install", + "--config", + "Release", + "-j", + str(self.build_opts.num_jobs), + ], + env=env, + ) + + def run_tests( + self, install_dirs, schedule_type, owner, test_filter, retry, no_testpilot + ): + env = self._compute_env(install_dirs) + ctest = path_search(env, "ctest") + cmake = path_search(env, "cmake") + + # On Windows, we also need to update $PATH to include the directories that + # contain runtime library dependencies. This is not needed on other platforms + # since CMake will emit RPATH properly in the binary so they can find these + # dependencies. + if self.build_opts.is_windows(): + path_entries = self.get_dev_run_extra_path_dirs(install_dirs) + path = env.get("PATH") + if path: + path_entries.insert(0, path) + env["PATH"] = ";".join(path_entries) + + # Don't use the cmd_prefix when running tests. This is vcvarsall.bat on + # Windows. vcvarsall.bat is only needed for the build, not tests. It + # unfortunately fails if invoked with a long PATH environment variable when + # running the tests. + use_cmd_prefix = False + + def get_property(test, propname, defval=None): + """extracts a named property from a cmake test info json blob. + The properties look like: + [{"name": "WORKING_DIRECTORY"}, + {"value": "something"}] + We assume that it is invalid for the same named property to be + listed more than once. + """ + props = test.get("properties", []) + for p in props: + if p.get("name", None) == propname: + return p.get("value", defval) + return defval + + def list_tests(): + output = subprocess.check_output( + [ctest, "--show-only=json-v1"], env=env, cwd=self.build_dir + ) + try: + data = json.loads(output.decode("utf-8")) + except ValueError as exc: + raise Exception( + "Failed to decode cmake test info using %s: %s. Output was: %r" + % (ctest, str(exc), output) + ) + + tests = [] + machine_suffix = self.build_opts.host_type.as_tuple_string() + for test in data["tests"]: + working_dir = get_property(test, "WORKING_DIRECTORY") + labels = [] + machine_suffix = self.build_opts.host_type.as_tuple_string() + labels.append("tpx_test_config::buildsystem=getdeps") + labels.append("tpx_test_config::platform={}".format(machine_suffix)) + + if get_property(test, "DISABLED"): + labels.append("disabled") + command = test["command"] + if working_dir: + command = [cmake, "-E", "chdir", working_dir] + command + + import os + + tests.append( + { + "type": "custom", + "target": "%s-%s-getdeps-%s" + % (self.manifest.name, test["name"], machine_suffix), + "command": command, + "labels": labels, + "env": {}, + "required_paths": [], + "contacts": [], + "cwd": os.getcwd(), + } + ) + return tests + + if schedule_type == "continuous" or schedule_type == "testwarden": + # for continuous and testwarden runs, disabling retry can give up + # better signals for flaky tests. + retry = 0 + + from sys import platform + + testpilot = path_search(env, "testpilot") + tpx = path_search(env, "tpx") + if (tpx or testpilot) and not no_testpilot: + buck_test_info = list_tests() + import os + + buck_test_info_name = os.path.join(self.build_dir, ".buck-test-info.json") + with open(buck_test_info_name, "w") as f: + json.dump(buck_test_info, f) + + env.set("http_proxy", "") + env.set("https_proxy", "") + runs = [] + from sys import platform + + if platform == "win32": + machine_suffix = self.build_opts.host_type.as_tuple_string() + testpilot_args = [ + "parexec-testinfra.exe", + "C:/tools/testpilot/sc_testpilot.par", + # Need to force the repo type otherwise testpilot on windows + # can be confused (presumably sparse profile related) + "--force-repo", + "fbcode", + "--force-repo-root", + self.build_opts.fbsource_dir, + "--buck-test-info", + buck_test_info_name, + "--retry=%d" % retry, + "-j=%s" % str(self.build_opts.num_jobs), + "--test-config", + "platform=%s" % machine_suffix, + "buildsystem=getdeps", + "--return-nonzero-on-failures", + ] + else: + testpilot_args = [ + tpx, + "--buck-test-info", + buck_test_info_name, + "--retry=%d" % retry, + "-j=%s" % str(self.build_opts.num_jobs), + "--print-long-results", + ] + + if owner: + testpilot_args += ["--contacts", owner] + + if tpx and env: + testpilot_args.append("--env") + testpilot_args.extend(f"{key}={val}" for key, val in env.items()) + + if test_filter: + testpilot_args += ["--", test_filter] + + if schedule_type == "continuous": + runs.append( + [ + "--tag-new-tests", + "--collection", + "oss-continuous", + "--purpose", + "continuous", + ] + ) + elif schedule_type == "testwarden": + # One run to assess new tests + runs.append( + [ + "--tag-new-tests", + "--collection", + "oss-new-test-stress", + "--stress-runs", + "10", + "--purpose", + "stress-run-new-test", + ] + ) + # And another for existing tests + runs.append( + [ + "--tag-new-tests", + "--collection", + "oss-existing-test-stress", + "--stress-runs", + "10", + "--purpose", + "stress-run", + ] + ) + else: + runs.append(["--collection", "oss-diff", "--purpose", "diff"]) + + for run in runs: + self._run_cmd( + testpilot_args + run, + cwd=self.build_opts.fbcode_builder_dir, + env=env, + use_cmd_prefix=use_cmd_prefix, + ) + else: + args = [ctest, "--output-on-failure", "-j", str(self.build_opts.num_jobs)] + if test_filter: + args += ["-R", test_filter] + + count = 0 + while count <= retry: + retcode = self._run_cmd( + args, env=env, use_cmd_prefix=use_cmd_prefix, allow_fail=True + ) + + if retcode == 0: + break + if count == 0: + # Only add this option in the second run. + args += ["--rerun-failed"] + count += 1 + if retcode != 0: + # Allow except clause in getdeps.main to catch and exit gracefully + # This allows non-testpilot runs to fail through the same logic as failed testpilot runs, which may become handy in case if post test processing is needed in the future + raise subprocess.CalledProcessError(retcode, args) + + +class NinjaBootstrap(BuilderBase): + def __init__(self, build_opts, ctx, manifest, build_dir, src_dir, inst_dir): + super(NinjaBootstrap, self).__init__( + build_opts, ctx, manifest, src_dir, build_dir, inst_dir + ) + + def _build(self, install_dirs, reconfigure): + self._run_cmd([sys.executable, "configure.py", "--bootstrap"], cwd=self.src_dir) + src_ninja = os.path.join(self.src_dir, "ninja") + dest_ninja = os.path.join(self.inst_dir, "bin/ninja") + bin_dir = os.path.dirname(dest_ninja) + if not os.path.exists(bin_dir): + os.makedirs(bin_dir) + shutil.copyfile(src_ninja, dest_ninja) + shutil.copymode(src_ninja, dest_ninja) + + +class OpenSSLBuilder(BuilderBase): + def __init__(self, build_opts, ctx, manifest, build_dir, src_dir, inst_dir): + super(OpenSSLBuilder, self).__init__( + build_opts, ctx, manifest, src_dir, build_dir, inst_dir + ) + + def _build(self, install_dirs, reconfigure): + configure = os.path.join(self.src_dir, "Configure") + + # prefer to resolve the perl that we installed from + # our manifest on windows, but fall back to the system + # path on eg: darwin + env = self.env.copy() + for d in install_dirs: + bindir = os.path.join(d, "bin") + add_path_entry(env, "PATH", bindir, append=False) + + perl = path_search(env, "perl", "perl") + + if self.build_opts.is_windows(): + make = "nmake.exe" + args = ["VC-WIN64A-masm", "-utf-8"] + elif self.build_opts.is_darwin(): + make = "make" + args = ["darwin64-x86_64-cc"] + elif self.build_opts.is_linux(): + make = "make" + args = ( + ["linux-x86_64"] if not self.build_opts.is_arm() else ["linux-aarch64"] + ) + else: + raise Exception("don't know how to build openssl for %r" % self.ctx) + + self._run_cmd( + [ + perl, + configure, + "--prefix=%s" % self.inst_dir, + "--openssldir=%s" % self.inst_dir, + ] + + args + + [ + "enable-static-engine", + "enable-capieng", + "no-makedepend", + "no-unit-test", + "no-tests", + ] + ) + self._run_cmd([make, "install_sw", "install_ssldirs"]) + + +class Boost(BuilderBase): + def __init__( + self, build_opts, ctx, manifest, src_dir, build_dir, inst_dir, b2_args + ): + children = os.listdir(src_dir) + assert len(children) == 1, "expected a single directory entry: %r" % (children,) + boost_src = children[0] + assert boost_src.startswith("boost") + src_dir = os.path.join(src_dir, children[0]) + super(Boost, self).__init__( + build_opts, ctx, manifest, src_dir, build_dir, inst_dir + ) + self.b2_args = b2_args + + def _build(self, install_dirs, reconfigure): + env = self._compute_env(install_dirs) + linkage = ["static"] + if self.build_opts.is_windows(): + linkage.append("shared") + + args = [] + if self.build_opts.is_darwin(): + clang = subprocess.check_output(["xcrun", "--find", "clang"]) + user_config = os.path.join(self.build_dir, "project-config.jam") + with open(user_config, "w") as jamfile: + jamfile.write("using clang : : %s ;\n" % clang.decode().strip()) + args.append("--user-config=%s" % user_config) + + for link in linkage: + if self.build_opts.is_windows(): + bootstrap = os.path.join(self.src_dir, "bootstrap.bat") + self._run_cmd([bootstrap], cwd=self.src_dir, env=env) + args += ["address-model=64"] + else: + bootstrap = os.path.join(self.src_dir, "bootstrap.sh") + self._run_cmd( + [bootstrap, "--prefix=%s" % self.inst_dir], + cwd=self.src_dir, + env=env, + ) + + b2 = os.path.join(self.src_dir, "b2") + self._run_cmd( + [ + b2, + "-j%s" % self.build_opts.num_jobs, + "--prefix=%s" % self.inst_dir, + "--builddir=%s" % self.build_dir, + ] + + args + + self.b2_args + + [ + "link=%s" % link, + "runtime-link=shared", + "variant=release", + "threading=multi", + "debug-symbols=on", + "visibility=global", + "-d2", + "install", + ], + cwd=self.src_dir, + env=env, + ) + + +class NopBuilder(BuilderBase): + def __init__(self, build_opts, ctx, manifest, src_dir, inst_dir): + super(NopBuilder, self).__init__( + build_opts, ctx, manifest, src_dir, None, inst_dir + ) + + def build(self, install_dirs, reconfigure): + print("Installing %s -> %s" % (self.src_dir, self.inst_dir)) + parent = os.path.dirname(self.inst_dir) + if not os.path.exists(parent): + os.makedirs(parent) + + install_files = self.manifest.get_section_as_ordered_pairs( + "install.files", self.ctx + ) + if install_files: + for src_name, dest_name in self.manifest.get_section_as_ordered_pairs( + "install.files", self.ctx + ): + full_dest = os.path.join(self.inst_dir, dest_name) + full_src = os.path.join(self.src_dir, src_name) + + dest_parent = os.path.dirname(full_dest) + if not os.path.exists(dest_parent): + os.makedirs(dest_parent) + if os.path.isdir(full_src): + if not os.path.exists(full_dest): + shutil.copytree(full_src, full_dest) + else: + shutil.copyfile(full_src, full_dest) + shutil.copymode(full_src, full_dest) + # This is a bit gross, but the mac ninja.zip doesn't + # give ninja execute permissions, so force them on + # for things that look like they live in a bin dir + if os.path.dirname(dest_name) == "bin": + st = os.lstat(full_dest) + os.chmod(full_dest, st.st_mode | stat.S_IXUSR) + else: + if not os.path.exists(self.inst_dir): + shutil.copytree(self.src_dir, self.inst_dir) + + +class OpenNSABuilder(NopBuilder): + # OpenNSA libraries are stored with git LFS. As a result, fetcher fetches + # LFS pointers and not the contents. Use git-lfs to pull the real contents + # before copying to install dir using NoopBuilder. + # In future, if more builders require git-lfs, we would consider installing + # git-lfs as part of the sandcastle infra as against repeating similar + # logic for each builder that requires git-lfs. + def __init__(self, build_opts, ctx, manifest, src_dir, inst_dir): + super(OpenNSABuilder, self).__init__( + build_opts, ctx, manifest, src_dir, inst_dir + ) + + def build(self, install_dirs, reconfigure): + env = self._compute_env(install_dirs) + self._run_cmd(["git", "lfs", "install", "--local"], cwd=self.src_dir, env=env) + self._run_cmd(["git", "lfs", "pull"], cwd=self.src_dir, env=env) + + super(OpenNSABuilder, self).build(install_dirs, reconfigure) + + +class SqliteBuilder(BuilderBase): + def __init__(self, build_opts, ctx, manifest, src_dir, build_dir, inst_dir): + super(SqliteBuilder, self).__init__( + build_opts, ctx, manifest, src_dir, build_dir, inst_dir + ) + + def _build(self, install_dirs, reconfigure): + for f in ["sqlite3.c", "sqlite3.h", "sqlite3ext.h"]: + src = os.path.join(self.src_dir, f) + dest = os.path.join(self.build_dir, f) + copy_if_different(src, dest) + + cmake_lists = """ +cmake_minimum_required(VERSION 3.1.3 FATAL_ERROR) +project(sqlite3 C) +add_library(sqlite3 STATIC sqlite3.c) +# These options are taken from the defaults in Makefile.msc in +# the sqlite distribution +target_compile_definitions(sqlite3 PRIVATE + -DSQLITE_ENABLE_COLUMN_METADATA=1 + -DSQLITE_ENABLE_FTS3=1 + -DSQLITE_ENABLE_RTREE=1 + -DSQLITE_ENABLE_GEOPOLY=1 + -DSQLITE_ENABLE_JSON1=1 + -DSQLITE_ENABLE_STMTVTAB=1 + -DSQLITE_ENABLE_DBPAGE_VTAB=1 + -DSQLITE_ENABLE_DBSTAT_VTAB=1 + -DSQLITE_INTROSPECTION_PRAGMAS=1 + -DSQLITE_ENABLE_DESERIALIZE=1 +) +install(TARGETS sqlite3) +install(FILES sqlite3.h sqlite3ext.h DESTINATION include) + """ + + with open(os.path.join(self.build_dir, "CMakeLists.txt"), "w") as f: + f.write(cmake_lists) + + defines = { + "CMAKE_INSTALL_PREFIX": self.inst_dir, + "BUILD_SHARED_LIBS": "OFF", + "CMAKE_BUILD_TYPE": "RelWithDebInfo", + } + define_args = ["-D%s=%s" % (k, v) for (k, v) in defines.items()] + define_args += ["-G", "Ninja"] + + env = self._compute_env(install_dirs) + + # Resolve the cmake that we installed + cmake = path_search(env, "cmake") + + self._run_cmd([cmake, self.build_dir] + define_args, env=env) + self._run_cmd( + [ + cmake, + "--build", + self.build_dir, + "--target", + "install", + "--config", + "Release", + "-j", + str(self.build_opts.num_jobs), + ], + env=env, + ) + + +class CargoBuilder(BuilderBase): + def __init__( + self, + build_opts, + ctx, + manifest, + src_dir, + build_dir, + inst_dir, + build_doc, + workspace_dir, + manifests_to_build, + loader, + ): + super(CargoBuilder, self).__init__( + build_opts, ctx, manifest, src_dir, build_dir, inst_dir + ) + self.build_doc = build_doc + self.ws_dir = workspace_dir + self.manifests_to_build = manifests_to_build and manifests_to_build.split(",") + self.loader = loader + + def run_cargo(self, install_dirs, operation, args=None): + args = args or [] + env = self._compute_env(install_dirs) + # Enable using nightly features with stable compiler + env["RUSTC_BOOTSTRAP"] = "1" + env["LIBZ_SYS_STATIC"] = "1" + cmd = [ + "cargo", + operation, + "--workspace", + "-j%s" % self.build_opts.num_jobs, + ] + args + self._run_cmd(cmd, cwd=self.workspace_dir(), env=env) + + def build_source_dir(self): + return os.path.join(self.build_dir, "source") + + def workspace_dir(self): + return os.path.join(self.build_source_dir(), self.ws_dir or "") + + def manifest_dir(self, manifest): + return os.path.join(self.build_source_dir(), manifest) + + def recreate_dir(self, src, dst): + if os.path.isdir(dst): + shutil.rmtree(dst) + shutil.copytree(src, dst) + + def _build(self, install_dirs, reconfigure): + build_source_dir = self.build_source_dir() + self.recreate_dir(self.src_dir, build_source_dir) + + dot_cargo_dir = os.path.join(build_source_dir, ".cargo") + if not os.path.isdir(dot_cargo_dir): + os.mkdir(dot_cargo_dir) + + with open(os.path.join(dot_cargo_dir, "config"), "w+") as f: + f.write( + """\ +[build] +target-dir = '''{}''' + +[net] +git-fetch-with-cli = true + +[profile.dev] +debug = false +incremental = false +""".format( + self.build_dir.replace("\\", "\\\\") + ) + ) + + if self.ws_dir is not None: + self._patchup_workspace() + + try: + from getdeps.facebook.rust import vendored_crates + + vendored_crates(self.build_opts, build_source_dir) + except ImportError: + # This FB internal module isn't shippped to github, + # so just rely on cargo downloading crates on it's own + pass + + if self.manifests_to_build is None: + self.run_cargo( + install_dirs, + "build", + ["--out-dir", os.path.join(self.inst_dir, "bin"), "-Zunstable-options"], + ) + else: + for manifest in self.manifests_to_build: + self.run_cargo( + install_dirs, + "build", + [ + "--out-dir", + os.path.join(self.inst_dir, "bin"), + "-Zunstable-options", + "--manifest-path", + self.manifest_dir(manifest), + ], + ) + + self.recreate_dir(build_source_dir, os.path.join(self.inst_dir, "source")) + + def run_tests( + self, install_dirs, schedule_type, owner, test_filter, retry, no_testpilot + ): + if test_filter: + args = ["--", test_filter] + else: + args = [] + + if self.manifests_to_build is None: + self.run_cargo(install_dirs, "test", args) + if self.build_doc: + self.run_cargo(install_dirs, "doc", ["--no-deps"]) + else: + for manifest in self.manifests_to_build: + margs = ["--manifest-path", self.manifest_dir(manifest)] + self.run_cargo(install_dirs, "test", args + margs) + if self.build_doc: + self.run_cargo(install_dirs, "doc", ["--no-deps"] + margs) + + def _patchup_workspace(self): + """ + This method makes some assumptions about the state of the project and + its cargo dependendies: + 1. Crates from cargo dependencies can be extracted from Cargo.toml files + using _extract_crates function. It is using a heuristic so check its + code to understand how it is done. + 2. The extracted cargo dependencies crates can be found in the + dependency's install dir using _resolve_crate_to_path function + which again is using a heuristic. + + Notice that many things might go wrong here. E.g. if someone depends + on another getdeps crate by writing in their Cargo.toml file: + + my-rename-of-crate = { package = "crate", git = "..." } + + they can count themselves lucky because the code will raise an + Exception. There migh be more cases where the code will silently pass + producing bad results. + """ + workspace_dir = self.workspace_dir() + config = self._resolve_config() + if config: + with open(os.path.join(workspace_dir, "Cargo.toml"), "r+") as f: + manifest_content = f.read() + if "[package]" not in manifest_content: + # A fake manifest has to be crated to change the virtual + # manifest into a non-virtual. The virtual manifests are limited + # in many ways and the inability to define patches on them is + # one. Check https://github.com/rust-lang/cargo/issues/4934 to + # see if it is resolved. + f.write( + """ + [package] + name = "fake_manifest_of_{}" + version = "0.0.0" + [lib] + path = "/dev/null" + """.format( + self.manifest.name + ) + ) + else: + f.write("\n") + f.write(config) + + def _resolve_config(self): + """ + Returns a configuration to be put inside root Cargo.toml file which + patches the dependencies git code with local getdeps versions. + See https://doc.rust-lang.org/cargo/reference/manifest.html#the-patch-section + """ + dep_to_git = self._resolve_dep_to_git() + dep_to_crates = CargoBuilder._resolve_dep_to_crates( + self.build_source_dir(), dep_to_git + ) + + config = [] + for name in sorted(dep_to_git.keys()): + git_conf = dep_to_git[name] + crates = sorted(dep_to_crates.get(name, [])) + if not crates: + continue # nothing to patch, move along + crates_patches = [ + '{} = {{ path = "{}" }}'.format( + crate, + CargoBuilder._resolve_crate_to_path(crate, git_conf).replace( + "\\", "\\\\" + ), + ) + for crate in crates + ] + + config.append( + '[patch."{0}"]\n'.format(git_conf["repo_url"]) + + "\n".join(crates_patches) + ) + return "\n".join(config) + + def _resolve_dep_to_git(self): + """ + For each direct dependency of the currently build manifest check if it + is also cargo-builded and if yes then extract it's git configs and + install dir + """ + dependencies = self.manifest.get_section_as_dict("dependencies", ctx=self.ctx) + if not dependencies: + return [] + + dep_to_git = {} + for dep in dependencies.keys(): + dep_manifest = self.loader.load_manifest(dep) + dep_builder = dep_manifest.get("build", "builder", ctx=self.ctx) + if dep_builder not in ["cargo", "nop"] or dep == "rust": + # This is a direct dependency, but it is not build with cargo + # and it is not simply copying files with nop, so ignore it. + # The "rust" dependency is an exception since it contains the + # toolchain. + continue + + git_conf = dep_manifest.get_section_as_dict("git", ctx=self.ctx) + if "repo_url" not in git_conf: + raise Exception( + "A cargo dependency requires git.repo_url to be defined." + ) + source_dir = self.loader.get_project_install_dir(dep_manifest) + if dep_builder == "cargo": + source_dir = os.path.join(source_dir, "source") + git_conf["source_dir"] = source_dir + dep_to_git[dep] = git_conf + return dep_to_git + + @staticmethod + def _resolve_dep_to_crates(build_source_dir, dep_to_git): + """ + This function traverse the build_source_dir in search of Cargo.toml + files, extracts the crate names from them using _extract_crates + function and returns a merged result containing crate names per + dependency name from all Cargo.toml files in the project. + """ + if not dep_to_git: + return {} # no deps, so don't waste time traversing files + + dep_to_crates = {} + for root, _, files in os.walk(build_source_dir): + for f in files: + if f == "Cargo.toml": + more_dep_to_crates = CargoBuilder._extract_crates( + os.path.join(root, f), dep_to_git + ) + for name, crates in more_dep_to_crates.items(): + dep_to_crates.setdefault(name, set()).update(crates) + return dep_to_crates + + @staticmethod + def _extract_crates(cargo_toml_file, dep_to_git): + """ + This functions reads content of provided cargo toml file and extracts + crate names per each dependency. The extraction is done by a heuristic + so it might be incorrect. + """ + deps_to_crates = {} + with open(cargo_toml_file, "r") as f: + for line in f.readlines(): + if line.startswith("#") or "git = " not in line: + continue # filter out commented lines and ones without git deps + for name, conf in dep_to_git.items(): + if 'git = "{}"'.format(conf["repo_url"]) in line: + pkg_template = ' package = "' + if pkg_template in line: + crate_name, _, _ = line.partition(pkg_template)[ + 2 + ].partition('"') + else: + crate_name, _, _ = line.partition("=") + deps_to_crates.setdefault(name, set()).add(crate_name.strip()) + return deps_to_crates + + @staticmethod + def _resolve_crate_to_path(crate, git_conf): + """ + Tries to find in git_conf["inst_dir"] by searching a [package] + keyword followed by name = "". + """ + source_dir = git_conf["source_dir"] + search_pattern = '[package]\nname = "{}"'.format(crate) + + for root, _, files in os.walk(source_dir): + for fname in files: + if fname == "Cargo.toml": + with open(os.path.join(root, fname), "r") as f: + if search_pattern in f.read(): + return root + + raise Exception("Failed to found crate {} in path {}".format(crate, source_dir)) diff --git a/build/fbcode_builder/getdeps/buildopts.py b/build/fbcode_builder/getdeps/buildopts.py new file mode 100644 index 000000000..bc6d2da87 --- /dev/null +++ b/build/fbcode_builder/getdeps/buildopts.py @@ -0,0 +1,458 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import errno +import glob +import ntpath +import os +import subprocess +import sys +import tempfile + +from .copytree import containing_repo_type +from .envfuncs import Env, add_path_entry +from .fetcher import get_fbsource_repo_data +from .manifest import ContextGenerator +from .platform import HostType, is_windows + + +try: + import typing # noqa: F401 +except ImportError: + pass + + +def detect_project(path): + repo_type, repo_root = containing_repo_type(path) + if repo_type is None: + return None, None + + # Look for a .projectid file. If it exists, read the project name from it. + project_id_path = os.path.join(repo_root, ".projectid") + try: + with open(project_id_path, "r") as f: + project_name = f.read().strip() + return repo_root, project_name + except EnvironmentError as ex: + if ex.errno != errno.ENOENT: + raise + + return repo_root, None + + +class BuildOptions(object): + def __init__( + self, + fbcode_builder_dir, + scratch_dir, + host_type, + install_dir=None, + num_jobs=0, + use_shipit=False, + vcvars_path=None, + allow_system_packages=False, + lfs_path=None, + ): + """fbcode_builder_dir - the path to either the in-fbsource fbcode_builder dir, + or for shipit-transformed repos, the build dir that + has been mapped into that dir. + scratch_dir - a place where we can store repos and build bits. + This path should be stable across runs and ideally + should not be in the repo of the project being built, + but that is ultimately where we generally fall back + for builds outside of FB + install_dir - where the project will ultimately be installed + num_jobs - the level of concurrency to use while building + use_shipit - use real shipit instead of the simple shipit transformer + vcvars_path - Path to external VS toolchain's vsvarsall.bat + """ + if not num_jobs: + import multiprocessing + + num_jobs = multiprocessing.cpu_count() // 2 + + if not install_dir: + install_dir = os.path.join(scratch_dir, "installed") + + self.project_hashes = None + for p in ["../deps/github_hashes", "../project_hashes"]: + hashes = os.path.join(fbcode_builder_dir, p) + if os.path.exists(hashes): + self.project_hashes = hashes + break + + # Detect what repository and project we are being run from. + self.repo_root, self.repo_project = detect_project(os.getcwd()) + + # If we are running from an fbsource repository, set self.fbsource_dir + # to allow the ShipIt-based fetchers to use it. + if self.repo_project == "fbsource": + self.fbsource_dir = self.repo_root + else: + self.fbsource_dir = None + + self.num_jobs = num_jobs + self.scratch_dir = scratch_dir + self.install_dir = install_dir + self.fbcode_builder_dir = fbcode_builder_dir + self.host_type = host_type + self.use_shipit = use_shipit + self.allow_system_packages = allow_system_packages + self.lfs_path = lfs_path + if vcvars_path is None and is_windows(): + + # On Windows, the compiler is not available in the PATH by + # default so we need to run the vcvarsall script to populate the + # environment. We use a glob to find some version of this script + # as deployed with Visual Studio 2017. This logic can also + # locate Visual Studio 2019 but note that at the time of writing + # the version of boost in our manifest cannot be built with + # VS 2019, so we're effectively tied to VS 2017 until we upgrade + # the boost dependency. + vcvarsall = [] + for year in ["2017", "2019"]: + vcvarsall += glob.glob( + os.path.join( + os.environ["ProgramFiles(x86)"], + "Microsoft Visual Studio", + year, + "*", + "VC", + "Auxiliary", + "Build", + "vcvarsall.bat", + ) + ) + vcvars_path = vcvarsall[0] + + self.vcvars_path = vcvars_path + + @property + def manifests_dir(self): + return os.path.join(self.fbcode_builder_dir, "manifests") + + def is_darwin(self): + return self.host_type.is_darwin() + + def is_windows(self): + return self.host_type.is_windows() + + def is_arm(self): + return self.host_type.is_arm() + + def get_vcvars_path(self): + return self.vcvars_path + + def is_linux(self): + return self.host_type.is_linux() + + def get_context_generator(self, host_tuple=None, facebook_internal=None): + """Create a manifest ContextGenerator for the specified target platform.""" + if host_tuple is None: + host_type = self.host_type + elif isinstance(host_tuple, HostType): + host_type = host_tuple + else: + host_type = HostType.from_tuple_string(host_tuple) + + # facebook_internal is an Optional[bool] + # If it is None, default to assuming this is a Facebook-internal build if + # we are running in an fbsource repository. + if facebook_internal is None: + facebook_internal = self.fbsource_dir is not None + + return ContextGenerator( + { + "os": host_type.ostype, + "distro": host_type.distro, + "distro_vers": host_type.distrovers, + "fb": "on" if facebook_internal else "off", + "test": "off", + } + ) + + def compute_env_for_install_dirs(self, install_dirs, env=None, manifest=None): + if env is not None: + env = env.copy() + else: + env = Env() + + env["GETDEPS_BUILD_DIR"] = os.path.join(self.scratch_dir, "build") + env["GETDEPS_INSTALL_DIR"] = self.install_dir + + # On macOS we need to set `SDKROOT` when we use clang for system + # header files. + if self.is_darwin() and "SDKROOT" not in env: + sdkroot = subprocess.check_output(["xcrun", "--show-sdk-path"]) + env["SDKROOT"] = sdkroot.decode().strip() + + if self.fbsource_dir: + env["YARN_YARN_OFFLINE_MIRROR"] = os.path.join( + self.fbsource_dir, "xplat/third-party/yarn/offline-mirror" + ) + yarn_exe = "yarn.bat" if self.is_windows() else "yarn" + env["YARN_PATH"] = os.path.join( + self.fbsource_dir, "xplat/third-party/yarn/", yarn_exe + ) + node_exe = "node-win-x64.exe" if self.is_windows() else "node" + env["NODE_BIN"] = os.path.join( + self.fbsource_dir, "xplat/third-party/node/bin/", node_exe + ) + env["RUST_VENDORED_CRATES_DIR"] = os.path.join( + self.fbsource_dir, "third-party/rust/vendor" + ) + hash_data = get_fbsource_repo_data(self) + env["FBSOURCE_HASH"] = hash_data.hash + env["FBSOURCE_DATE"] = hash_data.date + + lib_path = None + if self.is_darwin(): + lib_path = "DYLD_LIBRARY_PATH" + elif self.is_linux(): + lib_path = "LD_LIBRARY_PATH" + elif self.is_windows(): + lib_path = "PATH" + else: + lib_path = None + + for d in install_dirs: + bindir = os.path.join(d, "bin") + + if not ( + manifest and manifest.get("build", "disable_env_override_pkgconfig") + ): + pkgconfig = os.path.join(d, "lib/pkgconfig") + if os.path.exists(pkgconfig): + add_path_entry(env, "PKG_CONFIG_PATH", pkgconfig) + + pkgconfig = os.path.join(d, "lib64/pkgconfig") + if os.path.exists(pkgconfig): + add_path_entry(env, "PKG_CONFIG_PATH", pkgconfig) + + if not (manifest and manifest.get("build", "disable_env_override_path")): + add_path_entry(env, "CMAKE_PREFIX_PATH", d) + + # Allow resolving shared objects built earlier (eg: zstd + # doesn't include the full path to the dylib in its linkage + # so we need to give it an assist) + if lib_path: + for lib in ["lib", "lib64"]: + libdir = os.path.join(d, lib) + if os.path.exists(libdir): + add_path_entry(env, lib_path, libdir) + + # Allow resolving binaries (eg: cmake, ninja) and dlls + # built by earlier steps + if os.path.exists(bindir): + add_path_entry(env, "PATH", bindir, append=False) + + # If rustc is present in the `bin` directory, set RUSTC to prevent + # cargo uses the rustc installed in the system. + if self.is_windows(): + cargo_path = os.path.join(bindir, "cargo.exe") + rustc_path = os.path.join(bindir, "rustc.exe") + rustdoc_path = os.path.join(bindir, "rustdoc.exe") + else: + cargo_path = os.path.join(bindir, "cargo") + rustc_path = os.path.join(bindir, "rustc") + rustdoc_path = os.path.join(bindir, "rustdoc") + + if os.path.isfile(rustc_path): + env["CARGO_BIN"] = cargo_path + env["RUSTC"] = rustc_path + env["RUSTDOC"] = rustdoc_path + + openssl_include = os.path.join(d, "include/openssl") + if os.path.isdir(openssl_include) and any( + os.path.isfile(os.path.join(d, "lib", libcrypto)) + for libcrypto in ("libcrypto.lib", "libcrypto.so", "libcrypto.a") + ): + # This must be the openssl library, let Rust know about it + env["OPENSSL_DIR"] = d + + return env + + +def list_win32_subst_letters(): + output = subprocess.check_output(["subst"]).decode("utf-8") + # The output is a set of lines like: `F:\: => C:\open\some\where` + lines = output.strip().split("\r\n") + mapping = {} + for line in lines: + fields = line.split(": => ") + if len(fields) != 2: + continue + letter = fields[0] + path = fields[1] + mapping[letter] = path + + return mapping + + +def find_existing_win32_subst_for_path( + path, # type: str + subst_mapping, # type: typing.Mapping[str, str] +): + # type: (...) -> typing.Optional[str] + path = ntpath.normcase(ntpath.normpath(path)) + for letter, target in subst_mapping.items(): + if ntpath.normcase(target) == path: + return letter + return None + + +def find_unused_drive_letter(): + import ctypes + + buffer_len = 256 + blen = ctypes.c_uint(buffer_len) + rv = ctypes.c_uint() + bufs = ctypes.create_string_buffer(buffer_len) + rv = ctypes.windll.kernel32.GetLogicalDriveStringsA(blen, bufs) + if rv > buffer_len: + raise Exception("GetLogicalDriveStringsA result too large for buffer") + nul = "\x00".encode("ascii") + + used = [drive.decode("ascii")[0] for drive in bufs.raw.strip(nul).split(nul)] + possible = [c for c in "ABCDEFGHIJKLMNOPQRSTUVWXYZ"] + available = sorted(list(set(possible) - set(used))) + if len(available) == 0: + return None + # Prefer to assign later letters rather than earlier letters + return available[-1] + + +def create_subst_path(path): + for _attempt in range(0, 24): + drive = find_existing_win32_subst_for_path( + path, subst_mapping=list_win32_subst_letters() + ) + if drive: + return drive + available = find_unused_drive_letter() + if available is None: + raise Exception( + ( + "unable to make shorter subst mapping for %s; " + "no available drive letters" + ) + % path + ) + + # Try to set up a subst mapping; note that we may be racing with + # other processes on the same host, so this may not succeed. + try: + subprocess.check_call(["subst", "%s:" % available, path]) + return "%s:\\" % available + except Exception: + print("Failed to map %s -> %s" % (available, path)) + + raise Exception("failed to set up a subst path for %s" % path) + + +def _check_host_type(args, host_type): + if host_type is None: + host_tuple_string = getattr(args, "host_type", None) + if host_tuple_string: + host_type = HostType.from_tuple_string(host_tuple_string) + else: + host_type = HostType() + + assert isinstance(host_type, HostType) + return host_type + + +def setup_build_options(args, host_type=None): + """Create a BuildOptions object based on the arguments""" + + fbcode_builder_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + scratch_dir = args.scratch_path + if not scratch_dir: + # TODO: `mkscratch` doesn't currently know how best to place things on + # sandcastle, so whip up something reasonable-ish + if "SANDCASTLE" in os.environ: + if "DISK_TEMP" not in os.environ: + raise Exception( + ( + "I need DISK_TEMP to be set in the sandcastle environment " + "so that I can store build products somewhere sane" + ) + ) + scratch_dir = os.path.join( + os.environ["DISK_TEMP"], "fbcode_builder_getdeps" + ) + if not scratch_dir: + try: + scratch_dir = ( + subprocess.check_output( + ["mkscratch", "path", "--subdir", "fbcode_builder_getdeps"] + ) + .strip() + .decode("utf-8") + ) + except OSError as exc: + if exc.errno != errno.ENOENT: + # A legit failure; don't fall back, surface the error + raise + # This system doesn't have mkscratch so we fall back to + # something local. + munged = fbcode_builder_dir.replace("Z", "zZ") + for s in ["/", "\\", ":"]: + munged = munged.replace(s, "Z") + + if is_windows() and os.path.isdir("c:/open"): + temp = "c:/open/scratch" + else: + temp = tempfile.gettempdir() + + scratch_dir = os.path.join(temp, "fbcode_builder_getdeps-%s" % munged) + if not is_windows() and os.geteuid() == 0: + # Running as root; in the case where someone runs + # sudo getdeps.py install-system-deps + # and then runs as build without privs, we want to avoid creating + # a scratch dir that the second stage cannot write to. + # So we generate a different path if we are root. + scratch_dir += "-root" + + if not os.path.exists(scratch_dir): + os.makedirs(scratch_dir) + + if is_windows(): + subst = create_subst_path(scratch_dir) + print( + "Mapping scratch dir %s -> %s" % (scratch_dir, subst), file=sys.stderr + ) + scratch_dir = subst + else: + if not os.path.exists(scratch_dir): + os.makedirs(scratch_dir) + + # Make sure we normalize the scratch path. This path is used as part of the hash + # computation for detecting if projects have been updated, so we need to always + # use the exact same string to refer to a given directory. + # But! realpath in some combinations of Windows/Python3 versions can expand the + # drive substitutions on Windows, so avoid that! + if not is_windows(): + scratch_dir = os.path.realpath(scratch_dir) + + # Save any extra cmake defines passed by the user in an env variable, so it + # can be used while hashing this build. + os.environ["GETDEPS_CMAKE_DEFINES"] = getattr(args, "extra_cmake_defines", "") or "" + + host_type = _check_host_type(args, host_type) + + return BuildOptions( + fbcode_builder_dir, + scratch_dir, + host_type, + install_dir=args.install_prefix, + num_jobs=args.num_jobs, + use_shipit=args.use_shipit, + vcvars_path=args.vcvars_path, + allow_system_packages=args.allow_system_packages, + lfs_path=args.lfs_path, + ) diff --git a/build/fbcode_builder/getdeps/cache.py b/build/fbcode_builder/getdeps/cache.py new file mode 100644 index 000000000..a261541c7 --- /dev/null +++ b/build/fbcode_builder/getdeps/cache.py @@ -0,0 +1,39 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + + +class ArtifactCache(object): + """The ArtifactCache is a small abstraction that allows caching + named things in some external storage mechanism. + The primary use case is for storing the build products on CI + systems to accelerate the build""" + + def download_to_file(self, name, dest_file_name): + """If `name` exists in the cache, download it and place it + in the specified `dest_file_name` location on the filesystem. + If a transient issue was encountered a TransientFailure shall + be raised. + If `name` doesn't exist in the cache `False` shall be returned. + If `dest_file_name` was successfully updated `True` shall be + returned. + All other conditions shall raise an appropriate exception.""" + return False + + def upload_from_file(self, name, source_file_name): + """Causes `name` to be populated in the cache by uploading + the contents of `source_file_name` to the storage system. + If a transient issue was encountered a TransientFailure shall + be raised. + If the upload failed for some other reason, an appropriate + exception shall be raised.""" + pass + + +def create_cache(): + """This function is monkey patchable to provide an actual + implementation""" + return None diff --git a/build/fbcode_builder/getdeps/copytree.py b/build/fbcode_builder/getdeps/copytree.py new file mode 100644 index 000000000..2790bc0d9 --- /dev/null +++ b/build/fbcode_builder/getdeps/copytree.py @@ -0,0 +1,78 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import os +import shutil +import subprocess + +from .platform import is_windows + + +PREFETCHED_DIRS = set() + + +def containing_repo_type(path): + while True: + if os.path.exists(os.path.join(path, ".git")): + return ("git", path) + if os.path.exists(os.path.join(path, ".hg")): + return ("hg", path) + + parent = os.path.dirname(path) + if parent == path: + return None, None + path = parent + + +def find_eden_root(dirpath): + """If the specified directory is inside an EdenFS checkout, returns + the canonical absolute path to the root of that checkout. + + Returns None if the specified directory is not in an EdenFS checkout. + """ + if is_windows(): + repo_type, repo_root = containing_repo_type(dirpath) + if repo_root is not None: + if os.path.exists(os.path.join(repo_root, ".eden", "config")): + return os.path.realpath(repo_root) + return None + + try: + return os.readlink(os.path.join(dirpath, ".eden", "root")) + except OSError: + return None + + +def prefetch_dir_if_eden(dirpath): + """After an amend/rebase, Eden may need to fetch a large number + of trees from the servers. The simplistic single threaded walk + performed by copytree makes this more expensive than is desirable + so we help accelerate things by performing a prefetch on the + source directory""" + global PREFETCHED_DIRS + if dirpath in PREFETCHED_DIRS: + return + root = find_eden_root(dirpath) + if root is None: + return + glob = f"{os.path.relpath(dirpath, root).replace(os.sep, '/')}/**" + print(f"Prefetching {glob}") + subprocess.call(["edenfsctl", "prefetch", "--repo", root, "--silent", glob]) + PREFETCHED_DIRS.add(dirpath) + + +def copytree(src_dir, dest_dir, ignore=None): + """Recursively copy the src_dir to the dest_dir, filtering + out entries using the ignore lambda. The behavior of the + ignore lambda must match that described by `shutil.copytree`. + This `copytree` function knows how to prefetch data when + running in an eden repo. + TODO: I'd like to either extend this or add a variant that + uses watchman to mirror src_dir into dest_dir. + """ + prefetch_dir_if_eden(src_dir) + return shutil.copytree(src_dir, dest_dir, ignore=ignore) diff --git a/build/fbcode_builder/getdeps/dyndeps.py b/build/fbcode_builder/getdeps/dyndeps.py new file mode 100644 index 000000000..216f26c46 --- /dev/null +++ b/build/fbcode_builder/getdeps/dyndeps.py @@ -0,0 +1,430 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import errno +import glob +import os +import re +import shutil +import stat +import subprocess +import sys +from struct import unpack + +from .envfuncs import path_search + + +OBJECT_SUBDIRS = ("bin", "lib", "lib64") + + +def copyfile(src, dest): + shutil.copyfile(src, dest) + shutil.copymode(src, dest) + + +class DepBase(object): + def __init__(self, buildopts, install_dirs, strip): + self.buildopts = buildopts + self.env = buildopts.compute_env_for_install_dirs(install_dirs) + self.install_dirs = install_dirs + self.strip = strip + self.processed_deps = set() + + def list_dynamic_deps(self, objfile): + raise RuntimeError("list_dynamic_deps not implemented") + + def interesting_dep(self, d): + return True + + # final_install_prefix must be the equivalent path to `destdir` on the + # installed system. For example, if destdir is `/tmp/RANDOM/usr/local' which + # is intended to map to `/usr/local` in the install image, then + # final_install_prefix='/usr/local'. + # If left unspecified, destdir will be used. + def process_deps(self, destdir, final_install_prefix=None): + if self.buildopts.is_windows(): + lib_dir = "bin" + else: + lib_dir = "lib" + self.munged_lib_dir = os.path.join(destdir, lib_dir) + + final_lib_dir = os.path.join(final_install_prefix or destdir, lib_dir) + + if not os.path.isdir(self.munged_lib_dir): + os.makedirs(self.munged_lib_dir) + + # Look only at the things that got installed in the leaf package, + # which will be the last entry in the install dirs list + inst_dir = self.install_dirs[-1] + print("Process deps under %s" % inst_dir, file=sys.stderr) + + for dir in OBJECT_SUBDIRS: + src_dir = os.path.join(inst_dir, dir) + if not os.path.isdir(src_dir): + continue + dest_dir = os.path.join(destdir, dir) + if not os.path.exists(dest_dir): + os.makedirs(dest_dir) + + for objfile in self.list_objs_in_dir(src_dir): + print("Consider %s/%s" % (dir, objfile)) + dest_obj = os.path.join(dest_dir, objfile) + copyfile(os.path.join(src_dir, objfile), dest_obj) + self.munge_in_place(dest_obj, final_lib_dir) + + def find_all_dependencies(self, build_dir): + all_deps = set() + for objfile in self.list_objs_in_dir( + build_dir, recurse=True, output_prefix=build_dir + ): + for d in self.list_dynamic_deps(objfile): + all_deps.add(d) + + interesting_deps = {d for d in all_deps if self.interesting_dep(d)} + dep_paths = [] + for dep in interesting_deps: + dep_path = self.resolve_loader_path(dep) + if dep_path: + dep_paths.append(dep_path) + + return dep_paths + + def munge_in_place(self, objfile, final_lib_dir): + print("Munging %s" % objfile) + for d in self.list_dynamic_deps(objfile): + if not self.interesting_dep(d): + continue + + # Resolve this dep: does it exist in any of our installation + # directories? If so, then it is a candidate for processing + dep = self.resolve_loader_path(d) + print("dep: %s -> %s" % (d, dep)) + if dep: + dest_dep = os.path.join(self.munged_lib_dir, os.path.basename(dep)) + if dep not in self.processed_deps: + self.processed_deps.add(dep) + copyfile(dep, dest_dep) + self.munge_in_place(dest_dep, final_lib_dir) + + self.rewrite_dep(objfile, d, dep, dest_dep, final_lib_dir) + + if self.strip: + self.strip_debug_info(objfile) + + def rewrite_dep(self, objfile, depname, old_dep, new_dep, final_lib_dir): + raise RuntimeError("rewrite_dep not implemented") + + def resolve_loader_path(self, dep): + if os.path.isabs(dep): + return dep + d = os.path.basename(dep) + for inst_dir in self.install_dirs: + for libdir in OBJECT_SUBDIRS: + candidate = os.path.join(inst_dir, libdir, d) + if os.path.exists(candidate): + return candidate + return None + + def list_objs_in_dir(self, dir, recurse=False, output_prefix=""): + for entry in os.listdir(dir): + entry_path = os.path.join(dir, entry) + st = os.lstat(entry_path) + if stat.S_ISREG(st.st_mode): + if self.is_objfile(entry_path): + relative_result = os.path.join(output_prefix, entry) + yield os.path.normcase(relative_result) + elif recurse and stat.S_ISDIR(st.st_mode): + child_prefix = os.path.join(output_prefix, entry) + for result in self.list_objs_in_dir( + entry_path, recurse=recurse, output_prefix=child_prefix + ): + yield result + + def is_objfile(self, objfile): + return True + + def strip_debug_info(self, objfile): + """override this to define how to remove debug information + from an object file""" + pass + + +class WinDeps(DepBase): + def __init__(self, buildopts, install_dirs, strip): + super(WinDeps, self).__init__(buildopts, install_dirs, strip) + self.dumpbin = self.find_dumpbin() + + def find_dumpbin(self): + # Looking for dumpbin in the following hardcoded paths. + # The registry option to find the install dir doesn't work anymore. + globs = [ + ( + "C:/Program Files (x86)/" + "Microsoft Visual Studio/" + "*/*/VC/Tools/" + "MSVC/*/bin/Hostx64/x64/dumpbin.exe" + ), + ( + "C:/Program Files (x86)/" + "Common Files/" + "Microsoft/Visual C++ for Python/*/" + "VC/bin/dumpbin.exe" + ), + ("c:/Program Files (x86)/Microsoft Visual Studio */VC/bin/dumpbin.exe"), + ] + for pattern in globs: + for exe in glob.glob(pattern): + return exe + + raise RuntimeError("could not find dumpbin.exe") + + def list_dynamic_deps(self, exe): + deps = [] + print("Resolve deps for %s" % exe) + output = subprocess.check_output( + [self.dumpbin, "/nologo", "/dependents", exe] + ).decode("utf-8") + + lines = output.split("\n") + for line in lines: + m = re.match("\\s+(\\S+.dll)", line, re.IGNORECASE) + if m: + deps.append(m.group(1).lower()) + + return deps + + def rewrite_dep(self, objfile, depname, old_dep, new_dep, final_lib_dir): + # We can't rewrite on windows, but we will + # place the deps alongside the exe so that + # they end up in the search path + pass + + # These are the Windows system dll, which we don't want to copy while + # packaging. + SYSTEM_DLLS = set( # noqa: C405 + [ + "advapi32.dll", + "dbghelp.dll", + "kernel32.dll", + "msvcp140.dll", + "vcruntime140.dll", + "ws2_32.dll", + "ntdll.dll", + "shlwapi.dll", + ] + ) + + def interesting_dep(self, d): + if "api-ms-win-crt" in d: + return False + if d in self.SYSTEM_DLLS: + return False + return True + + def is_objfile(self, objfile): + if not os.path.isfile(objfile): + return False + if objfile.lower().endswith(".exe"): + return True + return False + + def emit_dev_run_script(self, script_path, dep_dirs): + """Emit a script that can be used to run build artifacts directly from the + build directory, without installing them. + + The dep_dirs parameter should be a list of paths that need to be added to $PATH. + This can be computed by calling compute_dependency_paths() or + compute_dependency_paths_fast(). + + This is only necessary on Windows, which does not have RPATH, and instead + requires the $PATH environment variable be updated in order to find the proper + library dependencies. + """ + contents = self._get_dev_run_script_contents(dep_dirs) + with open(script_path, "w") as f: + f.write(contents) + + def compute_dependency_paths(self, build_dir): + """Return a list of all directories that need to be added to $PATH to ensure + that library dependencies can be found correctly. This is computed by scanning + binaries to determine exactly the right list of dependencies. + + The compute_dependency_paths_fast() is a alternative function that runs faster + but may return additional extraneous paths. + """ + dep_dirs = set() + # Find paths by scanning the binaries. + for dep in self.find_all_dependencies(build_dir): + dep_dirs.add(os.path.dirname(dep)) + + dep_dirs.update(self.read_custom_dep_dirs(build_dir)) + return sorted(dep_dirs) + + def compute_dependency_paths_fast(self, build_dir): + """Similar to compute_dependency_paths(), but rather than actually scanning + binaries, just add all library paths from the specified installation + directories. This is much faster than scanning the binaries, but may result in + more paths being returned than actually necessary. + """ + dep_dirs = set() + for inst_dir in self.install_dirs: + for subdir in OBJECT_SUBDIRS: + path = os.path.join(inst_dir, subdir) + if os.path.exists(path): + dep_dirs.add(path) + + dep_dirs.update(self.read_custom_dep_dirs(build_dir)) + return sorted(dep_dirs) + + def read_custom_dep_dirs(self, build_dir): + # The build system may also have included libraries from other locations that + # we might not be able to find normally in find_all_dependencies(). + # To handle this situation we support reading additional library paths + # from a LIBRARY_DEP_DIRS.txt file that may have been generated in the build + # output directory. + dep_dirs = set() + try: + explicit_dep_dirs_path = os.path.join(build_dir, "LIBRARY_DEP_DIRS.txt") + with open(explicit_dep_dirs_path, "r") as f: + for line in f.read().splitlines(): + dep_dirs.add(line) + except OSError as ex: + if ex.errno != errno.ENOENT: + raise + + return dep_dirs + + def _get_dev_run_script_contents(self, path_dirs): + path_entries = ["$env:PATH"] + path_dirs + path_str = ";".join(path_entries) + return """\ +$orig_env = $env:PATH +$env:PATH = "{path_str}" + +try {{ + $cmd_args = $args[1..$args.length] + & $args[0] @cmd_args +}} finally {{ + $env:PATH = $orig_env +}} +""".format( + path_str=path_str + ) + + +class ElfDeps(DepBase): + def __init__(self, buildopts, install_dirs, strip): + super(ElfDeps, self).__init__(buildopts, install_dirs, strip) + + # We need patchelf to rewrite deps, so ensure that it is built... + subprocess.check_call([sys.executable, sys.argv[0], "build", "patchelf"]) + # ... and that we know where it lives + self.patchelf = os.path.join( + os.fsdecode( + subprocess.check_output( + [sys.executable, sys.argv[0], "show-inst-dir", "patchelf"] + ).strip() + ), + "bin/patchelf", + ) + + def list_dynamic_deps(self, objfile): + out = ( + subprocess.check_output( + [self.patchelf, "--print-needed", objfile], env=dict(self.env.items()) + ) + .decode("utf-8") + .strip() + ) + lines = out.split("\n") + return lines + + def rewrite_dep(self, objfile, depname, old_dep, new_dep, final_lib_dir): + final_dep = os.path.join( + final_lib_dir, os.path.relpath(new_dep, self.munged_lib_dir) + ) + subprocess.check_call( + [self.patchelf, "--replace-needed", depname, final_dep, objfile] + ) + + def is_objfile(self, objfile): + if not os.path.isfile(objfile): + return False + with open(objfile, "rb") as f: + # https://en.wikipedia.org/wiki/Executable_and_Linkable_Format#File_header + magic = f.read(4) + return magic == b"\x7fELF" + + def strip_debug_info(self, objfile): + subprocess.check_call(["strip", objfile]) + + +# MACH-O magic number +MACH_MAGIC = 0xFEEDFACF + + +class MachDeps(DepBase): + def interesting_dep(self, d): + if d.startswith("/usr/lib/") or d.startswith("/System/"): + return False + return True + + def is_objfile(self, objfile): + if not os.path.isfile(objfile): + return False + with open(objfile, "rb") as f: + # mach stores the magic number in native endianness, + # so unpack as native here and compare + header = f.read(4) + if len(header) != 4: + return False + magic = unpack("I", header)[0] + return magic == MACH_MAGIC + + def list_dynamic_deps(self, objfile): + if not self.interesting_dep(objfile): + return + out = ( + subprocess.check_output( + ["otool", "-L", objfile], env=dict(self.env.items()) + ) + .decode("utf-8") + .strip() + ) + lines = out.split("\n") + deps = [] + for line in lines: + m = re.match("\t(\\S+)\\s", line) + if m: + if os.path.basename(m.group(1)) != os.path.basename(objfile): + deps.append(os.path.normcase(m.group(1))) + return deps + + def rewrite_dep(self, objfile, depname, old_dep, new_dep, final_lib_dir): + if objfile.endswith(".dylib"): + # Erase the original location from the id of the shared + # object. It doesn't appear to hurt to retain it, but + # it does look weird, so let's rewrite it to be sure. + subprocess.check_call( + ["install_name_tool", "-id", os.path.basename(objfile), objfile] + ) + final_dep = os.path.join( + final_lib_dir, os.path.relpath(new_dep, self.munged_lib_dir) + ) + + subprocess.check_call( + ["install_name_tool", "-change", depname, final_dep, objfile] + ) + + +def create_dyn_dep_munger(buildopts, install_dirs, strip=False): + if buildopts.is_linux(): + return ElfDeps(buildopts, install_dirs, strip) + if buildopts.is_darwin(): + return MachDeps(buildopts, install_dirs, strip) + if buildopts.is_windows(): + return WinDeps(buildopts, install_dirs, strip) diff --git a/build/fbcode_builder/getdeps/envfuncs.py b/build/fbcode_builder/getdeps/envfuncs.py new file mode 100644 index 000000000..f2e13f16f --- /dev/null +++ b/build/fbcode_builder/getdeps/envfuncs.py @@ -0,0 +1,195 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import os +import shlex +import sys + + +class Env(object): + def __init__(self, src=None): + self._dict = {} + if src is None: + self.update(os.environ) + else: + self.update(src) + + def update(self, src): + for k, v in src.items(): + self.set(k, v) + + def copy(self): + return Env(self._dict) + + def _key(self, key): + # The `str` cast may not appear to be needed, but without it we run + # into issues when passing the environment to subprocess. The main + # issue is that in python2 `os.environ` (which is the initial source + # of data for the environment) uses byte based strings, but this + # project uses `unicode_literals`. `subprocess` will raise an error + # if the environment that it is passed has a mixture of byte and + # unicode strings. + # It is simplest to force everthing to be `str` for the sake of + # consistency. + key = str(key) + if sys.platform.startswith("win"): + # Windows env var names are case insensitive but case preserving. + # An implementation of PAR files on windows gets confused if + # the env block contains keys with conflicting case, so make a + # pass over the contents to remove any. + # While this O(n) scan is technically expensive and gross, it + # is practically not a problem because the volume of calls is + # relatively low and the cost of manipulating the env is dwarfed + # by the cost of spawning a process on windows. In addition, + # since the processes that we run are expensive anyway, this + # overhead is not the worst thing to worry about. + for k in list(self._dict.keys()): + if str(k).lower() == key.lower(): + return k + elif key in self._dict: + return key + return None + + def get(self, key, defval=None): + key = self._key(key) + if key is None: + return defval + return self._dict[key] + + def __getitem__(self, key): + val = self.get(key) + if key is None: + raise KeyError(key) + return val + + def unset(self, key): + if key is None: + raise KeyError("attempting to unset env[None]") + + key = self._key(key) + if key: + del self._dict[key] + + def __delitem__(self, key): + self.unset(key) + + def __repr__(self): + return repr(self._dict) + + def set(self, key, value): + if key is None: + raise KeyError("attempting to assign env[None] = %r" % value) + + if value is None: + raise ValueError("attempting to assign env[%s] = None" % key) + + # The `str` conversion is important to avoid triggering errors + # with subprocess if we pass in a unicode value; see commentary + # in the `_key` method. + key = str(key) + value = str(value) + + # The `unset` call is necessary on windows where the keys are + # case insensitive. Since this dict is case sensitive, simply + # assigning the value to the new key is not sufficient to remove + # the old value. The `unset` call knows how to match keys and + # remove any potential duplicates. + self.unset(key) + self._dict[key] = value + + def __setitem__(self, key, value): + self.set(key, value) + + def __iter__(self): + return self._dict.__iter__() + + def __len__(self): + return len(self._dict) + + def keys(self): + return self._dict.keys() + + def values(self): + return self._dict.values() + + def items(self): + return self._dict.items() + + +def add_path_entry(env, name, item, append=True, separator=os.pathsep): + """Cause `item` to be added to the path style env var named + `name` held in the `env` dict. `append` specifies whether + the item is added to the end (the default) or should be + prepended if `name` already exists.""" + val = env.get(name, "") + if len(val) > 0: + val = val.split(separator) + else: + val = [] + if append: + val.append(item) + else: + val.insert(0, item) + env.set(name, separator.join(val)) + + +def add_flag(env, name, flag, append=True): + """Cause `flag` to be added to the CXXFLAGS-style env var named + `name` held in the `env` dict. `append` specifies whether the + flag is added to the end (the default) or should be prepended if + `name` already exists.""" + val = shlex.split(env.get(name, "")) + if append: + val.append(flag) + else: + val.insert(0, flag) + env.set(name, " ".join(val)) + + +_path_search_cache = {} +_not_found = object() + + +def tpx_path(): + return "xplat/testinfra/tpx/ctp.tpx" + + +def path_search(env, exename, defval=None): + """Search for exename in the PATH specified in env. + exename is eg: `ninja` and this function knows to append a .exe + to the end on windows. + Returns the path to the exe if found, or None if either no + PATH is set in env or no executable is found.""" + + path = env.get("PATH", None) + if path is None: + return defval + + # The project hash computation code searches for C++ compilers (g++, clang, etc) + # repeatedly. Cache the result so we don't end up searching for these over and over + # again. + cache_key = (path, exename) + result = _path_search_cache.get(cache_key, _not_found) + if result is _not_found: + result = _perform_path_search(path, exename) + _path_search_cache[cache_key] = result + return result + + +def _perform_path_search(path, exename): + is_win = sys.platform.startswith("win") + if is_win: + exename = "%s.exe" % exename + + for bindir in path.split(os.pathsep): + full_name = os.path.join(bindir, exename) + if os.path.exists(full_name) and os.path.isfile(full_name): + if not is_win and not os.access(full_name, os.X_OK): + continue + return full_name + + return None diff --git a/build/fbcode_builder/getdeps/errors.py b/build/fbcode_builder/getdeps/errors.py new file mode 100644 index 000000000..3fad1a1de --- /dev/null +++ b/build/fbcode_builder/getdeps/errors.py @@ -0,0 +1,19 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + + +class TransientFailure(Exception): + """Raising this error causes getdeps to return with an error code + that Sandcastle will consider to be a retryable transient + infrastructure error""" + + pass + + +class ManifestNotFound(Exception): + def __init__(self, manifest_name): + super(Exception, self).__init__("Unable to find manifest '%s'" % manifest_name) diff --git a/build/fbcode_builder/getdeps/expr.py b/build/fbcode_builder/getdeps/expr.py new file mode 100644 index 000000000..6c0485d03 --- /dev/null +++ b/build/fbcode_builder/getdeps/expr.py @@ -0,0 +1,184 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import re +import shlex + + +def parse_expr(expr_text, valid_variables): + """parses the simple criteria expression syntax used in + dependency specifications. + Returns an ExprNode instance that can be evaluated like this: + + ``` + expr = parse_expr("os=windows") + ok = expr.eval({ + "os": "windows" + }) + ``` + + Whitespace is allowed between tokens. The following terms + are recognized: + + KEY = VALUE # Evaluates to True if ctx[KEY] == VALUE + not(EXPR) # Evaluates to True if EXPR evaluates to False + # and vice versa + all(EXPR1, EXPR2, ...) # Evaluates True if all of the supplied + # EXPR's also evaluate True + any(EXPR1, EXPR2, ...) # Evaluates True if any of the supplied + # EXPR's also evaluate True, False if + # none of them evaluated true. + """ + + p = Parser(expr_text, valid_variables) + return p.parse() + + +class ExprNode(object): + def eval(self, ctx): + return False + + +class TrueExpr(ExprNode): + def eval(self, ctx): + return True + + def __str__(self): + return "true" + + +class NotExpr(ExprNode): + def __init__(self, node): + self._node = node + + def eval(self, ctx): + return not self._node.eval(ctx) + + def __str__(self): + return "not(%s)" % self._node + + +class AllExpr(ExprNode): + def __init__(self, nodes): + self._nodes = nodes + + def eval(self, ctx): + for node in self._nodes: + if not node.eval(ctx): + return False + return True + + def __str__(self): + items = [] + for node in self._nodes: + items.append(str(node)) + return "all(%s)" % ",".join(items) + + +class AnyExpr(ExprNode): + def __init__(self, nodes): + self._nodes = nodes + + def eval(self, ctx): + for node in self._nodes: + if node.eval(ctx): + return True + return False + + def __str__(self): + items = [] + for node in self._nodes: + items.append(str(node)) + return "any(%s)" % ",".join(items) + + +class EqualExpr(ExprNode): + def __init__(self, key, value): + self._key = key + self._value = value + + def eval(self, ctx): + return ctx.get(self._key) == self._value + + def __str__(self): + return "%s=%s" % (self._key, self._value) + + +class Parser(object): + def __init__(self, text, valid_variables): + self.text = text + self.lex = shlex.shlex(text) + self.valid_variables = valid_variables + + def parse(self): + expr = self.top() + garbage = self.lex.get_token() + if garbage != "": + raise Exception( + "Unexpected token %s after EqualExpr in %s" % (garbage, self.text) + ) + return expr + + def top(self): + name = self.ident() + op = self.lex.get_token() + + if op == "(": + parsers = { + "not": self.parse_not, + "any": self.parse_any, + "all": self.parse_all, + } + func = parsers.get(name) + if not func: + raise Exception("invalid term %s in %s" % (name, self.text)) + return func() + + if op == "=": + if name not in self.valid_variables: + raise Exception("unknown variable %r in expression" % (name,)) + return EqualExpr(name, self.lex.get_token()) + + raise Exception( + "Unexpected token sequence '%s %s' in %s" % (name, op, self.text) + ) + + def ident(self): + ident = self.lex.get_token() + if not re.match("[a-zA-Z]+", ident): + raise Exception("expected identifier found %s" % ident) + return ident + + def parse_not(self): + node = self.top() + expr = NotExpr(node) + tok = self.lex.get_token() + if tok != ")": + raise Exception("expected ')' found %s" % tok) + return expr + + def parse_any(self): + nodes = [] + while True: + nodes.append(self.top()) + tok = self.lex.get_token() + if tok == ")": + break + if tok != ",": + raise Exception("expected ',' or ')' but found %s" % tok) + return AnyExpr(nodes) + + def parse_all(self): + nodes = [] + while True: + nodes.append(self.top()) + tok = self.lex.get_token() + if tok == ")": + break + if tok != ",": + raise Exception("expected ',' or ')' but found %s" % tok) + return AllExpr(nodes) diff --git a/build/fbcode_builder/getdeps/fetcher.py b/build/fbcode_builder/getdeps/fetcher.py new file mode 100644 index 000000000..041549ad7 --- /dev/null +++ b/build/fbcode_builder/getdeps/fetcher.py @@ -0,0 +1,771 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import errno +import hashlib +import os +import re +import shutil +import stat +import subprocess +import sys +import tarfile +import time +import zipfile +from datetime import datetime +from typing import Dict, NamedTuple + +from .copytree import prefetch_dir_if_eden +from .envfuncs import Env +from .errors import TransientFailure +from .platform import is_windows +from .runcmd import run_cmd + + +try: + from urllib import urlretrieve + from urlparse import urlparse +except ImportError: + from urllib.parse import urlparse + from urllib.request import urlretrieve + + +def file_name_is_cmake_file(file_name): + file_name = file_name.lower() + base = os.path.basename(file_name) + return ( + base.endswith(".cmake") + or base.endswith(".cmake.in") + or base == "cmakelists.txt" + ) + + +class ChangeStatus(object): + """Indicates the nature of changes that happened while updating + the source directory. There are two broad uses: + * When extracting archives for third party software we want to + know that we did something (eg: we either extracted code or + we didn't do anything) + * For 1st party code where we use shipit to transform the code, + we want to know if we changed anything so that we can perform + a build, but we generally want to be a little more nuanced + and be able to distinguish between just changing a source file + and whether we might need to reconfigure the build system. + """ + + def __init__(self, all_changed=False): + """Construct a ChangeStatus object. The default is to create + a status that indicates no changes, but passing all_changed=True + will create one that indicates that everything changed""" + if all_changed: + self.source_files = 1 + self.make_files = 1 + else: + self.source_files = 0 + self.make_files = 0 + + def record_change(self, file_name): + """Used by the shipit fetcher to record changes as it updates + files in the destination. If the file name might be one used + in the cmake build system that we use for 1st party code, then + record that as a "make file" change. We could broaden this + to match any file used by various build systems, but it is + only really useful for our internal cmake stuff at this time. + If the file isn't a build file and is under the `fbcode_builder` + dir then we don't class that as an interesting change that we + might need to rebuild, so we ignore it. + Otherwise we record the file as a source file change.""" + + file_name = file_name.lower() + if file_name_is_cmake_file(file_name): + self.make_files += 1 + elif "/fbcode_builder/cmake" in file_name: + self.source_files += 1 + elif "/fbcode_builder/" not in file_name: + self.source_files += 1 + + def sources_changed(self): + """Returns true if any source files were changed during + an update operation. This will typically be used to decide + that the build system to be run on the source dir in an + incremental mode""" + return self.source_files > 0 + + def build_changed(self): + """Returns true if any build files were changed during + an update operation. This will typically be used to decidfe + that the build system should be reconfigured and re-run + as a full build""" + return self.make_files > 0 + + +class Fetcher(object): + """The Fetcher is responsible for fetching and extracting the + sources for project. The Fetcher instance defines where the + extracted data resides and reports this to the consumer via + its `get_src_dir` method.""" + + def update(self): + """Brings the src dir up to date, ideally minimizing + changes so that a subsequent build doesn't over-build. + Returns a ChangeStatus object that helps the caller to + understand the nature of the changes required during + the update.""" + return ChangeStatus() + + def clean(self): + """Reverts any changes that might have been made to + the src dir""" + pass + + def hash(self): + """Returns a hash that identifies the version of the code in the + working copy. For a git repo this is commit hash for the working + copy. For other Fetchers this should relate to the version of + the code in the src dir. The intent is that if a manifest + changes the version/rev of a project that the hash be different. + Importantly, this should be computable without actually fetching + the code, as we want this to factor into a hash used to download + a pre-built version of the code, without having to first download + and extract its sources (eg: boost on windows is pretty painful). + """ + pass + + def get_src_dir(self): + """Returns the source directory that the project was + extracted into""" + pass + + +class LocalDirFetcher(object): + """This class exists to override the normal fetching behavior, and + use an explicit user-specified directory for the project sources. + + This fetcher cannot update or track changes. It always reports that the + project has changed, forcing it to always be built.""" + + def __init__(self, path): + self.path = os.path.realpath(path) + + def update(self): + return ChangeStatus(all_changed=True) + + def hash(self): + return "0" * 40 + + def get_src_dir(self): + return self.path + + +class SystemPackageFetcher(object): + def __init__(self, build_options, packages): + self.manager = build_options.host_type.get_package_manager() + self.packages = packages.get(self.manager) + if self.packages: + self.installed = None + else: + self.installed = False + + def packages_are_installed(self): + if self.installed is not None: + return self.installed + + if self.manager == "rpm": + result = run_cmd(["rpm", "-q"] + self.packages, allow_fail=True) + self.installed = result == 0 + elif self.manager == "deb": + result = run_cmd(["dpkg", "-s"] + self.packages, allow_fail=True) + self.installed = result == 0 + else: + self.installed = False + + return self.installed + + def update(self): + assert self.installed + return ChangeStatus(all_changed=False) + + def hash(self): + return "0" * 40 + + def get_src_dir(self): + return None + + +class PreinstalledNopFetcher(SystemPackageFetcher): + def __init__(self): + self.installed = True + + +class GitFetcher(Fetcher): + DEFAULT_DEPTH = 1 + + def __init__(self, build_options, manifest, repo_url, rev, depth): + # Extract the host/path portions of the URL and generate a flattened + # directory name. eg: + # github.com/facebook/folly.git -> github.com-facebook-folly.git + url = urlparse(repo_url) + directory = "%s%s" % (url.netloc, url.path) + for s in ["/", "\\", ":"]: + directory = directory.replace(s, "-") + + # Place it in a repos dir in the scratch space + repos_dir = os.path.join(build_options.scratch_dir, "repos") + if not os.path.exists(repos_dir): + os.makedirs(repos_dir) + self.repo_dir = os.path.join(repos_dir, directory) + + if not rev and build_options.project_hashes: + hash_file = os.path.join( + build_options.project_hashes, + re.sub("\\.git$", "-rev.txt", url.path[1:]), + ) + if os.path.exists(hash_file): + with open(hash_file, "r") as f: + data = f.read() + m = re.match("Subproject commit ([a-fA-F0-9]{40})", data) + if not m: + raise Exception("Failed to parse rev from %s" % hash_file) + rev = m.group(1) + print("Using pinned rev %s for %s" % (rev, repo_url)) + + self.rev = rev or "master" + self.origin_repo = repo_url + self.manifest = manifest + self.depth = depth if depth else GitFetcher.DEFAULT_DEPTH + + def _update(self): + current_hash = ( + subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=self.repo_dir) + .strip() + .decode("utf-8") + ) + target_hash = ( + subprocess.check_output(["git", "rev-parse", self.rev], cwd=self.repo_dir) + .strip() + .decode("utf-8") + ) + if target_hash == current_hash: + # It's up to date, so there are no changes. This doesn't detect eg: + # if origin/master moved and rev='master', but that's ok for our purposes; + # we should be using explicit hashes or eg: a stable branch for the cases + # that we care about, and it isn't unreasonable to require that the user + # explicitly perform a clean build if those have moved. For the most + # part we prefer that folks build using a release tarball from github + # rather than use the git protocol, as it is generally a bit quicker + # to fetch and easier to hash and verify tarball downloads. + return ChangeStatus() + + print("Updating %s -> %s" % (self.repo_dir, self.rev)) + run_cmd(["git", "fetch", "origin", self.rev], cwd=self.repo_dir) + run_cmd(["git", "checkout", self.rev], cwd=self.repo_dir) + run_cmd(["git", "submodule", "update", "--init"], cwd=self.repo_dir) + + return ChangeStatus(True) + + def update(self): + if os.path.exists(self.repo_dir): + return self._update() + self._clone() + return ChangeStatus(True) + + def _clone(self): + print("Cloning %s..." % self.origin_repo) + # The basename/dirname stuff allows us to dance around issues where + # eg: this python process is native win32, but the git.exe is cygwin + # or msys and doesn't like the absolute windows path that we'd otherwise + # pass to it. Careful use of cwd helps avoid headaches with cygpath. + run_cmd( + [ + "git", + "clone", + "--depth=" + str(self.depth), + "--", + self.origin_repo, + os.path.basename(self.repo_dir), + ], + cwd=os.path.dirname(self.repo_dir), + ) + self._update() + + def clean(self): + if os.path.exists(self.repo_dir): + run_cmd(["git", "clean", "-fxd"], cwd=self.repo_dir) + + def hash(self): + return self.rev + + def get_src_dir(self): + return self.repo_dir + + +def does_file_need_update(src_name, src_st, dest_name): + try: + target_st = os.lstat(dest_name) + except OSError as exc: + if exc.errno != errno.ENOENT: + raise + return True + + if src_st.st_size != target_st.st_size: + return True + + if stat.S_IFMT(src_st.st_mode) != stat.S_IFMT(target_st.st_mode): + return True + if stat.S_ISLNK(src_st.st_mode): + return os.readlink(src_name) != os.readlink(dest_name) + if not stat.S_ISREG(src_st.st_mode): + return True + + # They might have the same content; compare. + with open(src_name, "rb") as sf, open(dest_name, "rb") as df: + chunk_size = 8192 + while True: + src_data = sf.read(chunk_size) + dest_data = df.read(chunk_size) + if src_data != dest_data: + return True + if len(src_data) < chunk_size: + # EOF + break + return False + + +def copy_if_different(src_name, dest_name): + """Copy src_name -> dest_name, but only touch dest_name + if src_name is different from dest_name, making this a + more build system friendly way to copy.""" + src_st = os.lstat(src_name) + if not does_file_need_update(src_name, src_st, dest_name): + return False + + dest_parent = os.path.dirname(dest_name) + if not os.path.exists(dest_parent): + os.makedirs(dest_parent) + if stat.S_ISLNK(src_st.st_mode): + try: + os.unlink(dest_name) + except OSError as exc: + if exc.errno != errno.ENOENT: + raise + target = os.readlink(src_name) + print("Symlinking %s -> %s" % (dest_name, target)) + os.symlink(target, dest_name) + else: + print("Copying %s -> %s" % (src_name, dest_name)) + shutil.copy2(src_name, dest_name) + + return True + + +def list_files_under_dir_newer_than_timestamp(dir_to_scan, ts): + for root, _dirs, files in os.walk(dir_to_scan): + for src_file in files: + full_name = os.path.join(root, src_file) + st = os.lstat(full_name) + if st.st_mtime > ts: + yield full_name + + +class ShipitPathMap(object): + def __init__(self): + self.roots = [] + self.mapping = [] + self.exclusion = [] + + def add_mapping(self, fbsource_dir, target_dir): + """Add a posix path or pattern. We cannot normpath the input + here because that would change the paths from posix to windows + form and break the logic throughout this class.""" + self.roots.append(fbsource_dir) + self.mapping.append((fbsource_dir, target_dir)) + + def add_exclusion(self, pattern): + self.exclusion.append(re.compile(pattern)) + + def _minimize_roots(self): + """compute the de-duplicated set of roots within fbsource. + We take the shortest common directory prefix to make this + determination""" + self.roots.sort(key=len) + minimized = [] + + for r in self.roots: + add_this_entry = True + for existing in minimized: + if r.startswith(existing + "/"): + add_this_entry = False + break + if add_this_entry: + minimized.append(r) + + self.roots = minimized + + def _sort_mapping(self): + self.mapping.sort(reverse=True, key=lambda x: len(x[0])) + + def _map_name(self, norm_name, dest_root): + if norm_name.endswith(".pyc") or norm_name.endswith(".swp"): + # Ignore some incidental garbage while iterating + return None + + for excl in self.exclusion: + if excl.match(norm_name): + return None + + for src_name, dest_name in self.mapping: + if norm_name == src_name or norm_name.startswith(src_name + "/"): + rel_name = os.path.relpath(norm_name, src_name) + # We can have "." as a component of some paths, depending + # on the contents of the shipit transformation section. + # normpath doesn't always remove `.` as the final component + # of the path, which be problematic when we later mkdir + # the dirname of the path that we return. Take care to avoid + # returning a path with a `.` in it. + rel_name = os.path.normpath(rel_name) + if dest_name == ".": + return os.path.normpath(os.path.join(dest_root, rel_name)) + dest_name = os.path.normpath(dest_name) + return os.path.normpath(os.path.join(dest_root, dest_name, rel_name)) + + raise Exception("%s did not match any rules" % norm_name) + + def mirror(self, fbsource_root, dest_root): + self._minimize_roots() + self._sort_mapping() + + change_status = ChangeStatus() + + # Record the full set of files that should be in the tree + full_file_list = set() + + for fbsource_subdir in self.roots: + dir_to_mirror = os.path.join(fbsource_root, fbsource_subdir) + prefetch_dir_if_eden(dir_to_mirror) + if not os.path.exists(dir_to_mirror): + raise Exception( + "%s doesn't exist; check your sparse profile!" % dir_to_mirror + ) + for root, _dirs, files in os.walk(dir_to_mirror): + for src_file in files: + full_name = os.path.join(root, src_file) + rel_name = os.path.relpath(full_name, fbsource_root) + norm_name = rel_name.replace("\\", "/") + + target_name = self._map_name(norm_name, dest_root) + if target_name: + full_file_list.add(target_name) + if copy_if_different(full_name, target_name): + change_status.record_change(target_name) + + # Compare the list of previously shipped files; if a file is + # in the old list but not the new list then it has been + # removed from the source and should be removed from the + # destination. + # Why don't we simply create this list by walking dest_root? + # Some builds currently have to be in-source builds and + # may legitimately need to keep some state in the source tree :-/ + installed_name = os.path.join(dest_root, ".shipit_shipped") + if os.path.exists(installed_name): + with open(installed_name, "rb") as f: + for name in f.read().decode("utf-8").splitlines(): + name = name.strip() + if name not in full_file_list: + print("Remove %s" % name) + os.unlink(name) + change_status.record_change(name) + + with open(installed_name, "wb") as f: + for name in sorted(list(full_file_list)): + f.write(("%s\n" % name).encode("utf-8")) + + return change_status + + +class FbsourceRepoData(NamedTuple): + hash: str + date: str + + +FBSOURCE_REPO_DATA: Dict[str, FbsourceRepoData] = {} + + +def get_fbsource_repo_data(build_options): + """Returns the commit metadata for the fbsource repo. + Since we may have multiple first party projects to + hash, and because we don't mutate the repo, we cache + this hash in a global.""" + cached_data = FBSOURCE_REPO_DATA.get(build_options.fbsource_dir) + if cached_data: + return cached_data + + cmd = ["hg", "log", "-r.", "-T{node}\n{date|hgdate}"] + env = Env() + env.set("HGPLAIN", "1") + log_data = subprocess.check_output( + cmd, cwd=build_options.fbsource_dir, env=dict(env.items()) + ).decode("ascii") + + (hash, datestr) = log_data.split("\n") + + # datestr is like "seconds fractionalseconds" + # We want "20200324.113140" + (unixtime, _fractional) = datestr.split(" ") + date = datetime.fromtimestamp(int(unixtime)).strftime("%Y%m%d.%H%M%S") + cached_data = FbsourceRepoData(hash=hash, date=date) + + FBSOURCE_REPO_DATA[build_options.fbsource_dir] = cached_data + + return cached_data + + +class SimpleShipitTransformerFetcher(Fetcher): + def __init__(self, build_options, manifest): + self.build_options = build_options + self.manifest = manifest + self.repo_dir = os.path.join(build_options.scratch_dir, "shipit", manifest.name) + + def clean(self): + if os.path.exists(self.repo_dir): + shutil.rmtree(self.repo_dir) + + def update(self): + mapping = ShipitPathMap() + for src, dest in self.manifest.get_section_as_ordered_pairs("shipit.pathmap"): + mapping.add_mapping(src, dest) + if self.manifest.shipit_fbcode_builder: + mapping.add_mapping( + "fbcode/opensource/fbcode_builder", "build/fbcode_builder" + ) + for pattern in self.manifest.get_section_as_args("shipit.strip"): + mapping.add_exclusion(pattern) + + return mapping.mirror(self.build_options.fbsource_dir, self.repo_dir) + + def hash(self): + # We return a fixed non-hash string for in-fbsource builds. + # We're relying on the `update` logic to correctly invalidate + # the build in the case that files have changed. + return "fbsource" + + def get_src_dir(self): + return self.repo_dir + + +class ShipitTransformerFetcher(Fetcher): + SHIPIT = "/var/www/scripts/opensource/shipit/run_shipit.php" + + def __init__(self, build_options, project_name): + self.build_options = build_options + self.project_name = project_name + self.repo_dir = os.path.join(build_options.scratch_dir, "shipit", project_name) + + def update(self): + if os.path.exists(self.repo_dir): + return ChangeStatus() + self.run_shipit() + return ChangeStatus(True) + + def clean(self): + if os.path.exists(self.repo_dir): + shutil.rmtree(self.repo_dir) + + @classmethod + def available(cls): + return os.path.exists(cls.SHIPIT) + + def run_shipit(self): + tmp_path = self.repo_dir + ".new" + try: + if os.path.exists(tmp_path): + shutil.rmtree(tmp_path) + + # Run shipit + run_cmd( + [ + "php", + ShipitTransformerFetcher.SHIPIT, + "--project=" + self.project_name, + "--create-new-repo", + "--source-repo-dir=" + self.build_options.fbsource_dir, + "--source-branch=.", + "--skip-source-init", + "--skip-source-pull", + "--skip-source-clean", + "--skip-push", + "--skip-reset", + "--destination-use-anonymous-https", + "--create-new-repo-output-path=" + tmp_path, + ] + ) + + # Remove the .git directory from the repository it generated. + # There is no need to commit this. + repo_git_dir = os.path.join(tmp_path, ".git") + shutil.rmtree(repo_git_dir) + os.rename(tmp_path, self.repo_dir) + except Exception: + # Clean up after a failed extraction + if os.path.exists(tmp_path): + shutil.rmtree(tmp_path) + self.clean() + raise + + def hash(self): + # We return a fixed non-hash string for in-fbsource builds. + return "fbsource" + + def get_src_dir(self): + return self.repo_dir + + +def download_url_to_file_with_progress(url, file_name): + print("Download %s -> %s ..." % (url, file_name)) + + class Progress(object): + last_report = 0 + + def progress(self, count, block, total): + if total == -1: + total = "(Unknown)" + amount = count * block + + if sys.stdout.isatty(): + sys.stdout.write("\r downloading %s of %s " % (amount, total)) + else: + # When logging to CI logs, avoid spamming the logs and print + # status every few seconds + now = time.time() + if now - self.last_report > 5: + sys.stdout.write(".. %s of %s " % (amount, total)) + self.last_report = now + sys.stdout.flush() + + progress = Progress() + start = time.time() + try: + (_filename, headers) = urlretrieve(url, file_name, reporthook=progress.progress) + except (OSError, IOError) as exc: # noqa: B014 + raise TransientFailure( + "Failed to download %s to %s: %s" % (url, file_name, str(exc)) + ) + + end = time.time() + sys.stdout.write(" [Complete in %f seconds]\n" % (end - start)) + sys.stdout.flush() + print(f"{headers}") + + +class ArchiveFetcher(Fetcher): + def __init__(self, build_options, manifest, url, sha256): + self.manifest = manifest + self.url = url + self.sha256 = sha256 + self.build_options = build_options + + url = urlparse(self.url) + basename = "%s-%s" % (manifest.name, os.path.basename(url.path)) + self.file_name = os.path.join(build_options.scratch_dir, "downloads", basename) + self.src_dir = os.path.join(build_options.scratch_dir, "extracted", basename) + self.hash_file = self.src_dir + ".hash" + + def _verify_hash(self): + h = hashlib.sha256() + with open(self.file_name, "rb") as f: + while True: + block = f.read(8192) + if not block: + break + h.update(block) + digest = h.hexdigest() + if digest != self.sha256: + os.unlink(self.file_name) + raise Exception( + "%s: expected sha256 %s but got %s" % (self.url, self.sha256, digest) + ) + + def _download_dir(self): + """returns the download dir, creating it if it doesn't already exist""" + download_dir = os.path.dirname(self.file_name) + if not os.path.exists(download_dir): + os.makedirs(download_dir) + return download_dir + + def _download(self): + self._download_dir() + download_url_to_file_with_progress(self.url, self.file_name) + self._verify_hash() + + def clean(self): + if os.path.exists(self.src_dir): + shutil.rmtree(self.src_dir) + + def update(self): + try: + with open(self.hash_file, "r") as f: + saved_hash = f.read().strip() + if saved_hash == self.sha256 and os.path.exists(self.src_dir): + # Everything is up to date + return ChangeStatus() + print( + "saved hash %s doesn't match expected hash %s, re-validating" + % (saved_hash, self.sha256) + ) + os.unlink(self.hash_file) + except EnvironmentError: + pass + + # If we got here we know the contents of src_dir are either missing + # or wrong, so blow away whatever happened to be there first. + if os.path.exists(self.src_dir): + shutil.rmtree(self.src_dir) + + # If we already have a file here, make sure it looks legit before + # proceeding: any errors and we just remove it and re-download + if os.path.exists(self.file_name): + try: + self._verify_hash() + except Exception: + if os.path.exists(self.file_name): + os.unlink(self.file_name) + + if not os.path.exists(self.file_name): + self._download() + + if tarfile.is_tarfile(self.file_name): + opener = tarfile.open + elif zipfile.is_zipfile(self.file_name): + opener = zipfile.ZipFile + else: + raise Exception("don't know how to extract %s" % self.file_name) + os.makedirs(self.src_dir) + print("Extract %s -> %s" % (self.file_name, self.src_dir)) + t = opener(self.file_name) + if is_windows(): + # Ensure that we don't fall over when dealing with long paths + # on windows + src = r"\\?\%s" % os.path.normpath(self.src_dir) + else: + src = self.src_dir + # The `str` here is necessary to ensure that we don't pass a unicode + # object down to tarfile.extractall on python2. When extracting + # the boost tarball it makes some assumptions and tries to convert + # a non-ascii path to ascii and throws. + src = str(src) + t.extractall(src) + + with open(self.hash_file, "w") as f: + f.write(self.sha256) + + return ChangeStatus(True) + + def hash(self): + return self.sha256 + + def get_src_dir(self): + return self.src_dir diff --git a/build/fbcode_builder/getdeps/load.py b/build/fbcode_builder/getdeps/load.py new file mode 100644 index 000000000..c5f40d2fa --- /dev/null +++ b/build/fbcode_builder/getdeps/load.py @@ -0,0 +1,354 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import base64 +import hashlib +import os + +from . import fetcher +from .envfuncs import path_search +from .errors import ManifestNotFound +from .manifest import ManifestParser + + +class Loader(object): + """The loader allows our tests to patch the load operation""" + + def _list_manifests(self, build_opts): + """Returns a generator that iterates all the available manifests""" + for (path, _, files) in os.walk(build_opts.manifests_dir): + for name in files: + # skip hidden files + if name.startswith("."): + continue + + yield os.path.join(path, name) + + def _load_manifest(self, path): + return ManifestParser(path) + + def load_project(self, build_opts, project_name): + if "/" in project_name or "\\" in project_name: + # Assume this is a path already + return ManifestParser(project_name) + + for manifest in self._list_manifests(build_opts): + if os.path.basename(manifest) == project_name: + return ManifestParser(manifest) + + raise ManifestNotFound(project_name) + + def load_all(self, build_opts): + manifests_by_name = {} + + for manifest in self._list_manifests(build_opts): + m = self._load_manifest(manifest) + + if m.name in manifests_by_name: + raise Exception("found duplicate manifest '%s'" % m.name) + + manifests_by_name[m.name] = m + + return manifests_by_name + + +class ResourceLoader(Loader): + def __init__(self, namespace, manifests_dir): + self.namespace = namespace + self.manifests_dir = manifests_dir + + def _list_manifests(self, _build_opts): + import pkg_resources + + dirs = [self.manifests_dir] + + while dirs: + current = dirs.pop(0) + for name in pkg_resources.resource_listdir(self.namespace, current): + path = "%s/%s" % (current, name) + + if pkg_resources.resource_isdir(self.namespace, path): + dirs.append(path) + else: + yield "%s/%s" % (current, name) + + def _find_manifest(self, project_name): + for name in self._list_manifests(): + if name.endswith("/%s" % project_name): + return name + + raise ManifestNotFound(project_name) + + def _load_manifest(self, path): + import pkg_resources + + contents = pkg_resources.resource_string(self.namespace, path).decode("utf8") + return ManifestParser(file_name=path, fp=contents) + + def load_project(self, build_opts, project_name): + project_name = self._find_manifest(project_name) + return self._load_resource_manifest(project_name) + + +LOADER = Loader() + + +def patch_loader(namespace, manifests_dir="manifests"): + global LOADER + LOADER = ResourceLoader(namespace, manifests_dir) + + +def load_project(build_opts, project_name): + """given the name of a project or a path to a manifest file, + load up the ManifestParser instance for it and return it""" + return LOADER.load_project(build_opts, project_name) + + +def load_all_manifests(build_opts): + return LOADER.load_all(build_opts) + + +class ManifestLoader(object): + """ManifestLoader stores information about project manifest relationships for a + given set of (build options + platform) configuration. + + The ManifestLoader class primarily serves as a location to cache project dependency + relationships and project hash values for this build configuration. + """ + + def __init__(self, build_opts, ctx_gen=None): + self._loader = LOADER + self.build_opts = build_opts + if ctx_gen is None: + self.ctx_gen = self.build_opts.get_context_generator() + else: + self.ctx_gen = ctx_gen + + self.manifests_by_name = {} + self._loaded_all = False + self._project_hashes = {} + self._fetcher_overrides = {} + self._build_dir_overrides = {} + self._install_dir_overrides = {} + self._install_prefix_overrides = {} + + def load_manifest(self, name): + manifest = self.manifests_by_name.get(name) + if manifest is None: + manifest = self._loader.load_project(self.build_opts, name) + self.manifests_by_name[name] = manifest + return manifest + + def load_all_manifests(self): + if not self._loaded_all: + all_manifests_by_name = self._loader.load_all(self.build_opts) + if self.manifests_by_name: + # To help ensure that we only ever have a single manifest object for a + # given project, and that it can't change once we have loaded it, + # only update our mapping for projects that weren't already loaded. + for name, manifest in all_manifests_by_name.items(): + self.manifests_by_name.setdefault(name, manifest) + else: + self.manifests_by_name = all_manifests_by_name + self._loaded_all = True + + return self.manifests_by_name + + def manifests_in_dependency_order(self, manifest=None): + """Compute all dependencies of the specified project. Returns a list of the + dependencies plus the project itself, in topologically sorted order. + + Each entry in the returned list only depends on projects that appear before it + in the list. + + If the input manifest is None, the dependencies for all currently loaded + projects will be computed. i.e., if you call load_all_manifests() followed by + manifests_in_dependency_order() this will return a global dependency ordering of + all projects.""" + # The list of deps that have been fully processed + seen = set() + # The list of deps which have yet to be evaluated. This + # can potentially contain duplicates. + if manifest is None: + deps = list(self.manifests_by_name.values()) + else: + assert manifest.name in self.manifests_by_name + deps = [manifest] + # The list of manifests in dependency order + dep_order = [] + + while len(deps) > 0: + m = deps.pop(0) + if m.name in seen: + continue + + # Consider its deps, if any. + # We sort them for increased determinism; we'll produce + # a correct order even if they aren't sorted, but we prefer + # to produce the same order regardless of how they are listed + # in the project manifest files. + ctx = self.ctx_gen.get_context(m.name) + dep_list = sorted(m.get_section_as_dict("dependencies", ctx).keys()) + builder = m.get("build", "builder", ctx=ctx) + if builder in ("cmake", "python-wheel"): + dep_list.append("cmake") + elif builder == "autoconf" and m.name not in ( + "autoconf", + "libtool", + "automake", + ): + # they need libtool and its deps (automake, autoconf) so add + # those as deps (but obviously not if we're building those + # projects themselves) + dep_list.append("libtool") + + dep_count = 0 + for dep_name in dep_list: + # If we're not sure whether it is done, queue it up + if dep_name not in seen: + dep = self.manifests_by_name.get(dep_name) + if dep is None: + dep = self._loader.load_project(self.build_opts, dep_name) + self.manifests_by_name[dep.name] = dep + + deps.append(dep) + dep_count += 1 + + if dep_count > 0: + # If we queued anything, re-queue this item, as it depends + # those new item(s) and their transitive deps. + deps.append(m) + continue + + # Its deps are done, so we can emit it + seen.add(m.name) + dep_order.append(m) + + return dep_order + + def set_project_src_dir(self, project_name, path): + self._fetcher_overrides[project_name] = fetcher.LocalDirFetcher(path) + + def set_project_build_dir(self, project_name, path): + self._build_dir_overrides[project_name] = path + + def set_project_install_dir(self, project_name, path): + self._install_dir_overrides[project_name] = path + + def set_project_install_prefix(self, project_name, path): + self._install_prefix_overrides[project_name] = path + + def create_fetcher(self, manifest): + override = self._fetcher_overrides.get(manifest.name) + if override is not None: + return override + + ctx = self.ctx_gen.get_context(manifest.name) + return manifest.create_fetcher(self.build_opts, ctx) + + def get_project_hash(self, manifest): + h = self._project_hashes.get(manifest.name) + if h is None: + h = self._compute_project_hash(manifest) + self._project_hashes[manifest.name] = h + return h + + def _compute_project_hash(self, manifest): + """This recursive function computes a hash for a given manifest. + The hash takes into account some environmental factors on the + host machine and includes the hashes of its dependencies. + No caching of the computation is performed, which is theoretically + wasteful but the computation is fast enough that it is not required + to cache across multiple invocations.""" + ctx = self.ctx_gen.get_context(manifest.name) + + hasher = hashlib.sha256() + # Some environmental and configuration things matter + env = {} + env["install_dir"] = self.build_opts.install_dir + env["scratch_dir"] = self.build_opts.scratch_dir + env["vcvars_path"] = self.build_opts.vcvars_path + env["os"] = self.build_opts.host_type.ostype + env["distro"] = self.build_opts.host_type.distro + env["distro_vers"] = self.build_opts.host_type.distrovers + for name in [ + "CXXFLAGS", + "CPPFLAGS", + "LDFLAGS", + "CXX", + "CC", + "GETDEPS_CMAKE_DEFINES", + ]: + env[name] = os.environ.get(name) + for tool in ["cc", "c++", "gcc", "g++", "clang", "clang++"]: + env["tool-%s" % tool] = path_search(os.environ, tool) + for name in manifest.get_section_as_args("depends.environment", ctx): + env[name] = os.environ.get(name) + + fetcher = self.create_fetcher(manifest) + env["fetcher.hash"] = fetcher.hash() + + for name in sorted(env.keys()): + hasher.update(name.encode("utf-8")) + value = env.get(name) + if value is not None: + try: + hasher.update(value.encode("utf-8")) + except AttributeError as exc: + raise AttributeError("name=%r, value=%r: %s" % (name, value, exc)) + + manifest.update_hash(hasher, ctx) + + dep_list = sorted(manifest.get_section_as_dict("dependencies", ctx).keys()) + for dep in dep_list: + dep_manifest = self.load_manifest(dep) + dep_hash = self.get_project_hash(dep_manifest) + hasher.update(dep_hash.encode("utf-8")) + + # Use base64 to represent the hash, rather than the simple hex digest, + # so that the string is shorter. Use the URL-safe encoding so that + # the hash can also be safely used as a filename component. + h = base64.urlsafe_b64encode(hasher.digest()).decode("ascii") + # ... and because cmd.exe is troublesome with `=` signs, nerf those. + # They tend to be padding characters at the end anyway, so we can + # safely discard them. + h = h.replace("=", "") + + return h + + def _get_project_dir_name(self, manifest): + if manifest.is_first_party_project(): + return manifest.name + else: + project_hash = self.get_project_hash(manifest) + return "%s-%s" % (manifest.name, project_hash) + + def get_project_install_dir(self, manifest): + override = self._install_dir_overrides.get(manifest.name) + if override: + return override + + project_dir_name = self._get_project_dir_name(manifest) + return os.path.join(self.build_opts.install_dir, project_dir_name) + + def get_project_build_dir(self, manifest): + override = self._build_dir_overrides.get(manifest.name) + if override: + return override + + project_dir_name = self._get_project_dir_name(manifest) + return os.path.join(self.build_opts.scratch_dir, "build", project_dir_name) + + def get_project_install_prefix(self, manifest): + return self._install_prefix_overrides.get(manifest.name) + + def get_project_install_dir_respecting_install_prefix(self, manifest): + inst_dir = self.get_project_install_dir(manifest) + prefix = self.get_project_install_prefix(manifest) + if prefix: + return inst_dir + prefix + return inst_dir diff --git a/build/fbcode_builder/getdeps/manifest.py b/build/fbcode_builder/getdeps/manifest.py new file mode 100644 index 000000000..71566d659 --- /dev/null +++ b/build/fbcode_builder/getdeps/manifest.py @@ -0,0 +1,606 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import io +import os + +from .builder import ( + AutoconfBuilder, + Boost, + CargoBuilder, + CMakeBuilder, + BistroBuilder, + Iproute2Builder, + MakeBuilder, + NinjaBootstrap, + NopBuilder, + OpenNSABuilder, + OpenSSLBuilder, + SqliteBuilder, + CMakeBootStrapBuilder, +) +from .expr import parse_expr +from .fetcher import ( + ArchiveFetcher, + GitFetcher, + PreinstalledNopFetcher, + ShipitTransformerFetcher, + SimpleShipitTransformerFetcher, + SystemPackageFetcher, +) +from .py_wheel_builder import PythonWheelBuilder + + +try: + import configparser +except ImportError: + import ConfigParser as configparser + +REQUIRED = "REQUIRED" +OPTIONAL = "OPTIONAL" + +SCHEMA = { + "manifest": { + "optional_section": False, + "fields": { + "name": REQUIRED, + "fbsource_path": OPTIONAL, + "shipit_project": OPTIONAL, + "shipit_fbcode_builder": OPTIONAL, + }, + }, + "dependencies": {"optional_section": True, "allow_values": False}, + "depends.environment": {"optional_section": True}, + "git": { + "optional_section": True, + "fields": {"repo_url": REQUIRED, "rev": OPTIONAL, "depth": OPTIONAL}, + }, + "download": { + "optional_section": True, + "fields": {"url": REQUIRED, "sha256": REQUIRED}, + }, + "build": { + "optional_section": True, + "fields": { + "builder": REQUIRED, + "subdir": OPTIONAL, + "build_in_src_dir": OPTIONAL, + "disable_env_override_pkgconfig": OPTIONAL, + "disable_env_override_path": OPTIONAL, + }, + }, + "msbuild": {"optional_section": True, "fields": {"project": REQUIRED}}, + "cargo": { + "optional_section": True, + "fields": { + "build_doc": OPTIONAL, + "workspace_dir": OPTIONAL, + "manifests_to_build": OPTIONAL, + }, + }, + "cmake.defines": {"optional_section": True}, + "autoconf.args": {"optional_section": True}, + "rpms": {"optional_section": True}, + "debs": {"optional_section": True}, + "preinstalled.env": {"optional_section": True}, + "b2.args": {"optional_section": True}, + "make.build_args": {"optional_section": True}, + "make.install_args": {"optional_section": True}, + "make.test_args": {"optional_section": True}, + "header-only": {"optional_section": True, "fields": {"includedir": REQUIRED}}, + "shipit.pathmap": {"optional_section": True}, + "shipit.strip": {"optional_section": True}, + "install.files": {"optional_section": True}, +} + +# These sections are allowed to vary for different platforms +# using the expression syntax to enable/disable sections +ALLOWED_EXPR_SECTIONS = [ + "autoconf.args", + "build", + "cmake.defines", + "dependencies", + "make.build_args", + "make.install_args", + "b2.args", + "download", + "git", + "install.files", +] + + +def parse_conditional_section_name(name, section_def): + expr = name[len(section_def) + 1 :] + return parse_expr(expr, ManifestContext.ALLOWED_VARIABLES) + + +def validate_allowed_fields(file_name, section, config, allowed_fields): + for field in config.options(section): + if not allowed_fields.get(field): + raise Exception( + ("manifest file %s section '%s' contains " "unknown field '%s'") + % (file_name, section, field) + ) + + for field in allowed_fields: + if allowed_fields[field] == REQUIRED and not config.has_option(section, field): + raise Exception( + ("manifest file %s section '%s' is missing " "required field '%s'") + % (file_name, section, field) + ) + + +def validate_allow_values(file_name, section, config): + for field in config.options(section): + value = config.get(section, field) + if value is not None: + raise Exception( + ( + "manifest file %s section '%s' has '%s = %s' but " + "this section doesn't allow specifying values " + "for its entries" + ) + % (file_name, section, field, value) + ) + + +def validate_section(file_name, section, config): + section_def = SCHEMA.get(section) + if not section_def: + for name in ALLOWED_EXPR_SECTIONS: + if section.startswith(name + "."): + # Verify that the conditional parses, but discard it + try: + parse_conditional_section_name(section, name) + except Exception as exc: + raise Exception( + ("manifest file %s section '%s' has invalid " "conditional: %s") + % (file_name, section, str(exc)) + ) + section_def = SCHEMA.get(name) + canonical_section_name = name + break + if not section_def: + raise Exception( + "manifest file %s contains unknown section '%s'" % (file_name, section) + ) + else: + canonical_section_name = section + + allowed_fields = section_def.get("fields") + if allowed_fields: + validate_allowed_fields(file_name, section, config, allowed_fields) + elif not section_def.get("allow_values", True): + validate_allow_values(file_name, section, config) + return canonical_section_name + + +class ManifestParser(object): + def __init__(self, file_name, fp=None): + # allow_no_value enables listing parameters in the + # autoconf.args section one per line + config = configparser.RawConfigParser(allow_no_value=True) + config.optionxform = str # make it case sensitive + + if fp is None: + with open(file_name, "r") as fp: + config.read_file(fp) + elif isinstance(fp, type("")): + # For testing purposes, parse from a string (str + # or unicode) + config.read_file(io.StringIO(fp)) + else: + config.read_file(fp) + + # validate against the schema + seen_sections = set() + + for section in config.sections(): + seen_sections.add(validate_section(file_name, section, config)) + + for section in SCHEMA.keys(): + section_def = SCHEMA[section] + if ( + not section_def.get("optional_section", False) + and section not in seen_sections + ): + raise Exception( + "manifest file %s is missing required section %s" + % (file_name, section) + ) + + self._config = config + self.name = config.get("manifest", "name") + self.fbsource_path = self.get("manifest", "fbsource_path") + self.shipit_project = self.get("manifest", "shipit_project") + self.shipit_fbcode_builder = self.get("manifest", "shipit_fbcode_builder") + + if self.name != os.path.basename(file_name): + raise Exception( + "filename of the manifest '%s' does not match the manifest name '%s'" + % (file_name, self.name) + ) + + def get(self, section, key, defval=None, ctx=None): + ctx = ctx or {} + + for s in self._config.sections(): + if s == section: + if self._config.has_option(s, key): + return self._config.get(s, key) + return defval + + if s.startswith(section + "."): + expr = parse_conditional_section_name(s, section) + if not expr.eval(ctx): + continue + + if self._config.has_option(s, key): + return self._config.get(s, key) + + return defval + + def get_section_as_args(self, section, ctx=None): + """Intended for use with the make.[build_args/install_args] and + autoconf.args sections, this method collects the entries and returns an + array of strings. + If the manifest contains conditional sections, ctx is used to + evaluate the condition and merge in the values. + """ + args = [] + ctx = ctx or {} + + for s in self._config.sections(): + if s != section: + if not s.startswith(section + "."): + continue + expr = parse_conditional_section_name(s, section) + if not expr.eval(ctx): + continue + for field in self._config.options(s): + value = self._config.get(s, field) + if value is None: + args.append(field) + else: + args.append("%s=%s" % (field, value)) + return args + + def get_section_as_ordered_pairs(self, section, ctx=None): + """Used for eg: shipit.pathmap which has strong + ordering requirements""" + res = [] + ctx = ctx or {} + + for s in self._config.sections(): + if s != section: + if not s.startswith(section + "."): + continue + expr = parse_conditional_section_name(s, section) + if not expr.eval(ctx): + continue + + for key in self._config.options(s): + value = self._config.get(s, key) + res.append((key, value)) + return res + + def get_section_as_dict(self, section, ctx=None): + d = {} + ctx = ctx or {} + + for s in self._config.sections(): + if s != section: + if not s.startswith(section + "."): + continue + expr = parse_conditional_section_name(s, section) + if not expr.eval(ctx): + continue + for field in self._config.options(s): + value = self._config.get(s, field) + d[field] = value + return d + + def update_hash(self, hasher, ctx): + """Compute a hash over the configuration for the given + context. The goal is for the hash to change if the config + for that context changes, but not if a change is made to + the config only for a different platform than that expressed + by ctx. The hash is intended to be used to help invalidate + a future cache for the third party build products. + The hasher argument is a hash object returned from hashlib.""" + for section in sorted(SCHEMA.keys()): + hasher.update(section.encode("utf-8")) + + # Note: at the time of writing, nothing in the implementation + # relies on keys in any config section being ordered. + # In theory we could have conflicting flags in different + # config sections and later flags override earlier flags. + # For the purposes of computing a hash we're not super + # concerned about this: manifest changes should be rare + # enough and we'd rather that this trigger an invalidation + # than strive for a cache hit at this time. + pairs = self.get_section_as_ordered_pairs(section, ctx) + pairs.sort(key=lambda pair: pair[0]) + for key, value in pairs: + hasher.update(key.encode("utf-8")) + if value is not None: + hasher.update(value.encode("utf-8")) + + def is_first_party_project(self): + """returns true if this is an FB first-party project""" + return self.shipit_project is not None + + def get_required_system_packages(self, ctx): + """Returns dictionary of packager system -> list of packages""" + return { + "rpm": self.get_section_as_args("rpms", ctx), + "deb": self.get_section_as_args("debs", ctx), + } + + def _is_satisfied_by_preinstalled_environment(self, ctx): + envs = self.get_section_as_args("preinstalled.env", ctx) + if not envs: + return False + for key in envs: + val = os.environ.get(key, None) + print(f"Testing ENV[{key}]: {repr(val)}") + if val is None: + return False + if len(val) == 0: + return False + + return True + + def create_fetcher(self, build_options, ctx): + use_real_shipit = ( + ShipitTransformerFetcher.available() and build_options.use_shipit + ) + if ( + not use_real_shipit + and self.fbsource_path + and build_options.fbsource_dir + and self.shipit_project + ): + return SimpleShipitTransformerFetcher(build_options, self) + + if ( + self.fbsource_path + and build_options.fbsource_dir + and self.shipit_project + and ShipitTransformerFetcher.available() + ): + # We can use the code from fbsource + return ShipitTransformerFetcher(build_options, self.shipit_project) + + # Can we satisfy this dep with system packages? + if build_options.allow_system_packages: + if self._is_satisfied_by_preinstalled_environment(ctx): + return PreinstalledNopFetcher() + + packages = self.get_required_system_packages(ctx) + package_fetcher = SystemPackageFetcher(build_options, packages) + if package_fetcher.packages_are_installed(): + return package_fetcher + + repo_url = self.get("git", "repo_url", ctx=ctx) + if repo_url: + rev = self.get("git", "rev") + depth = self.get("git", "depth") + return GitFetcher(build_options, self, repo_url, rev, depth) + + url = self.get("download", "url", ctx=ctx) + if url: + # We need to defer this import until now to avoid triggering + # a cycle when the facebook/__init__.py is loaded. + try: + from getdeps.facebook.lfs import LFSCachingArchiveFetcher + + return LFSCachingArchiveFetcher( + build_options, self, url, self.get("download", "sha256", ctx=ctx) + ) + except ImportError: + # This FB internal module isn't shippped to github, + # so just use its base class + return ArchiveFetcher( + build_options, self, url, self.get("download", "sha256", ctx=ctx) + ) + + raise KeyError( + "project %s has no fetcher configuration matching %s" % (self.name, ctx) + ) + + def create_builder( # noqa:C901 + self, + build_options, + src_dir, + build_dir, + inst_dir, + ctx, + loader, + final_install_prefix=None, + extra_cmake_defines=None, + ): + builder = self.get("build", "builder", ctx=ctx) + if not builder: + raise Exception("project %s has no builder for %r" % (self.name, ctx)) + build_in_src_dir = self.get("build", "build_in_src_dir", "false", ctx=ctx) + if build_in_src_dir == "true": + # Some scripts don't work when they are configured and build in + # a different directory than source (or when the build directory + # is not a subdir of source). + build_dir = src_dir + subdir = self.get("build", "subdir", None, ctx=ctx) + if subdir is not None: + build_dir = os.path.join(build_dir, subdir) + print("build_dir is %s" % build_dir) # just to quiet lint + + if builder == "make" or builder == "cmakebootstrap": + build_args = self.get_section_as_args("make.build_args", ctx) + install_args = self.get_section_as_args("make.install_args", ctx) + test_args = self.get_section_as_args("make.test_args", ctx) + if builder == "cmakebootstrap": + return CMakeBootStrapBuilder( + build_options, + ctx, + self, + src_dir, + None, + inst_dir, + build_args, + install_args, + test_args, + ) + else: + return MakeBuilder( + build_options, + ctx, + self, + src_dir, + None, + inst_dir, + build_args, + install_args, + test_args, + ) + + if builder == "autoconf": + args = self.get_section_as_args("autoconf.args", ctx) + return AutoconfBuilder( + build_options, ctx, self, src_dir, build_dir, inst_dir, args + ) + + if builder == "boost": + args = self.get_section_as_args("b2.args", ctx) + return Boost(build_options, ctx, self, src_dir, build_dir, inst_dir, args) + + if builder == "bistro": + return BistroBuilder( + build_options, + ctx, + self, + src_dir, + build_dir, + inst_dir, + ) + + if builder == "cmake": + defines = self.get_section_as_dict("cmake.defines", ctx) + return CMakeBuilder( + build_options, + ctx, + self, + src_dir, + build_dir, + inst_dir, + defines, + final_install_prefix, + extra_cmake_defines, + ) + + if builder == "python-wheel": + return PythonWheelBuilder( + build_options, ctx, self, src_dir, build_dir, inst_dir + ) + + if builder == "sqlite": + return SqliteBuilder(build_options, ctx, self, src_dir, build_dir, inst_dir) + + if builder == "ninja_bootstrap": + return NinjaBootstrap( + build_options, ctx, self, build_dir, src_dir, inst_dir + ) + + if builder == "nop": + return NopBuilder(build_options, ctx, self, src_dir, inst_dir) + + if builder == "openssl": + return OpenSSLBuilder( + build_options, ctx, self, build_dir, src_dir, inst_dir + ) + + if builder == "iproute2": + return Iproute2Builder( + build_options, ctx, self, src_dir, build_dir, inst_dir + ) + + if builder == "cargo": + build_doc = self.get("cargo", "build_doc", False, ctx) + workspace_dir = self.get("cargo", "workspace_dir", None, ctx) + manifests_to_build = self.get("cargo", "manifests_to_build", None, ctx) + return CargoBuilder( + build_options, + ctx, + self, + src_dir, + build_dir, + inst_dir, + build_doc, + workspace_dir, + manifests_to_build, + loader, + ) + + if builder == "OpenNSA": + return OpenNSABuilder(build_options, ctx, self, src_dir, inst_dir) + + raise KeyError("project %s has no known builder" % (self.name)) + + +class ManifestContext(object): + """ProjectContext contains a dictionary of values to use when evaluating boolean + expressions in a project manifest. + + This object should be passed as the `ctx` parameter in ManifestParser.get() calls. + """ + + ALLOWED_VARIABLES = {"os", "distro", "distro_vers", "fb", "test"} + + def __init__(self, ctx_dict): + assert set(ctx_dict.keys()) == self.ALLOWED_VARIABLES + self.ctx_dict = ctx_dict + + def get(self, key): + return self.ctx_dict[key] + + def set(self, key, value): + assert key in self.ALLOWED_VARIABLES + self.ctx_dict[key] = value + + def copy(self): + return ManifestContext(dict(self.ctx_dict)) + + def __str__(self): + s = ", ".join( + "%s=%s" % (key, value) for key, value in sorted(self.ctx_dict.items()) + ) + return "{" + s + "}" + + +class ContextGenerator(object): + """ContextGenerator allows creating ManifestContext objects on a per-project basis. + This allows us to evaluate different projects with slightly different contexts. + + For instance, this can be used to only enable tests for some projects.""" + + def __init__(self, default_ctx): + self.default_ctx = ManifestContext(default_ctx) + self.ctx_by_project = {} + + def set_value_for_project(self, project_name, key, value): + project_ctx = self.ctx_by_project.get(project_name) + if project_ctx is None: + project_ctx = self.default_ctx.copy() + self.ctx_by_project[project_name] = project_ctx + project_ctx.set(key, value) + + def set_value_for_all_projects(self, key, value): + self.default_ctx.set(key, value) + for ctx in self.ctx_by_project.values(): + ctx.set(key, value) + + def get_context(self, project_name): + return self.ctx_by_project.get(project_name, self.default_ctx) diff --git a/build/fbcode_builder/getdeps/platform.py b/build/fbcode_builder/getdeps/platform.py new file mode 100644 index 000000000..fd8382e73 --- /dev/null +++ b/build/fbcode_builder/getdeps/platform.py @@ -0,0 +1,118 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import platform +import re +import shlex +import sys + + +def is_windows(): + """Returns true if the system we are currently running on + is a Windows system""" + return sys.platform.startswith("win") + + +def get_linux_type(): + try: + with open("/etc/os-release") as f: + data = f.read() + except EnvironmentError: + return (None, None) + + os_vars = {} + for line in data.splitlines(): + parts = line.split("=", 1) + if len(parts) != 2: + continue + key = parts[0].strip() + value_parts = shlex.split(parts[1].strip()) + if not value_parts: + value = "" + else: + value = value_parts[0] + os_vars[key] = value + + name = os_vars.get("NAME") + if name: + name = name.lower() + name = re.sub("linux", "", name) + name = name.strip() + + version_id = os_vars.get("VERSION_ID") + if version_id: + version_id = version_id.lower() + + return "linux", name, version_id + + +class HostType(object): + def __init__(self, ostype=None, distro=None, distrovers=None): + if ostype is None: + distro = None + distrovers = None + if sys.platform.startswith("linux"): + ostype, distro, distrovers = get_linux_type() + elif sys.platform.startswith("darwin"): + ostype = "darwin" + elif is_windows(): + ostype = "windows" + distrovers = str(sys.getwindowsversion().major) + else: + ostype = sys.platform + + # The operating system type + self.ostype = ostype + # The distribution, if applicable + self.distro = distro + # The OS/distro version if known + self.distrovers = distrovers + machine = platform.machine().lower() + if "arm" in machine or "aarch" in machine: + self.isarm = True + else: + self.isarm = False + + def is_windows(self): + return self.ostype == "windows" + + def is_arm(self): + return self.isarm + + def is_darwin(self): + return self.ostype == "darwin" + + def is_linux(self): + return self.ostype == "linux" + + def as_tuple_string(self): + return "%s-%s-%s" % ( + self.ostype, + self.distro or "none", + self.distrovers or "none", + ) + + def get_package_manager(self): + if not self.is_linux(): + return None + if self.distro in ("fedora", "centos"): + return "rpm" + if self.distro in ("debian", "ubuntu"): + return "deb" + return None + + @staticmethod + def from_tuple_string(s): + ostype, distro, distrovers = s.split("-") + return HostType(ostype=ostype, distro=distro, distrovers=distrovers) + + def __eq__(self, b): + return ( + self.ostype == b.ostype + and self.distro == b.distro + and self.distrovers == b.distrovers + ) diff --git a/build/fbcode_builder/getdeps/py_wheel_builder.py b/build/fbcode_builder/getdeps/py_wheel_builder.py new file mode 100644 index 000000000..82ad8b807 --- /dev/null +++ b/build/fbcode_builder/getdeps/py_wheel_builder.py @@ -0,0 +1,289 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import codecs +import collections +import email +import os +import re +import stat + +from .builder import BuilderBase, CMakeBuilder + + +WheelNameInfo = collections.namedtuple( + "WheelNameInfo", ("distribution", "version", "build", "python", "abi", "platform") +) + +CMAKE_HEADER = """ +cmake_minimum_required(VERSION 3.8) + +project("{manifest_name}" LANGUAGES C) + +set(CMAKE_MODULE_PATH + "{cmake_dir}" + ${{CMAKE_MODULE_PATH}} +) +include(FBPythonBinary) + +set(CMAKE_INSTALL_DIR lib/cmake/{manifest_name} CACHE STRING + "The subdirectory where CMake package config files should be installed") +""" + +CMAKE_FOOTER = """ +install_fb_python_library({lib_name} EXPORT all) +install( + EXPORT all + FILE {manifest_name}-targets.cmake + NAMESPACE {namespace}:: + DESTINATION ${{CMAKE_INSTALL_DIR}} +) + +include(CMakePackageConfigHelpers) +configure_package_config_file( + ${{CMAKE_BINARY_DIR}}/{manifest_name}-config.cmake.in + {manifest_name}-config.cmake + INSTALL_DESTINATION ${{CMAKE_INSTALL_DIR}} + PATH_VARS + CMAKE_INSTALL_DIR +) +install( + FILES ${{CMAKE_CURRENT_BINARY_DIR}}/{manifest_name}-config.cmake + DESTINATION ${{CMAKE_INSTALL_DIR}} +) +""" + +CMAKE_CONFIG_FILE = """ +@PACKAGE_INIT@ + +include(CMakeFindDependencyMacro) + +set_and_check({upper_name}_CMAKE_DIR "@PACKAGE_CMAKE_INSTALL_DIR@") + +if (NOT TARGET {namespace}::{lib_name}) + include("${{{upper_name}_CMAKE_DIR}}/{manifest_name}-targets.cmake") +endif() + +set({upper_name}_LIBRARIES {namespace}::{lib_name}) + +{find_dependency_lines} + +if (NOT {manifest_name}_FIND_QUIETLY) + message(STATUS "Found {manifest_name}: ${{PACKAGE_PREFIX_DIR}}") +endif() +""" + + +# Note: for now we are manually manipulating the wheel packet contents. +# The wheel format is documented here: +# https://www.python.org/dev/peps/pep-0491/#file-format +# +# We currently aren't particularly smart about correctly handling the full wheel +# functionality, but this is good enough to handle simple pure-python wheels, +# which is the main thing we care about right now. +# +# We could potentially use pip to install the wheel to a temporary location and +# then copy its "installed" files, but this has its own set of complications. +# This would require pip to already be installed and available, and we would +# need to correctly find the right version of pip or pip3 to use. +# If we did ever want to go down that path, we would probably want to use +# something like the following pip3 command: +# pip3 --isolated install --no-cache-dir --no-index --system \ +# --target +class PythonWheelBuilder(BuilderBase): + """This Builder can take Python wheel archives and install them as python libraries + that can be used by add_fb_python_library()/add_fb_python_executable() CMake rules. + """ + + def _build(self, install_dirs, reconfigure): + # type: (List[str], bool) -> None + + # When we are invoked, self.src_dir contains the unpacked wheel contents. + # + # Since a wheel file is just a zip file, the Fetcher code recognizes it as such + # and goes ahead and unpacks it. (We could disable that Fetcher behavior in the + # future if we ever wanted to, say if we wanted to call pip here.) + wheel_name = self._parse_wheel_name() + name_version_prefix = "-".join((wheel_name.distribution, wheel_name.version)) + dist_info_name = name_version_prefix + ".dist-info" + data_dir_name = name_version_prefix + ".data" + self.dist_info_dir = os.path.join(self.src_dir, dist_info_name) + wheel_metadata = self._read_wheel_metadata(wheel_name) + + # Check that we can understand the wheel version. + # We don't really care about wheel_metadata["Root-Is-Purelib"] since + # we are generating our own standalone python archives rather than installing + # into site-packages. + version = wheel_metadata["Wheel-Version"] + if not version.startswith("1."): + raise Exception("unsupported wheel version %s" % (version,)) + + # Add a find_dependency() call for each of our dependencies. + # The dependencies are also listed in the wheel METADATA file, but it is simpler + # to pull this directly from the getdeps manifest. + dep_list = sorted( + self.manifest.get_section_as_dict("dependencies", self.ctx).keys() + ) + find_dependency_lines = ["find_dependency({})".format(dep) for dep in dep_list] + + getdeps_cmake_dir = os.path.join( + os.path.dirname(os.path.dirname(__file__)), "CMake" + ) + self.template_format_dict = { + # Note that CMake files always uses forward slash separators in path names, + # even on Windows. Therefore replace path separators here. + "cmake_dir": _to_cmake_path(getdeps_cmake_dir), + "lib_name": self.manifest.name, + "manifest_name": self.manifest.name, + "namespace": self.manifest.name, + "upper_name": self.manifest.name.upper().replace("-", "_"), + "find_dependency_lines": "\n".join(find_dependency_lines), + } + + # Find sources from the root directory + path_mapping = {} + for entry in os.listdir(self.src_dir): + if entry in (dist_info_name, data_dir_name): + continue + self._add_sources(path_mapping, os.path.join(self.src_dir, entry), entry) + + # Files under the .data directory also need to be installed in the correct + # locations + if os.path.exists(data_dir_name): + # TODO: process the subdirectories of data_dir_name + # This isn't implemented yet since for now we have only needed dependencies + # on some simple pure Python wheels, so I haven't tested against wheels with + # additional files in the .data directory. + raise Exception( + "handling of the subdirectories inside %s is not implemented yet" + % data_dir_name + ) + + # Emit CMake files + self._write_cmakelists(path_mapping, dep_list) + self._write_cmake_config_template() + + # Run the build + self._run_cmake_build(install_dirs, reconfigure) + + def _run_cmake_build(self, install_dirs, reconfigure): + # type: (List[str], bool) -> None + + cmake_builder = CMakeBuilder( + build_opts=self.build_opts, + ctx=self.ctx, + manifest=self.manifest, + # Note that we intentionally supply src_dir=build_dir, + # since we wrote out our generated CMakeLists.txt in the build directory + src_dir=self.build_dir, + build_dir=self.build_dir, + inst_dir=self.inst_dir, + defines={}, + final_install_prefix=None, + ) + cmake_builder.build(install_dirs=install_dirs, reconfigure=reconfigure) + + def _write_cmakelists(self, path_mapping, dependencies): + # type: (List[str]) -> None + + cmake_path = os.path.join(self.build_dir, "CMakeLists.txt") + with open(cmake_path, "w") as f: + f.write(CMAKE_HEADER.format(**self.template_format_dict)) + for dep in dependencies: + f.write("find_package({0} REQUIRED)\n".format(dep)) + + f.write( + "add_fb_python_library({lib_name}\n".format(**self.template_format_dict) + ) + f.write(' BASE_DIR "%s"\n' % _to_cmake_path(self.src_dir)) + f.write(" SOURCES\n") + for src_path, install_path in path_mapping.items(): + f.write( + ' "%s=%s"\n' + % (_to_cmake_path(src_path), _to_cmake_path(install_path)) + ) + if dependencies: + f.write(" DEPENDS\n") + for dep in dependencies: + f.write(' "{0}::{0}"\n'.format(dep)) + f.write(")\n") + + f.write(CMAKE_FOOTER.format(**self.template_format_dict)) + + def _write_cmake_config_template(self): + config_path_name = self.manifest.name + "-config.cmake.in" + output_path = os.path.join(self.build_dir, config_path_name) + + with open(output_path, "w") as f: + f.write(CMAKE_CONFIG_FILE.format(**self.template_format_dict)) + + def _add_sources(self, path_mapping, src_path, install_path): + # type: (List[str], str, str) -> None + + s = os.lstat(src_path) + if not stat.S_ISDIR(s.st_mode): + path_mapping[src_path] = install_path + return + + for entry in os.listdir(src_path): + self._add_sources( + path_mapping, + os.path.join(src_path, entry), + os.path.join(install_path, entry), + ) + + def _parse_wheel_name(self): + # type: () -> WheelNameInfo + + # The ArchiveFetcher prepends "manifest_name-", so strip that off first. + wheel_name = os.path.basename(self.src_dir) + prefix = self.manifest.name + "-" + if not wheel_name.startswith(prefix): + raise Exception( + "expected wheel source directory to be of the form %s-NAME.whl" + % (prefix,) + ) + wheel_name = wheel_name[len(prefix) :] + + wheel_name_re = re.compile( + r"(?P[^-]+)" + r"-(?P\d+[^-]*)" + r"(-(?P\d+[^-]*))?" + r"-(?P\w+\d+(\.\w+\d+)*)" + r"-(?P\w+)" + r"-(?P\w+(\.\w+)*)" + r"\.whl" + ) + match = wheel_name_re.match(wheel_name) + if not match: + raise Exception( + "bad python wheel name %s: expected to have the form " + "DISTRIBUTION-VERSION-[-BUILD]-PYTAG-ABI-PLATFORM" + ) + + return WheelNameInfo( + distribution=match.group("distribution"), + version=match.group("version"), + build=match.group("build"), + python=match.group("python"), + abi=match.group("abi"), + platform=match.group("platform"), + ) + + def _read_wheel_metadata(self, wheel_name): + metadata_path = os.path.join(self.dist_info_dir, "WHEEL") + with codecs.open(metadata_path, "r", encoding="utf-8") as f: + return email.message_from_file(f) + + +def _to_cmake_path(path): + # CMake always uses forward slashes to separate paths in CMakeLists.txt files, + # even on Windows. It treats backslashes as character escapes, so using + # backslashes in the path will cause problems. Therefore replace all path + # separators with forward slashes to make sure the paths are correct on Windows. + # e.g. "C:\foo\bar.txt" becomes "C:/foo/bar.txt" + return path.replace(os.path.sep, "/") diff --git a/build/fbcode_builder/getdeps/runcmd.py b/build/fbcode_builder/getdeps/runcmd.py new file mode 100644 index 000000000..44e7994aa --- /dev/null +++ b/build/fbcode_builder/getdeps/runcmd.py @@ -0,0 +1,169 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import os +import select +import subprocess +import sys + +from .envfuncs import Env +from .platform import is_windows + + +try: + from shlex import quote as shellquote +except ImportError: + from pipes import quote as shellquote + + +class RunCommandError(Exception): + pass + + +def _print_env_diff(env, log_fn): + current_keys = set(os.environ.keys()) + wanted_env = set(env.keys()) + + unset_keys = current_keys.difference(wanted_env) + for k in sorted(unset_keys): + log_fn("+ unset %s\n" % k) + + added_keys = wanted_env.difference(current_keys) + for k in wanted_env.intersection(current_keys): + if os.environ[k] != env[k]: + added_keys.add(k) + + for k in sorted(added_keys): + if ("PATH" in k) and (os.pathsep in env[k]): + log_fn("+ %s=\\\n" % k) + for elem in env[k].split(os.pathsep): + log_fn("+ %s%s\\\n" % (shellquote(elem), os.pathsep)) + else: + log_fn("+ %s=%s \\\n" % (k, shellquote(env[k]))) + + +def run_cmd(cmd, env=None, cwd=None, allow_fail=False, log_file=None): + def log_to_stdout(msg): + sys.stdout.buffer.write(msg.encode(errors="surrogateescape")) + + if log_file is not None: + with open(log_file, "a", encoding="utf-8", errors="surrogateescape") as log: + + def log_function(msg): + log.write(msg) + log_to_stdout(msg) + + return _run_cmd( + cmd, env=env, cwd=cwd, allow_fail=allow_fail, log_fn=log_function + ) + else: + return _run_cmd( + cmd, env=env, cwd=cwd, allow_fail=allow_fail, log_fn=log_to_stdout + ) + + +def _run_cmd(cmd, env, cwd, allow_fail, log_fn): + log_fn("---\n") + try: + cmd_str = " \\\n+ ".join(shellquote(arg) for arg in cmd) + except TypeError: + # eg: one of the elements is None + raise RunCommandError("problem quoting cmd: %r" % cmd) + + if env: + assert isinstance(env, Env) + _print_env_diff(env, log_fn) + + # Convert from our Env type to a regular dict. + # This is needed because python3 looks up b'PATH' and 'PATH' + # and emits an error if both are present. In our Env type + # we'll return the same value for both requests, but we don't + # have duplicate potentially conflicting values which is the + # spirit of the check. + env = dict(env.items()) + + if cwd: + log_fn("+ cd %s && \\\n" % shellquote(cwd)) + # Our long path escape sequence may confuse cmd.exe, so if the cwd + # is short enough, strip that off. + if is_windows() and (len(cwd) < 250) and cwd.startswith("\\\\?\\"): + cwd = cwd[4:] + + log_fn("+ %s\n" % cmd_str) + + isinteractive = os.isatty(sys.stdout.fileno()) + if isinteractive: + stdout = None + sys.stdout.buffer.flush() + else: + stdout = subprocess.PIPE + + try: + p = subprocess.Popen( + cmd, env=env, cwd=cwd, stdout=stdout, stderr=subprocess.STDOUT + ) + except (TypeError, ValueError, OSError) as exc: + log_fn("error running `%s`: %s" % (cmd_str, exc)) + raise RunCommandError( + "%s while running `%s` with env=%r\nos.environ=%r" + % (str(exc), cmd_str, env, os.environ) + ) + + if not isinteractive: + _pipe_output(p, log_fn) + + p.wait() + if p.returncode != 0 and not allow_fail: + raise subprocess.CalledProcessError(p.returncode, cmd) + + return p.returncode + + +if hasattr(select, "poll"): + + def _pipe_output(p, log_fn): + """Read output from p.stdout and call log_fn() with each chunk of data as it + becomes available.""" + # Perform non-blocking reads + import fcntl + + fcntl.fcntl(p.stdout.fileno(), fcntl.F_SETFL, os.O_NONBLOCK) + poll = select.poll() + poll.register(p.stdout.fileno(), select.POLLIN) + + buffer_size = 4096 + while True: + poll.poll() + data = p.stdout.read(buffer_size) + if not data: + break + # log_fn() accepts arguments as str (binary in Python 2, unicode in + # Python 3). In Python 3 the subprocess output will be plain bytes, + # and need to be decoded. + if not isinstance(data, str): + data = data.decode("utf-8", errors="surrogateescape") + log_fn(data) + + +else: + + def _pipe_output(p, log_fn): + """Read output from p.stdout and call log_fn() with each chunk of data as it + becomes available.""" + # Perform blocking reads. Use a smaller buffer size to avoid blocking + # for very long when data is available. + buffer_size = 64 + while True: + data = p.stdout.read(buffer_size) + if not data: + break + # log_fn() accepts arguments as str (binary in Python 2, unicode in + # Python 3). In Python 3 the subprocess output will be plain bytes, + # and need to be decoded. + if not isinstance(data, str): + data = data.decode("utf-8", errors="surrogateescape") + log_fn(data) diff --git a/build/fbcode_builder/getdeps/subcmd.py b/build/fbcode_builder/getdeps/subcmd.py new file mode 100644 index 000000000..95f9a07ca --- /dev/null +++ b/build/fbcode_builder/getdeps/subcmd.py @@ -0,0 +1,58 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + + +class SubCmd(object): + NAME = None + HELP = None + + def run(self, args): + """perform the command""" + return 0 + + def setup_parser(self, parser): + # Subclasses should override setup_parser() if they have any + # command line options or arguments. + pass + + +CmdTable = [] + + +def add_subcommands(parser, common_args, cmd_table=CmdTable): + """Register parsers for the defined commands with the provided parser""" + for cls in cmd_table: + command = cls() + command_parser = parser.add_parser( + command.NAME, help=command.HELP, parents=[common_args] + ) + command.setup_parser(command_parser) + command_parser.set_defaults(func=command.run) + + +def cmd(name, help=None, cmd_table=CmdTable): + """ + @cmd() is a decorator that can be used to help define Subcmd instances + + Example usage: + + @subcmd('list', 'Show the result list') + class ListCmd(Subcmd): + def run(self, args): + # Perform the command actions here... + pass + """ + + def wrapper(cls): + class SubclassedCmd(cls): + NAME = name + HELP = help + + cmd_table.append(SubclassedCmd) + return SubclassedCmd + + return wrapper diff --git a/build/fbcode_builder/getdeps/test/expr_test.py b/build/fbcode_builder/getdeps/test/expr_test.py new file mode 100644 index 000000000..59d66a943 --- /dev/null +++ b/build/fbcode_builder/getdeps/test/expr_test.py @@ -0,0 +1,49 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import unittest + +from ..expr import parse_expr + + +class ExprTest(unittest.TestCase): + def test_equal(self): + valid_variables = {"foo", "some_var", "another_var"} + e = parse_expr("foo=bar", valid_variables) + self.assertTrue(e.eval({"foo": "bar"})) + self.assertFalse(e.eval({"foo": "not-bar"})) + self.assertFalse(e.eval({"not-foo": "bar"})) + + def test_not_equal(self): + valid_variables = {"foo"} + e = parse_expr("not(foo=bar)", valid_variables) + self.assertFalse(e.eval({"foo": "bar"})) + self.assertTrue(e.eval({"foo": "not-bar"})) + + def test_bad_not(self): + valid_variables = {"foo"} + with self.assertRaises(Exception): + parse_expr("foo=not(bar)", valid_variables) + + def test_bad_variable(self): + valid_variables = {"bar"} + with self.assertRaises(Exception): + parse_expr("foo=bar", valid_variables) + + def test_all(self): + valid_variables = {"foo", "baz"} + e = parse_expr("all(foo = bar, baz = qux)", valid_variables) + self.assertTrue(e.eval({"foo": "bar", "baz": "qux"})) + self.assertFalse(e.eval({"foo": "bar", "baz": "nope"})) + self.assertFalse(e.eval({"foo": "nope", "baz": "nope"})) + + def test_any(self): + valid_variables = {"foo", "baz"} + e = parse_expr("any(foo = bar, baz = qux)", valid_variables) + self.assertTrue(e.eval({"foo": "bar", "baz": "qux"})) + self.assertTrue(e.eval({"foo": "bar", "baz": "nope"})) + self.assertFalse(e.eval({"foo": "nope", "baz": "nope"})) diff --git a/build/fbcode_builder/getdeps/test/fixtures/duplicate/foo b/build/fbcode_builder/getdeps/test/fixtures/duplicate/foo new file mode 100644 index 000000000..a0384ee3b --- /dev/null +++ b/build/fbcode_builder/getdeps/test/fixtures/duplicate/foo @@ -0,0 +1,2 @@ +[manifest] +name = foo diff --git a/build/fbcode_builder/getdeps/test/fixtures/duplicate/subdir/foo b/build/fbcode_builder/getdeps/test/fixtures/duplicate/subdir/foo new file mode 100644 index 000000000..a0384ee3b --- /dev/null +++ b/build/fbcode_builder/getdeps/test/fixtures/duplicate/subdir/foo @@ -0,0 +1,2 @@ +[manifest] +name = foo diff --git a/build/fbcode_builder/getdeps/test/manifest_test.py b/build/fbcode_builder/getdeps/test/manifest_test.py new file mode 100644 index 000000000..8be9896d8 --- /dev/null +++ b/build/fbcode_builder/getdeps/test/manifest_test.py @@ -0,0 +1,233 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import sys +import unittest + +from ..load import load_all_manifests, patch_loader +from ..manifest import ManifestParser + + +class ManifestTest(unittest.TestCase): + def test_missing_section(self): + with self.assertRaisesRegex( + Exception, "manifest file test is missing required section manifest" + ): + ManifestParser("test", "") + + def test_missing_name(self): + with self.assertRaisesRegex( + Exception, + "manifest file test section 'manifest' is missing required field 'name'", + ): + ManifestParser( + "test", + """ +[manifest] +""", + ) + + def test_minimal(self): + p = ManifestParser( + "test", + """ +[manifest] +name = test +""", + ) + self.assertEqual(p.name, "test") + self.assertEqual(p.fbsource_path, None) + + def test_minimal_with_fbsource_path(self): + p = ManifestParser( + "test", + """ +[manifest] +name = test +fbsource_path = fbcode/wat +""", + ) + self.assertEqual(p.name, "test") + self.assertEqual(p.fbsource_path, "fbcode/wat") + + def test_unknown_field(self): + with self.assertRaisesRegex( + Exception, + ( + "manifest file test section 'manifest' contains " + "unknown field 'invalid.field'" + ), + ): + ManifestParser( + "test", + """ +[manifest] +name = test +invalid.field = woot +""", + ) + + def test_invalid_section_name(self): + with self.assertRaisesRegex( + Exception, "manifest file test contains unknown section 'invalid.section'" + ): + ManifestParser( + "test", + """ +[manifest] +name = test + +[invalid.section] +foo = bar +""", + ) + + def test_value_in_dependencies_section(self): + with self.assertRaisesRegex( + Exception, + ( + "manifest file test section 'dependencies' has " + "'foo = bar' but this section doesn't allow " + "specifying values for its entries" + ), + ): + ManifestParser( + "test", + """ +[manifest] +name = test + +[dependencies] +foo = bar +""", + ) + + def test_invalid_conditional_section_name(self): + with self.assertRaisesRegex( + Exception, + ( + "manifest file test section 'dependencies.=' " + "has invalid conditional: expected " + "identifier found =" + ), + ): + ManifestParser( + "test", + """ +[manifest] +name = test + +[dependencies.=] +""", + ) + + def test_section_as_args(self): + p = ManifestParser( + "test", + """ +[manifest] +name = test + +[dependencies] +a +b +c + +[dependencies.test=on] +foo +""", + ) + self.assertEqual(p.get_section_as_args("dependencies"), ["a", "b", "c"]) + self.assertEqual( + p.get_section_as_args("dependencies", {"test": "off"}), ["a", "b", "c"] + ) + self.assertEqual( + p.get_section_as_args("dependencies", {"test": "on"}), + ["a", "b", "c", "foo"], + ) + + p2 = ManifestParser( + "test", + """ +[manifest] +name = test + +[autoconf.args] +--prefix=/foo +--with-woot +""", + ) + self.assertEqual( + p2.get_section_as_args("autoconf.args"), ["--prefix=/foo", "--with-woot"] + ) + + def test_section_as_dict(self): + p = ManifestParser( + "test", + """ +[manifest] +name = test + +[cmake.defines] +foo = bar + +[cmake.defines.test=on] +foo = baz +""", + ) + self.assertEqual(p.get_section_as_dict("cmake.defines"), {"foo": "bar"}) + self.assertEqual( + p.get_section_as_dict("cmake.defines", {"test": "on"}), {"foo": "baz"} + ) + + p2 = ManifestParser( + "test", + """ +[manifest] +name = test + +[cmake.defines.test=on] +foo = baz + +[cmake.defines] +foo = bar +""", + ) + self.assertEqual( + p2.get_section_as_dict("cmake.defines", {"test": "on"}), + {"foo": "bar"}, + msg="sections cascade in the order they appear in the manifest", + ) + + def test_parse_common_manifests(self): + patch_loader(__name__) + manifests = load_all_manifests(None) + self.assertNotEqual(0, len(manifests), msg="parsed some number of manifests") + + def test_mismatch_name(self): + with self.assertRaisesRegex( + Exception, + "filename of the manifest 'foo' does not match the manifest name 'bar'", + ): + ManifestParser( + "foo", + """ +[manifest] +name = bar +""", + ) + + def test_duplicate_manifest(self): + patch_loader(__name__, "fixtures/duplicate") + + with self.assertRaisesRegex(Exception, "found duplicate manifest 'foo'"): + load_all_manifests(None) + + if sys.version_info < (3, 2): + + def assertRaisesRegex(self, *args, **kwargs): + return self.assertRaisesRegexp(*args, **kwargs) diff --git a/build/fbcode_builder/getdeps/test/platform_test.py b/build/fbcode_builder/getdeps/test/platform_test.py new file mode 100644 index 000000000..311e9c76c --- /dev/null +++ b/build/fbcode_builder/getdeps/test/platform_test.py @@ -0,0 +1,40 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function, unicode_literals + +import unittest + +from ..platform import HostType + + +class PlatformTest(unittest.TestCase): + def test_create(self): + p = HostType() + self.assertNotEqual(p.ostype, None, msg="probed and returned something") + + tuple_string = p.as_tuple_string() + round_trip = HostType.from_tuple_string(tuple_string) + self.assertEqual(round_trip, p) + + def test_rendering_of_none(self): + p = HostType(ostype="foo") + self.assertEqual(p.as_tuple_string(), "foo-none-none") + + def test_is_methods(self): + p = HostType(ostype="windows") + self.assertTrue(p.is_windows()) + self.assertFalse(p.is_darwin()) + self.assertFalse(p.is_linux()) + + p = HostType(ostype="darwin") + self.assertFalse(p.is_windows()) + self.assertTrue(p.is_darwin()) + self.assertFalse(p.is_linux()) + + p = HostType(ostype="linux") + self.assertFalse(p.is_windows()) + self.assertFalse(p.is_darwin()) + self.assertTrue(p.is_linux()) diff --git a/build/fbcode_builder/getdeps/test/scratch_test.py b/build/fbcode_builder/getdeps/test/scratch_test.py new file mode 100644 index 000000000..1f43c5951 --- /dev/null +++ b/build/fbcode_builder/getdeps/test/scratch_test.py @@ -0,0 +1,80 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import absolute_import, division, print_function + +import unittest + +from ..buildopts import find_existing_win32_subst_for_path + + +class Win32SubstTest(unittest.TestCase): + def test_no_existing_subst(self): + self.assertIsNone( + find_existing_win32_subst_for_path( + r"C:\users\alice\appdata\local\temp\fbcode_builder_getdeps", + subst_mapping={}, + ) + ) + self.assertIsNone( + find_existing_win32_subst_for_path( + r"C:\users\alice\appdata\local\temp\fbcode_builder_getdeps", + subst_mapping={"X:\\": r"C:\users\alice\appdata\local\temp\other"}, + ) + ) + + def test_exact_match_returns_drive_path(self): + self.assertEqual( + find_existing_win32_subst_for_path( + r"C:\temp\fbcode_builder_getdeps", + subst_mapping={"X:\\": r"C:\temp\fbcode_builder_getdeps"}, + ), + "X:\\", + ) + self.assertEqual( + find_existing_win32_subst_for_path( + r"C:/temp/fbcode_builder_getdeps", + subst_mapping={"X:\\": r"C:/temp/fbcode_builder_getdeps"}, + ), + "X:\\", + ) + + def test_multiple_exact_matches_returns_arbitrary_drive_path(self): + self.assertIn( + find_existing_win32_subst_for_path( + r"C:\temp\fbcode_builder_getdeps", + subst_mapping={ + "X:\\": r"C:\temp\fbcode_builder_getdeps", + "Y:\\": r"C:\temp\fbcode_builder_getdeps", + "Z:\\": r"C:\temp\fbcode_builder_getdeps", + }, + ), + ("X:\\", "Y:\\", "Z:\\"), + ) + + def test_drive_letter_is_case_insensitive(self): + self.assertEqual( + find_existing_win32_subst_for_path( + r"C:\temp\fbcode_builder_getdeps", + subst_mapping={"X:\\": r"c:\temp\fbcode_builder_getdeps"}, + ), + "X:\\", + ) + + def test_path_components_are_case_insensitive(self): + self.assertEqual( + find_existing_win32_subst_for_path( + r"C:\TEMP\FBCODE_builder_getdeps", + subst_mapping={"X:\\": r"C:\temp\fbcode_builder_getdeps"}, + ), + "X:\\", + ) + self.assertEqual( + find_existing_win32_subst_for_path( + r"C:\temp\fbcode_builder_getdeps", + subst_mapping={"X:\\": r"C:\TEMP\FBCODE_builder_getdeps"}, + ), + "X:\\", + ) diff --git a/build/fbcode_builder/make_docker_context.py b/build/fbcode_builder/make_docker_context.py new file mode 100755 index 000000000..d4b0f0a89 --- /dev/null +++ b/build/fbcode_builder/make_docker_context.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +""" +Reads `fbcode_builder_config.py` from the current directory, and prepares a +Docker context directory to build this project. Prints to stdout the path +to the context directory. + +Try `.../make_docker_context.py --help` from a project's `build/` directory. + +By default, the Docker context directory will be in /tmp. It will always +contain a Dockerfile, and might also contain copies of your local repos, and +other data needed for the build container. +""" + +import os +import tempfile +import textwrap + +from docker_builder import DockerFBCodeBuilder +from parse_args import parse_args_to_fbcode_builder_opts + + +def make_docker_context( + get_steps_fn, github_project, opts=None, default_context_dir=None +): + """ + Returns a path to the Docker context directory. See parse_args.py. + + Helper for making a command-line utility that writes your project's + Dockerfile and associated data into a (temporary) directory. Your main + program might look something like this: + + print(make_docker_context( + lambda builder: [builder.step(...), ...], + 'facebook/your_project', + )) + """ + + if opts is None: + opts = {} + + valid_versions = ( + ("ubuntu:16.04", "5"), + ("ubuntu:18.04", "7"), + ) + + def add_args(parser): + parser.add_argument( + "--docker-context-dir", + metavar="DIR", + default=default_context_dir, + help="Write the Dockerfile and its context into this directory. " + "If empty, make a temporary directory. Default: %(default)s.", + ) + parser.add_argument( + "--user", + metavar="NAME", + default=opts.get("user", "nobody"), + help="Build and install as this user. Default: %(default)s.", + ) + parser.add_argument( + "--prefix", + metavar="DIR", + default=opts.get("prefix", "/home/install"), + help="Install all libraries in this prefix. Default: %(default)s.", + ) + parser.add_argument( + "--projects-dir", + metavar="DIR", + default=opts.get("projects_dir", "/home"), + help="Place project code directories here. Default: %(default)s.", + ) + parser.add_argument( + "--os-image", + metavar="IMG", + choices=zip(*valid_versions)[0], + default=opts.get("os_image", valid_versions[0][0]), + help="Docker OS image -- be sure to use only ones you trust (See " + "README.docker). Choices: %(choices)s. Default: %(default)s.", + ) + parser.add_argument( + "--gcc-version", + metavar="VER", + choices=set(zip(*valid_versions)[1]), + default=opts.get("gcc_version", valid_versions[0][1]), + help="Choices: %(choices)s. Default: %(default)s.", + ) + parser.add_argument( + "--make-parallelism", + metavar="NUM", + type=int, + default=opts.get("make_parallelism", 1), + help="Use `make -j` on multi-CPU systems with lots of RAM. " + "Default: %(default)s.", + ) + parser.add_argument( + "--local-repo-dir", + metavar="DIR", + help="If set, build {0} from a local directory instead of Github.".format( + github_project + ), + ) + parser.add_argument( + "--ccache-tgz", + metavar="PATH", + help="If set, enable ccache for the build. To initialize the " + "cache, first try to hardlink, then to copy --cache-tgz " + "as ccache.tgz into the --docker-context-dir.", + ) + + opts = parse_args_to_fbcode_builder_opts( + add_args, + # These have add_argument() calls, others are set via --option. + ( + "docker_context_dir", + "user", + "prefix", + "projects_dir", + "os_image", + "gcc_version", + "make_parallelism", + "local_repo_dir", + "ccache_tgz", + ), + opts, + help=textwrap.dedent( + """ + + Reads `fbcode_builder_config.py` from the current directory, and + prepares a Docker context directory to build {github_project} and + its dependencies. Prints to stdout the path to the context + directory. + + Pass --option {github_project}:git_hash SHA1 to build something + other than the master branch from Github. + + Or, pass --option {github_project}:local_repo_dir LOCAL_PATH to + build from a local repo instead of cloning from Github. + + Usage: + (cd $(./make_docker_context.py) && docker build . 2>&1 | tee log) + + """.format( + github_project=github_project + ) + ), + ) + + # This allows travis_docker_build.sh not to know the main Github project. + local_repo_dir = opts.pop("local_repo_dir", None) + if local_repo_dir is not None: + opts["{0}:local_repo_dir".format(github_project)] = local_repo_dir + + if (opts.get("os_image"), opts.get("gcc_version")) not in valid_versions: + raise Exception( + "Due to 4/5 ABI changes (std::string), we can only use {0}".format( + " / ".join("GCC {1} on {0}".format(*p) for p in valid_versions) + ) + ) + + if opts.get("docker_context_dir") is None: + opts["docker_context_dir"] = tempfile.mkdtemp(prefix="docker-context-") + elif not os.path.exists(opts.get("docker_context_dir")): + os.makedirs(opts.get("docker_context_dir")) + + builder = DockerFBCodeBuilder(**opts) + context_dir = builder.option("docker_context_dir") # Mark option "in-use" + # The renderer may also populate some files into the context_dir. + dockerfile = builder.render(get_steps_fn(builder)) + + with os.fdopen( + os.open( + os.path.join(context_dir, "Dockerfile"), + os.O_RDWR | os.O_CREAT | os.O_EXCL, # Do not overwrite existing files + 0o644, + ), + "w", + ) as f: + f.write(dockerfile) + + return context_dir + + +if __name__ == "__main__": + from utils import read_fbcode_builder_config, build_fbcode_builder_config + + # Load a spec from the current directory + config = read_fbcode_builder_config("fbcode_builder_config.py") + print( + make_docker_context( + build_fbcode_builder_config(config), + config["github_project"], + ) + ) diff --git a/build/fbcode_builder/manifests/CLI11 b/build/fbcode_builder/manifests/CLI11 new file mode 100644 index 000000000..14cb2332a --- /dev/null +++ b/build/fbcode_builder/manifests/CLI11 @@ -0,0 +1,14 @@ +[manifest] +name = CLI11 + +[download] +url = https://github.com/CLIUtils/CLI11/archive/v2.0.0.tar.gz +sha256 = 2c672f17bf56e8e6223a3bfb74055a946fa7b1ff376510371902adb9cb0ab6a3 + +[build] +builder = cmake +subdir = CLI11-2.0.0 + +[cmake.defines] +CLI11_BUILD_TESTS = OFF +CLI11_BUILD_EXAMPLES = OFF diff --git a/build/fbcode_builder/manifests/OpenNSA b/build/fbcode_builder/manifests/OpenNSA new file mode 100644 index 000000000..62354c997 --- /dev/null +++ b/build/fbcode_builder/manifests/OpenNSA @@ -0,0 +1,17 @@ +[manifest] +name = OpenNSA + +[download] +url = https://docs.broadcom.com/docs-and-downloads/csg/opennsa-6.5.22.tgz +sha256 = 74bfbdaebb6bfe9ebb0deac3aff624385cdcf5aa416ba63706c36538b3c3c46c + +[build] +builder = nop +subdir = opennsa-6.5.22 + +[install.files] +lib/x86-64 = lib +include = include +src/gpl-modules/systems/bde/linux/include = include/systems/bde/linux +src/gpl-modules/include/ibde.h = include/ibde.h +src/gpl-modules = src/gpl-modules diff --git a/build/fbcode_builder/manifests/autoconf b/build/fbcode_builder/manifests/autoconf new file mode 100644 index 000000000..35963096c --- /dev/null +++ b/build/fbcode_builder/manifests/autoconf @@ -0,0 +1,16 @@ +[manifest] +name = autoconf + +[rpms] +autoconf + +[debs] +autoconf + +[download] +url = http://ftp.gnu.org/gnu/autoconf/autoconf-2.69.tar.gz +sha256 = 954bd69b391edc12d6a4a51a2dd1476543da5c6bbf05a95b59dc0dd6fd4c2969 + +[build] +builder = autoconf +subdir = autoconf-2.69 diff --git a/build/fbcode_builder/manifests/automake b/build/fbcode_builder/manifests/automake new file mode 100644 index 000000000..71115068a --- /dev/null +++ b/build/fbcode_builder/manifests/automake @@ -0,0 +1,19 @@ +[manifest] +name = automake + +[rpms] +automake + +[debs] +automake + +[download] +url = http://ftp.gnu.org/gnu/automake/automake-1.16.1.tar.gz +sha256 = 608a97523f97db32f1f5d5615c98ca69326ced2054c9f82e65bade7fc4c9dea8 + +[build] +builder = autoconf +subdir = automake-1.16.1 + +[dependencies] +autoconf diff --git a/build/fbcode_builder/manifests/bison b/build/fbcode_builder/manifests/bison new file mode 100644 index 000000000..6e355d052 --- /dev/null +++ b/build/fbcode_builder/manifests/bison @@ -0,0 +1,27 @@ +[manifest] +name = bison + +[rpms] +bison + +[debs] +bison + +[download.not(os=windows)] +url = https://mirrors.kernel.org/gnu/bison/bison-3.3.tar.gz +sha256 = fdeafb7fffade05604a61e66b8c040af4b2b5cbb1021dcfe498ed657ac970efd + +[download.os=windows] +url = https://github.com/lexxmark/winflexbison/releases/download/v2.5.17/winflexbison-2.5.17.zip +sha256 = 3dc27a16c21b717bcc5de8590b564d4392a0b8577170c058729d067d95ded825 + +[build.not(os=windows)] +builder = autoconf +subdir = bison-3.3 + +[build.os=windows] +builder = nop + +[install.files.os=windows] +data = bin/data +win_bison.exe = bin/bison.exe diff --git a/build/fbcode_builder/manifests/bistro b/build/fbcode_builder/manifests/bistro new file mode 100644 index 000000000..d93839275 --- /dev/null +++ b/build/fbcode_builder/manifests/bistro @@ -0,0 +1,28 @@ +[manifest] +name = bistro +fbsource_path = fbcode/bistro +shipit_project = bistro +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebook/bistro.git + +[build.os=linux] +builder = bistro + +# Bistro is Linux-specific +[build.not(os=linux)] +builder = nop + +[dependencies] +fmt +folly +proxygen +fbthrift +libsodium +googletest_1_8 +sqlite3 + +[shipit.pathmap] +fbcode/bistro/public_tld = . +fbcode/bistro = bistro diff --git a/build/fbcode_builder/manifests/boost b/build/fbcode_builder/manifests/boost new file mode 100644 index 000000000..4b254e308 --- /dev/null +++ b/build/fbcode_builder/manifests/boost @@ -0,0 +1,86 @@ +[manifest] +name = boost + +[download.not(os=windows)] +url = https://versaweb.dl.sourceforge.net/project/boost/boost/1.69.0/boost_1_69_0.tar.bz2 +sha256 = 8f32d4617390d1c2d16f26a27ab60d97807b35440d45891fa340fc2648b04406 + +[download.os=windows] +url = https://versaweb.dl.sourceforge.net/project/boost/boost/1.69.0/boost_1_69_0.zip +sha256 = d074bcbcc0501c4917b965fc890e303ee70d8b01ff5712bae4a6c54f2b6b4e52 + +[preinstalled.env] +BOOST_ROOT_1_69_0 + +[debs] +libboost-all-dev + +[rpms] +boost +boost-math +boost-test +boost-fiber +boost-graph +boost-log +boost-openmpi +boost-timer +boost-chrono +boost-locale +boost-thread +boost-atomic +boost-random +boost-static +boost-contract +boost-date-time +boost-iostreams +boost-container +boost-coroutine +boost-filesystem +boost-system +boost-stacktrace +boost-regex +boost-devel +boost-context +boost-python3-devel +boost-type_erasure +boost-wave +boost-python3 +boost-serialization +boost-program-options + +[build] +builder = boost + +[b2.args] +--with-atomic +--with-chrono +--with-container +--with-context +--with-contract +--with-coroutine +--with-date_time +--with-exception +--with-fiber +--with-filesystem +--with-graph +--with-graph_parallel +--with-iostreams +--with-locale +--with-log +--with-math +--with-mpi +--with-program_options +--with-python +--with-random +--with-regex +--with-serialization +--with-stacktrace +--with-system +--with-test +--with-thread +--with-timer +--with-type_erasure +--with-wave + +[b2.args.os=darwin] +toolset=clang diff --git a/build/fbcode_builder/manifests/cmake b/build/fbcode_builder/manifests/cmake new file mode 100644 index 000000000..f756caed0 --- /dev/null +++ b/build/fbcode_builder/manifests/cmake @@ -0,0 +1,43 @@ +[manifest] +name = cmake + +[rpms] +cmake + +# All current deb based distros have a cmake that is too old +#[debs] +#cmake + +[dependencies] +ninja + +[download.os=windows] +url = https://github.com/Kitware/CMake/releases/download/v3.14.0/cmake-3.14.0-win64-x64.zip +sha256 = 40e8140d68120378262322bbc8c261db8d184d7838423b2e5bf688a6209d3807 + +[download.os=darwin] +url = https://github.com/Kitware/CMake/releases/download/v3.14.0/cmake-3.14.0-Darwin-x86_64.tar.gz +sha256 = a02ad0d5b955dfad54c095bd7e937eafbbbfe8a99860107025cc442290a3e903 + +[download.os=linux] +url = https://github.com/Kitware/CMake/releases/download/v3.14.0/cmake-3.14.0.tar.gz +sha256 = aa76ba67b3c2af1946701f847073f4652af5cbd9f141f221c97af99127e75502 + +[build.os=windows] +builder = nop +subdir = cmake-3.14.0-win64-x64 + +[build.os=darwin] +builder = nop +subdir = cmake-3.14.0-Darwin-x86_64 + +[install.files.os=darwin] +CMake.app/Contents/bin = bin +CMake.app/Contents/share = share + +[build.os=linux] +builder = cmakebootstrap +subdir = cmake-3.14.0 + +[make.install_args.os=linux] +install diff --git a/build/fbcode_builder/manifests/cpptoml b/build/fbcode_builder/manifests/cpptoml new file mode 100644 index 000000000..5a3c781dc --- /dev/null +++ b/build/fbcode_builder/manifests/cpptoml @@ -0,0 +1,10 @@ +[manifest] +name = cpptoml + +[download] +url = https://github.com/skystrife/cpptoml/archive/v0.1.1.tar.gz +sha256 = 23af72468cfd4040984d46a0dd2a609538579c78ddc429d6b8fd7a10a6e24403 + +[build] +builder = cmake +subdir = cpptoml-0.1.1 diff --git a/build/fbcode_builder/manifests/delos_core b/build/fbcode_builder/manifests/delos_core new file mode 100644 index 000000000..1de6c3342 --- /dev/null +++ b/build/fbcode_builder/manifests/delos_core @@ -0,0 +1,25 @@ +[manifest] +name = delos_core +fbsource_path = fbcode/delos_core +shipit_project = delos_core +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebookincubator/delos_core.git + +[build.os=linux] +builder = cmake + +[build.not(os=linux)] +builder = nop + +[dependencies] +glog +googletest +folly +fbthrift +fb303 +re2 + +[shipit.pathmap] +fbcode/delos_core = . diff --git a/build/fbcode_builder/manifests/double-conversion b/build/fbcode_builder/manifests/double-conversion new file mode 100644 index 000000000..e27c7ae06 --- /dev/null +++ b/build/fbcode_builder/manifests/double-conversion @@ -0,0 +1,11 @@ +[manifest] +name = double-conversion + +[download] +url = https://github.com/google/double-conversion/archive/v3.1.4.tar.gz +sha256 = 95004b65e43fefc6100f337a25da27bb99b9ef8d4071a36a33b5e83eb1f82021 + +[build] +builder = cmake +subdir = double-conversion-3.1.4 + diff --git a/build/fbcode_builder/manifests/eden b/build/fbcode_builder/manifests/eden new file mode 100644 index 000000000..700cc82ec --- /dev/null +++ b/build/fbcode_builder/manifests/eden @@ -0,0 +1,70 @@ +[manifest] +name = eden +fbsource_path = fbcode/eden +shipit_project = eden +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebookexperimental/eden.git + +[build] +builder = cmake + +[dependencies] +googletest +folly +fbthrift +fb303 +cpptoml +rocksdb +re2 +libgit2 +lz4 +pexpect +python-toml + +[dependencies.fb=on] +rust + +# macOS ships with sqlite3, and some of the core system +# frameworks require that that version be linked rather +# than the one we might build for ourselves here, so we +# skip building it on macos. +[dependencies.not(os=darwin)] +sqlite3 + +[dependencies.os=darwin] +osxfuse + +# TODO: teach getdeps to compile curl on Windows. +# Enabling curl on Windows requires us to find a way to compile libcurl with +# msvc. +[dependencies.not(os=windows)] +libcurl + +[shipit.pathmap] +fbcode/common/rust/shed/hostcaps = common/rust/shed/hostcaps +fbcode/eden/oss = . +fbcode/eden = eden +fbcode/tools/lfs = tools/lfs +fbcode/thrift/lib/rust = thrift/lib/rust + +[shipit.strip] +^fbcode/eden/fs/eden-config\.h$ +^fbcode/eden/fs/py/eden/config\.py$ +^fbcode/eden/hg/.*$ +^fbcode/eden/mononoke/(?!lfs_protocol) +^fbcode/eden/scm/build/.*$ +^fbcode/eden/scm/lib/third-party/rust/.*/Cargo.toml$ +^fbcode/eden/.*/\.cargo/.*$ +/Cargo\.lock$ +\.pyc$ + +[cmake.defines.all(fb=on,os=windows)] +INSTALL_PYTHON_LIB=ON + +[cmake.defines.fb=on] +USE_CARGO_VENDOR=ON + +[depends.environment] +EDEN_VERSION_OVERRIDE diff --git a/build/fbcode_builder/manifests/eden_scm b/build/fbcode_builder/manifests/eden_scm new file mode 100644 index 000000000..cfe9c7096 --- /dev/null +++ b/build/fbcode_builder/manifests/eden_scm @@ -0,0 +1,57 @@ +[manifest] +name = eden_scm +fbsource_path = fbcode/eden +shipit_project = eden +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebookexperimental/eden.git + +[build.not(os=windows)] +builder = make +subdir = eden/scm +disable_env_override_pkgconfig = 1 +disable_env_override_path = 1 + +[build.os=windows] +# For now the biggest blocker is missing "make" on windows, but there are bound +# to be more +builder = nop + +[make.build_args] +getdepsbuild + +[make.install_args] +install-getdeps + +[make.test_args] +test-getdeps + +[shipit.pathmap] +fbcode/common/rust = common/rust +fbcode/eden/oss = . +fbcode/eden = eden +fbcode/tools/lfs = tools/lfs +fbcode/fboss/common = common + +[shipit.strip] +^fbcode/eden/fs/eden-config\.h$ +^fbcode/eden/fs/py/eden/config\.py$ +^fbcode/eden/hg/.*$ +^fbcode/eden/mononoke/(?!lfs_protocol) +^fbcode/eden/scm/build/.*$ +^fbcode/eden/scm/lib/third-party/rust/.*/Cargo.toml$ +^fbcode/eden/.*/\.cargo/.*$ +^.*/fb/.*$ +/Cargo\.lock$ +\.pyc$ + +[dependencies] +fb303-source +fbthrift +fbthrift-source +openssl +rust-shed + +[dependencies.fb=on] +rust diff --git a/build/fbcode_builder/manifests/eden_scm_lib_edenapi_tools b/build/fbcode_builder/manifests/eden_scm_lib_edenapi_tools new file mode 100644 index 000000000..be29d70f8 --- /dev/null +++ b/build/fbcode_builder/manifests/eden_scm_lib_edenapi_tools @@ -0,0 +1,36 @@ +[manifest] +name = eden_scm_lib_edenapi_tools +fbsource_path = fbcode/eden +shipit_project = eden +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebookexperimental/eden.git + +[build] +builder = cargo + +[cargo] +build_doc = true +manifests_to_build = eden/scm/lib/edenapi/tools/make_req/Cargo.toml,eden/scm/lib/edenapi/tools/read_res/Cargo.toml + +[shipit.pathmap] +fbcode/eden/oss = . +fbcode/eden = eden +fbcode/tools/lfs = tools/lfs +fbcode/fboss/common = common + +[shipit.strip] +^fbcode/eden/fs/eden-config\.h$ +^fbcode/eden/fs/py/eden/config\.py$ +^fbcode/eden/hg/.*$ +^fbcode/eden/mononoke/(?!lfs_protocol) +^fbcode/eden/scm/build/.*$ +^fbcode/eden/scm/lib/third-party/rust/.*/Cargo.toml$ +^fbcode/eden/.*/\.cargo/.*$ +^.*/fb/.*$ +/Cargo\.lock$ +\.pyc$ + +[dependencies.fb=on] +rust diff --git a/build/fbcode_builder/manifests/f4d b/build/fbcode_builder/manifests/f4d new file mode 100644 index 000000000..db30894c7 --- /dev/null +++ b/build/fbcode_builder/manifests/f4d @@ -0,0 +1,29 @@ +[manifest] +name = f4d +fbsource_path = fbcode/f4d +shipit_project = f4d +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebookexternal/f4d.git + +[build.os=windows] +builder = nop + +[build.not(os=windows)] +builder = cmake + +[dependencies] +double-conversion +folly +glog +googletest +boost +protobuf +lzo +libicu +re2 + +[shipit.pathmap] +fbcode/f4d/public_tld = . +fbcode/f4d = f4d diff --git a/build/fbcode_builder/manifests/fatal b/build/fbcode_builder/manifests/fatal new file mode 100644 index 000000000..3c333561f --- /dev/null +++ b/build/fbcode_builder/manifests/fatal @@ -0,0 +1,15 @@ +[manifest] +name = fatal +fbsource_path = fbcode/fatal +shipit_project = fatal + +[git] +repo_url = https://github.com/facebook/fatal.git + +[shipit.pathmap] +fbcode/fatal = . +fbcode/fatal/public_tld = . + +[build] +builder = nop +subdir = . diff --git a/build/fbcode_builder/manifests/fb303 b/build/fbcode_builder/manifests/fb303 new file mode 100644 index 000000000..743aca01e --- /dev/null +++ b/build/fbcode_builder/manifests/fb303 @@ -0,0 +1,27 @@ +[manifest] +name = fb303 +fbsource_path = fbcode/fb303 +shipit_project = fb303 +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebookincubator/fb303.git + +[build] +builder = cmake + +[dependencies] +folly +gflags +glog +fbthrift + +[cmake.defines.test=on] +BUILD_TESTS=ON + +[cmake.defines.test=off] +BUILD_TESTS=OFF + +[shipit.pathmap] +fbcode/fb303/github = . +fbcode/fb303 = fb303 diff --git a/build/fbcode_builder/manifests/fb303-source b/build/fbcode_builder/manifests/fb303-source new file mode 100644 index 000000000..ea160c500 --- /dev/null +++ b/build/fbcode_builder/manifests/fb303-source @@ -0,0 +1,15 @@ +[manifest] +name = fb303-source +fbsource_path = fbcode/fb303 +shipit_project = fb303 +shipit_fbcode_builder = false + +[git] +repo_url = https://github.com/facebook/fb303.git + +[build] +builder = nop + +[shipit.pathmap] +fbcode/fb303/github = . +fbcode/fb303 = fb303 diff --git a/build/fbcode_builder/manifests/fboss b/build/fbcode_builder/manifests/fboss new file mode 100644 index 000000000..f29873e72 --- /dev/null +++ b/build/fbcode_builder/manifests/fboss @@ -0,0 +1,42 @@ +[manifest] +name = fboss +fbsource_path = fbcode/fboss +shipit_project = fboss +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebook/fboss.git + +[build.os=linux] +builder = cmake + +[build.not(os=linux)] +builder = nop + +[dependencies] +folly +fb303 +wangle +fizz +fmt +libsodium +googletest +zstd +fbthrift +iproute2 +libmnl +libusb +libcurl +libnl +libsai +OpenNSA +re2 +python +yaml-cpp +libyaml +CLI11 + +[shipit.pathmap] +fbcode/fboss/github = . +fbcode/fboss/common = common +fbcode/fboss = fboss diff --git a/build/fbcode_builder/manifests/fbthrift b/build/fbcode_builder/manifests/fbthrift new file mode 100644 index 000000000..072dd4512 --- /dev/null +++ b/build/fbcode_builder/manifests/fbthrift @@ -0,0 +1,33 @@ +[manifest] +name = fbthrift +fbsource_path = fbcode/thrift +shipit_project = fbthrift +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebook/fbthrift.git + +[build] +builder = cmake + +[dependencies] +bison +flex +folly +wangle +fizz +fmt +googletest +libsodium +python-six +zstd + +[shipit.pathmap] +fbcode/thrift/public_tld = . +fbcode/thrift = thrift + +[shipit.strip] +^fbcode/thrift/thrift-config\.h$ +^fbcode/thrift/perf/canary.py$ +^fbcode/thrift/perf/loadtest.py$ +^fbcode/thrift/.castle/.* diff --git a/build/fbcode_builder/manifests/fbthrift-source b/build/fbcode_builder/manifests/fbthrift-source new file mode 100644 index 000000000..7af0d6dda --- /dev/null +++ b/build/fbcode_builder/manifests/fbthrift-source @@ -0,0 +1,21 @@ +[manifest] +name = fbthrift-source +fbsource_path = fbcode/thrift +shipit_project = fbthrift +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebook/fbthrift.git + +[build] +builder = nop + +[shipit.pathmap] +fbcode/thrift/public_tld = . +fbcode/thrift = thrift + +[shipit.strip] +^fbcode/thrift/thrift-config\.h$ +^fbcode/thrift/perf/canary.py$ +^fbcode/thrift/perf/loadtest.py$ +^fbcode/thrift/.castle/.* diff --git a/build/fbcode_builder/manifests/fbzmq b/build/fbcode_builder/manifests/fbzmq new file mode 100644 index 000000000..5739016c8 --- /dev/null +++ b/build/fbcode_builder/manifests/fbzmq @@ -0,0 +1,29 @@ +[manifest] +name = fbzmq +fbsource_path = facebook/fbzmq +shipit_project = fbzmq +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebook/fbzmq.git + +[build.os=linux] +builder = cmake + +[build.not(os=linux)] +# boost.fiber is required and that is not available on macos. +# libzmq doesn't currently build on windows. +builder = nop + +[dependencies] +boost +folly +fbthrift +googletest +libzmq + +[shipit.pathmap] +fbcode/fbzmq = fbzmq +fbcode/fbzmq/public_tld = . + +[shipit.strip] diff --git a/build/fbcode_builder/manifests/fizz b/build/fbcode_builder/manifests/fizz new file mode 100644 index 000000000..72f29973f --- /dev/null +++ b/build/fbcode_builder/manifests/fizz @@ -0,0 +1,36 @@ +[manifest] +name = fizz +fbsource_path = fbcode/fizz +shipit_project = fizz +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebookincubator/fizz.git + +[build] +builder = cmake +subdir = fizz + +[cmake.defines] +BUILD_EXAMPLES = OFF + +[cmake.defines.test=on] +BUILD_TESTS = ON + +[cmake.defines.all(os=windows, test=on)] +BUILD_TESTS = OFF + +[cmake.defines.test=off] +BUILD_TESTS = OFF + +[dependencies] +folly +libsodium +zstd + +[dependencies.all(test=on, not(os=windows))] +googletest_1_8 + +[shipit.pathmap] +fbcode/fizz/public_tld = . +fbcode/fizz = fizz diff --git a/build/fbcode_builder/manifests/flex b/build/fbcode_builder/manifests/flex new file mode 100644 index 000000000..f266c4033 --- /dev/null +++ b/build/fbcode_builder/manifests/flex @@ -0,0 +1,32 @@ +[manifest] +name = flex + +[rpms] +flex + +[debs] +flex + +[download.not(os=windows)] +url = https://github.com/westes/flex/releases/download/v2.6.4/flex-2.6.4.tar.gz +sha256 = e87aae032bf07c26f85ac0ed3250998c37621d95f8bd748b31f15b33c45ee995 + +[download.os=windows] +url = https://github.com/lexxmark/winflexbison/releases/download/v2.5.17/winflexbison-2.5.17.zip +sha256 = 3dc27a16c21b717bcc5de8590b564d4392a0b8577170c058729d067d95ded825 + +[build.not(os=windows)] +builder = autoconf +subdir = flex-2.6.4 + +[build.os=windows] +builder = nop + +[install.files.os=windows] +data = bin/data +win_flex.exe = bin/flex.exe + +# Moral equivalent to this PR that fixes a crash when bootstrapping flex +# on linux: https://github.com/easybuilders/easybuild-easyconfigs/pull/5792 +[autoconf.args.os=linux] +CFLAGS=-D_GNU_SOURCE diff --git a/build/fbcode_builder/manifests/fmt b/build/fbcode_builder/manifests/fmt new file mode 100644 index 000000000..21503d202 --- /dev/null +++ b/build/fbcode_builder/manifests/fmt @@ -0,0 +1,14 @@ +[manifest] +name = fmt + +[download] +url = https://github.com/fmtlib/fmt/archive/6.1.1.tar.gz +sha256 = bf4e50955943c1773cc57821d6c00f7e2b9e10eb435fafdd66739d36056d504e + +[build] +builder = cmake +subdir = fmt-6.1.1 + +[cmake.defines] +FMT_TEST = OFF +FMT_DOC = OFF diff --git a/build/fbcode_builder/manifests/folly b/build/fbcode_builder/manifests/folly new file mode 100644 index 000000000..9647b17f8 --- /dev/null +++ b/build/fbcode_builder/manifests/folly @@ -0,0 +1,58 @@ +[manifest] +name = folly +fbsource_path = fbcode/folly +shipit_project = folly +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebook/folly.git + +[build] +builder = cmake + +[dependencies] +gflags +glog +googletest +boost +libevent +double-conversion +fmt +lz4 +snappy +zstd +# no openssl or zlib in the linux case, why? +# these are usually installed on the system +# and are the easiest system deps to pull in. +# In the future we want to be able to express +# that a system dep is sufficient in the manifest +# for eg: openssl and zlib, but for now we don't +# have it. + +# macOS doesn't expose the openssl api so we need +# to build our own. +[dependencies.os=darwin] +openssl + +# Windows has neither openssl nor zlib, so we get +# to provide both +[dependencies.os=windows] +openssl +zlib + +[shipit.pathmap] +fbcode/folly/public_tld = . +fbcode/folly = folly + +[shipit.strip] +^fbcode/folly/folly-config\.h$ +^fbcode/folly/public_tld/build/facebook_.* + +[cmake.defines] +BUILD_SHARED_LIBS=OFF + +[cmake.defines.test=on] +BUILD_TESTS=ON + +[cmake.defines.test=off] +BUILD_TESTS=OFF diff --git a/build/fbcode_builder/manifests/gflags b/build/fbcode_builder/manifests/gflags new file mode 100644 index 000000000..d7ec44eab --- /dev/null +++ b/build/fbcode_builder/manifests/gflags @@ -0,0 +1,17 @@ +[manifest] +name = gflags + +[download] +url = https://github.com/gflags/gflags/archive/v2.2.2.tar.gz +sha256 = 34af2f15cf7367513b352bdcd2493ab14ce43692d2dcd9dfc499492966c64dcf + +[build] +builder = cmake +subdir = gflags-2.2.2 + +[cmake.defines] +BUILD_SHARED_LIBS = ON +BUILD_STATIC_LIBS = ON +#BUILD_gflags_nothreads_LIB = OFF +BUILD_gflags_LIB = ON + diff --git a/build/fbcode_builder/manifests/git-lfs b/build/fbcode_builder/manifests/git-lfs new file mode 100644 index 000000000..38a5e6aeb --- /dev/null +++ b/build/fbcode_builder/manifests/git-lfs @@ -0,0 +1,12 @@ +[manifest] +name = git-lfs + +[download.os=linux] +url = https://github.com/git-lfs/git-lfs/releases/download/v2.9.1/git-lfs-linux-amd64-v2.9.1.tar.gz +sha256 = 2a8e60cf51ec45aa0f4332aa0521d60ec75c76e485d13ebaeea915b9d70ea466 + +[build] +builder = nop + +[install.files] +git-lfs = bin/git-lfs diff --git a/build/fbcode_builder/manifests/glog b/build/fbcode_builder/manifests/glog new file mode 100644 index 000000000..d2354610a --- /dev/null +++ b/build/fbcode_builder/manifests/glog @@ -0,0 +1,16 @@ +[manifest] +name = glog + +[download] +url = https://github.com/google/glog/archive/v0.4.0.tar.gz +sha256 = f28359aeba12f30d73d9e4711ef356dc842886968112162bc73002645139c39c + +[build] +builder = cmake +subdir = glog-0.4.0 + +[dependencies] +gflags + +[cmake.defines] +BUILD_SHARED_LIBS=ON diff --git a/build/fbcode_builder/manifests/gnu-bash b/build/fbcode_builder/manifests/gnu-bash new file mode 100644 index 000000000..89da77ca2 --- /dev/null +++ b/build/fbcode_builder/manifests/gnu-bash @@ -0,0 +1,20 @@ +[manifest] +name = gnu-bash + +[download.os=darwin] +url = https://ftp.gnu.org/gnu/bash/bash-5.1-rc1.tar.gz +sha256 = 0b2684eb1990329d499c96decfe2459f3e150deb915b0a9d03cf1be692b1d6d3 + +[build.os=darwin] +# The buildin FreeBSD bash on OSX is both outdated and incompatible with the +# modern GNU bash, so for the sake of being cross-platform friendly this +# manifest provides GNU bash. +# NOTE: This is the 5.1-rc1 version, which is almost the same as what Homebrew +# uses (Homebrew installs 5.0 with the 18 patches that in fact make the 5.1-rc1 +# version). +builder = autoconf +subdir = bash-5.1-rc1 +build_in_src_dir = true + +[build.not(os=darwin)] +builder = nop diff --git a/build/fbcode_builder/manifests/gnu-coreutils b/build/fbcode_builder/manifests/gnu-coreutils new file mode 100644 index 000000000..1ab4d9d4a --- /dev/null +++ b/build/fbcode_builder/manifests/gnu-coreutils @@ -0,0 +1,15 @@ +[manifest] +name = gnu-coreutils + +[download.os=darwin] +url = https://ftp.gnu.org/gnu/coreutils/coreutils-8.32.tar.gz +sha256 = d5ab07435a74058ab69a2007e838be4f6a90b5635d812c2e26671e3972fca1b8 + +[build.os=darwin] +# The buildin FreeBSD version incompatible with the GNU one, so for the sake of +# being cross-platform friendly this manifest provides the GNU version. +builder = autoconf +subdir = coreutils-8.32 + +[build.not(os=darwin)] +builder = nop diff --git a/build/fbcode_builder/manifests/gnu-grep b/build/fbcode_builder/manifests/gnu-grep new file mode 100644 index 000000000..e6a163d37 --- /dev/null +++ b/build/fbcode_builder/manifests/gnu-grep @@ -0,0 +1,15 @@ +[manifest] +name = gnu-grep + +[download.os=darwin] +url = https://ftp.gnu.org/gnu/grep/grep-3.5.tar.gz +sha256 = 9897220992a8fd38a80b70731462defa95f7ff2709b235fb54864ddd011141dd + +[build.os=darwin] +# The buildin FreeBSD version incompatible with the GNU one, so for the sake of +# being cross-platform friendly this manifest provides the GNU version. +builder = autoconf +subdir = grep-3.5 + +[build.not(os=darwin)] +builder = nop diff --git a/build/fbcode_builder/manifests/gnu-sed b/build/fbcode_builder/manifests/gnu-sed new file mode 100644 index 000000000..9b458df6e --- /dev/null +++ b/build/fbcode_builder/manifests/gnu-sed @@ -0,0 +1,15 @@ +[manifest] +name = gnu-sed + +[download.os=darwin] +url = https://ftp.gnu.org/gnu/sed/sed-4.8.tar.gz +sha256 = 53cf3e14c71f3a149f29d13a0da64120b3c1d3334fba39c4af3e520be053982a + +[build.os=darwin] +# The buildin FreeBSD version incompatible with the GNU one, so for the sake of +# being cross-platform friendly this manifest provides the GNU version. +builder = autoconf +subdir = sed-4.8 + +[build.not(os=darwin)] +builder = nop diff --git a/build/fbcode_builder/manifests/googletest b/build/fbcode_builder/manifests/googletest new file mode 100644 index 000000000..775aac34f --- /dev/null +++ b/build/fbcode_builder/manifests/googletest @@ -0,0 +1,18 @@ +[manifest] +name = googletest + +[download] +url = https://github.com/google/googletest/archive/release-1.10.0.tar.gz +sha256 = 9dc9157a9a1551ec7a7e43daea9a694a0bb5fb8bec81235d8a1e6ef64c716dcb + +[build] +builder = cmake +subdir = googletest-release-1.10.0 + +[cmake.defines] +# Everything else defaults to the shared runtime, so tell gtest that +# it should not use its choice of the static runtime +gtest_force_shared_crt=ON + +[cmake.defines.os=windows] +BUILD_SHARED_LIBS=ON diff --git a/build/fbcode_builder/manifests/googletest_1_8 b/build/fbcode_builder/manifests/googletest_1_8 new file mode 100644 index 000000000..76c0ce51f --- /dev/null +++ b/build/fbcode_builder/manifests/googletest_1_8 @@ -0,0 +1,18 @@ +[manifest] +name = googletest_1_8 + +[download] +url = https://github.com/google/googletest/archive/release-1.8.0.tar.gz +sha256 = 58a6f4277ca2bc8565222b3bbd58a177609e9c488e8a72649359ba51450db7d8 + +[build] +builder = cmake +subdir = googletest-release-1.8.0 + +[cmake.defines] +# Everything else defaults to the shared runtime, so tell gtest that +# it should not use its choice of the static runtime +gtest_force_shared_crt=ON + +[cmake.defines.os=windows] +BUILD_SHARED_LIBS=ON diff --git a/build/fbcode_builder/manifests/gperf b/build/fbcode_builder/manifests/gperf new file mode 100644 index 000000000..13d7a890f --- /dev/null +++ b/build/fbcode_builder/manifests/gperf @@ -0,0 +1,14 @@ +[manifest] +name = gperf + +[download] +url = http://ftp.gnu.org/pub/gnu/gperf/gperf-3.1.tar.gz +sha256 = 588546b945bba4b70b6a3a616e80b4ab466e3f33024a352fc2198112cdbb3ae2 + +[build.not(os=windows)] +builder = autoconf +subdir = gperf-3.1 + +[build.os=windows] +builder = nop + diff --git a/build/fbcode_builder/manifests/iproute2 b/build/fbcode_builder/manifests/iproute2 new file mode 100644 index 000000000..6fb7f77ed --- /dev/null +++ b/build/fbcode_builder/manifests/iproute2 @@ -0,0 +1,13 @@ +[manifest] +name = iproute2 + +[download] +url = https://mirrors.edge.kernel.org/pub/linux/utils/net/iproute2/iproute2-4.12.0.tar.gz +sha256 = 46612a1e2d01bb31932557bccdb1b8618cae9a439dfffc08ef35ed8e197f14ce + +[build.os=linux] +builder = iproute2 +subdir = iproute2-4.12.0 + +[build.not(os=linux)] +builder = nop diff --git a/build/fbcode_builder/manifests/jq b/build/fbcode_builder/manifests/jq new file mode 100644 index 000000000..231818f34 --- /dev/null +++ b/build/fbcode_builder/manifests/jq @@ -0,0 +1,24 @@ +[manifest] +name = jq + +[rpms] +jq + +[debs] +jq + +[download.not(os=windows)] +url = https://github.com/stedolan/jq/releases/download/jq-1.5/jq-1.5.tar.gz +sha256 = c4d2bfec6436341113419debf479d833692cc5cdab7eb0326b5a4d4fbe9f493c + +[build.not(os=windows)] +builder = autoconf +subdir = jq-1.5 + +[build.os=windows] +builder = nop + +[autoconf.args] +# This argument turns off some developers tool and it is recommended in jq's +# README +--disable-maintainer-mode diff --git a/build/fbcode_builder/manifests/katran b/build/fbcode_builder/manifests/katran new file mode 100644 index 000000000..224ccbe21 --- /dev/null +++ b/build/fbcode_builder/manifests/katran @@ -0,0 +1,38 @@ +[manifest] +name = katran +fbsource_path = fbcode/katran +shipit_project = katran +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebookincubator/katran.git + +[build.not(os=linux)] +builder = nop + +[build.os=linux] +builder = cmake +subdir = . + +[cmake.defines.test=on] +BUILD_TESTS=ON + +[cmake.defines.test=off] +BUILD_TESTS=OFF + +[dependencies] +folly +fizz +libbpf +libmnl +zlib +googletest + + +[shipit.pathmap] +fbcode/katran/public_root = . +fbcode/katran = katran + +[shipit.strip] +^fbcode/katran/facebook +^fbcode/katran/OSS_SYNC diff --git a/build/fbcode_builder/manifests/libbpf b/build/fbcode_builder/manifests/libbpf new file mode 100644 index 000000000..0416822e4 --- /dev/null +++ b/build/fbcode_builder/manifests/libbpf @@ -0,0 +1,26 @@ +[manifest] +name = libbpf + +[download] +url = https://github.com/libbpf/libbpf/archive/v0.3.tar.gz +sha256 = c168d84a75b541f753ceb49015d9eb886e3fb5cca87cdd9aabce7e10ad3a1efc + +# BPF only builds on linux, so make it a NOP on other platforms +[build.not(os=linux)] +builder = nop + +[build.os=linux] +builder = make +subdir = libbpf-0.3/src + +[make.build_args] +BUILD_STATIC_ONLY=y + +# libbpf-0.3 requires uapi headers >= 5.8 +[make.install_args] +install +install_uapi_headers +BUILD_STATIC_ONLY=y + +[dependencies] +libelf diff --git a/build/fbcode_builder/manifests/libbpf_0_2_0_beta b/build/fbcode_builder/manifests/libbpf_0_2_0_beta new file mode 100644 index 000000000..072639817 --- /dev/null +++ b/build/fbcode_builder/manifests/libbpf_0_2_0_beta @@ -0,0 +1,26 @@ +[manifest] +name = libbpf_0_2_0_beta + +[download] +url = https://github.com/libbpf/libbpf/archive/b6dd2f2.tar.gz +sha256 = 8db9dca90f5c445ef2362e3c6a00f3d6c4bf36e8782f8e27704109c78e541497 + +# BPF only builds on linux, so make it a NOP on other platforms +[build.not(os=linux)] +builder = nop + +[build.os=linux] +builder = make +subdir = libbpf-b6dd2f2b7df4d3bd35d64aaf521d9ad18d766f53/src + +[make.build_args] +BUILD_STATIC_ONLY=y + +# libbpf now requires uapi headers >= 5.8 +[make.install_args] +install +install_uapi_headers +BUILD_STATIC_ONLY=y + +[dependencies] +libelf diff --git a/build/fbcode_builder/manifests/libcurl b/build/fbcode_builder/manifests/libcurl new file mode 100644 index 000000000..466b4497c --- /dev/null +++ b/build/fbcode_builder/manifests/libcurl @@ -0,0 +1,39 @@ +[manifest] +name = libcurl + +[rpms] +libcurl-devel +libcurl + +[debs] +libcurl4-openssl-dev + +[download] +url = https://curl.haxx.se/download/curl-7.65.1.tar.gz +sha256 = 821aeb78421375f70e55381c9ad2474bf279fc454b791b7e95fc83562951c690 + +[dependencies] +nghttp2 + +# We use system OpenSSL on Linux (see folly's manifest for details) +[dependencies.not(os=linux)] +openssl + +[build.not(os=windows)] +builder = autoconf +subdir = curl-7.65.1 + +[autoconf.args] +# fboss (which added the libcurl dep) doesn't need ldap so it is disabled here. +# if someone in the future wants to add ldap for something else, it won't hurt +# fboss. However, that would require adding an ldap manifest. +# +# For the same reason, we disable libssh2 and libidn2 which aren't really used +# but would require adding manifests if we don't disable them. +--disable-ldap +--without-libssh2 +--without-libidn2 + +[build.os=windows] +builder = cmake +subdir = curl-7.65.1 diff --git a/build/fbcode_builder/manifests/libelf b/build/fbcode_builder/manifests/libelf new file mode 100644 index 000000000..a46aab879 --- /dev/null +++ b/build/fbcode_builder/manifests/libelf @@ -0,0 +1,20 @@ +[manifest] +name = libelf + +[rpms] +elfutils-libelf-devel-static + +[debs] +libelf-dev + +[download] +url = https://ftp.osuosl.org/pub/blfs/conglomeration/libelf/libelf-0.8.13.tar.gz +sha256 = 591a9b4ec81c1f2042a97aa60564e0cb79d041c52faa7416acb38bc95bd2c76d + +# libelf only makes sense on linux, so make it a NOP on other platforms +[build.not(os=linux)] +builder = nop + +[build.os=linux] +builder = autoconf +subdir = libelf-0.8.13 diff --git a/build/fbcode_builder/manifests/libevent b/build/fbcode_builder/manifests/libevent new file mode 100644 index 000000000..eaa39a9e6 --- /dev/null +++ b/build/fbcode_builder/manifests/libevent @@ -0,0 +1,29 @@ +[manifest] +name = libevent + +[rpms] +libevent-devel + +[debs] +libevent-dev + +# Note that the CMakeLists.txt file is present only in +# git repo and not in the release tarball, so take care +# to use the github generated source tarball rather than +# the explicitly uploaded source tarball +[download] +url = https://github.com/libevent/libevent/archive/release-2.1.8-stable.tar.gz +sha256 = 316ddb401745ac5d222d7c529ef1eada12f58f6376a66c1118eee803cb70f83d + +[build] +builder = cmake +subdir = libevent-release-2.1.8-stable + +[cmake.defines] +EVENT__DISABLE_TESTS = ON +EVENT__DISABLE_BENCHMARK = ON +EVENT__DISABLE_SAMPLES = ON +EVENT__DISABLE_REGRESS = ON + +[dependencies.not(os=linux)] +openssl diff --git a/build/fbcode_builder/manifests/libgit2 b/build/fbcode_builder/manifests/libgit2 new file mode 100644 index 000000000..1d6a53e5e --- /dev/null +++ b/build/fbcode_builder/manifests/libgit2 @@ -0,0 +1,24 @@ +[manifest] +name = libgit2 + +[rpms] +libgit2-devel + +[debs] +libgit2-dev + +[download] +url = https://github.com/libgit2/libgit2/archive/v0.28.1.tar.gz +sha256 = 0ca11048795b0d6338f2e57717370208c2c97ad66c6d5eac0c97a8827d13936b + +[build] +builder = cmake +subdir = libgit2-0.28.1 + +[cmake.defines] +# Could turn this on if we also wanted to add a manifest for libssh2 +USE_SSH = OFF +BUILD_CLAR = OFF +# Have to build shared to work around annoying problems with cmake +# mis-parsing the frameworks required to link this on macos :-/ +BUILD_SHARED_LIBS = ON diff --git a/build/fbcode_builder/manifests/libicu b/build/fbcode_builder/manifests/libicu new file mode 100644 index 000000000..c1deda503 --- /dev/null +++ b/build/fbcode_builder/manifests/libicu @@ -0,0 +1,19 @@ +[manifest] +name = libicu + +[rpms] +libicu-devel + +[debs] +libicu-dev + +[download] +url = https://github.com/unicode-org/icu/releases/download/release-68-2/icu4c-68_2-src.tgz +sha256 = c79193dee3907a2199b8296a93b52c5cb74332c26f3d167269487680d479d625 + +[build.not(os=windows)] +builder = autoconf +subdir = icu/source + +[build.os=windows] +builder = nop diff --git a/build/fbcode_builder/manifests/libmnl b/build/fbcode_builder/manifests/libmnl new file mode 100644 index 000000000..9b28b87b9 --- /dev/null +++ b/build/fbcode_builder/manifests/libmnl @@ -0,0 +1,17 @@ +[manifest] +name = libmnl + +[rpms] +libmnl-devel +libmnl-static + +[debs] +libmnl-dev + +[download] +url = http://www.netfilter.org/pub/libmnl/libmnl-1.0.4.tar.bz2 +sha256 = 171f89699f286a5854b72b91d06e8f8e3683064c5901fb09d954a9ab6f551f81 + +[build.os=linux] +builder = autoconf +subdir = libmnl-1.0.4 diff --git a/build/fbcode_builder/manifests/libnl b/build/fbcode_builder/manifests/libnl new file mode 100644 index 000000000..f864acb49 --- /dev/null +++ b/build/fbcode_builder/manifests/libnl @@ -0,0 +1,17 @@ +[manifest] +name = libnl + +[rpms] +libnl3-devel +libnl3 + +[debs] +libnl-3-dev + +[download] +url = https://www.infradead.org/~tgr/libnl/files/libnl-3.2.25.tar.gz +sha256 = 8beb7590674957b931de6b7f81c530b85dc7c1ad8fbda015398bc1e8d1ce8ec5 + +[build.os=linux] +builder = autoconf +subdir = libnl-3.2.25 diff --git a/build/fbcode_builder/manifests/libsai b/build/fbcode_builder/manifests/libsai new file mode 100644 index 000000000..4f422d8e1 --- /dev/null +++ b/build/fbcode_builder/manifests/libsai @@ -0,0 +1,13 @@ +[manifest] +name = libsai + +[download] +url = https://github.com/opencomputeproject/SAI/archive/v1.7.1.tar.gz +sha256 = e18eb1a2a6e5dd286d97e13569d8b78cc1f8229030beed0db4775b9a50ab6a83 + +[build] +builder = nop +subdir = SAI-1.7.1 + +[install.files] +inc = include diff --git a/build/fbcode_builder/manifests/libsodium b/build/fbcode_builder/manifests/libsodium new file mode 100644 index 000000000..d69bfcc4b --- /dev/null +++ b/build/fbcode_builder/manifests/libsodium @@ -0,0 +1,33 @@ +[manifest] +name = libsodium + +[rpms] +libsodium-devel +libsodium-static + +[debs] +libsodium-dev + +[download.not(os=windows)] +url = https://github.com/jedisct1/libsodium/releases/download/1.0.17/libsodium-1.0.17.tar.gz +sha256 = 0cc3dae33e642cc187b5ceb467e0ad0e1b51dcba577de1190e9ffa17766ac2b1 + +[build.not(os=windows)] +builder = autoconf +subdir = libsodium-1.0.17 + +[download.os=windows] +url = https://download.libsodium.org/libsodium/releases/libsodium-1.0.17-msvc.zip +sha256 = f0f32ad8ebd76eee99bb039f843f583f2babca5288a8c26a7261db9694c11467 + +[build.os=windows] +builder = nop + +[install.files.os=windows] +x64/Release/v141/dynamic/libsodium.dll = bin/libsodium.dll +x64/Release/v141/dynamic/libsodium.lib = lib/libsodium.lib +x64/Release/v141/dynamic/libsodium.exp = lib/libsodium.exp +x64/Release/v141/dynamic/libsodium.pdb = lib/libsodium.pdb +include = include + +[autoconf.args] diff --git a/build/fbcode_builder/manifests/libtool b/build/fbcode_builder/manifests/libtool new file mode 100644 index 000000000..1ec99b5f4 --- /dev/null +++ b/build/fbcode_builder/manifests/libtool @@ -0,0 +1,22 @@ +[manifest] +name = libtool + +[rpms] +libtool + +[debs] +libtool + +[download] +url = http://ftp.gnu.org/gnu/libtool/libtool-2.4.6.tar.gz +sha256 = e3bd4d5d3d025a36c21dd6af7ea818a2afcd4dfc1ea5a17b39d7854bcd0c06e3 + +[build] +builder = autoconf +subdir = libtool-2.4.6 + +[dependencies] +automake + +[autoconf.args] +--enable-ltdl-install diff --git a/build/fbcode_builder/manifests/libusb b/build/fbcode_builder/manifests/libusb new file mode 100644 index 000000000..74702d3f0 --- /dev/null +++ b/build/fbcode_builder/manifests/libusb @@ -0,0 +1,23 @@ +[manifest] +name = libusb + +[rpms] +libusb-devel +libusb + +[debs] +libusb-1.0-0-dev + +[download] +url = https://github.com/libusb/libusb/releases/download/v1.0.22/libusb-1.0.22.tar.bz2 +sha256 = 75aeb9d59a4fdb800d329a545c2e6799f732362193b465ea198f2aa275518157 + +[build.os=linux] +builder = autoconf +subdir = libusb-1.0.22 + +[autoconf.args] +# fboss (which added the libusb dep) doesn't need udev so it is disabled here. +# if someone in the future wants to add udev for something else, it won't hurt +# fboss. +--disable-udev diff --git a/build/fbcode_builder/manifests/libyaml b/build/fbcode_builder/manifests/libyaml new file mode 100644 index 000000000..a7ff57316 --- /dev/null +++ b/build/fbcode_builder/manifests/libyaml @@ -0,0 +1,13 @@ +[manifest] +name = libyaml + +[download] +url = http://pyyaml.org/download/libyaml/yaml-0.1.7.tar.gz +sha256 = 8088e457264a98ba451a90b8661fcb4f9d6f478f7265d48322a196cec2480729 + +[build.os=linux] +builder = autoconf +subdir = yaml-0.1.7 + +[build.not(os=linux)] +builder = nop diff --git a/build/fbcode_builder/manifests/libzmq b/build/fbcode_builder/manifests/libzmq new file mode 100644 index 000000000..4f555fa65 --- /dev/null +++ b/build/fbcode_builder/manifests/libzmq @@ -0,0 +1,24 @@ +[manifest] +name = libzmq + +[rpms] +zeromq-devel +zeromq + +[debs] +libzmq3-dev + +[download] +url = https://github.com/zeromq/libzmq/releases/download/v4.3.1/zeromq-4.3.1.tar.gz +sha256 = bcbabe1e2c7d0eec4ed612e10b94b112dd5f06fcefa994a0c79a45d835cd21eb + + +[build] +builder = autoconf +subdir = zeromq-4.3.1 + +[autoconf.args] + +[dependencies] +autoconf +libtool diff --git a/build/fbcode_builder/manifests/lz4 b/build/fbcode_builder/manifests/lz4 new file mode 100644 index 000000000..03dbd9de4 --- /dev/null +++ b/build/fbcode_builder/manifests/lz4 @@ -0,0 +1,17 @@ +[manifest] +name = lz4 + +[rpms] +lz4-devel +lz4-static + +[debs] +liblz4-dev + +[download] +url = https://github.com/lz4/lz4/archive/v1.8.3.tar.gz +sha256 = 33af5936ac06536805f9745e0b6d61da606a1f8b4cc5c04dd3cbaca3b9b4fc43 + +[build] +builder = cmake +subdir = lz4-1.8.3/contrib/cmake_unofficial diff --git a/build/fbcode_builder/manifests/lzo b/build/fbcode_builder/manifests/lzo new file mode 100644 index 000000000..342428ab5 --- /dev/null +++ b/build/fbcode_builder/manifests/lzo @@ -0,0 +1,19 @@ +[manifest] +name = lzo + +[rpms] +lzo-devel + +[debs] +liblzo2-dev + +[download] +url = http://www.oberhumer.com/opensource/lzo/download/lzo-2.10.tar.gz +sha256 = c0f892943208266f9b6543b3ae308fab6284c5c90e627931446fb49b4221a072 + +[build.not(os=windows)] +builder = autoconf +subdir = lzo-2.10 + +[build.os=windows] +builder = nop diff --git a/build/fbcode_builder/manifests/mononoke b/build/fbcode_builder/manifests/mononoke new file mode 100644 index 000000000..7df92c77b --- /dev/null +++ b/build/fbcode_builder/manifests/mononoke @@ -0,0 +1,44 @@ +[manifest] +name = mononoke +fbsource_path = fbcode/eden +shipit_project = eden +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebookexperimental/eden.git + +[build.not(os=windows)] +builder = cargo + +[build.os=windows] +# building Mononoke on windows is not supported +builder = nop + +[cargo] +build_doc = true +workspace_dir = eden/mononoke + +[shipit.pathmap] +fbcode/configerator/structs/scm/mononoke/public_autocargo = configerator/structs/scm/mononoke +fbcode/configerator/structs/scm/mononoke = configerator/structs/scm/mononoke +fbcode/eden/oss = . +fbcode/eden = eden +fbcode/eden/mononoke/public_autocargo = eden/mononoke +fbcode/tools/lfs = tools/lfs +tools/rust/ossconfigs = . + +[shipit.strip] +# strip all code unrelated to mononoke to prevent triggering unnecessary checks +^fbcode/eden/(?!mononoke|scm/lib/xdiff.*)/.*$ +^fbcode/eden/scm/lib/third-party/rust/.*/Cargo.toml$ +^fbcode/eden/mononoke/Cargo\.toml$ +^fbcode/eden/mononoke/(?!public_autocargo).+/Cargo\.toml$ +^fbcode/configerator/structs/scm/mononoke/(?!public_autocargo).+/Cargo\.toml$ +^.*/facebook/.*$ + +[dependencies] +fbthrift-source +rust-shed + +[dependencies.fb=on] +rust diff --git a/build/fbcode_builder/manifests/mononoke_integration b/build/fbcode_builder/manifests/mononoke_integration new file mode 100644 index 000000000..a796e967e --- /dev/null +++ b/build/fbcode_builder/manifests/mononoke_integration @@ -0,0 +1,47 @@ +[manifest] +name = mononoke_integration +fbsource_path = fbcode/eden +shipit_project = eden +shipit_fbcode_builder = true + +[build.not(os=windows)] +builder = make +subdir = eden/mononoke/tests/integration + +[build.os=windows] +# building Mononoke on windows is not supported +builder = nop + +[make.build_args] +build-getdeps + +[make.install_args] +install-getdeps + +[make.test_args] +test-getdeps + +[shipit.pathmap] +fbcode/eden/mononoke/tests/integration = eden/mononoke/tests/integration + +[shipit.strip] +^.*/facebook/.*$ + +[dependencies] +eden_scm +eden_scm_lib_edenapi_tools +jq +mononoke +nmap +python-click +python-dulwich +tree + +[dependencies.os=linux] +sqlite3-bin + +[dependencies.os=darwin] +gnu-bash +gnu-coreutils +gnu-grep +gnu-sed diff --git a/build/fbcode_builder/manifests/mvfst b/build/fbcode_builder/manifests/mvfst new file mode 100644 index 000000000..4f72a9192 --- /dev/null +++ b/build/fbcode_builder/manifests/mvfst @@ -0,0 +1,32 @@ +[manifest] +name = mvfst +fbsource_path = fbcode/quic +shipit_project = mvfst +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebookincubator/mvfst.git + +[build] +builder = cmake +subdir = . + +[cmake.defines.test=on] +BUILD_TESTS = ON + +[cmake.defines.all(os=windows, test=on)] +BUILD_TESTS = OFF + +[cmake.defines.test=off] +BUILD_TESTS = OFF + +[dependencies] +folly +fizz + +[dependencies.all(test=on, not(os=windows))] +googletest_1_8 + +[shipit.pathmap] +fbcode/quic/public_root = . +fbcode/quic = quic diff --git a/build/fbcode_builder/manifests/nghttp2 b/build/fbcode_builder/manifests/nghttp2 new file mode 100644 index 000000000..151daf8af --- /dev/null +++ b/build/fbcode_builder/manifests/nghttp2 @@ -0,0 +1,20 @@ +[manifest] +name = nghttp2 + +[rpms] +libnghttp2-devel +libnghttp2 + +[debs] +libnghttp2-dev + +[download] +url = https://github.com/nghttp2/nghttp2/releases/download/v1.39.2/nghttp2-1.39.2.tar.gz +sha256 = fc820a305e2f410fade1a3260f09229f15c0494fc089b0100312cd64a33a38c0 + +[build] +builder = autoconf +subdir = nghttp2-1.39.2 + +[autoconf.args] +--enable-lib-only diff --git a/build/fbcode_builder/manifests/ninja b/build/fbcode_builder/manifests/ninja new file mode 100644 index 000000000..2b6c5dc8d --- /dev/null +++ b/build/fbcode_builder/manifests/ninja @@ -0,0 +1,26 @@ +[manifest] +name = ninja + +[rpms] +ninja-build + +[debs] +ninja-build + +[download.os=windows] +url = https://github.com/ninja-build/ninja/releases/download/v1.10.2/ninja-win.zip +sha256 = bbde850d247d2737c5764c927d1071cbb1f1957dcabda4a130fa8547c12c695f + +[build.os=windows] +builder = nop + +[install.files.os=windows] +ninja.exe = bin/ninja.exe + +[download.not(os=windows)] +url = https://github.com/ninja-build/ninja/archive/v1.10.2.tar.gz +sha256 = ce35865411f0490368a8fc383f29071de6690cbadc27704734978221f25e2bed + +[build.not(os=windows)] +builder = ninja_bootstrap +subdir = ninja-1.10.2 diff --git a/build/fbcode_builder/manifests/nmap b/build/fbcode_builder/manifests/nmap new file mode 100644 index 000000000..c245e1241 --- /dev/null +++ b/build/fbcode_builder/manifests/nmap @@ -0,0 +1,25 @@ +[manifest] +name = nmap + +[rpms] +nmap + +[debs] +nmap + +[download.not(os=windows)] +url = https://api.github.com/repos/nmap/nmap/tarball/ef8213a36c2e89233c806753a57b5cd473605408 +sha256 = eda39e5a8ef4964fac7db16abf91cc11ff568eac0fa2d680b0bfa33b0ed71f4a + +[build.not(os=windows)] +builder = autoconf +subdir = nmap-nmap-ef8213a +build_in_src_dir = true + +[build.os=windows] +builder = nop + +[autoconf.args] +# Without this option the build was filing to find some third party libraries +# that we don't need +enable_rdma=no diff --git a/build/fbcode_builder/manifests/openr b/build/fbcode_builder/manifests/openr new file mode 100644 index 000000000..754ba8cd5 --- /dev/null +++ b/build/fbcode_builder/manifests/openr @@ -0,0 +1,37 @@ +[manifest] +name = openr +fbsource_path = facebook/openr +shipit_project = openr +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebook/openr.git + +[build.os=linux] +builder = cmake + +[build.not(os=linux)] +# boost.fiber is required and that is not available on macos. +# libzmq doesn't currently build on windows. +builder = nop + +[dependencies] +boost +fb303 +fbthrift +fbzmq +folly +googletest +re2 + +[cmake.defines.test=on] +BUILD_TESTS=ON +ADD_ROOT_TESTS=OFF + +[cmake.defines.test=off] +BUILD_TESTS=OFF + + +[shipit.pathmap] +fbcode/openr = openr +fbcode/openr/public_tld = . diff --git a/build/fbcode_builder/manifests/openssl b/build/fbcode_builder/manifests/openssl new file mode 100644 index 000000000..991196c9a --- /dev/null +++ b/build/fbcode_builder/manifests/openssl @@ -0,0 +1,20 @@ +[manifest] +name = openssl + +[rpms] +openssl-devel +openssl + +[debs] +libssl-dev + +[download] +url = https://www.openssl.org/source/openssl-1.1.1i.tar.gz +sha256 = e8be6a35fe41d10603c3cc635e93289ed00bf34b79671a3a4de64fcee00d5242 + +[build] +builder = openssl +subdir = openssl-1.1.1i + +[dependencies.os=windows] +perl diff --git a/build/fbcode_builder/manifests/osxfuse b/build/fbcode_builder/manifests/osxfuse new file mode 100644 index 000000000..b6c6c551f --- /dev/null +++ b/build/fbcode_builder/manifests/osxfuse @@ -0,0 +1,12 @@ +[manifest] +name = osxfuse + +[download] +url = https://github.com/osxfuse/osxfuse/archive/osxfuse-3.8.3.tar.gz +sha256 = 93bab6731bdfe8dc1ef069483437270ce7fe5a370f933d40d8d0ef09ba846c0c + +[build] +builder = nop + +[install.files] +osxfuse-osxfuse-3.8.3/common = include diff --git a/build/fbcode_builder/manifests/patchelf b/build/fbcode_builder/manifests/patchelf new file mode 100644 index 000000000..f9d050424 --- /dev/null +++ b/build/fbcode_builder/manifests/patchelf @@ -0,0 +1,17 @@ +[manifest] +name = patchelf + +[rpms] +patchelf + +[debs] +patchelf + +[download] +url = https://github.com/NixOS/patchelf/archive/0.10.tar.gz +sha256 = b3cb6bdedcef5607ce34a350cf0b182eb979f8f7bc31eae55a93a70a3f020d13 + +[build] +builder = autoconf +subdir = patchelf-0.10 + diff --git a/build/fbcode_builder/manifests/pcre b/build/fbcode_builder/manifests/pcre new file mode 100644 index 000000000..5353d8c27 --- /dev/null +++ b/build/fbcode_builder/manifests/pcre @@ -0,0 +1,18 @@ +[manifest] +name = pcre + +[rpms] +pcre-devel +pcre-static + +[debs] +libpcre3-dev + +[download] +url = https://ftp.pcre.org/pub/pcre/pcre-8.43.tar.gz +sha256 = 0b8e7465dc5e98c757cc3650a20a7843ee4c3edf50aaf60bb33fd879690d2c73 + +[build] +builder = cmake +subdir = pcre-8.43 + diff --git a/build/fbcode_builder/manifests/perl b/build/fbcode_builder/manifests/perl new file mode 100644 index 000000000..32bddc51c --- /dev/null +++ b/build/fbcode_builder/manifests/perl @@ -0,0 +1,11 @@ +[manifest] +name = perl + +[download.os=windows] +url = http://strawberryperl.com/download/5.28.1.1/strawberry-perl-5.28.1.1-64bit-portable.zip +sha256 = 935c95ba096fa11c4e1b5188732e3832d330a2a79e9882ab7ba8460ddbca810d + +[build.os=windows] +builder = nop +subdir = perl + diff --git a/build/fbcode_builder/manifests/pexpect b/build/fbcode_builder/manifests/pexpect new file mode 100644 index 000000000..682e66a54 --- /dev/null +++ b/build/fbcode_builder/manifests/pexpect @@ -0,0 +1,12 @@ +[manifest] +name = pexpect + +[download] +url = https://files.pythonhosted.org/packages/0e/3e/377007e3f36ec42f1b84ec322ee12141a9e10d808312e5738f52f80a232c/pexpect-4.7.0-py2.py3-none-any.whl +sha256 = 2094eefdfcf37a1fdbfb9aa090862c1a4878e5c7e0e7e7088bdb511c558e5cd1 + +[build] +builder = python-wheel + +[dependencies] +python-ptyprocess diff --git a/build/fbcode_builder/manifests/protobuf b/build/fbcode_builder/manifests/protobuf new file mode 100644 index 000000000..7f21e4821 --- /dev/null +++ b/build/fbcode_builder/manifests/protobuf @@ -0,0 +1,17 @@ +[manifest] +name = protobuf + +[rpms] +protobuf-devel + +[debs] +libprotobuf-dev + +[git] +repo_url = https://github.com/protocolbuffers/protobuf.git + +[build.not(os=windows)] +builder = autoconf + +[build.os=windows] +builder = nop diff --git a/build/fbcode_builder/manifests/proxygen b/build/fbcode_builder/manifests/proxygen new file mode 100644 index 000000000..5452a2454 --- /dev/null +++ b/build/fbcode_builder/manifests/proxygen @@ -0,0 +1,39 @@ +[manifest] +name = proxygen +fbsource_path = fbcode/proxygen +shipit_project = proxygen +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebook/proxygen.git + +[build.os=windows] +builder = nop + +[build] +builder = cmake +subdir = . + +[cmake.defines] +BUILD_QUIC = ON + +[cmake.defines.test=on] +BUILD_TESTS = ON + +[cmake.defines.test=off] +BUILD_TESTS = OFF + +[dependencies] +zlib +gperf +folly +fizz +wangle +mvfst + +[dependencies.test=on] +googletest_1_8 + +[shipit.pathmap] +fbcode/proxygen/public_tld = . +fbcode/proxygen = proxygen diff --git a/build/fbcode_builder/manifests/python b/build/fbcode_builder/manifests/python new file mode 100644 index 000000000..e51c0ab51 --- /dev/null +++ b/build/fbcode_builder/manifests/python @@ -0,0 +1,17 @@ +[manifest] +name = python + +[rpms] +python3 +python3-devel + +[debs] +python3-all-dev + +[download.os=linux] +url = https://www.python.org/ftp/python/3.7.6/Python-3.7.6.tgz +sha256 = aeee681c235ad336af116f08ab6563361a0c81c537072c1b309d6e4050aa2114 + +[build.os=linux] +builder = autoconf +subdir = Python-3.7.6 diff --git a/build/fbcode_builder/manifests/python-click b/build/fbcode_builder/manifests/python-click new file mode 100644 index 000000000..ea9a9d2d3 --- /dev/null +++ b/build/fbcode_builder/manifests/python-click @@ -0,0 +1,9 @@ +[manifest] +name = python-click + +[download] +url = https://files.pythonhosted.org/packages/d2/3d/fa76db83bf75c4f8d338c2fd15c8d33fdd7ad23a9b5e57eb6c5de26b430e/click-7.1.2-py2.py3-none-any.whl +sha256 = dacca89f4bfadd5de3d7489b7c8a566eee0d3676333fbb50030263894c38c0dc + +[build] +builder = python-wheel diff --git a/build/fbcode_builder/manifests/python-dulwich b/build/fbcode_builder/manifests/python-dulwich new file mode 100644 index 000000000..0d995e12f --- /dev/null +++ b/build/fbcode_builder/manifests/python-dulwich @@ -0,0 +1,19 @@ +[manifest] +name = python-dulwich + +# The below links point to custom github forks of project dulwich, because the +# 0.18.6 version didn't have an official rollout of wheel packages. + +[download.os=linux] +url = https://github.com/lukaspiatkowski/dulwich/releases/download/dulwich-0.18.6-wheel/dulwich-0.18.6-cp36-cp36m-linux_x86_64.whl +sha256 = e96f545f3d003e67236785473caaba2c368e531ea85fd508a3bd016ebac3a6d8 + +[download.os=darwin] +url = https://github.com/lukaspiatkowski/dulwich/releases/download/dulwich-0.18.6-wheel/dulwich-0.18.6-cp37-cp37m-macosx_10_14_x86_64.whl +sha256 = 8373652056284ad40ea5220b659b3489b0a91f25536322345a3e4b5d29069308 + +[build.not(os=windows)] +builder = python-wheel + +[build.os=windows] +builder = nop diff --git a/build/fbcode_builder/manifests/python-ptyprocess b/build/fbcode_builder/manifests/python-ptyprocess new file mode 100644 index 000000000..adc60e048 --- /dev/null +++ b/build/fbcode_builder/manifests/python-ptyprocess @@ -0,0 +1,9 @@ +[manifest] +name = python-ptyprocess + +[download] +url = https://files.pythonhosted.org/packages/d1/29/605c2cc68a9992d18dada28206eeada56ea4bd07a239669da41674648b6f/ptyprocess-0.6.0-py2.py3-none-any.whl +sha256 = d7cc528d76e76342423ca640335bd3633420dc1366f258cb31d05e865ef5ca1f + +[build] +builder = python-wheel diff --git a/build/fbcode_builder/manifests/python-six b/build/fbcode_builder/manifests/python-six new file mode 100644 index 000000000..a712188dc --- /dev/null +++ b/build/fbcode_builder/manifests/python-six @@ -0,0 +1,9 @@ +[manifest] +name = python-six + +[download] +url = https://files.pythonhosted.org/packages/73/fb/00a976f728d0d1fecfe898238ce23f502a721c0ac0ecfedb80e0d88c64e9/six-1.12.0-py2.py3-none-any.whl +sha256 = 3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c + +[build] +builder = python-wheel diff --git a/build/fbcode_builder/manifests/python-toml b/build/fbcode_builder/manifests/python-toml new file mode 100644 index 000000000..b49a3b8fb --- /dev/null +++ b/build/fbcode_builder/manifests/python-toml @@ -0,0 +1,9 @@ +[manifest] +name = python-toml + +[download] +url = https://files.pythonhosted.org/packages/a2/12/ced7105d2de62fa7c8fb5fce92cc4ce66b57c95fb875e9318dba7f8c5db0/toml-0.10.0-py2.py3-none-any.whl +sha256 = 235682dd292d5899d361a811df37e04a8828a5b1da3115886b73cf81ebc9100e + +[build] +builder = python-wheel diff --git a/build/fbcode_builder/manifests/re2 b/build/fbcode_builder/manifests/re2 new file mode 100644 index 000000000..eb4d6a92c --- /dev/null +++ b/build/fbcode_builder/manifests/re2 @@ -0,0 +1,17 @@ +[manifest] +name = re2 + +[rpms] +re2 +re2-devel + +[debs] +libre2-dev + +[download] +url = https://github.com/google/re2/archive/2019-06-01.tar.gz +sha256 = 02b7d73126bd18e9fbfe5d6375a8bb13fadaf8e99e48cbb062e4500fc18e8e2e + +[build] +builder = cmake +subdir = re2-2019-06-01 diff --git a/build/fbcode_builder/manifests/rocksdb b/build/fbcode_builder/manifests/rocksdb new file mode 100644 index 000000000..323e6dc6d --- /dev/null +++ b/build/fbcode_builder/manifests/rocksdb @@ -0,0 +1,41 @@ +[manifest] +name = rocksdb + +[download] +url = https://github.com/facebook/rocksdb/archive/v6.8.1.tar.gz +sha256 = ca192a06ed3bcb9f09060add7e9d0daee1ae7a8705a3d5ecbe41867c5e2796a2 + +[dependencies] +lz4 +snappy + +[build] +builder = cmake +subdir = rocksdb-6.8.1 + +[cmake.defines] +WITH_SNAPPY=ON +WITH_LZ4=ON +WITH_TESTS=OFF +WITH_BENCHMARK_TOOLS=OFF +# We get relocation errors with the static gflags lib, +# and there's no clear way to make it pick the shared gflags +# so just turn it off. +WITH_GFLAGS=OFF +# mac pro machines don't have some of the newer features that +# rocksdb enables by default; ask it to disable their use even +# when building on new hardware +PORTABLE = ON +# Disable the use of -Werror +FAIL_ON_WARNINGS = OFF + +[cmake.defines.os=windows] +ROCKSDB_INSTALL_ON_WINDOWS=ON +# RocksDB hard codes the paths to the snappy libs to something +# that doesn't exist; ignoring the usual cmake rules. As a result, +# we can't build it with snappy without either patching rocksdb or +# without introducing more complex logic to the build system to +# connect the snappy build outputs to rocksdb's custom logic here. +# Let's just turn it off on windows. +WITH_SNAPPY=OFF +WITH_LZ4=OFF diff --git a/build/fbcode_builder/manifests/rust-shed b/build/fbcode_builder/manifests/rust-shed new file mode 100644 index 000000000..c94b3fdd6 --- /dev/null +++ b/build/fbcode_builder/manifests/rust-shed @@ -0,0 +1,34 @@ +[manifest] +name = rust-shed +fbsource_path = fbcode/common/rust/shed +shipit_project = rust-shed +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebookexperimental/rust-shed.git + +[build] +builder = cargo + +[cargo] +build_doc = true +workspace_dir = + +[shipit.pathmap] +fbcode/common/rust/shed = shed +fbcode/common/rust/shed/public_autocargo = shed +fbcode/common/rust/shed/public_tld = . +tools/rust/ossconfigs = . + +[shipit.strip] +^fbcode/common/rust/shed/(?!public_autocargo|public_tld).+/Cargo\.toml$ + +[dependencies] +fbthrift +# macOS doesn't expose the openssl api so we need to build our own. +# Windows doesn't have openssl and Linux might contain an old version, +# so we get to provide it +openssl + +[dependencies.fb=on] +rust diff --git a/build/fbcode_builder/manifests/snappy b/build/fbcode_builder/manifests/snappy new file mode 100644 index 000000000..2f46a7734 --- /dev/null +++ b/build/fbcode_builder/manifests/snappy @@ -0,0 +1,25 @@ +[manifest] +name = snappy + +[rpms] +snappy +snappy-devel + +[debs] +libsnappy-dev + +[download] +url = https://github.com/google/snappy/archive/1.1.7.tar.gz +sha256 = 3dfa02e873ff51a11ee02b9ca391807f0c8ea0529a4924afa645fbf97163f9d4 + +[build] +builder = cmake +subdir = snappy-1.1.7 + +[cmake.defines] +SNAPPY_BUILD_TESTS = OFF + +# Avoid problems like `relocation R_X86_64_PC32 against symbol` on ELF systems +# when linking rocksdb, which builds PIC even when building a static lib +[cmake.defines.os=linux] +BUILD_SHARED_LIBS = ON diff --git a/build/fbcode_builder/manifests/sqlite3 b/build/fbcode_builder/manifests/sqlite3 new file mode 100644 index 000000000..2463f5761 --- /dev/null +++ b/build/fbcode_builder/manifests/sqlite3 @@ -0,0 +1,21 @@ +[manifest] +name = sqlite3 + +[rpms] +sqlite-devel +sqlite-libs + +[debs] +libsqlite3-dev + +[download] +url = https://sqlite.org/2019/sqlite-amalgamation-3280000.zip +sha256 = d02fc4e95cfef672b45052e221617a050b7f2e20103661cda88387349a9b1327 + +[dependencies] +cmake +ninja + +[build] +builder = sqlite +subdir = sqlite-amalgamation-3280000 diff --git a/build/fbcode_builder/manifests/sqlite3-bin b/build/fbcode_builder/manifests/sqlite3-bin new file mode 100644 index 000000000..aa138d499 --- /dev/null +++ b/build/fbcode_builder/manifests/sqlite3-bin @@ -0,0 +1,28 @@ +[manifest] +name = sqlite3-bin + +[rpms] +sqlite + +[debs] +sqlite3 + +[download.os=linux] +url = https://github.com/sqlite/sqlite/archive/version-3.33.0.tar.gz +sha256 = 48e5f989eefe9af0ac758096f82ead0f3c7b58118ac17cc5810495bd5084a331 + +[build.os=linux] +builder = autoconf +subdir = sqlite-version-3.33.0 + +[build.not(os=linux)] +# MacOS comes with sqlite3 preinstalled and don't need Windows here +builder = nop + +[dependencies.os=linux] +tcl + +[autoconf.args] +# This flag disabled tcl as a runtime library used for some functionality, +# but tcl is still a required dependency as it is used by the build files +--disable-tcl diff --git a/build/fbcode_builder/manifests/tcl b/build/fbcode_builder/manifests/tcl new file mode 100644 index 000000000..5e9892f37 --- /dev/null +++ b/build/fbcode_builder/manifests/tcl @@ -0,0 +1,20 @@ +[manifest] +name = tcl + +[rpms] +tcl + +[debs] +tcl + +[download] +url = https://github.com/tcltk/tcl/archive/core-8-7a3.tar.gz +sha256 = 22d748f0c9652f3ecc195fed3f24a1b6eea8d449003085e6651197951528982e + +[build.os=linux] +builder = autoconf +subdir = tcl-core-8-7a3/unix + +[build.not(os=linux)] +# This is for sqlite3 on Linux for now +builder = nop diff --git a/build/fbcode_builder/manifests/tree b/build/fbcode_builder/manifests/tree new file mode 100644 index 000000000..0c982f35a --- /dev/null +++ b/build/fbcode_builder/manifests/tree @@ -0,0 +1,34 @@ +[manifest] +name = tree + +[rpms] +tree + +[debs] +tree + +[download.os=linux] +url = https://salsa.debian.org/debian/tree-packaging/-/archive/debian/1.8.0-1/tree-packaging-debian-1.8.0-1.tar.gz +sha256 = a841eee1d52bfd64a48f54caab9937b9bd92935055c48885c4ab1ae4dab7fae5 + +[download.os=darwin] +# The official package of tree source requires users of non-Linux platform to +# comment/uncomment certain lines in the Makefile to build for their platform. +# Besauce getdeps.py doesn't have that functionality we just use this custom +# fork of tree which has proper lines uncommented for a OSX build +url = https://github.com/lukaspiatkowski/tree-command/archive/debian/1.8.0-1-macos.tar.gz +sha256 = 9cbe889553d95cf5a2791dd0743795d46a3c092c5bba691769c0e5c52e11229e + +[build.os=linux] +builder = make +subdir = tree-packaging-debian-1.8.0-1 + +[build.os=darwin] +builder = make +subdir = tree-command-debian-1.8.0-1-macos + +[build.os=windows] +builder = nop + +[make.install_args] +install diff --git a/build/fbcode_builder/manifests/wangle b/build/fbcode_builder/manifests/wangle new file mode 100644 index 000000000..6b330d620 --- /dev/null +++ b/build/fbcode_builder/manifests/wangle @@ -0,0 +1,27 @@ +[manifest] +name = wangle +fbsource_path = fbcode/wangle +shipit_project = wangle +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebook/wangle.git + +[build] +builder = cmake +subdir = wangle + +[cmake.defines.test=on] +BUILD_TESTS=ON + +[cmake.defines.test=off] +BUILD_TESTS=OFF + +[dependencies] +folly +googletest +fizz + +[shipit.pathmap] +fbcode/wangle/public_tld = . +fbcode/wangle = wangle diff --git a/build/fbcode_builder/manifests/watchman b/build/fbcode_builder/manifests/watchman new file mode 100644 index 000000000..0fcd6bb9f --- /dev/null +++ b/build/fbcode_builder/manifests/watchman @@ -0,0 +1,45 @@ +[manifest] +name = watchman +fbsource_path = fbcode/watchman +shipit_project = watchman +shipit_fbcode_builder = true + +[git] +repo_url = https://github.com/facebook/watchman.git + +[build] +builder = cmake + +[dependencies] +boost +cpptoml +fb303 +fbthrift +folly +pcre +googletest + +[dependencies.fb=on] +rust + +[shipit.pathmap] +fbcode/watchman = watchman +fbcode/watchman/oss = . +fbcode/eden/fs = eden/fs + +[shipit.strip] +^fbcode/eden/fs/(?!.*\.thrift|service/shipit_test_file\.txt) + +[cmake.defines.fb=on] +ENABLE_EDEN_SUPPORT=ON + +# FB macos specific settings +[cmake.defines.all(fb=on,os=darwin)] +# this path is coupled with the FB internal watchman-osx.spec +WATCHMAN_STATE_DIR=/opt/facebook/watchman/var/run/watchman +# tell cmake not to try to create /opt/facebook/... +INSTALL_WATCHMAN_STATE_DIR=OFF +USE_SYS_PYTHON=OFF + +[depends.environment] +WATCHMAN_VERSION_OVERRIDE diff --git a/build/fbcode_builder/manifests/yaml-cpp b/build/fbcode_builder/manifests/yaml-cpp new file mode 100644 index 000000000..bffa540fe --- /dev/null +++ b/build/fbcode_builder/manifests/yaml-cpp @@ -0,0 +1,20 @@ +[manifest] +name = yaml-cpp + +[download] +url = https://github.com/jbeder/yaml-cpp/archive/yaml-cpp-0.6.2.tar.gz +sha256 = e4d8560e163c3d875fd5d9e5542b5fd5bec810febdcba61481fe5fc4e6b1fd05 + +[build.os=linux] +builder = cmake +subdir = yaml-cpp-yaml-cpp-0.6.2 + +[build.not(os=linux)] +builder = nop + +[dependencies] +boost +googletest + +[cmake.defines] +YAML_CPP_BUILD_TESTS=OFF diff --git a/build/fbcode_builder/manifests/zlib b/build/fbcode_builder/manifests/zlib new file mode 100644 index 000000000..8df0e3e48 --- /dev/null +++ b/build/fbcode_builder/manifests/zlib @@ -0,0 +1,22 @@ +[manifest] +name = zlib + +[rpms] +zlib-devel +zlib-static + +[debs] +zlib1g-dev + +[download] +url = http://www.zlib.net/zlib-1.2.11.tar.gz +sha256 = c3e5e9fdd5004dcb542feda5ee4f0ff0744628baf8ed2dd5d66f8ca1197cb1a1 + +[build.os=windows] +builder = cmake +subdir = zlib-1.2.11 + +# Every platform but windows ships with zlib, so just skip +# building on not(windows) +[build.not(os=windows)] +builder = nop diff --git a/build/fbcode_builder/manifests/zstd b/build/fbcode_builder/manifests/zstd new file mode 100644 index 000000000..71db9d5c6 --- /dev/null +++ b/build/fbcode_builder/manifests/zstd @@ -0,0 +1,28 @@ +[manifest] +name = zstd + +[rpms] +libzstd-devel +libzstd + +[debs] +libzstd-dev + +[download] +url = https://github.com/facebook/zstd/releases/download/v1.4.5/zstd-1.4.5.tar.gz +sha256 = 98e91c7c6bf162bf90e4e70fdbc41a8188b9fa8de5ad840c401198014406ce9e + +[build] +builder = cmake +subdir = zstd-1.4.5/build/cmake + +# The zstd cmake build explicitly sets the install name +# for the shared library in such a way that cmake discards +# the path to the library from the install_name, rendering +# the library non-resolvable during the build. The short +# term solution for this is just to link static on macos. +[cmake.defines.os=darwin] +ZSTD_BUILD_SHARED = OFF + +[cmake.defines.os=windows] +ZSTD_BUILD_SHARED = OFF diff --git a/build/fbcode_builder/parse_args.py b/build/fbcode_builder/parse_args.py new file mode 100644 index 000000000..8d5e35330 --- /dev/null +++ b/build/fbcode_builder/parse_args.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +"Argument parsing logic shared by all fbcode_builder CLI tools." + +import argparse +import logging + +from shell_quoting import raw_shell, ShellQuoted + + +def parse_args_to_fbcode_builder_opts(add_args_fn, top_level_opts, opts, help): + """ + + Provides some standard arguments: --debug, --option, --shell-quoted-option + + Then, calls `add_args_fn(parser)` to add application-specific arguments. + + `opts` are first used as defaults for the various command-line + arguments. Then, the parsed arguments are mapped back into `opts`, + which then become the values for `FBCodeBuilder.option()`, to be used + both by the builder and by `get_steps_fn()`. + + `help` is printed in response to the `--help` argument. + + """ + top_level_opts = set(top_level_opts) + + parser = argparse.ArgumentParser( + description=help, formatter_class=argparse.RawDescriptionHelpFormatter + ) + + add_args_fn(parser) + + parser.add_argument( + "--option", + nargs=2, + metavar=("KEY", "VALUE"), + action="append", + default=[ + (k, v) + for k, v in opts.items() + if k not in top_level_opts and not isinstance(v, ShellQuoted) + ], + help="Set project-specific options. These are assumed to be raw " + "strings, to be shell-escaped as needed. Default: %(default)s.", + ) + parser.add_argument( + "--shell-quoted-option", + nargs=2, + metavar=("KEY", "VALUE"), + action="append", + default=[ + (k, raw_shell(v)) + for k, v in opts.items() + if k not in top_level_opts and isinstance(v, ShellQuoted) + ], + help="Set project-specific options. These are assumed to be shell-" + "quoted, and may be used in commands as-is. Default: %(default)s.", + ) + + parser.add_argument("--debug", action="store_true", help="Log more") + args = parser.parse_args() + + logging.basicConfig( + level=logging.DEBUG if args.debug else logging.INFO, + format="%(levelname)s: %(message)s", + ) + + # Map command-line args back into opts. + logging.debug("opts before command-line arguments: {0}".format(opts)) + + new_opts = {} + for key in top_level_opts: + val = getattr(args, key) + # Allow clients to unset a default by passing a value of None in opts + if val is not None: + new_opts[key] = val + for key, val in args.option: + new_opts[key] = val + for key, val in args.shell_quoted_option: + new_opts[key] = ShellQuoted(val) + + logging.debug("opts after command-line arguments: {0}".format(new_opts)) + + return new_opts diff --git a/build/fbcode_builder/shell_builder.py b/build/fbcode_builder/shell_builder.py new file mode 100644 index 000000000..e0d5429ad --- /dev/null +++ b/build/fbcode_builder/shell_builder.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +""" +shell_builder.py allows running the fbcode_builder logic +on the host rather than in a container. + +It emits a bash script with set -exo pipefail configured such that +any failing step will cause the script to exit with failure. + +== How to run it? == + +cd build +python fbcode_builder/shell_builder.py > ~/run.sh +bash ~/run.sh +""" + +import distutils.spawn +import os + +from fbcode_builder import FBCodeBuilder +from shell_quoting import raw_shell, shell_comment, shell_join, ShellQuoted +from utils import recursively_flatten_list + + +class ShellFBCodeBuilder(FBCodeBuilder): + def _render_impl(self, steps): + return raw_shell(shell_join("\n", recursively_flatten_list(steps))) + + def set_env(self, key, value): + return ShellQuoted("export {key}={val}").format(key=key, val=value) + + def workdir(self, dir): + return [ + ShellQuoted("mkdir -p {d} && cd {d}").format(d=dir), + ] + + def run(self, shell_cmd): + return ShellQuoted("{cmd}").format(cmd=shell_cmd) + + def step(self, name, actions): + assert "\n" not in name, "Name {0} would span > 1 line".format(name) + b = ShellQuoted("") + return [ShellQuoted("### {0} ###".format(name)), b] + actions + [b] + + def setup(self): + steps = ( + [ + ShellQuoted("set -exo pipefail"), + ] + + self.create_python_venv() + + self.python_venv() + ) + if self.has_option("ccache_dir"): + ccache_dir = self.option("ccache_dir") + steps += [ + ShellQuoted( + # Set CCACHE_DIR before the `ccache` invocations below. + "export CCACHE_DIR={ccache_dir} " + 'CC="ccache ${{CC:-gcc}}" CXX="ccache ${{CXX:-g++}}"' + ).format(ccache_dir=ccache_dir) + ] + return steps + + def comment(self, comment): + return shell_comment(comment) + + def copy_local_repo(self, dir, dest_name): + return [ + ShellQuoted("cp -r {dir} {dest_name}").format(dir=dir, dest_name=dest_name), + ] + + +def find_project_root(): + here = os.path.dirname(os.path.realpath(__file__)) + maybe_root = os.path.dirname(os.path.dirname(here)) + if os.path.isdir(os.path.join(maybe_root, ".git")): + return maybe_root + raise RuntimeError( + "I expected shell_builder.py to be in the " + "build/fbcode_builder subdir of a git repo" + ) + + +def persistent_temp_dir(repo_root): + escaped = repo_root.replace("/", "sZs").replace("\\", "sZs").replace(":", "") + return os.path.join(os.path.expandvars("$HOME"), ".fbcode_builder-" + escaped) + + +if __name__ == "__main__": + from utils import read_fbcode_builder_config, build_fbcode_builder_config + + repo_root = find_project_root() + temp = persistent_temp_dir(repo_root) + + config = read_fbcode_builder_config("fbcode_builder_config.py") + builder = ShellFBCodeBuilder(projects_dir=temp) + + if distutils.spawn.find_executable("ccache"): + builder.add_option( + "ccache_dir", os.environ.get("CCACHE_DIR", os.path.join(temp, ".ccache")) + ) + builder.add_option("prefix", os.path.join(temp, "installed")) + builder.add_option("make_parallelism", 4) + builder.add_option( + "{project}:local_repo_dir".format(project=config["github_project"]), repo_root + ) + make_steps = build_fbcode_builder_config(config) + steps = make_steps(builder) + print(builder.render(steps)) diff --git a/build/fbcode_builder/shell_quoting.py b/build/fbcode_builder/shell_quoting.py new file mode 100644 index 000000000..7429226bd --- /dev/null +++ b/build/fbcode_builder/shell_quoting.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +""" + +Almost every FBCodeBuilder string is ultimately passed to a shell. Escaping +too little or too much tends to be the most common error. The utilities in +this file give a systematic way of avoiding such bugs: + - When you write literal strings destined for the shell, use `ShellQuoted`. + - When these literal strings are parameterized, use `ShellQuoted.format`. + - Any parameters that are raw strings get `shell_quote`d automatically, + while any ShellQuoted parameters will be left intact. + - Use `path_join` to join path components. + - Use `shell_join` to join already-quoted command arguments or shell lines. + +""" + +import os +from collections import namedtuple + + +class ShellQuoted(namedtuple("ShellQuoted", ("do_not_use_raw_str",))): + """ + + Wrap a string with this to make it transparent to shell_quote(). It + will almost always suffice to use ShellQuoted.format(), path_join(), + or shell_join(). + + If you really must, use raw_shell() to access the raw string. + + """ + + def __new__(cls, s): + "No need to nest ShellQuoted." + return super(ShellQuoted, cls).__new__( + cls, s.do_not_use_raw_str if isinstance(s, ShellQuoted) else s + ) + + def __str__(self): + raise RuntimeError( + "One does not simply convert {0} to a string -- use path_join() " + "or ShellQuoted.format() instead".format(repr(self)) + ) + + def __repr__(self): + return "{0}({1})".format(self.__class__.__name__, repr(self.do_not_use_raw_str)) + + def format(self, **kwargs): + """ + + Use instead of str.format() when the arguments are either + `ShellQuoted()` or raw strings needing to be `shell_quote()`d. + + Positional args are deliberately not supported since they are more + error-prone. + + """ + return ShellQuoted( + self.do_not_use_raw_str.format( + **dict( + (k, shell_quote(v).do_not_use_raw_str) for k, v in kwargs.items() + ) + ) + ) + + +def shell_quote(s): + "Quotes a string if it is not already quoted" + return ( + s + if isinstance(s, ShellQuoted) + else ShellQuoted("'" + str(s).replace("'", "'\\''") + "'") + ) + + +def raw_shell(s): + "Not a member of ShellQuoted so we get a useful error for raw strings" + if isinstance(s, ShellQuoted): + return s.do_not_use_raw_str + raise RuntimeError("{0} should have been ShellQuoted".format(s)) + + +def shell_join(delim, it): + "Joins an iterable of ShellQuoted with a delimiter between each two" + return ShellQuoted(delim.join(raw_shell(s) for s in it)) + + +def path_join(*args): + "Joins ShellQuoted and raw pieces of paths to make a shell-quoted path" + return ShellQuoted(os.path.join(*[raw_shell(shell_quote(s)) for s in args])) + + +def shell_comment(c): + "Do not shell-escape raw strings in comments, but do handle line breaks." + return ShellQuoted("# {c}").format( + c=ShellQuoted( + (raw_shell(c) if isinstance(c, ShellQuoted) else c).replace("\n", "\n# ") + ) + ) diff --git a/build/fbcode_builder/specs/__init__.py b/build/fbcode_builder/specs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/build/fbcode_builder/specs/fbthrift.py b/build/fbcode_builder/specs/fbthrift.py new file mode 100644 index 000000000..f0c7e7ac7 --- /dev/null +++ b/build/fbcode_builder/specs/fbthrift.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import specs.fizz as fizz +import specs.fmt as fmt +import specs.folly as folly +import specs.sodium as sodium +import specs.wangle as wangle +import specs.zstd as zstd + + +def fbcode_builder_spec(builder): + return { + "depends_on": [fmt, folly, fizz, sodium, wangle, zstd], + "steps": [ + builder.fb_github_cmake_install("fbthrift/thrift"), + ], + } diff --git a/build/fbcode_builder/specs/fbzmq.py b/build/fbcode_builder/specs/fbzmq.py new file mode 100644 index 000000000..78c8bc9dd --- /dev/null +++ b/build/fbcode_builder/specs/fbzmq.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import specs.fbthrift as fbthrift +import specs.fmt as fmt +import specs.folly as folly +import specs.gmock as gmock +import specs.sodium as sodium +from shell_quoting import ShellQuoted + + +def fbcode_builder_spec(builder): + builder.add_option("zeromq/libzmq:git_hash", "v4.2.2") + return { + "depends_on": [fmt, folly, fbthrift, gmock, sodium], + "steps": [ + builder.github_project_workdir("zeromq/libzmq", "."), + builder.step( + "Build and install zeromq/libzmq", + [ + builder.run(ShellQuoted("./autogen.sh")), + builder.configure(), + builder.make_and_install(), + ], + ), + builder.fb_github_project_workdir("fbzmq/_build", "facebook"), + builder.step( + "Build and install fbzmq/", + [ + builder.cmake_configure("fbzmq/_build"), + # we need the pythonpath to find the thrift compiler + builder.run( + ShellQuoted( + 'PYTHONPATH="$PYTHONPATH:"{p}/lib/python2.7/site-packages ' + "make -j {n}" + ).format( + p=builder.option("prefix"), + n=builder.option("make_parallelism"), + ) + ), + builder.run(ShellQuoted("make install")), + ], + ), + ], + } diff --git a/build/fbcode_builder/specs/fizz.py b/build/fbcode_builder/specs/fizz.py new file mode 100644 index 000000000..82f26e67c --- /dev/null +++ b/build/fbcode_builder/specs/fizz.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import specs.fmt as fmt +import specs.folly as folly +import specs.gmock as gmock +import specs.sodium as sodium +import specs.zstd as zstd + + +def fbcode_builder_spec(builder): + builder.add_option( + "fizz/fizz/build:cmake_defines", + { + # Fizz's build is kind of broken, in the sense that both `mvfst` + # and `proxygen` depend on files that are only installed with + # `BUILD_TESTS` enabled, e.g. `fizz/crypto/test/TestUtil.h`. + "BUILD_TESTS": "ON" + }, + ) + return { + "depends_on": [gmock, fmt, folly, sodium, zstd], + "steps": [ + builder.fb_github_cmake_install( + "fizz/fizz/build", github_org="facebookincubator" + ) + ], + } diff --git a/build/fbcode_builder/specs/fmt.py b/build/fbcode_builder/specs/fmt.py new file mode 100644 index 000000000..395316799 --- /dev/null +++ b/build/fbcode_builder/specs/fmt.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + + +def fbcode_builder_spec(builder): + builder.add_option("fmtlib/fmt:git_hash", "6.2.1") + builder.add_option( + "fmtlib/fmt:cmake_defines", + { + # Avoids a bizarred failure to run tests in Bistro: + # test_crontab_selector: error while loading shared libraries: + # libfmt.so.6: cannot open shared object file: + # No such file or directory + "BUILD_SHARED_LIBS": "OFF", + }, + ) + return { + "steps": [ + builder.github_project_workdir("fmtlib/fmt", "build"), + builder.cmake_install("fmtlib/fmt"), + ], + } diff --git a/build/fbcode_builder/specs/folly.py b/build/fbcode_builder/specs/folly.py new file mode 100644 index 000000000..e89d5e955 --- /dev/null +++ b/build/fbcode_builder/specs/folly.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import specs.fmt as fmt + + +def fbcode_builder_spec(builder): + return { + "depends_on": [fmt], + "steps": [ + # on macOS the filesystem is typically case insensitive. + # We need to ensure that the CWD is not the folly source + # dir when we build, otherwise the system will decide + # that `folly/String.h` is the file it wants when including + # `string.h` and the build will fail. + builder.fb_github_project_workdir("folly/_build"), + builder.cmake_install("facebook/folly"), + ], + } diff --git a/build/fbcode_builder/specs/gmock.py b/build/fbcode_builder/specs/gmock.py new file mode 100644 index 000000000..774137301 --- /dev/null +++ b/build/fbcode_builder/specs/gmock.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + + +def fbcode_builder_spec(builder): + builder.add_option("google/googletest:git_hash", "release-1.8.1") + builder.add_option( + "google/googletest:cmake_defines", + { + "BUILD_GTEST": "ON", + # Avoid problems with MACOSX_RPATH + "BUILD_SHARED_LIBS": "OFF", + }, + ) + return { + "steps": [ + builder.github_project_workdir("google/googletest", "build"), + builder.cmake_install("google/googletest"), + ], + } diff --git a/build/fbcode_builder/specs/mvfst.py b/build/fbcode_builder/specs/mvfst.py new file mode 100644 index 000000000..ce8b003d9 --- /dev/null +++ b/build/fbcode_builder/specs/mvfst.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import specs.fizz as fizz +import specs.folly as folly +import specs.gmock as gmock + + +def fbcode_builder_spec(builder): + # Projects that **depend** on mvfst should don't need to build tests. + builder.add_option( + "mvfst/build:cmake_defines", + { + # This is set to ON in the mvfst `fbcode_builder_config.py` + "BUILD_TESTS": "OFF" + }, + ) + return { + "depends_on": [gmock, folly, fizz], + "steps": [ + builder.fb_github_cmake_install( + "mvfst/build", github_org="facebookincubator" + ) + ], + } diff --git a/build/fbcode_builder/specs/proxygen.py b/build/fbcode_builder/specs/proxygen.py new file mode 100644 index 000000000..6a584d710 --- /dev/null +++ b/build/fbcode_builder/specs/proxygen.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import specs.fizz as fizz +import specs.fmt as fmt +import specs.folly as folly +import specs.gmock as gmock +import specs.mvfst as mvfst +import specs.sodium as sodium +import specs.wangle as wangle +import specs.zstd as zstd + + +def fbcode_builder_spec(builder): + # Projects that **depend** on proxygen should don't need to build tests + # or QUIC support. + builder.add_option( + "proxygen/proxygen:cmake_defines", + { + # These 2 are set to ON in `proxygen_quic.py` + "BUILD_QUIC": "OFF", + "BUILD_TESTS": "OFF", + # For bistro + "BUILD_SHARED_LIBS": "OFF", + }, + ) + + return { + "depends_on": [gmock, fmt, folly, wangle, fizz, sodium, zstd, mvfst], + "steps": [builder.fb_github_cmake_install("proxygen/proxygen", "..")], + } diff --git a/build/fbcode_builder/specs/proxygen_quic.py b/build/fbcode_builder/specs/proxygen_quic.py new file mode 100644 index 000000000..b4959fb89 --- /dev/null +++ b/build/fbcode_builder/specs/proxygen_quic.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import specs.fizz as fizz +import specs.fmt as fmt +import specs.folly as folly +import specs.gmock as gmock +import specs.mvfst as mvfst +import specs.sodium as sodium +import specs.wangle as wangle +import specs.zstd as zstd + +# DO NOT USE THIS AS A LIBRARY -- this is currently effectively just part +# ofthe implementation of proxygen's `fbcode_builder_config.py`. This is +# why this builds tests and sets `BUILD_QUIC`. +def fbcode_builder_spec(builder): + builder.add_option( + "proxygen/proxygen:cmake_defines", + {"BUILD_QUIC": "ON", "BUILD_SHARED_LIBS": "OFF", "BUILD_TESTS": "ON"}, + ) + return { + "depends_on": [gmock, fmt, folly, wangle, fizz, sodium, zstd, mvfst], + "steps": [builder.fb_github_cmake_install("proxygen/proxygen", "..")], + } diff --git a/build/fbcode_builder/specs/re2.py b/build/fbcode_builder/specs/re2.py new file mode 100644 index 000000000..cf4e08a0b --- /dev/null +++ b/build/fbcode_builder/specs/re2.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + + +def fbcode_builder_spec(builder): + return { + "steps": [ + builder.github_project_workdir("google/re2", "build"), + builder.cmake_install("google/re2"), + ], + } diff --git a/build/fbcode_builder/specs/rocksdb.py b/build/fbcode_builder/specs/rocksdb.py new file mode 100644 index 000000000..9ebfe4739 --- /dev/null +++ b/build/fbcode_builder/specs/rocksdb.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + + +def fbcode_builder_spec(builder): + builder.add_option( + "rocksdb/_build:cmake_defines", + { + "USE_RTTI": "1", + "PORTABLE": "ON", + }, + ) + return { + "steps": [ + builder.fb_github_cmake_install("rocksdb/_build"), + ], + } diff --git a/build/fbcode_builder/specs/sodium.py b/build/fbcode_builder/specs/sodium.py new file mode 100644 index 000000000..8be9833cf --- /dev/null +++ b/build/fbcode_builder/specs/sodium.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from shell_quoting import ShellQuoted + + +def fbcode_builder_spec(builder): + builder.add_option("jedisct1/libsodium:git_hash", "stable") + return { + "steps": [ + builder.github_project_workdir("jedisct1/libsodium", "."), + builder.step( + "Build and install jedisct1/libsodium", + [ + builder.run(ShellQuoted("./autogen.sh")), + builder.configure(), + builder.make_and_install(), + ], + ), + ], + } diff --git a/build/fbcode_builder/specs/wangle.py b/build/fbcode_builder/specs/wangle.py new file mode 100644 index 000000000..62b5b3c86 --- /dev/null +++ b/build/fbcode_builder/specs/wangle.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import specs.fizz as fizz +import specs.fmt as fmt +import specs.folly as folly +import specs.gmock as gmock +import specs.sodium as sodium + + +def fbcode_builder_spec(builder): + # Projects that **depend** on wangle need not spend time on tests. + builder.add_option( + "wangle/wangle/build:cmake_defines", + { + # This is set to ON in the wangle `fbcode_builder_config.py` + "BUILD_TESTS": "OFF" + }, + ) + return { + "depends_on": [gmock, fmt, folly, fizz, sodium], + "steps": [builder.fb_github_cmake_install("wangle/wangle/build")], + } diff --git a/build/fbcode_builder/specs/zstd.py b/build/fbcode_builder/specs/zstd.py new file mode 100644 index 000000000..14d9a1249 --- /dev/null +++ b/build/fbcode_builder/specs/zstd.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +from shell_quoting import ShellQuoted + + +def fbcode_builder_spec(builder): + # This API should change rarely, so build the latest tag instead of master. + builder.add_option( + "facebook/zstd:git_hash", + ShellQuoted("$(git describe --abbrev=0 --tags origin/master)"), + ) + return { + "steps": [ + builder.github_project_workdir("facebook/zstd", "."), + builder.step( + "Build and install zstd", + [ + builder.make_and_install( + make_vars={ + "PREFIX": builder.option("prefix"), + } + ) + ], + ), + ], + } diff --git a/build/fbcode_builder/travis.yml b/build/fbcode_builder/travis.yml new file mode 100644 index 000000000..d2bb60778 --- /dev/null +++ b/build/fbcode_builder/travis.yml @@ -0,0 +1,51 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Facebook projects that use `fbcode_builder` for continuous integration +# share this Travis configuration to run builds via Docker. + +# Docker disables IPv6 in containers by default. Enable it for unit tests that need [::1]. +before_script: + - if [[ "$TRAVIS_OS_NAME" != "osx" ]]; + then + sudo build/fbcode_builder/docker_enable_ipv6.sh; + fi + +env: + global: + - travis_cache_dir=$HOME/travis_ccache + # Travis times out after 50 minutes. Very generously leave 10 minutes + # for setup (e.g. cache download, compression, and upload), so we never + # fail to cache the progress we made. + - docker_build_timeout=40m + +cache: + # Our build caches can be 200-300MB, so increase the timeout to 7 minutes + # to make sure we never fail to cache the progress we made. + timeout: 420 + directories: + - $HOME/travis_ccache # see docker_build_with_ccache.sh + +# Ugh, `services:` must be in the matrix, or we get `docker: command not found` +# https://github.com/travis-ci/travis-ci/issues/5142 +matrix: + include: + - env: ['os_image=ubuntu:18.04', gcc_version=7] + services: [docker] + +addons: + apt: + packages: python2.7 + +script: + # We don't want to write the script inline because of Travis kludginess -- + # it looks like it escapes " and \ in scripts when using `matrix:`. + - ./build/fbcode_builder/travis_docker_build.sh diff --git a/build/fbcode_builder/travis_docker_build.sh b/build/fbcode_builder/travis_docker_build.sh new file mode 100755 index 000000000..d4cba10ef --- /dev/null +++ b/build/fbcode_builder/travis_docker_build.sh @@ -0,0 +1,42 @@ +#!/bin/bash -uex +# Copyright (c) Facebook, Inc. and its affiliates. +# .travis.yml in the top-level dir explains why this is a separate script. +# Read the docs: ./make_docker_context.py --help + +os_image=${os_image?Must be set by Travis} +gcc_version=${gcc_version?Must be set by Travis} +make_parallelism=${make_parallelism:-4} +# ccache is off unless requested +travis_cache_dir=${travis_cache_dir:-} +# The docker build never times out, unless specified +docker_build_timeout=${docker_build_timeout:-} + +cur_dir="$(realpath "$(dirname "$0")")" + +if [[ "$travis_cache_dir" == "" ]]; then + echo "ccache disabled, enable by setting env. var. travis_cache_dir" + ccache_tgz="" +elif [[ -e "$travis_cache_dir/ccache.tgz" ]]; then + ccache_tgz="$travis_cache_dir/ccache.tgz" +else + echo "$travis_cache_dir/ccache.tgz does not exist, starting with empty cache" + ccache_tgz=$(mktemp) + tar -T /dev/null -czf "$ccache_tgz" +fi + +docker_context_dir=$( + cd "$cur_dir/.." # Let the script find our fbcode_builder_config.py + "$cur_dir/make_docker_context.py" \ + --os-image "$os_image" \ + --gcc-version "$gcc_version" \ + --make-parallelism "$make_parallelism" \ + --local-repo-dir "$cur_dir/../.." \ + --ccache-tgz "$ccache_tgz" +) +cd "${docker_context_dir?Failed to make Docker context directory}" + +# Make it safe to iterate on the .sh in the tree while the script runs. +cp "$cur_dir/docker_build_with_ccache.sh" . +exec ./docker_build_with_ccache.sh \ + --build-timeout "$docker_build_timeout" \ + "$travis_cache_dir" diff --git a/build/fbcode_builder/utils.py b/build/fbcode_builder/utils.py new file mode 100644 index 000000000..02459a200 --- /dev/null +++ b/build/fbcode_builder/utils.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +"Miscellaneous utility functions." + +import itertools +import logging +import os +import shutil +import subprocess +import sys +from contextlib import contextmanager + + +def recursively_flatten_list(l): + return itertools.chain.from_iterable( + (recursively_flatten_list(i) if type(i) is list else (i,)) for i in l + ) + + +def run_command(*cmd, **kwargs): + "The stdout of most fbcode_builder utilities is meant to be parsed." + logging.debug("Running: {0} with {1}".format(cmd, kwargs)) + kwargs["stdout"] = sys.stderr + subprocess.check_call(cmd, **kwargs) + + +@contextmanager +def make_temp_dir(d): + os.mkdir(d) + try: + yield d + finally: + shutil.rmtree(d, ignore_errors=True) + + +def _inner_read_config(path): + """ + Helper to read a named config file. + The grossness with the global is a workaround for this python bug: + https://bugs.python.org/issue21591 + The bug prevents us from defining either a local function or a lambda + in the scope of read_fbcode_builder_config below. + """ + global _project_dir + full_path = os.path.join(_project_dir, path) + return read_fbcode_builder_config(full_path) + + +def read_fbcode_builder_config(filename): + # Allow one spec to read another + # When doing so, treat paths as relative to the config's project directory. + # _project_dir is a "local" for _inner_read_config; see the comments + # in that function for an explanation of the use of global. + global _project_dir + _project_dir = os.path.dirname(filename) + + scope = {"read_fbcode_builder_config": _inner_read_config} + with open(filename) as config_file: + code = compile(config_file.read(), filename, mode="exec") + exec(code, scope) + return scope["config"] + + +def steps_for_spec(builder, spec, processed_modules=None): + """ + Sets `builder` configuration, and returns all the builder steps + necessary to build `spec` and its dependencies. + + Traverses the dependencies in depth-first order, honoring the sequencing + in each 'depends_on' list. + """ + if processed_modules is None: + processed_modules = set() + steps = [] + for module in spec.get("depends_on", []): + if module not in processed_modules: + processed_modules.add(module) + steps.extend( + steps_for_spec( + builder, module.fbcode_builder_spec(builder), processed_modules + ) + ) + steps.extend(spec.get("steps", [])) + return steps + + +def build_fbcode_builder_config(config): + return lambda builder: builder.build( + steps_for_spec(builder, config["fbcode_builder_spec"](builder)) + ) diff --git a/build/fbcode_builder_config.py b/build/fbcode_builder_config.py new file mode 100644 index 000000000..85018bf05 --- /dev/null +++ b/build/fbcode_builder_config.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +'fbcode_builder steps to build rsocket' + +import specs.rsocket as rsocket + + +def fbcode_builder_spec(builder): + return { + 'depends_on': [rsocket], + } + + +config = { + 'github_project': 'rsocket/rsocket-cpp', + 'fbcode_builder_spec': fbcode_builder_spec, +} diff --git a/cmake/FindFolly.cmake b/cmake/FindFolly.cmake deleted file mode 100644 index b7fa0a42a..000000000 --- a/cmake/FindFolly.cmake +++ /dev/null @@ -1,14 +0,0 @@ -cmake_minimum_required(VERSION 3.2) - -include(FindPackageHandleStandardArgs) - -if (FOLLY_INSTALL_DIR) - set(lib_paths ${FOLLY_INSTALL_DIR}/lib) - set(include_paths ${FOLLY_INSTALL_DIR}/include) -endif () - -find_library(FOLLY_LIBRARY folly PATHS ${lib_paths}) -find_path(FOLLY_INCLUDE_DIR "folly/String.h" PATHS ${include_paths}) - -find_package_handle_standard_args(Folly - DEFAULT_MSG FOLLY_LIBRARY FOLLY_INCLUDE_DIR) diff --git a/cmake/InstallFolly.cmake b/cmake/InstallFolly.cmake new file mode 100644 index 000000000..2bd17460c --- /dev/null +++ b/cmake/InstallFolly.cmake @@ -0,0 +1,22 @@ +# Copyright (c) 2018, Facebook, Inc. +# All rights reserved. +# +if (NOT FOLLY_INSTALL_DIR) + set(FOLLY_INSTALL_DIR ${CMAKE_BINARY_DIR}/folly-install) +endif () + +if (RSOCKET_INSTALL_DEPS) + execute_process( + COMMAND + ${CMAKE_SOURCE_DIR}/scripts/build_folly.sh + ${CMAKE_BINARY_DIR}/folly-src + ${FOLLY_INSTALL_DIR} + RESULT_VARIABLE folly_result + ) + if (NOT "${folly_result}" STREQUAL "0") + message(FATAL_ERROR "failed to build folly") + endif() +endif () + +find_package(Threads) +find_package(folly CONFIG REQUIRED PATHS ${FOLLY_INSTALL_DIR}) diff --git a/cmake/rsocket-config.cmake.in b/cmake/rsocket-config.cmake.in new file mode 100644 index 000000000..d5579a856 --- /dev/null +++ b/cmake/rsocket-config.cmake.in @@ -0,0 +1,12 @@ +# Copyright (c) 2018, Facebook, Inc. +# All rights reserved. + +@PACKAGE_INIT@ + +if(NOT TARGET rsocket::ReactiveSocket) + include("${PACKAGE_PREFIX_DIR}/lib/cmake/rsocket/rsocket-exports.cmake") +endif() + +if (NOT rsocket_FIND_QUIETLY) + message(STATUS "Found rsocket: ${PACKAGE_PREFIX_DIR}") +endif() diff --git a/devtools/format_all.sh b/devtools/format_all.sh index aed32b572..235b985e2 100755 --- a/devtools/format_all.sh +++ b/devtools/format_all.sh @@ -1,4 +1,7 @@ #!/usr/bin/env bash +# +# Copyright 2004-present Facebook. All Rights Reserved. +# set -xue cd "$(dirname "$0")/.." diff --git a/examples/channel-hello-world/ChannelHelloWorld_Client.cpp b/examples/channel-hello-world/ChannelHelloWorld_Client.cpp deleted file mode 100644 index 8ac8e30e6..000000000 --- a/examples/channel-hello-world/ChannelHelloWorld_Client.cpp +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include - -#include -#include -#include - -#include "examples/util/ExampleSubscriber.h" -#include "rsocket/RSocket.h" -#include "rsocket/transports/tcp/TcpConnectionFactory.h" - -#include "yarpl/Flowable.h" - -using namespace rsocket_example; -using namespace rsocket; -using namespace yarpl::flowable; - -DEFINE_string(host, "localhost", "host to connect to"); -DEFINE_int32(port, 9898, "host:port to connect to"); - -int main(int argc, char* argv[]) { - FLAGS_logtostderr = true; - FLAGS_minloglevel = 0; - folly::init(&argc, &argv); - - folly::SocketAddress address; - address.setFromHostPort(FLAGS_host, FLAGS_port); - - auto client = RSocket::createConnectedClient( - std::make_unique(std::move(address))) - .get(); - - client->getRequester() - ->requestChannel(Flowables::justN({"initialPayload", "Bob", "Jane"}) - ->map([](std::string v) { - std::cout << "Sending: " << v << std::endl; - return Payload(v); - })) - ->subscribe([](Payload p) { - std::cout << "Received: " << p.moveDataToString() << std::endl; - }); - - // Wait for a newline on the console to terminate the server. - std::getchar(); - return 0; -} diff --git a/examples/conditional-request-handling/JsonRequestHandler.cpp b/examples/conditional-request-handling/JsonRequestHandler.cpp deleted file mode 100644 index c19d40032..000000000 --- a/examples/conditional-request-handling/JsonRequestHandler.cpp +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "JsonRequestHandler.h" -#include -#include "yarpl/Flowable.h" - -using namespace rsocket; -using namespace yarpl::flowable; - -/// Handles a new inbound Stream requested by the other end. -yarpl::Reference> -JsonRequestResponder::handleRequestStream(Payload request, StreamId) { - LOG(INFO) << "JsonRequestResponder.handleRequestStream " << request; - - // string from payload data - auto requestString = request.moveDataToString(); - - return Flowables::range(1, 100)->map([name = std::move(requestString)]( - int64_t v) { - std::stringstream ss; - ss << "Hello (should be JSON) " << name << " " << v << "!"; - std::string s = ss.str(); - return Payload(s, "metadata"); - }); -} diff --git a/examples/conditional-request-handling/JsonRequestHandler.h b/examples/conditional-request-handling/JsonRequestHandler.h deleted file mode 100644 index f24f06ccf..000000000 --- a/examples/conditional-request-handling/JsonRequestHandler.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "rsocket/Payload.h" -#include "rsocket/RSocket.h" - -class JsonRequestResponder : public rsocket::RSocketResponder { - public: - /// Handles a new inbound Stream requested by the other end. - yarpl::Reference> - handleRequestStream(rsocket::Payload request, rsocket::StreamId streamId) - override; -}; diff --git a/examples/conditional-request-handling/TextRequestHandler.cpp b/examples/conditional-request-handling/TextRequestHandler.cpp deleted file mode 100644 index a6f0717a1..000000000 --- a/examples/conditional-request-handling/TextRequestHandler.cpp +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "TextRequestHandler.h" -#include -#include "yarpl/Flowable.h" - -using namespace rsocket; -using namespace yarpl::flowable; - -/// Handles a new inbound Stream requested by the other end. -yarpl::Reference> -TextRequestResponder::handleRequestStream(Payload request, StreamId) { - LOG(INFO) << "TextRequestResponder.handleRequestStream " << request; - - // string from payload data - auto requestString = request.moveDataToString(); - - return Flowables::range(1, 100)->map([name = std::move(requestString)]( - int64_t v) { - std::stringstream ss; - ss << "Hello " << name << " " << v << "!"; - std::string s = ss.str(); - return Payload(s, "metadata"); - }); -} diff --git a/examples/conditional-request-handling/TextRequestHandler.h b/examples/conditional-request-handling/TextRequestHandler.h deleted file mode 100644 index 604fdbeea..000000000 --- a/examples/conditional-request-handling/TextRequestHandler.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "rsocket/Payload.h" -#include "rsocket/RSocket.h" - -class TextRequestResponder : public rsocket::RSocketResponder { - public: - /// Handles a new inbound Stream requested by the other end. - yarpl::Reference> - handleRequestStream(rsocket::Payload request, rsocket::StreamId streamId) - override; -}; diff --git a/rsocket/ColdResumeHandler.cpp b/rsocket/ColdResumeHandler.cpp new file mode 100644 index 000000000..870faef48 --- /dev/null +++ b/rsocket/ColdResumeHandler.cpp @@ -0,0 +1,46 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/ColdResumeHandler.h" + +#include "yarpl/flowable/CancelingSubscriber.h" + +#include + +using namespace yarpl::flowable; + +namespace rsocket { + +std::string ColdResumeHandler::generateStreamToken( + const Payload&, + StreamId streamId, + StreamType) const { + return folly::to(streamId); +} + +std::shared_ptr> +ColdResumeHandler::handleResponderResumeStream( + std::string /* streamToken */, + size_t /* publisherAllowance */) { + return Flowable::error( + std::logic_error("ResumeHandler method not implemented")); +} + +std::shared_ptr> +ColdResumeHandler::handleRequesterResumeStream( + std::string /* streamToken */, + size_t /* consumerAllowance */) { + return std::make_shared>(); +} +} // namespace rsocket diff --git a/rsocket/ColdResumeHandler.h b/rsocket/ColdResumeHandler.h index 9c6d203cb..f4190e16f 100644 --- a/rsocket/ColdResumeHandler.h +++ b/rsocket/ColdResumeHandler.h @@ -1,8 +1,56 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once +#include "yarpl/Flowable.h" + +#include "rsocket/Payload.h" +#include "rsocket/framing/FrameHeader.h" +#include "rsocket/internal/Common.h" + namespace rsocket { -class ColdResumeHandler {}; -} \ No newline at end of file +// This class has to be implemented by the client application for cold +// resumption. The default implementation will error/close the streams. +class ColdResumeHandler { + public: + virtual ~ColdResumeHandler() = default; + + // Generate an application-aware streamToken for the given stream parameters. + virtual std::string + generateStreamToken(const Payload&, StreamId streamId, StreamType) const; + + // This method will be called for each REQUEST_STREAM for which the + // application acted as a responder. The default action would be to return a + // Flowable which errors out immediately. + // The second parameter is the allowance which the application received + // before cold-start and hasn't been fulfilled yet. + virtual std::shared_ptr> + handleResponderResumeStream( + std::string streamToken, + size_t publisherAllowance); + + // This method will be called for each REQUEST_STREAM for which the + // application acted as a requester. The default action would be to return a + // Subscriber which cancels the stream immediately after getting subscribed. + // The second parameter is the allowance which the application requested + // before cold-start and hasn't been fulfilled yet. + virtual std::shared_ptr> + handleRequesterResumeStream( + std::string streamToken, + size_t consumerAllowance); +}; + +} // namespace rsocket diff --git a/rsocket/ConnectionAcceptor.h b/rsocket/ConnectionAcceptor.h index 92aa5f535..3e94a4416 100644 --- a/rsocket/ConnectionAcceptor.h +++ b/rsocket/ConnectionAcceptor.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -12,9 +24,8 @@ class EventBase; namespace rsocket { -using OnDuplexConnectionAccept = std::function, - folly::EventBase&)>; +using OnDuplexConnectionAccept = std::function< + void(std::unique_ptr, folly::EventBase&)>; /** * Common interface for a server that accepts connections and turns them into @@ -24,8 +35,6 @@ using OnDuplexConnectionAccept = std::function +#include #include "rsocket/DuplexConnection.h" +#include "rsocket/framing/ProtocolVersion.h" namespace folly { class EventBase; @@ -11,9 +25,7 @@ class EventBase; namespace rsocket { -using OnDuplexConnectionConnect = folly::Function, - folly::EventBase&)>; +enum class ResumeStatus { NEW_SESSION, RESUMING }; /** * Common interface for a client to create connections and turn them into @@ -33,6 +45,11 @@ class ConnectionFactory { ConnectionFactory& operator=(const ConnectionFactory&) = delete; // copy ConnectionFactory& operator=(ConnectionFactory&&) = delete; // move + struct ConnectedDuplexConnection { + std::unique_ptr connection; + folly::EventBase& eventBase; + }; + /** * Connect to server defined by constructor of the implementing class. * @@ -42,6 +59,8 @@ class ConnectionFactory { * * Resource creation depends on the particular implementation. */ - virtual void connect(OnDuplexConnectionConnect onConnect) = 0; + virtual folly::Future connect( + ProtocolVersion, + ResumeStatus resume) = 0; }; } // namespace rsocket diff --git a/rsocket/DuplexConnection.h b/rsocket/DuplexConnection.h index 84891b9d9..7aaff2156 100644 --- a/rsocket/DuplexConnection.h +++ b/rsocket/DuplexConnection.h @@ -1,55 +1,63 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include -#include "yarpl/flowable/Subscriber.h" -namespace folly { -class IOBuf; -} +#include + +#include "yarpl/flowable/Subscriber.h" namespace rsocket { -/// Represents a connection of the underlying protocol, on top of which -/// the ReactiveSocket is layered. The underlying protocol MUST provide an -/// ordered, guaranteed, bidirectional transport of frames. Moreover, the frame +/// Represents a connection of the underlying protocol, on top of which the +/// RSocket protocol is layered. The underlying protocol MUST provide an +/// ordered, guaranteed, bidirectional transport of frames. Moreover, frame /// boundaries MUST be preserved. /// /// The frames exchanged through this interface are serialized, and lack the -/// optional frame length field. Presence of the field is determined by the -/// underlying protocol. If the protocol natively supports framing (e.g. Aeron), -/// the fill MUST be omitted, otherwise (e.g. TCP) is must be present. -/// ReactiveSocket implementation MUST NOT ever be provided with a frame that -/// contains the length field nor it ever requests to sends such a frame. +/// optional frame length field. Presence of the field is determined by the +/// underlying protocol. If the protocol natively supports framing +/// (e.g. Aeron), the fileld MUST be omitted, otherwise (e.g. TCP) it must be +/// present. The RSocket implementation MUST NOT be provided with a frame that +/// contains the length field nor can it ever send such a frame. /// /// It can be assumed that both input and output will be closed by sending /// appropriate terminal signals (according to ReactiveStreams specification) /// before the connection is destroyed. class DuplexConnection { public: + using Subscriber = yarpl::flowable::Subscriber>; + virtual ~DuplexConnection() = default; /// Sets a Subscriber that will consume received frames (a reader). /// - /// This method is invoked by ReactiveSocket implementation once in an entire - /// lifetime of the connection. The connection MUST NOT assume an ownership of - /// provided Subscriber. - virtual void setInput( - yarpl::Reference>> - framesSink) = 0; - - /// Obtains a Subscriber that should be fed with frames to send (a writer). + /// If setInput() has already been called, then calling setInput() again will + /// complete the previous subscriber. + virtual void setInput(std::shared_ptr) = 0; + + /// Write a serialized frame to the connection. /// - /// This method is invoked by ReactiveSocket - /// implementation once in an entire lifetime of the connection. The - /// connection MUST manage the lifetime of provided Subscriber. - virtual yarpl::Reference>> - getOutput() = 0; - - /// property telling whether the duplex connection respects frame boundaries - virtual bool isFramed() { + /// Does nothing if the underlying connection is closed. + virtual void send(std::unique_ptr) = 0; + + /// Whether the duplex connection respects frame boundaries. + virtual bool isFramed() const { return false; } }; -} + +} // namespace rsocket diff --git a/rsocket/Payload.cpp b/rsocket/Payload.cpp index 4f008ca64..b4037d888 100644 --- a/rsocket/Payload.cpp +++ b/rsocket/Payload.cpp @@ -1,64 +1,77 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/Payload.h" + #include #include -#include "rsocket/framing/Frame.h" + +#include "rsocket/internal/Common.h" namespace rsocket { -Payload::Payload( - std::unique_ptr _data, - std::unique_ptr _metadata) - : data(std::move(_data)), metadata(std::move(_metadata)) {} - -Payload::Payload(const std::string& _data, const std::string& _metadata) - : data(folly::IOBuf::copyBuffer(_data)) { - if (!_metadata.empty()) { - metadata = folly::IOBuf::copyBuffer(_metadata); - } +namespace { + +std::string moveIOBufToString(std::unique_ptr buf) { + return buf ? buf->moveToFbString().toStdString() : ""; +} + +std::string cloneIOBufToString(std::unique_ptr const& buf) { + return buf ? buf->cloneAsValue().moveToFbString().toStdString() : ""; } -void Payload::checkFlags(FrameFlags flags) const { - DCHECK(!!(flags & FrameFlags::METADATA) == bool(metadata)); +} // namespace + +Payload::Payload( + std::unique_ptr d, + std::unique_ptr m) + : data{std::move(d)}, metadata{std::move(m)} {} + +Payload::Payload(folly::StringPiece d, folly::StringPiece m) + : data{folly::IOBuf::copyBuffer(d.data(), d.size())} { + if (!m.empty()) { + metadata = folly::IOBuf::copyBuffer(m.data(), m.size()); + } } std::ostream& operator<<(std::ostream& os, const Payload& payload) { return os << "Metadata(" - << (payload.metadata - ? folly::to( - payload.metadata->computeChainDataLength()) - : "0") - << (payload.metadata - ? "): '" + - folly::humanify( - payload.metadata->cloneAsValue().moveToFbString().substr(0, 80)) + - "'" - : "): ") + << (payload.metadata ? payload.metadata->computeChainDataLength() + : 0) + << "): " + << (payload.metadata ? "'" + humanify(payload.metadata) + "'" + : "") << ", Data(" - << (payload.data ? folly::to( - payload.data->computeChainDataLength()) - : "0") - << (payload.data - ? "): '" + - folly::humanify( - payload.data->cloneAsValue().moveToFbString().substr(0, 80)) + - "'" - : "): "); + << (payload.data ? payload.data->computeChainDataLength() : 0) + << "): " + << (payload.data ? "'" + humanify(payload.data) + "'" : ""); } std::string Payload::moveDataToString() { - if (!data) { - return ""; - } - return data->moveToFbString().toStdString(); + return moveIOBufToString(std::move(data)); } std::string Payload::cloneDataToString() const { - if (!data) { - return ""; - } - return data->cloneAsValue().moveToFbString().toStdString(); + return cloneIOBufToString(data); +} + +std::string Payload::moveMetadataToString() { + return moveIOBufToString(std::move(metadata)); +} + +std::string Payload::cloneMetadataToString() const { + return cloneIOBufToString(metadata); } void Payload::clear() { @@ -71,15 +84,28 @@ Payload Payload::clone() const { if (data) { out.data = data->clone(); } - if (metadata) { out.metadata = metadata->clone(); } return out; } -FrameFlags Payload::getFlags() const { - return (metadata != nullptr ? FrameFlags::METADATA : FrameFlags::EMPTY); +ErrorWithPayload::ErrorWithPayload(Payload&& payload) + : payload(std::move(payload)) {} + +ErrorWithPayload::ErrorWithPayload(const ErrorWithPayload& oth) { + payload = oth.payload.clone(); +} + +ErrorWithPayload& ErrorWithPayload::operator=(const ErrorWithPayload& oth) { + payload = oth.payload.clone(); + return *this; +} + +std::ostream& operator<<( + std::ostream& os, + const ErrorWithPayload& errorWithPayload) { + return os << "rsocket::ErrorWithPayload: " << errorWithPayload.payload; } } // namespace rsocket diff --git a/rsocket/Payload.h b/rsocket/Payload.h index 4776cc59a..c21587014 100644 --- a/rsocket/Payload.h +++ b/rsocket/Payload.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -6,8 +18,6 @@ #include #include -#include "rsocket/framing/FrameFlags.h" - namespace rsocket { /// The type of a read-only view on a binary buffer. @@ -20,18 +30,19 @@ struct Payload { std::unique_ptr metadata = std::unique_ptr()); explicit Payload( - const std::string& data, - const std::string& metadata = std::string()); + folly::StringPiece data, + folly::StringPiece metadata = folly::StringPiece{}); explicit operator bool() const { return data != nullptr || metadata != nullptr; } - FrameFlags getFlags() const; - void checkFlags(FrameFlags flags) const; - std::string moveDataToString(); std::string cloneDataToString() const; + + std::string moveMetadataToString(); + std::string cloneMetadataToString() const; + void clear(); Payload clone() const; @@ -40,5 +51,23 @@ struct Payload { std::unique_ptr metadata; }; -std::ostream& operator<<(std::ostream& os, const Payload& payload); -} +struct ErrorWithPayload : public std::exception { + explicit ErrorWithPayload(Payload&& payload); + + // folly::ExceptionWrapper requires exceptions to have copy constructors + ErrorWithPayload(const ErrorWithPayload& oth); + ErrorWithPayload& operator=(const ErrorWithPayload&); + ErrorWithPayload(ErrorWithPayload&&) = default; + ErrorWithPayload& operator=(ErrorWithPayload&&) = default; + + const char* what() const noexcept override { + return "ErrorWithPayload"; + } + + Payload payload; +}; + +std::ostream& operator<<(std::ostream& os, const Payload&); +std::ostream& operator<<(std::ostream& os, const ErrorWithPayload&); + +} // namespace rsocket diff --git a/rsocket/README.md b/rsocket/README.md new file mode 100644 index 000000000..3b811aca9 --- /dev/null +++ b/rsocket/README.md @@ -0,0 +1,27 @@ +# rsocket-cpp + +C++ implementation of [RSocket](https://rsocket.io) + + +[![Coverage Status](https://coveralls.io/repos/github/rsocket/rsocket-cpp/badge.svg?branch=master)](https://coveralls.io/github/rsocket/rsocket-cpp?branch=master) + +# Dependencies + +Install `folly`: + +``` +brew install folly +``` + +# Building and running tests + +After installing dependencies as above, you can build and run tests with: + +``` +# inside root ./rsocket-cpp +mkdir -p build +cd build +cmake -DCMAKE_BUILD_TYPE=DEBUG ../ +make -j +./tests +``` diff --git a/rsocket/RSocket.cpp b/rsocket/RSocket.cpp index 01d3b7442..e83c5ca71 100644 --- a/rsocket/RSocket.cpp +++ b/rsocket/RSocket.cpp @@ -1,59 +1,140 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/RSocket.h" namespace rsocket { -folly::Future> RSocket::createConnectedClient( - std::unique_ptr connectionFactory, +folly::Future> RSocket::createConnectedClient( + std::shared_ptr connectionFactory, SetupParameters setupParameters, std::shared_ptr responder, - std::unique_ptr keepaliveTimer, + std::chrono::milliseconds keepaliveInterval, std::shared_ptr stats, std::shared_ptr connectionEvents, std::shared_ptr resumeManager, std::shared_ptr coldResumeHandler, - OnRSocketResume) { - auto c = std::shared_ptr(new RSocketClient( + folly::EventBase* stateMachineEvb) { + CHECK(resumeManager) + << "provide ResumeManager::makeEmpty() instead of nullptr"; + auto protocolVersion = setupParameters.protocolVersion; + auto createRSC = + [connectionFactory, + setupParameters = std::move(setupParameters), + responder = std::move(responder), + keepaliveInterval, + stats = std::move(stats), + connectionEvents = std::move(connectionEvents), + resumeManager = std::move(resumeManager), + coldResumeHandler = std::move(coldResumeHandler), + stateMachineEvb]( + ConnectionFactory::ConnectedDuplexConnection connection) mutable { + VLOG(3) << "createConnectedClient received DuplexConnection"; + return RSocket::createClientFromConnection( + std::move(connection.connection), + connection.eventBase, + std::move(setupParameters), + std::move(connectionFactory), + std::move(responder), + keepaliveInterval, + std::move(stats), + std::move(connectionEvents), + std::move(resumeManager), + std::move(coldResumeHandler), + stateMachineEvb); + }; + + return connectionFactory->connect(protocolVersion, ResumeStatus::NEW_SESSION) + .thenValue( + [createRSC = std::move(createRSC)]( + ConnectionFactory::ConnectedDuplexConnection connection) mutable { + // fromConnection method must be called from the transport eventBase + // and since there is no guarantee that the Future returned from the + // connectionFactory::connect method is executed on the event base, + // we have to ensure it by using folly::via + auto transportEvb = &connection.eventBase; + return folly::via( + transportEvb, + [connection = std::move(connection), + createRSC = std::move(createRSC)]() mutable { + return createRSC(std::move(connection)); + }); + }); +} + +folly::Future> RSocket::createResumedClient( + std::shared_ptr connectionFactory, + ResumeIdentificationToken token, + std::shared_ptr resumeManager, + std::shared_ptr coldResumeHandler, + std::shared_ptr responder, + std::chrono::milliseconds keepaliveInterval, + std::shared_ptr stats, + std::shared_ptr connectionEvents, + ProtocolVersion protocolVersion, + folly::EventBase* stateMachineEvb) { + auto* c = new RSocketClient( std::move(connectionFactory), - std::move(setupParameters), + std::move(protocolVersion), + std::move(token), std::move(responder), - std::move(keepaliveTimer), + keepaliveInterval, std::move(stats), std::move(connectionEvents), std::move(resumeManager), - std::move(coldResumeHandler))); + std::move(coldResumeHandler), + stateMachineEvb); - return c->connect().then([c]() mutable { return c; }); + return c->resume().thenValue( + [client = std::unique_ptr(c)](auto&&) mutable { + return std::move(client); + }); } -folly::Future> RSocket::createResumedClient( - std::unique_ptr connectionFactory, - SetupParameters setupParameters, - std::shared_ptr resumeManager, - std::shared_ptr coldResumeHandler, - OnRSocketResume, +std::unique_ptr RSocket::createClientFromConnection( + std::unique_ptr connection, + folly::EventBase& transportEvb, + SetupParameters params, + std::shared_ptr connectionFactory, std::shared_ptr responder, - std::unique_ptr keepaliveTimer, + std::chrono::milliseconds keepaliveInterval, std::shared_ptr stats, - std::shared_ptr connectionEvents) { - auto c = std::shared_ptr(new RSocketClient( + std::shared_ptr connectionEvents, + std::shared_ptr resumeManager, + std::shared_ptr coldResumeHandler, + folly::EventBase* stateMachineEvb) { + auto client = std::unique_ptr(new RSocketClient( std::move(connectionFactory), - std::move(setupParameters), + params.protocolVersion, + params.token, std::move(responder), - std::move(keepaliveTimer), + keepaliveInterval, std::move(stats), std::move(connectionEvents), std::move(resumeManager), - std::move(coldResumeHandler))); - - return c->resume().then([c]() mutable { return c; }); + std::move(coldResumeHandler), + stateMachineEvb)); + client->fromConnection( + std::move(connection), transportEvb, std::move(params)); + return client; } std::unique_ptr RSocket::createServer( - std::unique_ptr connectionAcceptor) { - return std::make_unique(std::move(connectionAcceptor)); -} + std::unique_ptr connectionAcceptor, + std::shared_ptr stats) { + return std::make_unique( + std::move(connectionAcceptor), std::move(stats)); } + +} // namespace rsocket diff --git a/rsocket/RSocket.h b/rsocket/RSocket.h index b5040aaef..13f642830 100644 --- a/rsocket/RSocket.h +++ b/rsocket/RSocket.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -13,50 +25,62 @@ namespace rsocket { class RSocket { public: // Creates a RSocketClient which is connected to the remoteside. - static folly::Future> createConnectedClient( - std::unique_ptr, + // keepaliveInterval of 0 will result in no keepAlives + static folly::Future> createConnectedClient( + std::shared_ptr, SetupParameters setupParameters = SetupParameters(), std::shared_ptr responder = std::make_shared(), - std::unique_ptr keepaliveTimer = - std::unique_ptr(), + std::chrono::milliseconds keepaliveInterval = kDefaultKeepaliveInterval, std::shared_ptr stats = RSocketStats::noop(), std::shared_ptr connectionEvents = std::shared_ptr(), - std::shared_ptr resumeManager = - std::shared_ptr(), + std::shared_ptr resumeManager = ResumeManager::makeEmpty(), std::shared_ptr coldResumeHandler = std::shared_ptr(), - OnRSocketResume onRSocketResume = - [](std::vector, std::vector) { return false; }); + folly::EventBase* stateMachineEvb = nullptr); // Creates a RSocketClient which cold-resumes from the provided state - static folly::Future> createResumedClient( - std::unique_ptr, - SetupParameters setupParameters, + // keepaliveInterval of 0 will result in no keepAlives + static folly::Future> createResumedClient( + std::shared_ptr, + ResumeIdentificationToken token, std::shared_ptr resumeManager, std::shared_ptr coldResumeHandler, - OnRSocketResume onRSocketResume, std::shared_ptr responder = std::make_shared(), - std::unique_ptr keepaliveTimer = - std::unique_ptr(), + std::chrono::milliseconds keepaliveInterval = kDefaultKeepaliveInterval, std::shared_ptr stats = RSocketStats::noop(), std::shared_ptr connectionEvents = - std::shared_ptr()); + std::shared_ptr(), + ProtocolVersion protocolVersion = ProtocolVersion::Latest, + folly::EventBase* stateMachineEvb = nullptr); + + // Creates a RSocketClient from an existing DuplexConnection. A keepalive + // interval of 0 will result in no keepalives. + static std::unique_ptr createClientFromConnection( + std::unique_ptr connection, + folly::EventBase& transportEvb, + SetupParameters setupParameters = SetupParameters(), + std::shared_ptr connectionFactory = nullptr, + std::shared_ptr responder = + std::make_shared(), + std::chrono::milliseconds keepaliveInterval = kDefaultKeepaliveInterval, + std::shared_ptr stats = RSocketStats::noop(), + std::shared_ptr connectionEvents = nullptr, + std::shared_ptr resumeManager = ResumeManager::makeEmpty(), + std::shared_ptr coldResumeHandler = nullptr, + folly::EventBase* stateMachineEvb = nullptr); // A convenience function to create RSocketServer static std::unique_ptr createServer( - std::unique_ptr); + std::unique_ptr, + std::shared_ptr stats = RSocketStats::noop()); RSocket() = delete; - RSocket(const RSocket&) = delete; - RSocket(RSocket&&) = delete; - RSocket& operator=(const RSocket&) = delete; - RSocket& operator=(RSocket&&) = delete; }; -} +} // namespace rsocket diff --git a/rsocket/RSocketClient.cpp b/rsocket/RSocketClient.cpp index 207ebcc26..7f10b3ead 100644 --- a/rsocket/RSocketClient.cpp +++ b/rsocket/RSocketClient.cpp @@ -1,158 +1,242 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/RSocketClient.h" #include "rsocket/RSocketRequester.h" #include "rsocket/RSocketResponder.h" #include "rsocket/RSocketStats.h" -#include "rsocket/framing/FrameTransport.h" +#include "rsocket/framing/FrameTransportImpl.h" #include "rsocket/framing/FramedDuplexConnection.h" +#include "rsocket/framing/ScheduledFrameTransport.h" #include "rsocket/internal/ClientResumeStatusCallback.h" -#include "rsocket/internal/FollyKeepaliveTimer.h" -#include "rsocket/internal/RSocketConnectionManager.h" - -using namespace folly; +#include "rsocket/internal/KeepaliveTimer.h" namespace rsocket { -RSocketClient::~RSocketClient() { - VLOG(4) << "RSocketClient destroyed .."; -} - -std::shared_ptr RSocketClient::getRequester() const { - return requester_; -} - RSocketClient::RSocketClient( - std::unique_ptr connectionFactory, - SetupParameters setupParameters, + std::shared_ptr connectionFactory, + ProtocolVersion protocolVersion, + ResumeIdentificationToken token, std::shared_ptr responder, - std::unique_ptr keepaliveTimer, + std::chrono::milliseconds keepaliveInterval, std::shared_ptr stats, std::shared_ptr connectionEvents, std::shared_ptr resumeManager, std::shared_ptr coldResumeHandler, - OnRSocketResume) + folly::EventBase* stateMachineEvb) : connectionFactory_(std::move(connectionFactory)), - connectionManager_(std::make_unique()), - setupParameters_(std::move(setupParameters)), responder_(std::move(responder)), - keepaliveTimer_(std::move(keepaliveTimer)), + keepaliveInterval_(keepaliveInterval), stats_(stats), connectionEvents_(connectionEvents), resumeManager_(resumeManager), coldResumeHandler_(coldResumeHandler), - protocolVersion_(setupParameters_.protocolVersion), - token_(setupParameters_.token) {} - -folly::Future RSocketClient::connect() { - VLOG(2) << "Starting connection"; + protocolVersion_(protocolVersion), + token_(std::move(token)), + evb_(stateMachineEvb) { + CHECK(resumeManager_) + << "provide ResumeManager::makeEmpty() instead of nullptr"; +} - folly::Promise promise; - auto future = promise.getFuture(); +RSocketClient::~RSocketClient() { + VLOG(3) << "~RSocketClient .."; - connectionFactory_->connect([ this, promise = std::move(promise) ]( - std::unique_ptr connection, - folly::EventBase & eventBase) mutable { - VLOG(3) << "onConnect received DuplexConnection"; - evb_ = &eventBase; - createState(eventBase); - std::unique_ptr framedConnection; - if (connection->isFramed()) { - framedConnection = std::move(connection); - } else { - framedConnection = std::make_unique( - std::move(connection), protocolVersion_); - } - stateMachine_->connectClientSendSetup( - std::move(framedConnection), std::move(setupParameters_)); - promise.setValue(); + evb_->runImmediatelyOrRunInEventBaseThreadAndWait([sm = stateMachine_] { + auto exn = folly::make_exception_wrapper( + "RSocketClient is closing"); + sm->close(std::move(exn), StreamCompletionSignal::CONNECTION_END); }); +} - return future; +const std::shared_ptr& RSocketClient::getRequester() const { + return requester_; +} + +// Returns if this client is currently disconnected +bool RSocketClient::isDisconnected() const { + return stateMachine_->isDisconnected(); } folly::Future RSocketClient::resume() { + CHECK(connectionFactory_) + << "The client was likely created without ConnectionFactory. Can't " + << "resume"; + + return connectionFactory_->connect(protocolVersion_, ResumeStatus::RESUMING) + .thenValue( + [this]( + ConnectionFactory::ConnectedDuplexConnection connection) mutable { + return resumeFromConnection(std::move(connection)); + }); +} + +folly::Future RSocketClient::resumeFromConnection( + ConnectionFactory::ConnectedDuplexConnection connection) { VLOG(2) << "Resuming connection"; - // TODO: CHECK whether the underlying transport is closed before attempting - // resumption. - // - folly::Promise promise; - auto future = promise.getFuture(); + if (!evb_) { + // Cold-resumption. EventBase hasn't been explicitly set for SM by the + // application. Use the transport's eventBase. + evb_ = &connection.eventBase; + } - connectionFactory_->connect([ this, promise = std::move(promise) ]( - std::unique_ptr connection, - folly::EventBase & eventBase) mutable { - - CHECK( - !evb_ /* cold-resumption */ || - evb_ == &eventBase /* warm-resumption */); - - class ResumeCallback : public ClientResumeStatusCallback { - public: - explicit ResumeCallback(folly::Promise promise) - : promise_(std::move(promise)) {} - - void onResumeOk() noexcept override { - promise_.setValue(); - } - - void onResumeError(folly::exception_wrapper ex) noexcept override { - promise_.setException(ex); - } - private: - folly::Promise promise_; - }; - - auto resumeCallback = std::make_unique(std::move(promise)); - std::unique_ptr framedConnection; - if (connection->isFramed()) { - framedConnection = std::move(connection); - } else { - framedConnection = std::make_unique( - std::move(connection), protocolVersion_); + class ResumeCallback : public ClientResumeStatusCallback { + public: + explicit ResumeCallback(folly::Promise promise) + : promise_(std::move(promise)) {} + + void onResumeOk() noexcept override { + promise_.setValue(); + } + + void onResumeError(folly::exception_wrapper ex) noexcept override { + promise_.setException(ex); } - auto frameTransport = - yarpl::make_ref(std::move(framedConnection)); + private: + folly::Promise promise_; + }; + + folly::Promise promise; + auto future = promise.getFuture(); + + auto resumeCallback = std::make_unique(std::move(promise)); + std::unique_ptr framedConnection; + if (connection.connection->isFramed()) { + framedConnection = std::move(connection.connection); + } else { + framedConnection = std::make_unique( + std::move(connection.connection), protocolVersion_); + } + auto transport = + std::make_shared(std::move(framedConnection)); + + std::shared_ptr ft; + if (evb_ != &connection.eventBase) { + // If the StateMachine EventBase is different from the transport + // EventBase, then use ScheduledFrameTransport and + // ScheduledFrameProcessor to ensure the RSocketStateMachine and + // Transport live on the desired EventBases + ft = std::make_shared( + std::move(transport), + &connection.eventBase, /* Transport EventBase */ + evb_); /* StateMachine EventBase */ + } else { + ft = std::move(transport); + } + + evb_->runInEventBaseThread([this, + frameTransport = std::move(ft), + callback = std::move(resumeCallback)]() mutable { if (!stateMachine_) { - createState(eventBase); + createState(); } - stateMachine_->tryClientResume( - token_, std::move(frameTransport), std::move(resumeCallback)); + stateMachine_->resumeClient( + token_, + std::move(frameTransport), + std::move(callback), + protocolVersion_); }); return future; } -void RSocketClient::disconnect(folly::exception_wrapper ex) { - CHECK(stateMachine_); - evb_->runInEventBaseThread([ this, ex = std::move(ex) ] { - VLOG(2) << "Disconnecting RSocketStateMachine on EventBase"; - stateMachine_->disconnect(std::move(ex)); - }); +folly::Future RSocketClient::disconnect( + folly::exception_wrapper ew) { + if (!stateMachine_) { + return folly::makeFuture( + std::runtime_error{"RSocketClient must always have a state machine"}); + } + + auto work = [sm = stateMachine_, e = std::move(ew)]() mutable { + sm->disconnect(std::move(e)); + }; + + if (evb_->isInEventBaseThread()) { + VLOG(2) << "Running RSocketClient disconnect synchronously"; + work(); + return folly::unit; + } + + VLOG(2) << "Scheduling RSocketClient disconnect"; + return folly::via(evb_, work); } -void RSocketClient::createState(folly::EventBase& eventBase) { - CHECK(eventBase.isInEventBaseThread()); +void RSocketClient::fromConnection( + std::unique_ptr connection, + folly::EventBase& transportEvb, + SetupParameters params) { + if (!evb_) { + // If no EventBase is given for the stateMachine, then use the transport's + // EventBase to drive the stateMachine. + evb_ = &transportEvb; + } + createState(); + + std::unique_ptr framed; + if (connection->isFramed()) { + framed = std::move(connection); + } else { + framed = std::make_unique( + std::move(connection), params.protocolVersion); + } + auto transport = std::make_shared(std::move(framed)); + + if (evb_ == &transportEvb) { + stateMachine_->connectClient(std::move(transport), std::move(params)); + return; + } + + // If the StateMachine EventBase is different from the transport EventBase, + // then use ScheduledFrameTransport and ScheduledFrameProcessor to ensure the + // RSocketStateMachine and Transport live on the desired EventBases. + auto scheduledFT = std::make_shared( + std::move(transport), &transportEvb, evb_); + evb_->runInEventBaseThread([stateMachine = stateMachine_, + scheduledFT = std::move(scheduledFT), + params = std::move(params)]() mutable { + stateMachine->connectClient(std::move(scheduledFT), std::move(params)); + }); +} +void RSocketClient::createState() { // Creation of state is permitted only once for each RSocketClient. // When evb is removed from RSocketStateMachine, the state can be // created in constructor - CHECK(!stateMachine_); + CHECK(!stateMachine_) << "A stateMachine has already been created"; - stateMachine_ = std::make_shared( - eventBase, - responder_, - std::move(keepaliveTimer_), - ReactiveSocketMode::CLIENT, - stats_, - connectionEvents_); + if (!responder_) { + responder_ = std::make_shared(); + } - requester_ = std::make_shared(stateMachine_, eventBase); + std::unique_ptr keepaliveTimer{nullptr}; + if (keepaliveInterval_ > std::chrono::milliseconds(0)) { + keepaliveTimer = + std::make_unique(keepaliveInterval_, *evb_); + } - connectionManager_->manageConnection(stateMachine_, eventBase); + stateMachine_ = std::make_shared( + std::move(responder_), + std::move(keepaliveTimer), + RSocketMode::CLIENT, + std::move(stats_), + std::move(connectionEvents_), + std::move(resumeManager_), + std::move(coldResumeHandler_)); + + requester_ = std::make_shared(stateMachine_, *evb_); } } // namespace rsocket diff --git a/rsocket/RSocketClient.h b/rsocket/RSocketClient.h index abcccaba5..070a3f6be 100644 --- a/rsocket/RSocketClient.h +++ b/rsocket/RSocketClient.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -6,16 +18,17 @@ #include "rsocket/ColdResumeHandler.h" #include "rsocket/ConnectionFactory.h" +#include "rsocket/DuplexConnection.h" #include "rsocket/RSocketConnectionEvents.h" #include "rsocket/RSocketParameters.h" #include "rsocket/RSocketRequester.h" +#include "rsocket/RSocketResponder.h" #include "rsocket/RSocketStats.h" #include "rsocket/ResumeManager.h" namespace rsocket { class RSocket; -class RSocketConnectionManager; /** * API for connecting to an RSocket server. Created with RSocket class. @@ -25,57 +38,63 @@ class RSocketClient { public: ~RSocketClient(); - RSocketClient(const RSocketClient&) = delete; // copy - RSocketClient(RSocketClient&&) = default; // move - RSocketClient& operator=(const RSocketClient&) = delete; // copy - RSocketClient& operator=(RSocketClient&&) = default; // move + RSocketClient(const RSocketClient&) = delete; + RSocketClient(RSocketClient&&) = delete; + RSocketClient& operator=(const RSocketClient&) = delete; + RSocketClient& operator=(RSocketClient&&) = delete; friend class RSocket; // Returns the RSocketRequester associated with the RSocketClient. - std::shared_ptr getRequester() const; + const std::shared_ptr& getRequester() const; - // Resumes the connection. If a stateMachine already exists, - // it provides a warm-resumption. If a stateMachine does not exist, - // it does a cold-resumption. The returned future resolves on successful - // resumption. Else either a ConnectionException or a ResumptionException - // is raised. + // Returns if this client is currently disconnected + bool isDisconnected() const; + + // Resumes the client's connection. If the client was previously connected + // this will attempt a warm-resumption. Otherwise this will attempt a + // cold-resumption. + // + // Uses the internal ConnectionFactory instance to re-connect. folly::Future resume(); - // Disconnect the underlying transport - void disconnect(folly::exception_wrapper); + // Like resume(), but this doesn't use a ConnectionFactory and instead takes + // the connection and transport EventBase by argument. + // + // Prefer using resume() if possible. + folly::Future resumeFromConnection( + ConnectionFactory::ConnectedDuplexConnection); + + // Disconnect the underlying transport. + folly::Future disconnect(folly::exception_wrapper = {}); private: // Private constructor. RSocket class should be used to create instances // of RSocketClient. RSocketClient( - std::unique_ptr, - SetupParameters setupParameters = SetupParameters(), - std::shared_ptr responder = - std::make_shared(), - std::unique_ptr keepaliveTimer = - std::unique_ptr(), - std::shared_ptr stats = RSocketStats::noop(), - std::shared_ptr connectionEvents = - std::shared_ptr(), - std::shared_ptr resumeManager = - std::shared_ptr(), - std::shared_ptr coldResumeHandler = - std::shared_ptr(), - OnRSocketResume onRSocketResume = - [](std::vector, std::vector) { return false; }); - - // Connects to the remote side and creates state. - folly::Future connect(); + std::shared_ptr, + ProtocolVersion protocolVersion, + ResumeIdentificationToken token, + std::shared_ptr responder, + std::chrono::milliseconds keepaliveInterval, + std::shared_ptr stats, + std::shared_ptr connectionEvents, + std::shared_ptr resumeManager, + std::shared_ptr coldResumeHandler, + folly::EventBase* stateMachineEvb); + + // Create stateMachine with the given DuplexConnection + void fromConnection( + std::unique_ptr connection, + folly::EventBase& transportEvb, + SetupParameters setupParameters); // Creates RSocketStateMachine and RSocketRequester - void createState(folly::EventBase& eventBase); + void createState(); - std::unique_ptr connectionFactory_; - std::unique_ptr connectionManager_; - SetupParameters setupParameters_; + const std::shared_ptr connectionFactory_; std::shared_ptr responder_; - std::unique_ptr keepaliveTimer_; + const std::chrono::milliseconds keepaliveInterval_; std::shared_ptr stats_; std::shared_ptr connectionEvents_; std::shared_ptr resumeManager_; @@ -84,11 +103,18 @@ class RSocketClient { std::shared_ptr stateMachine_; std::shared_ptr requester_; - // Remember the evb on which the client was created. Ensure warme-resume() - // operations are done on the same evb. - folly::EventBase* evb_; - - ProtocolVersion protocolVersion_; - ResumeIdentificationToken token_; + const ProtocolVersion protocolVersion_; + const ResumeIdentificationToken token_; + + // Remember the StateMachine's evb (supplied through constructor). If no + // EventBase is provided, the underlying transport's EventBase will be used + // to drive the StateMachine. + // If an EventBase is provided for StateMachine and underlying Transport's + // EventBase is different from it, then we use Scheduled* classes to let the + // StateMachine and Transport live on different EventBases. + // It might happen that the StateMachine and Transport live on same + // EventBase, but the transport ends up being in different EventBase after + // resumption, and vice versa. + folly::EventBase* evb_{nullptr}; }; -} +} // namespace rsocket diff --git a/rsocket/RSocketConnectionEvents.h b/rsocket/RSocketConnectionEvents.h index 20fb8601a..177a819d2 100644 --- a/rsocket/RSocketConnectionEvents.h +++ b/rsocket/RSocketConnectionEvents.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -8,12 +20,35 @@ class exception_wrapper; namespace rsocket { +// The application should implement this interface to get called-back +// on network events. class RSocketConnectionEvents { public: virtual ~RSocketConnectionEvents() = default; + // This method gets called when the underlying transport is connected to the + // remote side. This does not necessarily mean that the RSocket connection + // will be successful. As an example, the transport might get reconnected + // for an existing RSocketStateMachine. But resumption at the RSocket layer + // might not succeed. virtual void onConnected() {} + + // This gets called when the underlying transport has disconnected. This also + // means the RSocket connection is disconnected. virtual void onDisconnected(const folly::exception_wrapper&) {} + + // This gets called when the RSocketStateMachine is closed. You cant use this + // RSocketStateMachine anymore. virtual void onClosed(const folly::exception_wrapper&) {} + + // This gets called when no more frames can be sent over the RSocket streams. + // This typically happens immediately after onDisconnected(). The streams can + // be resumed after onStreamsResumed() event. + virtual void onStreamsPaused() {} + + // This gets called when the underlying transport has been successfully + // connected AND the connection can be resumed at the RSocket layer. This + // typically gets called after onConnected() + virtual void onStreamsResumed() {} }; -} +} // namespace rsocket diff --git a/rsocket/RSocketErrors.h b/rsocket/RSocketErrors.h index 8add6294f..e570e7532 100644 --- a/rsocket/RSocketErrors.h +++ b/rsocket/RSocketErrors.h @@ -1,7 +1,20 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once +#include #include namespace rsocket { @@ -19,7 +32,7 @@ class RSocketError : public std::runtime_error { * https://github.com/ReactiveSocket/reactivesocket/blob/master/Protocol.md#error-codes * @return */ - virtual int getErrorCode() = 0; + virtual int getErrorCode() const = 0; }; /** @@ -29,7 +42,7 @@ class InvalidSetupError : public RSocketError { public: using RSocketError::RSocketError; - int getErrorCode() override { + int getErrorCode() const override { return 0x00000001; } @@ -45,7 +58,7 @@ class UnsupportedSetupError : public RSocketError { public: using RSocketError::RSocketError; - int getErrorCode() override { + int getErrorCode() const override { return 0x00000002; } @@ -61,7 +74,7 @@ class RejectedSetupError : public RSocketError { public: using RSocketError::RSocketError; - int getErrorCode() override { + int getErrorCode() const override { return 0x00000003; } @@ -77,7 +90,7 @@ class RejectedResumeError : public RSocketError { public: using RSocketError::RSocketError; - int getErrorCode() override { + int getErrorCode() const override { return 0x00000004; } @@ -93,7 +106,7 @@ class ConnectionError : public RSocketError { public: using RSocketError::RSocketError; - int getErrorCode() override { + int getErrorCode() const override { return 0x00000101; } @@ -103,13 +116,13 @@ class ConnectionError : public RSocketError { }; /** -* Error Code: CONNECTION_CLOSE 0x00000102 -*/ + * Error Code: CONNECTION_CLOSE 0x00000102 + */ class ConnectionCloseError : public RSocketError { public: using RSocketError::RSocketError; - int getErrorCode() override { + int getErrorCode() const override { return 0x00000102; } @@ -117,4 +130,4 @@ class ConnectionCloseError : public RSocketError { return "CONNECTION_CLOSE"; } }; -} +} // namespace rsocket diff --git a/rsocket/RSocketException.h b/rsocket/RSocketException.h index ae99ab70d..9dc9d61e7 100644 --- a/rsocket/RSocketException.h +++ b/rsocket/RSocketException.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -21,4 +33,4 @@ class ResumptionException : public RSocketException { class ConnectionException : public RSocketException { using RSocketException::RSocketException; }; -} +} // namespace rsocket diff --git a/rsocket/RSocketParameters.cpp b/rsocket/RSocketParameters.cpp index b59b5ba62..08f221e44 100644 --- a/rsocket/RSocketParameters.cpp +++ b/rsocket/RSocketParameters.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/RSocketParameters.h" @@ -14,4 +26,4 @@ std::ostream& operator<<( << " token: " << setupPayload.token << " resumable: " << setupPayload.resumable; } -} +} // namespace rsocket diff --git a/rsocket/RSocketParameters.h b/rsocket/RSocketParameters.h index ba77ea6e7..0605bcfcf 100644 --- a/rsocket/RSocketParameters.h +++ b/rsocket/RSocketParameters.h @@ -1,12 +1,26 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include +#include +#include #include +#include + #include "rsocket/Payload.h" -#include "rsocket/framing/FrameSerializer.h" -#include "rsocket/internal/Common.h" +#include "rsocket/framing/Frame.h" namespace rsocket { @@ -15,8 +29,8 @@ using OnRSocketResume = class RSocketParameters { public: - RSocketParameters(bool _resumable, ProtocolVersion _protocolVersion) - : resumable(_resumable), protocolVersion(std::move(_protocolVersion)) {} + RSocketParameters(bool resume, ProtocolVersion version) + : resumable{resume}, protocolVersion{std::move(version)} {} bool resumable; ProtocolVersion protocolVersion; @@ -25,19 +39,18 @@ class RSocketParameters { class SetupParameters : public RSocketParameters { public: explicit SetupParameters( - std::string _metadataMimeType = "text/plain", - std::string _dataMimeType = "text/plain", - Payload _payload = Payload(), - bool _resumable = false, - const ResumeIdentificationToken& _token = + std::string metadataMime = "text/plain", + std::string dataMime = "text/plain", + Payload buf = Payload(), + bool resume = false, + ResumeIdentificationToken resumeToken = ResumeIdentificationToken::generateNew(), - ProtocolVersion _protocolVersion = - FrameSerializer::getCurrentProtocolVersion()) - : RSocketParameters(_resumable, _protocolVersion), - metadataMimeType(std::move(_metadataMimeType)), - dataMimeType(std::move(_dataMimeType)), - payload(std::move(_payload)), - token(_token) {} + ProtocolVersion version = ProtocolVersion::Latest) + : RSocketParameters(resume, version), + metadataMimeType(std::move(metadataMime)), + dataMimeType(std::move(dataMime)), + payload(std::move(buf)), + token(resumeToken) {} std::string metadataMimeType; std::string dataMimeType; @@ -50,18 +63,18 @@ std::ostream& operator<<(std::ostream&, const SetupParameters&); class ResumeParameters : public RSocketParameters { public: ResumeParameters( - ResumeIdentificationToken _token, - ResumePosition _serverPosition, - ResumePosition _clientPosition, - ProtocolVersion _protocolVersion) - : RSocketParameters(true, _protocolVersion), - token(std::move(_token)), - serverPosition(_serverPosition), - clientPosition(_clientPosition) {} + ResumeIdentificationToken resumeToken, + ResumePosition serverPos, + ResumePosition clientPos, + ProtocolVersion version) + : RSocketParameters(true, version), + token(std::move(resumeToken)), + serverPosition(serverPos), + clientPosition(clientPos) {} ResumeIdentificationToken token; ResumePosition serverPosition; ResumePosition clientPosition; }; -} // reactivesocket +} // namespace rsocket diff --git a/rsocket/RSocketRequester.cpp b/rsocket/RSocketRequester.cpp index 91c06cce6..cf1799506 100644 --- a/rsocket/RSocketRequester.cpp +++ b/rsocket/RSocketRequester.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/RSocketRequester.h" @@ -7,135 +19,163 @@ #include "rsocket/internal/ScheduledSingleObserver.h" #include "rsocket/internal/ScheduledSubscriber.h" #include "yarpl/Flowable.h" +#include "yarpl/single/SingleSubscriptions.h" using namespace folly; -using namespace yarpl; namespace rsocket { +namespace { + +template +void runOnCorrectThread(folly::EventBase& evb, Fn fn) { + if (evb.isInEventBaseThread()) { + fn(); + } else { + evb.runInEventBaseThread(std::move(fn)); + } +} + +} // namespace + RSocketRequester::RSocketRequester( std::shared_ptr srs, EventBase& eventBase) - : stateMachine_(std::move(srs)), eventBase_(eventBase) {} + : stateMachine_{std::move(srs)}, eventBase_{&eventBase} {} RSocketRequester::~RSocketRequester() { VLOG(1) << "Destroying RSocketRequester"; } void RSocketRequester::closeSocket() { - eventBase_.add([stateMachine = std::move(stateMachine_)]{ + eventBase_->runInEventBaseThread([stateMachine = std::move(stateMachine_)] { VLOG(2) << "Closing RSocketStateMachine on EventBase"; - stateMachine->close( - folly::exception_wrapper(), StreamCompletionSignal::SOCKET_CLOSED); + stateMachine->close({}, StreamCompletionSignal::SOCKET_CLOSED); }); } -yarpl::Reference> +std::shared_ptr> RSocketRequester::requestChannel( - yarpl::Reference> - requestStream) { - CHECK(stateMachine_); // verify the socket was not closed - - return yarpl::flowable::Flowables::fromPublisher([ - eb = &eventBase_, - requestStream = std::move(requestStream), - srs = stateMachine_ - ](yarpl::Reference> - subscriber) mutable { - eb->runInEventBaseThread([ - requestStream = std::move(requestStream), - subscriber = std::move(subscriber), - srs = std::move(srs), - eb - ]() mutable { - auto responseSink = srs->streamsFactory().createChannelRequester( - yarpl::make_ref>( - std::move(subscriber), *eb)); - // responseSink is wrapped with thread scheduling - // so all emissions happen on the right thread - requestStream->subscribe( - yarpl::make_ref>(std::move(responseSink), - *eb)); - }); - }); + std::shared_ptr> + requestStream) { + return requestChannel({}, false, std::move(requestStream)); +} + +std::shared_ptr> +RSocketRequester::requestChannel( + Payload request, + std::shared_ptr> + requestStream) { + return requestChannel(std::move(request), true, std::move(requestStream)); +} + +std::shared_ptr> +RSocketRequester::requestChannel( + Payload request, + bool hasInitialRequest, + std::shared_ptr> + requestStreamFlowable) { + CHECK(stateMachine_); + + return yarpl::flowable::internal::flowableFromSubscriber( + [eb = eventBase_, + req = std::move(request), + hasInitialRequest, + requestStream = std::move(requestStreamFlowable), + srs = stateMachine_]( + std::shared_ptr> subscriber) { + auto lambda = [eb, + r = req.clone(), + hasInitialRequest, + requestStream, + srs, + subs = std::move(subscriber)]() mutable { + auto scheduled = + std::make_shared>( + std::move(subs), *eb); + auto responseSink = srs->requestChannel( + std::move(r), hasInitialRequest, std::move(scheduled)); + // responseSink is wrapped with thread scheduling + // so all emissions happen on the right thread. + + // If we don't get a responseSink back, that means that + // the requesting peer wasn't connected (or similar error) + // and the Flowable it gets back will immediately call onError. + if (responseSink) { + auto scheduledResponse = + std::make_shared>( + std::move(responseSink), *eb); + requestStream->subscribe(std::move(scheduledResponse)); + } + }; + runOnCorrectThread(*eb, std::move(lambda)); + }); } -yarpl::Reference> +std::shared_ptr> RSocketRequester::requestStream(Payload request) { - CHECK(stateMachine_); // verify the socket was not closed - - return yarpl::flowable::Flowables::fromPublisher([ - eb = &eventBase_, - request = std::move(request), - srs = stateMachine_ - ](yarpl::Reference> - subscriber) mutable { - eb->runInEventBaseThread([ - request = std::move(request), - subscriber = std::move(subscriber), - srs = std::move(srs), - eb - ]() mutable { - srs->streamsFactory().createStreamRequester( - std::move(request), - yarpl::make_ref>( - std::move(subscriber), *eb)); - }); - }); + CHECK(stateMachine_); + + return yarpl::flowable::internal::flowableFromSubscriber( + [eb = eventBase_, req = std::move(request), srs = stateMachine_]( + std::shared_ptr> subscriber) { + auto lambda = + [eb, r = req.clone(), srs, subs = std::move(subscriber)]() mutable { + auto scheduled = + std::make_shared>( + std::move(subs), *eb); + srs->requestStream(std::move(r), std::move(scheduled)); + }; + runOnCorrectThread(*eb, std::move(lambda)); + }); } -yarpl::Reference> +std::shared_ptr> RSocketRequester::requestResponse(Payload request) { - CHECK(stateMachine_); // verify the socket was not closed + CHECK(stateMachine_); return yarpl::single::Single::create( - [eb = &eventBase_, request = std::move(request), srs = stateMachine_]( - yarpl::Reference> - observer) mutable { - eb->runInEventBaseThread([ - request = std::move(request), - observer = std::move(observer), - eb, - srs = std::move(srs) - ]() mutable { - srs->streamsFactory().createRequestResponseRequester( - std::move(request), - yarpl::make_ref>( - std::move(observer), *eb)); - }); - }); + [eb = eventBase_, req = std::move(request), srs = stateMachine_]( + std::shared_ptr> observer) { + auto lambda = [eb, + r = req.clone(), + srs, + obs = std::move(observer)]() mutable { + auto scheduled = + std::make_shared>( + std::move(obs), *eb); + srs->requestResponse(std::move(r), std::move(scheduled)); + }; + runOnCorrectThread(*eb, std::move(lambda)); + }); } -yarpl::Reference> RSocketRequester::fireAndForget( +std::shared_ptr> RSocketRequester::fireAndForget( rsocket::Payload request) { - CHECK(stateMachine_); // verify the socket was not closed - - return yarpl::single::Single::create([ - eb = &eventBase_, - request = std::move(request), - srs = stateMachine_ - ](yarpl::Reference> - subscriber) mutable { - eb->runInEventBaseThread([ - request = std::move(request), - subscriber = std::move(subscriber), - srs = std::move(srs) - ]() mutable { - // TODO pass in SingleSubscriber for underlying layers to - // call onSuccess/onError once put on network - srs->requestFireAndForget(std::move(request)); - // right now just immediately call onSuccess - subscriber->onSuccess(); - }); - }); + CHECK(stateMachine_); + + return yarpl::single::Single::create( + [eb = eventBase_, req = std::move(request), srs = stateMachine_]( + std::shared_ptr> subscriber) { + auto lambda = + [r = req.clone(), srs, subs = std::move(subscriber)]() mutable { + // TODO: Pass in SingleSubscriber for underlying layers to call + // onSuccess/onError once put on network. + srs->fireAndForget(std::move(r)); + subs->onSubscribe(yarpl::single::SingleSubscriptions::empty()); + subs->onSuccess(); + }; + runOnCorrectThread(*eb, std::move(lambda)); + }); } void RSocketRequester::metadataPush(std::unique_ptr metadata) { - CHECK(stateMachine_); // verify the socket was not closed + CHECK(stateMachine_); - eventBase_.runInEventBaseThread( - [srs = stateMachine_, metadata = std::move(metadata)]() mutable { - srs->metadataPush(std::move(metadata)); + runOnCorrectThread( + *eventBase_, [srs = stateMachine_, meta = std::move(metadata)]() mutable { + srs->metadataPush(std::move(meta)); }); } -} + +} // namespace rsocket diff --git a/rsocket/RSocketRequester.h b/rsocket/RSocketRequester.h index 694ade4dd..a87d15955 100644 --- a/rsocket/RSocketRequester.h +++ b/rsocket/RSocketRequester.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -37,7 +49,7 @@ class RSocketRequester { std::shared_ptr srs, folly::EventBase& eventBase); - ~RSocketRequester(); // implementing for logging right now + virtual ~RSocketRequester(); // implementing for logging right now RSocketRequester(const RSocketRequester&) = delete; RSocketRequester(RSocketRequester&&) = delete; @@ -51,8 +63,8 @@ class RSocketRequester { * Interaction model details can be found at * https://github.com/ReactiveSocket/reactivesocket/blob/master/Protocol.md#request-stream */ - yarpl::Reference> requestStream( - rsocket::Payload request); + virtual std::shared_ptr> + requestStream(rsocket::Payload request); /** * Start a channel (streams in both directions). @@ -60,8 +72,20 @@ class RSocketRequester { * Interaction model details can be found at * https://github.com/ReactiveSocket/reactivesocket/blob/master/Protocol.md#request-channel */ - yarpl::Reference> requestChannel( - yarpl::Reference> requests); + virtual std::shared_ptr> + requestChannel( + std::shared_ptr> requests); + + /** + * As requestStream function accepts an initial request, this version of + * requestChannel also accepts an initial request. + * @see requestChannel + * @see requestStream + */ + virtual std::shared_ptr> + requestChannel( + Payload request, + std::shared_ptr> requests); /** * Send a single request and get a single response. @@ -69,8 +93,8 @@ class RSocketRequester { * Interaction model details can be found at * https://github.com/ReactiveSocket/reactivesocket/blob/master/Protocol.md#stream-sequences-request-response */ - yarpl::Reference> requestResponse( - rsocket::Payload request); + virtual std::shared_ptr> + requestResponse(rsocket::Payload request); /** * Send a single Payload with no response. @@ -84,18 +108,24 @@ class RSocketRequester { * Interaction model details can be found at * https://github.com/ReactiveSocket/reactivesocket/blob/master/Protocol.md#request-fire-n-forget */ - yarpl::Reference> fireAndForget( + virtual std::shared_ptr> fireAndForget( rsocket::Payload request); /** * Send metadata without response. */ - void metadataPush(std::unique_ptr metadata); + virtual void metadataPush(std::unique_ptr metadata); + + virtual void closeSocket(); - void closeSocket(); + protected: + virtual std::shared_ptr> + requestChannel( + Payload request, + bool hasInitialRequest, + std::shared_ptr> requests); - private: std::shared_ptr stateMachine_; - folly::EventBase& eventBase_; + folly::EventBase* eventBase_; }; -} +} // namespace rsocket diff --git a/rsocket/RSocketResponder.cpp b/rsocket/RSocketResponder.cpp index 9c6cecb15..892d2e12e 100644 --- a/rsocket/RSocketResponder.cpp +++ b/rsocket/RSocketResponder.cpp @@ -1,35 +1,86 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/RSocketResponder.h" #include +#include namespace rsocket { -yarpl::Reference> -RSocketResponder::handleRequestResponse(rsocket::Payload, rsocket::StreamId) { - return yarpl::single::Singles::error( +using namespace yarpl::flowable; +using namespace yarpl::single; + +void RSocketResponderCore::handleRequestStream( + Payload, + StreamId, + std::shared_ptr> response) noexcept { + response->onSubscribe(Subscription::create()); + response->onError(std::logic_error("handleRequestStream not implemented")); +} + +void RSocketResponderCore::handleRequestResponse( + Payload, + StreamId, + std::shared_ptr> responseObserver) noexcept { + responseObserver->onSubscribe(SingleSubscriptions::empty()); + responseObserver->onError( std::logic_error("handleRequestResponse not implemented")); } -yarpl::Reference> -RSocketResponder::handleRequestStream(rsocket::Payload, rsocket::StreamId) { - return yarpl::flowable::Flowables::error( +void RSocketResponderCore::handleFireAndForget(Payload, StreamId) { + // No default implementation, no error response to provide. +} + +void RSocketResponderCore::handleMetadataPush(std::unique_ptr) { + // No default implementation, no error response to provide. +} + +std::shared_ptr> RSocketResponderCore::handleRequestChannel( + Payload, + StreamId, + std::shared_ptr> response) noexcept { + response->onSubscribe(Subscription::create()); + response->onError(std::logic_error("handleRequestStream not implemented")); + + // cancel immediately + return std::make_shared>(); +} + +std::shared_ptr> RSocketResponder::handleRequestResponse( + Payload, + StreamId) { + return Singles::error( + std::logic_error("handleRequestResponse not implemented")); +} + +std::shared_ptr> RSocketResponder::handleRequestStream( + Payload, + StreamId) { + return Flowable::error( std::logic_error("handleRequestStream not implemented")); } -yarpl::Reference> -RSocketResponder::handleRequestChannel( - rsocket::Payload, - yarpl::Reference>, - rsocket::StreamId) { - return yarpl::flowable::Flowables::error( +std::shared_ptr> RSocketResponder::handleRequestChannel( + Payload, + std::shared_ptr>, + StreamId) { + return Flowable::error( std::logic_error("handleRequestChannel not implemented")); } -void RSocketResponder::handleFireAndForget( - rsocket::Payload, - rsocket::StreamId) { +void RSocketResponder::handleFireAndForget(Payload, StreamId) { // No default implementation, no error response to provide. } @@ -38,17 +89,15 @@ void RSocketResponder::handleMetadataPush(std::unique_ptr) { } /// Handles a new Channel requested by the other end. -yarpl::Reference> -RSocketResponder::handleRequestChannelCore( +std::shared_ptr> +RSocketResponderAdapter::handleRequestChannel( Payload request, StreamId streamId, - const yarpl::Reference>& - response) noexcept { - class EagerSubscriberBridge - : public yarpl::flowable::Subscriber { + std::shared_ptr> response) noexcept { + class EagerSubscriberBridge : public Subscriber { public: - void onSubscribe(yarpl::Reference - subscription) noexcept override { + void onSubscribe( + std::shared_ptr subscription) noexcept override { CHECK(!subscription_); subscription_ = std::move(subscription); if (inner_) { @@ -56,49 +105,61 @@ RSocketResponder::handleRequestChannelCore( } } - void onNext(rsocket::Payload element) noexcept override { + void onNext(Payload element) noexcept override { DCHECK(inner_); inner_->onNext(std::move(element)); } void onComplete() noexcept override { - DCHECK(inner_); - inner_->onComplete(); - - inner_.reset(); - subscription_.reset(); + if (auto inner = std::move(inner_)) { + inner->onComplete(); + subscription_.reset(); + } else { + completed_ = true; + } } - void onError(std::exception_ptr ex) noexcept override { - DCHECK(inner_); - inner_->onError(std::move(ex)); - - inner_.reset(); - subscription_.reset(); + void onError(folly::exception_wrapper ex) noexcept override { + VLOG(3) << "handleRequestChannelCore::onError: " << ex.what(); + if (auto inner = std::move(inner_)) { + inner->onError(std::move(ex)); + subscription_.reset(); + } else { + error_ = std::move(ex); + } } - void subscribe( - yarpl::Reference> inner) { + void subscribe(std::shared_ptr> inner) { CHECK(!inner_); // only one call to subscribe is supported CHECK(inner); + inner_ = std::move(inner); if (subscription_) { inner_->onSubscribe(subscription_); + // it's possible to get an error or completion before subscribe happens, + // delay sending it but send it when this class gets subscribed + if (completed_) { + onComplete(); + } else if (error_) { + onError(std::move(error_)); + } } } private: - yarpl::Reference> inner_; - yarpl::Reference subscription_; + std::shared_ptr> inner_; + std::shared_ptr subscription_; + folly::exception_wrapper error_; + bool completed_{false}; }; - auto eagerSubscriber = yarpl::make_ref(); - auto flowable = handleRequestChannel( + auto eagerSubscriber = std::make_shared(); + auto flowable = inner_->handleRequestChannel( std::move(request), - yarpl::flowable::Flowables::fromPublisher([eagerSubscriber]( - yarpl::Reference> subscriber) { - eagerSubscriber->subscribe(subscriber); - }), + internal::flowableFromSubscriber( + [eagerSubscriber](std::shared_ptr> subscriber) { + eagerSubscriber->subscribe(subscriber); + }), std::move(streamId)); // bridge from the existing eager RequestHandler and old Subscriber type // to the lazy Flowable and new Subscriber type @@ -107,22 +168,33 @@ RSocketResponder::handleRequestChannelCore( } /// Handles a new Stream requested by the other end. -void RSocketResponder::handleRequestStreamCore( +void RSocketResponderAdapter::handleRequestStream( Payload request, StreamId streamId, - const yarpl::Reference>& - response) noexcept { - auto flowable = handleRequestStream(std::move(request), std::move(streamId)); + std::shared_ptr> response) noexcept { + auto flowable = + inner_->handleRequestStream(std::move(request), std::move(streamId)); flowable->subscribe(std::move(response)); } /// Handles a new inbound RequestResponse requested by the other end. -void RSocketResponder::handleRequestResponseCore( +void RSocketResponderAdapter::handleRequestResponse( Payload request, StreamId streamId, - const yarpl::Reference>& - responseObserver) noexcept { - auto single = handleRequestResponse(std::move(request), streamId); + std::shared_ptr> responseObserver) noexcept { + auto single = inner_->handleRequestResponse(std::move(request), streamId); single->subscribe(std::move(responseObserver)); } + +void RSocketResponderAdapter::handleFireAndForget( + Payload request, + StreamId streamId) { + inner_->handleFireAndForget(std::move(request), streamId); +} + +void RSocketResponderAdapter::handleMetadataPush( + std::unique_ptr buf) { + inner_->handleMetadataPush(std::move(buf)); } + +} // namespace rsocket diff --git a/rsocket/RSocketResponder.h b/rsocket/RSocketResponder.h index fdc428913..eedcc2ef8 100644 --- a/rsocket/RSocketResponder.h +++ b/rsocket/RSocketResponder.h @@ -1,14 +1,52 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include "rsocket/Payload.h" -#include "rsocket/internal/Common.h" +#include "rsocket/framing/FrameHeader.h" #include "yarpl/Flowable.h" #include "yarpl/Single.h" namespace rsocket { +class RSocketResponderCore { + public: + virtual ~RSocketResponderCore() = default; + + virtual void handleFireAndForget(Payload request, StreamId streamId); + + virtual void handleMetadataPush(std::unique_ptr metadata); + + virtual std::shared_ptr> + handleRequestChannel( + Payload request, + StreamId streamId, + std::shared_ptr> response) noexcept; + + virtual void handleRequestStream( + Payload request, + StreamId streamId, + std::shared_ptr> response) noexcept; + + virtual void handleRequestResponse( + Payload request, + StreamId streamId, + std::shared_ptr> + response) noexcept; +}; + /** * Responder APIs to handle requests on an RSocket connection. * @@ -37,28 +75,28 @@ class RSocketResponder { * * Returns a Single with the response. */ - virtual yarpl::Reference> - handleRequestResponse(rsocket::Payload request, rsocket::StreamId streamId); + virtual std::shared_ptr> handleRequestResponse( + Payload request, + StreamId streamId); /** * Called when a new `requestStream` occurs from an RSocketRequester. * * Returns a Flowable with the response stream. */ - virtual yarpl::Reference> - handleRequestStream(rsocket::Payload request, rsocket::StreamId streamId); + virtual std::shared_ptr> + handleRequestStream(Payload request, StreamId streamId); /** * Called when a new `requestChannel` occurs from an RSocketRequester. * * Returns a Flowable with the response stream. */ - virtual yarpl::Reference> + virtual std::shared_ptr> handleRequestChannel( - rsocket::Payload request, - yarpl::Reference> - requestStream, - rsocket::StreamId streamId); + Payload request, + std::shared_ptr> requestStream, + StreamId streamId); /** * Called when a new `fireAndForget` occurs from an RSocketRequester. @@ -75,30 +113,42 @@ class RSocketResponder { * No response. */ virtual void handleMetadataPush(std::unique_ptr metadata); +}; + +class RSocketResponderAdapter : public RSocketResponderCore { + public: + explicit RSocketResponderAdapter(std::shared_ptr inner) + : inner_(std::move(inner)) {} + virtual ~RSocketResponderAdapter() = default; /// Internal method for handling channel requests, not intended to be used by /// application code. - virtual yarpl::Reference> - handleRequestChannelCore( + std::shared_ptr> handleRequestChannel( Payload request, StreamId streamId, - const yarpl::Reference>& - response) noexcept; + std::shared_ptr> response) noexcept + override; /// Internal method for handling stream requests, not intended to be used /// by application code. - virtual void handleRequestStreamCore( + void handleRequestStream( Payload request, StreamId streamId, - const yarpl::Reference>& - response) noexcept; + std::shared_ptr> response) noexcept + override; /// Internal method for handling request-response requests, not intended to be /// used by application code. - virtual void handleRequestResponseCore( + void handleRequestResponse( Payload request, StreamId streamId, - const yarpl::Reference>& - response) noexcept; + std::shared_ptr> response) noexcept + override; + + void handleFireAndForget(Payload request, StreamId streamId) override; + void handleMetadataPush(std::unique_ptr buf) override; + + private: + std::shared_ptr inner_; }; -} +} // namespace rsocket diff --git a/rsocket/RSocketServer.cpp b/rsocket/RSocketServer.cpp index 78fce6bba..1e202810d 100644 --- a/rsocket/RSocketServer.cpp +++ b/rsocket/RSocketServer.cpp @@ -1,27 +1,43 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/RSocketServer.h" #include + #include #include "rsocket/RSocketErrors.h" #include "rsocket/RSocketStats.h" -#include "rsocket/framing/FrameTransport.h" #include "rsocket/framing/FramedDuplexConnection.h" -#include "rsocket/internal/RSocketConnectionManager.h" +#include "rsocket/framing/ScheduledFrameTransport.h" +#include "rsocket/internal/ConnectionSet.h" +#include "rsocket/internal/WarmResumeManager.h" namespace rsocket { RSocketServer::RSocketServer( - std::unique_ptr connectionAcceptor) + std::unique_ptr connectionAcceptor, + std::shared_ptr stats) : duplexConnectionAcceptor_(std::move(connectionAcceptor)), setupResumeAcceptors_([] { - return new rsocket::SetupResumeAcceptor( - ProtocolVersion::Unknown, - folly::EventBaseManager::get()->getExistingEventBase()); + return new rsocket::SetupResumeAcceptor{ + folly::EventBaseManager::get()->getExistingEventBase()}; }), - connectionManager_(std::make_unique()) {} + connectionSet_(std::make_unique()), + stats_(std::move(stats)) {} RSocketServer::~RSocketServer() { + VLOG(3) << "~RSocketServer .."; shutdownAndWait(); } @@ -41,34 +57,32 @@ void RSocketServer::shutdownAndWait() { std::vector> closingFutures; for (auto& acceptor : setupResumeAcceptors_.accessAllThreads()) { - // this call will queue up the cleanup on the eventBase + // This call will queue up the cleanup on the eventBase. closingFutures.push_back(acceptor.close()); } folly::collectAll(closingFutures).get(); - connectionManager_.reset(); // will close all existing RSockets and wait - - // All requests are fully finished, worker threads can be safely killed off. + // Close off all outstanding connections. + connectionSet_->shutdownAndWait(); } void RSocketServer::start( std::shared_ptr serviceHandler) { CHECK(duplexConnectionAcceptor_); // RSocketServer has to be initialized with - // the acceptor + // the acceptor if (started) { throw std::runtime_error("RSocketServer::start() already called."); } started = true; - LOG(INFO) << "Starting RSocketServer"; - - duplexConnectionAcceptor_->start([this, serviceHandler]( - std::unique_ptr connection, - folly::EventBase& eventBase) { - acceptConnection(std::move(connection), eventBase, serviceHandler); - }); + duplexConnectionAcceptor_->start( + [this, serviceHandler]( + std::unique_ptr connection, + folly::EventBase& eventBase) { + acceptConnection(std::move(connection), eventBase, serviceHandler); + }); } void RSocketServer::start(OnNewSetupFn onNewSetupFn) { @@ -79,10 +93,15 @@ void RSocketServer::startAndPark(OnNewSetupFn onNewSetupFn) { startAndPark(RSocketServiceHandler::create(std::move(onNewSetupFn))); } +void RSocketServer::setSingleThreadedResponder() { + useScheduledResponder_ = false; +} + void RSocketServer::acceptConnection( std::unique_ptr connection, folly::EventBase&, std::shared_ptr serviceHandler) { + stats_->serverConnectionAccepted(); if (isShutdown_) { // connection is getting out of scope and terminated return; @@ -102,12 +121,20 @@ void RSocketServer::acceptConnection( acceptor->accept( std::move(framedConnection), - std::bind( - &RSocketServer::onRSocketSetup, - this, - serviceHandler, - std::placeholders::_1, - std::placeholders::_2), + [serviceHandler, + weakConSet = std::weak_ptr(connectionSet_), + scheduledResponder = useScheduledResponder_]( + std::unique_ptr conn, + SetupParameters params) mutable { + if (auto connectionSet = weakConSet.lock()) { + RSocketServer::onRSocketSetup( + serviceHandler, + std::move(connectionSet), + scheduledResponder, + std::move(conn), + std::move(params)); + } + }, std::bind( &RSocketServer::onRSocketResume, this, @@ -118,48 +145,105 @@ void RSocketServer::acceptConnection( void RSocketServer::onRSocketSetup( std::shared_ptr serviceHandler, - yarpl::Reference frameTransport, + std::shared_ptr connectionSet, + bool scheduledResponder, + std::unique_ptr connection, SetupParameters setupParams) { - VLOG(1) << "Received new setup payload"; - auto* eventBase = folly::EventBaseManager::get()->getExistingEventBase(); + const auto eventBase = folly::EventBaseManager::get()->getExistingEventBase(); + VLOG(2) << "Received new setup payload on " << eventBase->getName(); CHECK(eventBase); auto result = serviceHandler->onNewSetup(setupParams); if (result.hasError()) { - VLOG(3) << "Terminating SETUP attempt from client. No Responder"; - throw result.error(); + VLOG(3) << "Terminating SETUP attempt from client. " + << result.error().what(); + connection->send( + FrameSerializer::createFrameSerializer(setupParams.protocolVersion) + ->serializeOut(Frame_ERROR::rejectedSetup(result.error().what()))); + return; } - auto connectionParams = result.value(); - CHECK(connectionParams.responder); - auto responder = std::make_shared( - std::move(connectionParams.responder), *eventBase); - auto rs = std::make_shared( - *eventBase, - std::move(responder), + auto connectionParams = std::move(result.value()); + if (!connectionParams.responder) { + LOG(ERROR) << "Received invalid Responder. Dropping connection"; + connection->send( + FrameSerializer::createFrameSerializer(setupParams.protocolVersion) + ->serializeOut(Frame_ERROR::rejectedSetup( + "Received invalid Responder from server"))); + return; + } + const auto rs = std::make_shared( + scheduledResponder + ? std::make_shared( + std::move(connectionParams.responder), *eventBase) + : std::move(connectionParams.responder), nullptr, - ReactiveSocketMode::SERVER, - std::move(connectionParams.stats), - std::move(connectionParams.connectionEvents)); - connectionManager_->manageConnection(rs, *eventBase); + RSocketMode::SERVER, + connectionParams.stats, + std::move(connectionParams.connectionEvents), + setupParams.resumable + ? std::make_shared(connectionParams.stats) + : ResumeManager::makeEmpty(), + nullptr /* coldResumeHandler */); + + if (!connectionSet->insert(rs, eventBase)) { + VLOG(1) << "Server is closed, so ignore the connection"; + connection->send( + FrameSerializer::createFrameSerializer(setupParams.protocolVersion) + ->serializeOut(Frame_ERROR::rejectedSetup( + "Server ignores the connection attempt"))); + return; + } + rs->registerCloseCallback(connectionSet.get()); + auto requester = std::make_shared(rs, *eventBase); auto serverState = std::shared_ptr( - new RSocketServerState(rs, requester)); + new RSocketServerState(*eventBase, rs, std::move(requester))); serviceHandler->onNewRSocketState(std::move(serverState), setupParams.token); - rs->connectServer(std::move(frameTransport), std::move(setupParams)); + rs->connectServer( + std::make_shared(std::move(connection)), + std::move(setupParams)); } void RSocketServer::onRSocketResume( std::shared_ptr serviceHandler, - yarpl::Reference frameTransport, + std::unique_ptr connection, ResumeParameters resumeParams) { auto result = serviceHandler->onResume(resumeParams.token); if (result.hasError()) { + stats_->resumeFailedNoState(); VLOG(3) << "Terminating RESUME attempt from client. No ServerState found"; - throw result.error(); + connection->send( + FrameSerializer::createFrameSerializer(resumeParams.protocolVersion) + ->serializeOut(Frame_ERROR::rejectedSetup(result.error().what()))); + return; } - auto serverState = std::move(result.value()); + const auto serverState = std::move(result.value()); CHECK(serverState); - serverState->rSocketStateMachine_->resumeServer( - std::move(frameTransport), resumeParams); + const auto eventBase = folly::EventBaseManager::get()->getExistingEventBase(); + VLOG(2) << "Resuming client on " << eventBase->getName(); + if (!serverState->eventBase_.isInEventBaseThread()) { + // If the resumed connection is on a different EventBase, then use + // ScheduledFrameTransport and ScheduledFrameProcessor to ensure the + // RSocketStateMachine continues to live on the same EventBase and the + // IO happens in the new EventBase + auto scheduledFT = std::make_shared( + std::make_shared(std::move(connection)), + eventBase, /* Transport EventBase */ + &serverState->eventBase_); /* StateMachine EventBase */ + serverState->eventBase_.runInEventBaseThread( + [serverState, + scheduledFT = std::move(scheduledFT), + resumeParams = std::move(resumeParams)]() mutable { + serverState->rSocketStateMachine_->resumeServer( + std::move(scheduledFT), resumeParams); + }); + } else { + // If the resumed connection is on the same EventBase, then the + // RSocketStateMachine and Transport can continue living in the same + // EventBase without any thread hopping between them. + serverState->rSocketStateMachine_->resumeServer( + std::make_shared(std::move(connection)), + resumeParams); + } } void RSocketServer::startAndPark( @@ -177,4 +261,8 @@ folly::Optional RSocketServer::listeningPort() const { : folly::none; } +size_t RSocketServer::getNumConnections() { + return connectionSet_ ? connectionSet_->size() : 0; +} + } // namespace rsocket diff --git a/rsocket/RSocketServer.h b/rsocket/RSocketServer.h index a4a96613a..39dae66a3 100644 --- a/rsocket/RSocketServer.h +++ b/rsocket/RSocketServer.h @@ -1,22 +1,34 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include -#include #include #include -#include "RSocketServiceHandler.h" +#include + #include "rsocket/ConnectionAcceptor.h" #include "rsocket/RSocketParameters.h" #include "rsocket/RSocketResponder.h" +#include "rsocket/RSocketServiceHandler.h" +#include "rsocket/internal/ConnectionSet.h" #include "rsocket/internal/SetupResumeAcceptor.h" namespace rsocket { -class RSocketConnectionManager; - /** * API for starting an RSocket server. Returned from RSocket::createServer. * @@ -26,7 +38,9 @@ class RSocketConnectionManager; */ class RSocketServer { public: - explicit RSocketServer(std::unique_ptr); + explicit RSocketServer( + std::unique_ptr, + std::shared_ptr stats = RSocketStats::noop()); ~RSocketServer(); RSocketServer(const RSocketServer&) = delete; @@ -83,25 +97,47 @@ class RSocketServer { */ folly::Optional listeningPort() const; + /** + * Use the same EventBase that is provided to acceptConnection function for + * internal operations. Don't schedule to another event base. + */ + void setSingleThreadedResponder(); + + /** + * Number of active connections to this server. + */ + size_t getNumConnections(); + private: - void onRSocketSetup( + static void onRSocketSetup( std::shared_ptr serviceHandler, - yarpl::Reference frameTransport, + std::shared_ptr connectionSet, + bool scheduledResponder, + std::unique_ptr connection, rsocket::SetupParameters setupPayload); void onRSocketResume( std::shared_ptr serviceHandler, - yarpl::Reference frameTransport, + std::unique_ptr connection, rsocket::ResumeParameters setupPayload); - std::unique_ptr duplexConnectionAcceptor_; + const std::unique_ptr duplexConnectionAcceptor_; bool started{false}; - class SetupResumeAcceptorTag{}; - folly::ThreadLocal setupResumeAcceptors_; + class SetupResumeAcceptorTag {}; + folly::ThreadLocal + setupResumeAcceptors_; folly::Baton<> waiting_; std::atomic isShutdown_{false}; - std::unique_ptr connectionManager_; + std::shared_ptr connectionSet_; + std::shared_ptr stats_; + + /** + * If this field is false, acceptConnection() function will assume that there + * will be a single thread for each connected client. The execution will not + * be scheduled to another event base. + */ + bool useScheduledResponder_{true}; }; } // namespace rsocket diff --git a/rsocket/RSocketServerState.h b/rsocket/RSocketServerState.h index f3d7bc02e..c5d010dbb 100644 --- a/rsocket/RSocketServerState.h +++ b/rsocket/RSocketServerState.h @@ -1,27 +1,56 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include "rsocket/RSocketRequester.h" +namespace folly { +class EventBase; +} + namespace rsocket { class RSocketServerState { - void close(); + public: + void close() { + eventBase_.runInEventBaseThread([sm = rSocketStateMachine_] { + sm->close({}, StreamCompletionSignal::SOCKET_CLOSED); + }); + } std::shared_ptr getRequester() { return rSocketRequester_; } + folly::EventBase* eventBase() { + return &eventBase_; + } + friend class RSocketServer; private: RSocketServerState( + folly::EventBase& eventBase, std::shared_ptr stateMachine, std::shared_ptr rSocketRequester) - : rSocketStateMachine_(stateMachine), + : eventBase_(eventBase), + rSocketStateMachine_(stateMachine), rSocketRequester_(rSocketRequester) {} - std::shared_ptr rSocketStateMachine_; - std::shared_ptr rSocketRequester_; + + folly::EventBase& eventBase_; + const std::shared_ptr rSocketStateMachine_; + const std::shared_ptr rSocketRequester_; }; -} +} // namespace rsocket diff --git a/rsocket/RSocketServiceHandler.cpp b/rsocket/RSocketServiceHandler.cpp index fce0681d1..8e3f8d341 100644 --- a/rsocket/RSocketServiceHandler.cpp +++ b/rsocket/RSocketServiceHandler.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/RSocketServiceHandler.h" @@ -16,7 +28,7 @@ RSocketServiceHandler::onResume(ResumeIdentificationToken) { bool RSocketServiceHandler::canResume( const std::vector& /* cleanStreamIds */, const std::vector& /* dirtyStreamIds */, - ResumeIdentificationToken) { + ResumeIdentificationToken) const { return true; } @@ -24,10 +36,15 @@ std::shared_ptr RSocketServiceHandler::create( OnNewSetupFn onNewSetupFn) { class ServiceHandler : public RSocketServiceHandler { public: - ServiceHandler(OnNewSetupFn fn) : onNewSetupFn_(std::move(fn)) {} + explicit ServiceHandler(OnNewSetupFn fn) : onNewSetupFn_(std::move(fn)) {} folly::Expected onNewSetup( const SetupParameters& setupParameters) override { - return RSocketConnectionParams(onNewSetupFn_(setupParameters)); + try { + return RSocketConnectionParams(onNewSetupFn_(setupParameters)); + } catch (const std::exception& e) { + return folly::Unexpected( + ConnectionException(e.what())); + } } private: @@ -35,4 +52,4 @@ std::shared_ptr RSocketServiceHandler::create( }; return std::make_shared(std::move(onNewSetupFn)); } -} +} // namespace rsocket diff --git a/rsocket/RSocketServiceHandler.h b/rsocket/RSocketServiceHandler.h index 4b12678a3..b67caa358 100644 --- a/rsocket/RSocketServiceHandler.h +++ b/rsocket/RSocketServiceHandler.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -22,7 +34,7 @@ using OnNewSetupFn = // This struct holds all the necessary information needed by the RSocketServer // to initiate a connection with a client. struct RSocketConnectionParams { - RSocketConnectionParams( + explicit RSocketConnectionParams( std::shared_ptr _responder, std::shared_ptr _stats = RSocketStats::noop(), std::shared_ptr _connectionEvents = nullptr) @@ -34,7 +46,6 @@ struct RSocketConnectionParams { std::shared_ptr connectionEvents; }; - // This class has to be implemented by the application. The methods can be // called from different threads and it is the application's responsibility to // ensure thread-safety. @@ -90,10 +101,10 @@ class RSocketServiceHandler { virtual bool canResume( const std::vector& /* cleanStreamIds */, const std::vector& /* dirtyStreamIds */, - ResumeIdentificationToken); + ResumeIdentificationToken) const; // Convenience constructor to create a simple RSocketServiceHandler. static std::shared_ptr create( OnNewSetupFn onNewSetupFn); }; -} +} // namespace rsocket diff --git a/rsocket/RSocketStats.cpp b/rsocket/RSocketStats.cpp index 008696841..ee7bc6f70 100644 --- a/rsocket/RSocketStats.cpp +++ b/rsocket/RSocketStats.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/RSocketStats.h" @@ -7,12 +19,19 @@ namespace rsocket { class NoopStats : public RSocketStats { public: NoopStats() = default; - ~NoopStats() = default; + NoopStats(const NoopStats&) = delete; // non construction-copyable + NoopStats& operator=(const NoopStats&) = delete; // non copyable + NoopStats& operator=(const NoopStats&&) = delete; // non movable + NoopStats(NoopStats&&) = delete; // non construction-movable + ~NoopStats() override = default; void socketCreated() override {} + void socketConnected() override {} void socketDisconnected() override {} void socketClosed(StreamCompletionSignal) override {} + void serverConnectionAccepted() override {} + void duplexConnectionCreated(const std::string&, rsocket::DuplexConnection*) override {} @@ -23,23 +42,23 @@ class NoopStats : public RSocketStats { void bytesRead(size_t) override {} void frameWritten(FrameType) override {} void frameRead(FrameType) override {} - + void serverResume(folly::Optional, int64_t, int64_t, ResumeOutcome) + override {} void resumeBufferChanged(int, int) override {} void streamBufferChanged(int64_t, int64_t) override {} + void resumeFailedNoState() override {} + + void keepaliveSent() override {} + void keepaliveReceived() override {} + static std::shared_ptr instance() { - static auto singleton = std::make_shared(); + static const auto singleton = std::make_shared(); return singleton; } - - private: - NoopStats(const NoopStats&) = delete; // non construction-copyable - NoopStats& operator=(const NoopStats&) = delete; // non copyable - NoopStats& operator=(const NoopStats&&) = delete; // non movable - NoopStats(NoopStats&&) = delete; // non construction-movable }; std::shared_ptr RSocketStats::noop() { return NoopStats::instance(); } -} +} // namespace rsocket diff --git a/rsocket/RSocketStats.h b/rsocket/RSocketStats.h index 6b4ad4c64..8e7480b91 100644 --- a/rsocket/RSocketStats.h +++ b/rsocket/RSocketStats.h @@ -1,7 +1,20 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once +#include #include #include #include @@ -15,28 +28,44 @@ class DuplexConnection; class RSocketStats { public: + enum class ResumeOutcome { SUCCESS, FAILURE }; + virtual ~RSocketStats() = default; static std::shared_ptr noop(); - virtual void socketCreated() = 0; - virtual void socketDisconnected() = 0; - virtual void socketClosed(StreamCompletionSignal signal) = 0; + virtual void socketCreated() {} + virtual void socketConnected() {} + virtual void socketDisconnected() {} + virtual void socketClosed(StreamCompletionSignal /* signal */) {} + + virtual void serverConnectionAccepted() {} virtual void duplexConnectionCreated( - const std::string& type, - DuplexConnection* connection) = 0; + const std::string& /* type */, + DuplexConnection* /* connection */) {} virtual void duplexConnectionClosed( - const std::string& type, - DuplexConnection* connection) = 0; - - virtual void bytesWritten(size_t bytes) = 0; - virtual void bytesRead(size_t bytes) = 0; - virtual void frameWritten(FrameType frameType) = 0; - virtual void frameRead(FrameType frameType) = 0; - virtual void resumeBufferChanged(int framesCountDelta, int dataSizeDelta) = 0; + const std::string& /* type */, + DuplexConnection* /* connection */) {} + virtual void serverResume( + folly::Optional /* clientAvailable */, + int64_t /* serverAvailable */, + int64_t /* serverDelta */, + ResumeOutcome /* outcome */) {} + virtual void bytesWritten(size_t /* bytes */) {} + virtual void bytesRead(size_t /* bytes */) {} + virtual void frameWritten(FrameType /* frameType */) {} + virtual void frameRead(FrameType /* frameType */) {} + virtual void resumeBufferChanged( + int /* framesCountDelta */, + int /* dataSizeDelta */) {} virtual void streamBufferChanged( - int64_t framesCountDelta, - int64_t dataSizeDelta) = 0; + int64_t /* framesCountDelta */, + int64_t /* dataSizeDelta */) {} + virtual void resumeFailedNoState() {} + virtual void keepaliveSent() {} + virtual void keepaliveReceived() {} + virtual void unknownFrameReceived() { + } // TODO(lehecka): add to all implementations }; -} +} // namespace rsocket diff --git a/rsocket/ResumeManager.h b/rsocket/ResumeManager.h index 3e54d66ef..198539916 100644 --- a/rsocket/ResumeManager.h +++ b/rsocket/ResumeManager.h @@ -1,8 +1,161 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once +#include +#include +#include "rsocket/framing/Frame.h" +#include "rsocket/framing/FrameTransportImpl.h" + +namespace folly { +class IOBuf; +} + namespace rsocket { -class ResumeManager {}; -} \ No newline at end of file +// Struct to hold information relevant per stream. +struct StreamResumeInfo { + StreamResumeInfo() = delete; + StreamResumeInfo(StreamType sType, RequestOriginator req, std::string sToken) + : streamType(sType), requester(req), streamToken(sToken) {} + + // REQUEST_STREAM, REQUEST_CHANNEL or REQUEST_RESPONSE. We don't + // have to store any stream level information for FNF. + StreamType streamType; + + // Did the stream originate locally or remotely. + RequestOriginator requester; + + // Application defined string representation for the stream. + std::string streamToken; + + // Stores the allowance which the local side has received but hasn't + // fulfilled yet. Relevant for REQUEST_STREAM Responder and REQUEST_CHANNEL + size_t producerAllowance{0}; + + // Stores the allowance which has been sent to the remote side and has not + // been fulfilled yet. Relevant for REQUEST_STREAM Requester and + // REQUEST_CHANNEL + size_t consumerAllowance{0}; +}; + +using StreamResumeInfos = std::unordered_map; + +// Applications desiring to have cold-resumption should implement a +// ResumeManager interface. By default, an in-memory implementation of this +// interface (WarmResumeManager) will be used by RSocket. +// +// The API refers to the stored frames by "position". "position" is the byte +// count at frame boundaries. For example, if the ResumeManager has stored 3 +// 100-byte sent frames starting from byte count 150. Then, +// - isPositionAvailable would return true for the values [150, 250, 350]. +// - firstSentPosition() would return 150 +// - lastSentPosition() would return 350 +class ResumeManager { + public: + static std::shared_ptr makeEmpty(); + + virtual ~ResumeManager() {} + + // The following methods will be called for each frame which is being + // sent/received on the wire. The application should implement a way to + // store the sent and received frames in persistent storage. + virtual void trackReceivedFrame( + size_t frameLength, + FrameType frameType, + StreamId streamId, + size_t consumerAllowance) = 0; + + virtual void trackSentFrame( + const folly::IOBuf& serializedFrame, + FrameType frameType, + StreamId streamId, + size_t consumerAllowance) = 0; + + // We have received acknowledgement from the remote-side that it has frames + // up to "position". We can discard all frames before that. This + // information is periodically received from remote-side through KeepAlive + // frames. + virtual void resetUpToPosition(ResumePosition position) = 0; + + // The application should check its persistent storage and respond whether it + // has frames starting from "position" in send buffer. + virtual bool isPositionAvailable(ResumePosition position) const = 0; + + // The application should send frames starting from the "position" using the + // provided "transport". As an alternative, we could design the API such + // that we retrieve individual frames from the application and send them over + // wire. But that would mean application has random access to frames + // indexed by position. This API gives the flexibility to the application to + // store the frames in any way it wants (randomly accessed or sequentially + // accessed). + virtual void sendFramesFromPosition( + ResumePosition position, + FrameTransport& transport) const = 0; + + // This should return the first (oldest) available position in the send + // buffer. + virtual ResumePosition firstSentPosition() const = 0; + + // This should return the last (latest) available position in the send + // buffer. + virtual ResumePosition lastSentPosition() const = 0; + + // This should return the latest tracked position of frames received from + // remote side. + virtual ResumePosition impliedPosition() const = 0; + + // This gets called when a stream is opened (both local/remote streams) + virtual void onStreamOpen( + StreamId, + RequestOriginator, + std::string streamToken, + StreamType streamType) = 0; + + // This gets called when a stream is closed (both local/remote streams) + virtual void onStreamClosed(StreamId streamId) = 0; + + // Returns the cached stream information. + virtual const StreamResumeInfos& getStreamResumeInfos() const = 0; + + // Returns the largest used StreamId so far. + virtual StreamId getLargestUsedStreamId() const = 0; + + // Utility method to check frames which should be tracked for resumption. + virtual bool shouldTrackFrame(const FrameType frameType) const { + switch (frameType) { + case FrameType::REQUEST_CHANNEL: + case FrameType::REQUEST_STREAM: + case FrameType::REQUEST_RESPONSE: + case FrameType::REQUEST_FNF: + case FrameType::REQUEST_N: + case FrameType::CANCEL: + case FrameType::ERROR: + case FrameType::PAYLOAD: + return true; + case FrameType::RESERVED: + case FrameType::SETUP: + case FrameType::LEASE: + case FrameType::KEEPALIVE: + case FrameType::METADATA_PUSH: + case FrameType::RESUME: + case FrameType::RESUME_OK: + case FrameType::EXT: + default: + return false; + } + } +}; +} // namespace rsocket diff --git a/rsocket/benchmarks/BaselinesAsyncSocket.cpp b/rsocket/benchmarks/BaselinesAsyncSocket.cpp new file mode 100644 index 000000000..467a36f68 --- /dev/null +++ b/rsocket/benchmarks/BaselinesAsyncSocket.cpp @@ -0,0 +1,285 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#define PORT (35437) + +// namespace { +// +// class TcpReader : public ::folly::AsyncTransportWrapper::ReadCallback { +// public: +// TcpReader( +// folly::AsyncTransportWrapper::UniquePtr&& socket, +// EventBase& eventBase, +// size_t loadSize, +// size_t recvBufferLength) +// : socket_(std::move(socket)), +// eventBase_(eventBase), +// loadSize_(loadSize), +// recvBufferLength_(recvBufferLength) { +// socket_->setReadCB(this); +// } +// +// private: +// void getReadBuffer(void** bufReturn, size_t* lenReturn) noexcept override { +// std::tie(*bufReturn, *lenReturn) = +// readBuffer_.preallocate(recvBufferLength_, recvBufferLength_); +// } +// +// void readDataAvailable(size_t len) noexcept override { +// readBuffer_.postallocate(len); +// auto readData = readBuffer_.split(len); +// +// receivedLength_ += readData->computeChainDataLength(); +// ++reads_; +// if (receivedLength_ >= loadSize_) { +// LOG(INFO) << "closing reader"; +// close(); +// } +// } +// +// void readBufferAvailable( +// std::unique_ptr readBuf) noexcept override { +// receivedLength_ += readBuf->computeChainDataLength(); +// ++reads_; +// if (receivedLength_ >= loadSize_) { +// LOG(INFO) << "closing reader"; +// close(); +// } +// } +// +// void readEOF() noexcept override { +// LOG(INFO) << "closing reader"; +// close(); +// } +// +// void readErr(const folly::AsyncSocketException& exn) noexcept override { +// LOG(ERROR) << exn.what(); +// close(); +// } +// +// bool isBufferMovable() noexcept override { +// return true; +// } +// +// void close() { +// if (socket_) { +// LOG(INFO) << "received " << receivedLength_ << " via " << reads_ +// << " reads"; +// auto socket = std::move(socket_); +// socket->close(); +// eventBase_.terminateLoopSoon(); +// delete this; +// } +// } +// +// folly::AsyncTransportWrapper::UniquePtr socket_; +// folly::IOBufQueue readBuffer_{folly::IOBufQueue::cacheChainLength()}; +// EventBase& eventBase_; +// const size_t loadSize_; +// const size_t recvBufferLength_; +// size_t receivedLength_{0}; +// int reads_{0}; +//}; +// +// class ServerAcceptCallback : public AsyncServerSocket::AcceptCallback { +// public: +// ServerAcceptCallback( +// EventBase& eventBase, +// size_t loadSize, +// size_t recvBufferLength) +// : eventBase_(eventBase), +// loadSize_(loadSize), +// recvBufferLength_(recvBufferLength) {} +// +// void connectionAccepted( +// int fd, +// const SocketAddress&) noexcept override { +// auto socket = +// folly::AsyncTransportWrapper::UniquePtr(new AsyncSocket(&eventBase_, +// fd)); +// +// new TcpReader( +// std::move(socket), eventBase_, loadSize_, recvBufferLength_); +// } +// +// void acceptError(folly::exception_wrapper ex) noexcept override { +// LOG(FATAL) << "acceptError" << ex << std::endl; +// eventBase_.terminateLoopSoon(); +// } +// +// private: +// EventBase& eventBase_; +// const size_t loadSize_; +// const size_t recvBufferLength_; +//}; +// +// class TcpWriter : public ::folly::AsyncTransportWrapper::WriteCallback { +// public: +// ~TcpWriter() { +// LOG(INFO) << "writes=" << writes_ << " success=" << success_ << " errors=" +// << errors_; +// } +// +// void startWriting(AsyncSocket& socket, size_t loadSize, +// size_t messageSize) { +// size_t bytesSent{0}; +// +// while (!closed_ && bytesSent < loadSize) { +// auto data = IOBuf::copyBuffer(std::string(messageSize, 'a')); +// socket.writeChain(this, std::move(data)); +// ++writes_; +// bytesSent += messageSize; +// } +// LOG(INFO) << "wrote " << bytesSent << " closed=" << closed_; +// } +// +// private: +// void writeSuccess() noexcept override { +// ++success_; +// } +// +// void writeErr( +// size_t, +// const folly::AsyncSocketException& exn) noexcept override { +// LOG_EVERY_N(ERROR,10000) << "writeError: " << exn.what(); +// closed_ = true; +// ++errors_; +// } +// +// bool closed_{false}; +// int writes_{0}; +// int success_{0}; +// int errors_{0}; +//}; +// +// class ClientConnectCallback : public AsyncSocket::ConnectCallback { +// public: +// ClientConnectCallback(EventBase& eventBase, size_t loadSize, +// size_t msgLength) +// : eventBase_(eventBase), loadSize_(loadSize), msgLength_(msgLength) {} +// +// void connect() { +// eventBase_.runInEventBaseThread([this] { +// socket_.reset(new AsyncSocket(&eventBase_)); +// SocketAddress clientAaddr("::", PORT); +// socket_->connect(this, clientAaddr); +// }); +// } +// +// private: +// void connectSuccess() noexcept override { +// { +// TcpWriter writer; +// LOG(INFO) << "startWriting"; +// writer.startWriting(*socket_, loadSize_, msgLength_); +// LOG(INFO) << "endWriting"; +// socket_->close(); +// LOG(INFO) << "socket closed, deleting this"; +// } +// delete this; +// } +// +// void connectErr(const AsyncSocketException& ex) noexcept override { +// LOG(FATAL) << "connectErr: " << ex.what() << " " << ex.getType(); +// delete this; +// } +// +// AsyncTransportWrapper::UniquePtr socket_; +// EventBase& eventBase_; +// const size_t loadSize_; +// const size_t msgLength_; +//}; +//} + +static void BM_Baseline_AsyncSocket_SendReceive( + size_t /*loadSize*/, + size_t /*msgLength*/, + size_t /*recvLength*/) { + LOG_EVERY_N(INFO, 10000) << "TODO(lehecka): benchmark needs updating, " + << "it has memory corruption bugs"; + // EventBase serverEventBase; + // auto serverSocket = AsyncServerSocket::newSocket(&serverEventBase); + // + // ServerAcceptCallback serverCallback(serverEventBase, loadSize, + // recvLength); + // + // SocketAddress addr("::", PORT); + // + // serverSocket->setReusePortEnabled(true); + // serverSocket->bind(addr); + // serverSocket->addAcceptCallback(&serverCallback, &serverEventBase); + // serverSocket->listen(1); + // serverSocket->startAccepting(); + // + // ScopedEventBaseThread clientThread; + // auto* clientCallback = new ClientConnectCallback( + // *clientThread.getEventBase(), loadSize, msgLength); + // clientCallback->connect(); + // + // serverEventBase.loopForever(); +} + +BENCHMARK(BM_Baseline_AsyncSocket_Throughput_100MB_s40B_r1024B, n) { + (void)n; + constexpr size_t loadSizeB = 100 * 1024 * 1024; + constexpr size_t sendSizeB = 40; + constexpr size_t receiveSizeB = 1024; + BM_Baseline_AsyncSocket_SendReceive(loadSizeB, sendSizeB, receiveSizeB); +} +BENCHMARK(BM_Baseline_AsyncSocket_Throughput_100MB_s40B_r4096B, n) { + (void)n; + constexpr size_t loadSizeB = 100 * 1024 * 1024; + constexpr size_t sendSizeB = 40; + constexpr size_t receiveSizeB = 4096; + BM_Baseline_AsyncSocket_SendReceive(loadSizeB, sendSizeB, receiveSizeB); +} +BENCHMARK(BM_Baseline_AsyncSocket_Throughput_100MB_s80B_r4096B, n) { + (void)n; + constexpr size_t loadSizeB = 100 * 1024 * 1024; + constexpr size_t sendSizeB = 80; + constexpr size_t receiveSizeB = 4096; + BM_Baseline_AsyncSocket_SendReceive(loadSizeB, sendSizeB, receiveSizeB); +} +BENCHMARK(BM_Baseline_AsyncSocket_Throughput_100MB_s4096B_r4096B, n) { + (void)n; + constexpr size_t loadSizeB = 100 * 1024 * 1024; + constexpr size_t sendSizeB = 4096; + constexpr size_t receiveSizeB = 4096; + BM_Baseline_AsyncSocket_SendReceive(loadSizeB, sendSizeB, receiveSizeB); +} + +BENCHMARK(BM_Baseline_AsyncSocket_Latency_1M_msgs_32B, n) { + (void)n; + constexpr size_t messageSizeB = 32; + constexpr size_t loadSizeB = 1000000 * messageSizeB; + BM_Baseline_AsyncSocket_SendReceive(loadSizeB, messageSizeB, messageSizeB); +} +BENCHMARK(BM_Baseline_AsyncSocket_Latency_1M_msgs_128B, n) { + (void)n; + constexpr size_t messageSizeB = 128; + constexpr size_t loadSizeB = 1000000 * messageSizeB; + BM_Baseline_AsyncSocket_SendReceive(loadSizeB, messageSizeB, messageSizeB); +} +BENCHMARK(BM_Baseline_AsyncSocket_Latency_1M_msgs_4kB, n) { + (void)n; + constexpr size_t messageSizeB = 4096; + constexpr size_t loadSizeB = 1000000 * messageSizeB; + BM_Baseline_AsyncSocket_SendReceive(loadSizeB, messageSizeB, messageSizeB); +} diff --git a/rsocket/benchmarks/BaselinesTcp.cpp b/rsocket/benchmarks/BaselinesTcp.cpp new file mode 100644 index 000000000..d9e22892d --- /dev/null +++ b/rsocket/benchmarks/BaselinesTcp.cpp @@ -0,0 +1,183 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define MAX_MESSAGE_LENGTH (8 * 1024) +#define PORT (35437) + +static void BM_Baseline_TCP_SendReceive( + size_t loadSize, + size_t msgLength, + size_t recvLength) { + std::atomic accepting{false}; + std::atomic accepted{false}; + + std::thread t([&]() { + int serverSock = socket(AF_INET, SOCK_STREAM, 0); + int sock = -1; + struct sockaddr_in addr = {}; + socklen_t addrlen = sizeof(addr); + std::array message = {}; + + if (serverSock < 0) { + perror("acceptor socket"); + return; + } + + int enable = 1; + if (setsockopt( + serverSock, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable)) < + 0) { + perror("setsocketopt SO_REUSEADDR"); + return; + } + + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = htonl(INADDR_ANY); + addr.sin_port = htons(PORT); + if (bind(serverSock, reinterpret_cast(&addr), addrlen) < + 0) { + perror("bind"); + return; + } + + if (listen(serverSock, 1) < 0) { + perror("listen"); + return; + } + + accepting.store(true); + + if ((sock = accept( + serverSock, reinterpret_cast(&addr), &addrlen)) < + 0) { + perror("accept"); + return; + } + + accepted.store(true); + + size_t sentBytes = 0; + while (sentBytes < loadSize) { + if (send(sock, message.data(), msgLength, 0) != + static_cast(msgLength)) { + perror("send"); + return; + } + sentBytes += msgLength; + } + + close(sock); + close(serverSock); + }); + + while (!accepting) { + std::this_thread::yield(); + } + + const int sock = socket(AF_INET, SOCK_STREAM, 0); + struct sockaddr_in addr = {}; + const socklen_t addrlen = sizeof(addr); + std::array message = {}; + + if (sock < 0) { + perror("connector socket"); + return; + } + + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = inet_addr("127.0.0.1"); + addr.sin_port = htons(PORT); + if (connect(sock, reinterpret_cast(&addr), addrlen) < 0) { + perror("connect"); + return; + } + + while (!accepted) { + std::this_thread::yield(); + } + + size_t receivedBytes = 0; + while (receivedBytes < loadSize) { + const ssize_t recved = recv(sock, message.data(), recvLength, 0); + + if (recved < 0) { + perror("recv"); + return; + } + + receivedBytes += recved; + } + + close(sock); + t.join(); +} + +BENCHMARK(BM_Baseline_TCP_Throughput_100MB_s40B_r1024B, n) { + (void)n; + constexpr size_t loadSizeB = 100 * 1024 * 1024; + constexpr size_t sendSizeB = 40; + constexpr size_t receiveSizeB = 1024; + BM_Baseline_TCP_SendReceive(loadSizeB, sendSizeB, receiveSizeB); +} +BENCHMARK(BM_Baseline_TCP_Throughput_100MB_s40B_r4096B, n) { + (void)n; + constexpr size_t loadSizeB = 100 * 1024 * 1024; + constexpr size_t sendSizeB = 40; + constexpr size_t receiveSizeB = 4096; + BM_Baseline_TCP_SendReceive(loadSizeB, sendSizeB, receiveSizeB); +} +BENCHMARK(BM_Baseline_TCP_Throughput_100MB_s80B_r4096B, n) { + (void)n; + constexpr size_t loadSizeB = 100 * 1024 * 1024; + constexpr size_t sendSizeB = 80; + constexpr size_t receiveSizeB = 4096; + BM_Baseline_TCP_SendReceive(loadSizeB, sendSizeB, receiveSizeB); +} +BENCHMARK(BM_Baseline_TCP_Throughput_100MB_s4096B_r4096B, n) { + (void)n; + constexpr size_t loadSizeB = 100 * 1024 * 1024; + constexpr size_t sendSizeB = 4096; + constexpr size_t receiveSizeB = 4096; + BM_Baseline_TCP_SendReceive(loadSizeB, sendSizeB, receiveSizeB); +} + +BENCHMARK(BM_Baseline_TCP_Latency_1M_msgs_32B, n) { + (void)n; + constexpr size_t messageSizeB = 32; + constexpr size_t loadSizeB = 1000000 * messageSizeB; + BM_Baseline_TCP_SendReceive(loadSizeB, messageSizeB, messageSizeB); +} +BENCHMARK(BM_Baseline_TCP_Latency_1M_msgs_128B, n) { + (void)n; + constexpr size_t messageSizeB = 128; + constexpr size_t loadSizeB = 1000000 * messageSizeB; + BM_Baseline_TCP_SendReceive(loadSizeB, messageSizeB, messageSizeB); +} +BENCHMARK(BM_Baseline_TCP_Latency_1M_msgs_4kB, n) { + (void)n; + constexpr size_t messageSizeB = 4096; + constexpr size_t loadSizeB = 1000000 * messageSizeB; + BM_Baseline_TCP_SendReceive(loadSizeB, messageSizeB, messageSizeB); +} diff --git a/rsocket/benchmarks/Benchmarks.cpp b/rsocket/benchmarks/Benchmarks.cpp new file mode 100644 index 000000000..69a2abc91 --- /dev/null +++ b/rsocket/benchmarks/Benchmarks.cpp @@ -0,0 +1,27 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +int main(int argc, char** argv) { + folly::init(&argc, &argv); + + FLAGS_logtostderr = true; + + LOG(INFO) << "Running benchmarks... (takes minutes)"; + folly::runBenchmarks(); + + return 0; +} diff --git a/rsocket/benchmarks/CMakeLists.txt b/rsocket/benchmarks/CMakeLists.txt new file mode 100644 index 000000000..4d0c4f51c --- /dev/null +++ b/rsocket/benchmarks/CMakeLists.txt @@ -0,0 +1,29 @@ +add_library(fixture Fixture.cpp Fixture.h) +target_link_libraries(fixture ReactiveSocket Folly::folly) + +function(benchmark NAME FILE) + add_executable(${NAME} ${FILE} Benchmarks.cpp) + target_link_libraries( + ${NAME} + fixture + ReactiveSocket + Folly::follybenchmark + glog::glog + gflags) +endfunction() + +benchmark(baselines_tcp BaselinesTcp.cpp) +benchmark(baselines_async_socket BaselinesAsyncSocket.cpp) + +benchmark(fire-forget-throughput-tcp FireForgetThroughputTcp.cpp) +benchmark(req-response-throughput-tcp RequestResponseThroughputTcp.cpp) +benchmark(stream-throughput-tcp StreamThroughputTcp.cpp) + +benchmark(stream-throughput-mem StreamThroughputMemory.cpp) + +add_test(NAME RequestResponseThroughputTcpTest COMMAND req-response-throughput-tcp --items 100000) +add_test(NAME StreamThroughputTcpTest COMMAND stream-throughput-tcp --items 100000) +add_test(NAME FireForgetThroughputTcpTest COMMAND fire-forget-throughput-tcp --items 100000) + +#TODO(lehecka):enable test +#add_test(NAME StreamThroughputMemoryTest COMMAND stream-throughput-mem --items 100000) diff --git a/rsocket/benchmarks/FireForgetThroughputTcp.cpp b/rsocket/benchmarks/FireForgetThroughputTcp.cpp new file mode 100644 index 000000000..03f5a29f6 --- /dev/null +++ b/rsocket/benchmarks/FireForgetThroughputTcp.cpp @@ -0,0 +1,87 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/benchmarks/Fixture.h" +#include "rsocket/benchmarks/Latch.h" + +#include +#include + +#include "rsocket/RSocket.h" + +using namespace rsocket; + +DEFINE_int32(server_threads, 8, "number of server threads to run"); +DEFINE_int32( + override_client_threads, + 0, + "control the number of client threads (defaults to the number of clients)"); +DEFINE_int32(clients, 10, "number of clients to run"); +DEFINE_int32(items, 1000000, "number of items to fire-and-forget, in total"); + +namespace { + +class Responder : public RSocketResponder { + public: + Responder(Latch& latch) : latch_{latch} {} + + void handleFireAndForget(Payload, StreamId) override { + latch_.post(); + } + + private: + Latch& latch_; +}; +} // namespace + +BENCHMARK(FireForgetThroughput, n) { + (void)n; + + Latch latch{static_cast(FLAGS_items)}; + + std::unique_ptr fixture; + Fixture::Options opts; + + BENCHMARK_SUSPEND { + auto responder = std::make_shared(latch); + + opts.serverThreads = FLAGS_server_threads; + opts.clients = FLAGS_clients; + if (FLAGS_override_client_threads > 0) { + opts.clientThreads = FLAGS_override_client_threads; + } + + fixture = std::make_unique(opts, std::move(responder)); + + LOG(INFO) << "Running:"; + LOG(INFO) << " Server with " << opts.serverThreads << " threads."; + LOG(INFO) << " " << opts.clients << " clients across " + << fixture->workers.size() << " threads."; + LOG(INFO) << " Running " << FLAGS_items << " requests in total."; + } + + for (int i = 0; i < FLAGS_items; ++i) { + for (auto& client : fixture->clients) { + client->getRequester() + ->fireAndForget(Payload("TcpFireAndForget")) + ->subscribe( + std::make_shared>()); + } + } + + constexpr std::chrono::minutes timeout{5}; + if (!latch.timed_wait(timeout)) { + LOG(ERROR) << "Timed out!"; + } +} diff --git a/rsocket/benchmarks/Fixture.cpp b/rsocket/benchmarks/Fixture.cpp new file mode 100644 index 000000000..2a42fd222 --- /dev/null +++ b/rsocket/benchmarks/Fixture.cpp @@ -0,0 +1,62 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/benchmarks/Fixture.h" + +#include "rsocket/RSocket.h" +#include "rsocket/transports/tcp/TcpConnectionAcceptor.h" +#include "rsocket/transports/tcp/TcpConnectionFactory.h" + +namespace rsocket { + +namespace { + +std::shared_ptr makeClient( + folly::EventBase* eventBase, + folly::SocketAddress address) { + auto factory = + std::make_unique(*eventBase, std::move(address)); + return RSocket::createConnectedClient(std::move(factory)).get(); +} +} // namespace + +Fixture::Fixture( + Fixture::Options fixtureOpts, + std::shared_ptr responder) + : options{std::move(fixtureOpts)} { + TcpConnectionAcceptor::Options opts; + opts.address = folly::SocketAddress{"0.0.0.0", 0}; + opts.threads = options.serverThreads; + + auto acceptor = std::make_unique(std::move(opts)); + server = std::make_unique(std::move(acceptor)); + server->start([responder](const SetupParameters&) { return responder; }); + + auto const numWorkers = + options.clientThreads ? *options.clientThreads : options.clients; + for (size_t i = 0; i < numWorkers; ++i) { + workers.push_back(std::make_unique( + "rsocket-client-thread")); + } + + const folly::SocketAddress actual{"127.0.0.1", *server->listeningPort()}; + + for (size_t i = 0; i < options.clients; ++i) { + auto worker = std::move(workers.front()); + workers.pop_front(); + clients.push_back(makeClient(worker->getEventBase(), actual)); + workers.push_back(std::move(worker)); + } +} +} // namespace rsocket diff --git a/rsocket/benchmarks/Fixture.h b/rsocket/benchmarks/Fixture.h new file mode 100644 index 000000000..a1b290f7d --- /dev/null +++ b/rsocket/benchmarks/Fixture.h @@ -0,0 +1,54 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "rsocket/RSocketClient.h" +#include "rsocket/RSocketServer.h" + +#include +#include + +#include +#include + +namespace rsocket { + +/// Benchmarks fixture object that contains a server, along with a list of +/// clients and their worker threads. +/// +/// Uses TCP as the transport. +struct Fixture { + struct Options { + /// Number of threads the server will run. + size_t serverThreads{8}; + + /// Number of clients to run. + size_t clients{8}; + + /// Number of worker threads driving the clients. A default value means to + /// use one thread per client. + folly::Optional clientThreads; + }; + + Fixture(Options, std::shared_ptr); + + // State is public, have at it. + + std::unique_ptr server; + std::deque> workers; + std::vector> clients; + const Options options; +}; +} // namespace rsocket diff --git a/rsocket/benchmarks/Latch.h b/rsocket/benchmarks/Latch.h new file mode 100644 index 000000000..fc5422169 --- /dev/null +++ b/rsocket/benchmarks/Latch.h @@ -0,0 +1,43 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +/// Simple implementation of a latch synchronization primitive, for testing. +class Latch { + public: + explicit Latch(size_t limit) : limit_{limit} {} + + void wait() { + baton_.wait(); + } + + bool timed_wait(std::chrono::milliseconds timeout) { + return baton_.timed_wait(timeout); + } + + void post() { + auto const old = count_.fetch_add(1); + if (old == limit_ - 1) { + baton_.post(); + } + } + + private: + folly::Baton<> baton_; + std::atomic count_{0}; + const size_t limit_{0}; +}; diff --git a/benchmarks/README.md b/rsocket/benchmarks/README.md similarity index 100% rename from benchmarks/README.md rename to rsocket/benchmarks/README.md diff --git a/rsocket/benchmarks/RequestResponseThroughputTcp.cpp b/rsocket/benchmarks/RequestResponseThroughputTcp.cpp new file mode 100644 index 000000000..aace80fd2 --- /dev/null +++ b/rsocket/benchmarks/RequestResponseThroughputTcp.cpp @@ -0,0 +1,106 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/benchmarks/Fixture.h" +#include "rsocket/benchmarks/Latch.h" +#include "rsocket/benchmarks/Throughput.h" + +#include +#include +#include + +#include "rsocket/RSocket.h" +#include "yarpl/Single.h" + +using namespace rsocket; + +constexpr size_t kMessageLen = 32; + +DEFINE_int32(server_threads, 8, "number of server threads to run"); +DEFINE_int32( + override_client_threads, + 0, + "control the number of client threads (defaults to the number of clients)"); +DEFINE_int32(clients, 10, "number of clients to run"); +DEFINE_int32( + items, + 1000000, + "number of request-response requests to send, in total"); + +namespace { + +class Observer : public yarpl::single::SingleObserverBase { + public: + explicit Observer(Latch& latch) : latch_{latch} {} + + void onSubscribe(std::shared_ptr + subscription) override { + yarpl::single::SingleObserverBase::onSubscribe( + std::move(subscription)); + } + + void onSuccess(Payload) override { + latch_.post(); + yarpl::single::SingleObserverBase::onSuccess({}); + } + + void onError(folly::exception_wrapper) override { + latch_.post(); + yarpl::single::SingleObserverBase::onError({}); + } + + private: + Latch& latch_; +}; +} // namespace + +BENCHMARK(RequestResponseThroughput, n) { + (void)n; + + Latch latch{static_cast(FLAGS_items)}; + + std::unique_ptr fixture; + Fixture::Options opts; + + BENCHMARK_SUSPEND { + auto responder = + std::make_shared(std::string(kMessageLen, 'a')); + + opts.serverThreads = FLAGS_server_threads; + opts.clients = FLAGS_clients; + if (FLAGS_override_client_threads > 0) { + opts.clientThreads = FLAGS_override_client_threads; + } + + fixture = std::make_unique(opts, std::move(responder)); + + LOG(INFO) << "Running:"; + LOG(INFO) << " Server with " << opts.serverThreads << " threads."; + LOG(INFO) << " " << opts.clients << " clients across " + << fixture->workers.size() << " threads."; + LOG(INFO) << " Running " << FLAGS_items << " requests in total"; + } + + for (int i = 0; i < FLAGS_items; ++i) { + auto& client = fixture->clients[i % opts.clients]; + client->getRequester() + ->requestResponse(Payload("RequestResponseTcp")) + ->subscribe(std::make_shared(latch)); + } + + constexpr std::chrono::minutes timeout{5}; + if (!latch.timed_wait(timeout)) { + LOG(ERROR) << "Timed out!"; + } +} diff --git a/rsocket/benchmarks/StreamThroughputMemory.cpp b/rsocket/benchmarks/StreamThroughputMemory.cpp new file mode 100644 index 000000000..c5128152e --- /dev/null +++ b/rsocket/benchmarks/StreamThroughputMemory.cpp @@ -0,0 +1,187 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/benchmarks/Throughput.h" + +#include +#include +#include +#include +#include +#include + +#include "rsocket/RSocket.h" +#include "yarpl/Flowable.h" + +using namespace rsocket; + +constexpr size_t kMessageLen = 32; + +DEFINE_int32(items, 1000000, "number of items in stream"); + +namespace { + +/// State shared across the client and server DirectDuplexConnections. +struct State { + /// Whether one of the two connections has been destroyed. + folly::Synchronized destroyed; +}; + +/// DuplexConnection that talks to another DuplexConnection via memory. +class DirectDuplexConnection : public DuplexConnection { + public: + DirectDuplexConnection(std::shared_ptr state, folly::EventBase& evb) + : state_{std::move(state)}, evb_{evb} {} + + ~DirectDuplexConnection() override { + *state_->destroyed.wlock() = true; + } + + // Tie two DirectDuplexConnections together so they can talk to each other. + void tie(DirectDuplexConnection* other) { + other_ = other; + other_->other_ = this; + } + + void setInput(std::shared_ptr input) override { + input_ = std::move(input); + } + + void send(std::unique_ptr buf) override { + auto destroyed = state_->destroyed.rlock(); + if (*destroyed || !other_) { + return; + } + + other_->evb_.runInEventBaseThread( + [state = state_, other = other_, b = std::move(buf)]() mutable { + auto destroyed = state->destroyed.rlock(); + if (*destroyed) { + return; + } + + other->input_->onNext(std::move(b)); + }); + } + + private: + std::shared_ptr state_; + folly::EventBase& evb_; + + DirectDuplexConnection* other_{nullptr}; + + std::shared_ptr input_; +}; + +class Acceptor : public ConnectionAcceptor { + public: + explicit Acceptor(std::shared_ptr state) : state_{std::move(state)} {} + + void setClientConnection(DirectDuplexConnection* connection) { + client_ = connection; + } + + void start(OnDuplexConnectionAccept onAccept) override { + worker_.getEventBase()->runInEventBaseThread( + [this, onAccept = std::move(onAccept)]() mutable { + auto server = std::make_unique( + std::move(state_), *worker_.getEventBase()); + server->tie(client_); + onAccept(std::move(server), *worker_.getEventBase()); + }); + } + + void stop() override {} + + folly::Optional listeningPort() const override { + return folly::none; + } + + private: + std::shared_ptr state_; + + DirectDuplexConnection* client_{nullptr}; + + folly::ScopedEventBaseThread worker_; +}; + +class Factory : public ConnectionFactory { + public: + Factory() { + auto state = std::make_shared(); + + connection_ = std::make_unique( + state, *worker_.getEventBase()); + + auto acceptor = std::make_unique(state); + acceptor_ = acceptor.get(); + + acceptor_->setClientConnection(connection_.get()); + + auto responder = + std::make_shared(std::string(kMessageLen, 'a')); + + server_ = std::make_unique(std::move(acceptor)); + server_->start([responder](const SetupParameters&) { return responder; }); + } + + folly::Future connect( + ProtocolVersion, + ResumeStatus /* unused */) override { + return folly::via(worker_.getEventBase(), [this] { + return ConnectedDuplexConnection{ + std::move(connection_), *worker_.getEventBase()}; + }); + } + + private: + std::unique_ptr connection_; + + std::unique_ptr server_; + Acceptor* acceptor_{nullptr}; + + folly::ScopedEventBaseThread worker_; +}; + +std::shared_ptr makeClient() { + auto factory = std::make_unique(); + return RSocket::createConnectedClient(std::move(factory)).get(); +} +} // namespace + +BENCHMARK(StreamThroughput, n) { + (void)n; + + std::shared_ptr client; + std::shared_ptr subscriber; + + folly::ScopedEventBaseThread worker; + + Latch latch{1}; + + BENCHMARK_SUSPEND { + LOG(INFO) << " Running with " << FLAGS_items << " items"; + + client = makeClient(); + } + + client->getRequester() + ->requestStream(Payload("InMemoryStream")) + ->subscribe(std::make_shared(latch, FLAGS_items)); + + constexpr std::chrono::minutes timeout{5}; + if (!latch.timed_wait(timeout)) { + LOG(ERROR) << "Timed out!"; + } +} diff --git a/rsocket/benchmarks/StreamThroughputTcp.cpp b/rsocket/benchmarks/StreamThroughputTcp.cpp new file mode 100644 index 000000000..4f9c5e343 --- /dev/null +++ b/rsocket/benchmarks/StreamThroughputTcp.cpp @@ -0,0 +1,77 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/benchmarks/Fixture.h" +#include "rsocket/benchmarks/Throughput.h" + +#include +#include +#include + +#include "rsocket/RSocket.h" + +using namespace rsocket; + +constexpr size_t kMessageLen = 32; + +DEFINE_int32(server_threads, 8, "number of server threads to run"); +DEFINE_int32( + override_client_threads, + 0, + "control the number of client threads (defaults to the number of clients)"); +DEFINE_int32(clients, 10, "number of clients to run"); +DEFINE_int32(items, 1000000, "number of items in stream, per client"); +DEFINE_int32(streams, 1, "number of streams, per client"); + +BENCHMARK(StreamThroughput, n) { + (void)n; + + Latch latch{static_cast(FLAGS_streams)}; + + std::unique_ptr fixture; + Fixture::Options opts; + + BENCHMARK_SUSPEND { + auto responder = + std::make_shared(std::string(kMessageLen, 'a')); + + opts.serverThreads = FLAGS_server_threads; + opts.clients = FLAGS_clients; + if (FLAGS_override_client_threads > 0) { + opts.clientThreads = FLAGS_override_client_threads; + } + + fixture = std::make_unique(opts, std::move(responder)); + + LOG(INFO) << "Running:"; + LOG(INFO) << " Server with " << opts.serverThreads << " threads."; + LOG(INFO) << " " << opts.clients << " clients across " + << fixture->workers.size() << " threads."; + LOG(INFO) << " Running " << FLAGS_streams << " streams of " << FLAGS_items + << " items each."; + } + + for (size_t i = 0; i < FLAGS_streams; ++i) { + for (auto& client : fixture->clients) { + client->getRequester() + ->requestStream(Payload("TcpStream")) + ->subscribe(std::make_shared(latch, FLAGS_items)); + } + } + + constexpr std::chrono::minutes timeout{5}; + if (!latch.timed_wait(timeout)) { + LOG(ERROR) << "Timed out!"; + } +} diff --git a/rsocket/benchmarks/Throughput.h b/rsocket/benchmarks/Throughput.h new file mode 100644 index 000000000..c5c215e99 --- /dev/null +++ b/rsocket/benchmarks/Throughput.h @@ -0,0 +1,87 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "rsocket/RSocketResponder.h" +#include "rsocket/benchmarks/Latch.h" + +namespace rsocket { + +/// Responder that always sends back a fixed message. +class FixedResponder : public RSocketResponder { + public: + explicit FixedResponder(const std::string& message) + : message_{folly::IOBuf::copyBuffer(message)} {} + + /// Infinitely streams back the message. + std::shared_ptr> handleRequestStream( + Payload, + StreamId) override { + return yarpl::flowable::Flowable::fromGenerator( + [msg = message_->clone()] { return Payload(msg->clone()); }); + } + + std::shared_ptr> handleRequestResponse( + Payload, + StreamId) override { + return yarpl::single::Singles::fromGenerator( + [msg = message_->clone()] { return Payload(msg->clone()); }); + } + + private: + std::unique_ptr message_; +}; + +/// Subscriber that requests N items and cancels the subscription once all of +/// them arrive. Signals a latch when it terminates. +class BoundedSubscriber : public yarpl::flowable::BaseSubscriber { + public: + BoundedSubscriber(Latch& latch, size_t requested) + : latch_{latch}, requested_{requested} {} + + void onSubscribeImpl() override { + this->request(requested_); + } + + void onNextImpl(Payload) override { + if (received_.fetch_add(1) == requested_ - 1) { + DCHECK(!terminated_.exchange(true)); + latch_.post(); + + // After this cancel we could be destroyed. + this->cancel(); + } + } + + void onCompleteImpl() override { + if (!terminated_.exchange(true)) { + latch_.post(); + } + } + + void onErrorImpl(folly::exception_wrapper) override { + if (!terminated_.exchange(true)) { + latch_.post(); + } + } + + private: + Latch& latch_; + + std::atomic_bool terminated_{false}; + size_t requested_{0}; + std::atomic received_{0}; +}; +} // namespace rsocket diff --git a/examples/README.md b/rsocket/examples/README.md similarity index 100% rename from examples/README.md rename to rsocket/examples/README.md diff --git a/rsocket/examples/channel-hello-world/ChannelHelloWorld_Client.cpp b/rsocket/examples/channel-hello-world/ChannelHelloWorld_Client.cpp new file mode 100644 index 000000000..55c12bee6 --- /dev/null +++ b/rsocket/examples/channel-hello-world/ChannelHelloWorld_Client.cpp @@ -0,0 +1,62 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include +#include + +#include "rsocket/RSocket.h" +#include "rsocket/examples/util/ExampleSubscriber.h" +#include "rsocket/transports/tcp/TcpConnectionFactory.h" + +#include "yarpl/Flowable.h" + +using namespace rsocket; +using namespace yarpl::flowable; + +DEFINE_string(host, "localhost", "host to connect to"); +DEFINE_int32(port, 9898, "host:port to connect to"); + +int main(int argc, char* argv[]) { + FLAGS_logtostderr = true; + FLAGS_minloglevel = 0; + folly::init(&argc, &argv); + + folly::ScopedEventBaseThread worker; + + folly::SocketAddress address; + address.setFromHostPort(FLAGS_host, FLAGS_port); + + auto client = RSocket::createConnectedClient( + std::make_unique( + *worker.getEventBase(), std::move(address))) + .get(); + + client->getRequester() + ->requestChannel( + Payload("initialPayload"), + Flowable<>::justN({"Bob", "Jane"})->map([](std::string v) { + std::cout << "Sending: " << v << std::endl; + return Payload(v); + })) + ->subscribe([](Payload p) { + std::cout << "Received: " << p.moveDataToString() << std::endl; + }); + + // Wait for a newline on the console to terminate the server. + std::getchar(); + return 0; +} diff --git a/examples/channel-hello-world/ChannelHelloWorld_Server.cpp b/rsocket/examples/channel-hello-world/ChannelHelloWorld_Server.cpp similarity index 69% rename from examples/channel-hello-world/ChannelHelloWorld_Server.cpp rename to rsocket/examples/channel-hello-world/ChannelHelloWorld_Server.cpp index 40cf04092..9be2a69d9 100644 --- a/examples/channel-hello-world/ChannelHelloWorld_Server.cpp +++ b/rsocket/examples/channel-hello-world/ChannelHelloWorld_Server.cpp @@ -1,6 +1,19 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include +#include #include #include @@ -18,9 +31,9 @@ DEFINE_int32(port, 9898, "port to connect to"); class HelloChannelRequestResponder : public rsocket::RSocketResponder { public: /// Handles a new inbound Stream requested by the other end. - yarpl::Reference> handleRequestChannel( + std::shared_ptr> handleRequestChannel( rsocket::Payload initialPayload, - yarpl::Reference> request, + std::shared_ptr> request, rsocket::StreamId) override { std::cout << "Initial request " << initialPayload.cloneDataToString() << std::endl; diff --git a/examples/channel-hello-world/README.md b/rsocket/examples/channel-hello-world/README.md similarity index 100% rename from examples/channel-hello-world/README.md rename to rsocket/examples/channel-hello-world/README.md diff --git a/examples/conditional-request-handling/ConditionalRequestHandling_Client.cpp b/rsocket/examples/conditional-request-handling/ConditionalRequestHandling_Client.cpp similarity index 68% rename from examples/conditional-request-handling/ConditionalRequestHandling_Client.cpp rename to rsocket/examples/conditional-request-handling/ConditionalRequestHandling_Client.cpp index 8c921d1fd..033649c17 100644 --- a/examples/conditional-request-handling/ConditionalRequestHandling_Client.cpp +++ b/rsocket/examples/conditional-request-handling/ConditionalRequestHandling_Client.cpp @@ -1,23 +1,34 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include #include #include -#include "examples/util/ExampleSubscriber.h" #include "rsocket/RSocket.h" +#include "rsocket/examples/util/ExampleSubscriber.h" #include "rsocket/transports/tcp/TcpConnectionFactory.h" #include "yarpl/Flowable.h" using namespace ::folly; -using namespace ::rsocket_example; using namespace ::rsocket; -using namespace yarpl::flowable; DEFINE_string(host, "localhost", "host to connect to"); DEFINE_int32(port, 9898, "host:port to connect to"); +namespace { class ChannelConnectionEvents : public RSocketConnectionEvents { public: void onConnected() override { @@ -40,16 +51,19 @@ class ChannelConnectionEvents : public RSocketConnectionEvents { private: std::atomic closed_{false}; }; +} // namespace void sendRequest(std::string mimeType) { + folly::ScopedEventBaseThread worker; folly::SocketAddress address; address.setFromHostPort(FLAGS_host, FLAGS_port); auto connectionEvents = std::make_shared(); auto client = RSocket::createConnectedClient( - std::make_unique(std::move(address)), + std::make_unique( + *worker.getEventBase(), std::move(address)), SetupParameters(mimeType, mimeType), std::make_shared(), - nullptr, + kDefaultKeepaliveInterval, RSocketStats::noop(), connectionEvents) .get(); diff --git a/examples/conditional-request-handling/ConditionalRequestHandling_Server.cpp b/rsocket/examples/conditional-request-handling/ConditionalRequestHandling_Server.cpp similarity index 66% rename from examples/conditional-request-handling/ConditionalRequestHandling_Server.cpp rename to rsocket/examples/conditional-request-handling/ConditionalRequestHandling_Server.cpp index 93ef0666d..97e640f98 100644 --- a/examples/conditional-request-handling/ConditionalRequestHandling_Server.cpp +++ b/rsocket/examples/conditional-request-handling/ConditionalRequestHandling_Server.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -7,10 +19,8 @@ #include "JsonRequestHandler.h" #include "TextRequestHandler.h" #include "rsocket/RSocket.h" -#include "rsocket/RSocketErrors.h" #include "rsocket/transports/tcp/TcpConnectionAcceptor.h" -using namespace ::folly; using namespace ::rsocket; DEFINE_int32(port, 9898, "port to connect to"); diff --git a/rsocket/examples/conditional-request-handling/JsonRequestHandler.cpp b/rsocket/examples/conditional-request-handling/JsonRequestHandler.cpp new file mode 100644 index 000000000..563ea7d10 --- /dev/null +++ b/rsocket/examples/conditional-request-handling/JsonRequestHandler.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "JsonRequestHandler.h" +#include +#include +#include "yarpl/Flowable.h" + +using namespace rsocket; +using namespace yarpl::flowable; + +/// Handles a new inbound Stream requested by the other end. +std::shared_ptr> +JsonRequestResponder::handleRequestStream(Payload request, StreamId) { + LOG(INFO) << "JsonRequestResponder.handleRequestStream " << request; + + // string from payload data + auto requestString = request.moveDataToString(); + + return Flowable<>::range(1, 100)->map( + [name = std::move(requestString)](int64_t v) { + std::stringstream ss; + ss << "Hello (should be JSON) " << name << " " << v << "!"; + std::string s = ss.str(); + return Payload(s, "metadata"); + }); +} diff --git a/rsocket/examples/conditional-request-handling/JsonRequestHandler.h b/rsocket/examples/conditional-request-handling/JsonRequestHandler.h new file mode 100644 index 000000000..2bc0f45ad --- /dev/null +++ b/rsocket/examples/conditional-request-handling/JsonRequestHandler.h @@ -0,0 +1,26 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "rsocket/Payload.h" +#include "rsocket/RSocket.h" + +class JsonRequestResponder : public rsocket::RSocketResponder { + public: + /// Handles a new inbound Stream requested by the other end. + std::shared_ptr> + handleRequestStream(rsocket::Payload request, rsocket::StreamId streamId) + override; +}; diff --git a/examples/conditional-request-handling/README.md b/rsocket/examples/conditional-request-handling/README.md similarity index 100% rename from examples/conditional-request-handling/README.md rename to rsocket/examples/conditional-request-handling/README.md diff --git a/rsocket/examples/conditional-request-handling/TextRequestHandler.cpp b/rsocket/examples/conditional-request-handling/TextRequestHandler.cpp new file mode 100644 index 000000000..708313186 --- /dev/null +++ b/rsocket/examples/conditional-request-handling/TextRequestHandler.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "TextRequestHandler.h" +#include +#include +#include "yarpl/Flowable.h" + +using namespace rsocket; +using namespace yarpl::flowable; + +/// Handles a new inbound Stream requested by the other end. +std::shared_ptr> +TextRequestResponder::handleRequestStream(Payload request, StreamId) { + LOG(INFO) << "TextRequestResponder.handleRequestStream " << request; + + // string from payload data + auto requestString = request.moveDataToString(); + + return Flowable<>::range(1, 100)->map( + [name = std::move(requestString)](int64_t v) { + std::stringstream ss; + ss << "Hello " << name << " " << v << "!"; + std::string s = ss.str(); + return Payload(s, "metadata"); + }); +} diff --git a/rsocket/examples/conditional-request-handling/TextRequestHandler.h b/rsocket/examples/conditional-request-handling/TextRequestHandler.h new file mode 100644 index 000000000..7098b516e --- /dev/null +++ b/rsocket/examples/conditional-request-handling/TextRequestHandler.h @@ -0,0 +1,26 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "rsocket/Payload.h" +#include "rsocket/RSocket.h" + +class TextRequestResponder : public rsocket::RSocketResponder { + public: + /// Handles a new inbound Stream requested by the other end. + std::shared_ptr> + handleRequestStream(rsocket::Payload request, rsocket::StreamId streamId) + override; +}; diff --git a/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Client.cpp b/rsocket/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Client.cpp similarity index 52% rename from examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Client.cpp rename to rsocket/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Client.cpp index cab0150a5..c20dd0971 100644 --- a/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Client.cpp +++ b/rsocket/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Client.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include @@ -6,16 +18,13 @@ #include #include -#include "examples/util/ExampleSubscriber.h" #include "rsocket/RSocket.h" +#include "rsocket/examples/util/ExampleSubscriber.h" #include "rsocket/transports/tcp/TcpConnectionFactory.h" #include "yarpl/Single.h" -using namespace rsocket_example; using namespace rsocket; -using namespace yarpl; -using namespace yarpl::single; DEFINE_string(host, "localhost", "host to connect to"); DEFINE_int32(port, 9898, "host:port to connect to"); @@ -25,11 +34,14 @@ int main(int argc, char* argv[]) { FLAGS_minloglevel = 0; folly::init(&argc, &argv); + folly::ScopedEventBaseThread worker; + folly::SocketAddress address; address.setFromHostPort(FLAGS_host, FLAGS_port); auto client = RSocket::createConnectedClient( - std::make_unique(std::move(address))) + std::make_unique( + *worker.getEventBase(), std::move(address))) .get(); client->getRequester()->fireAndForget(Payload("Hello World!"))->subscribe([] { diff --git a/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Server.cpp b/rsocket/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Server.cpp similarity index 67% rename from examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Server.cpp rename to rsocket/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Server.cpp index 9ee680095..7c1dfbce0 100644 --- a/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Server.cpp +++ b/rsocket/examples/fire-and-forget-hello-world/FireAndForgetHelloWorld_Server.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -8,11 +20,8 @@ #include "rsocket/RSocket.h" #include "rsocket/transports/tcp/TcpConnectionAcceptor.h" -#include "yarpl/Single.h" using namespace rsocket; -using namespace yarpl; -using namespace yarpl::single; DEFINE_int32(port, 9898, "port to connect to"); diff --git a/examples/fire-and-forget-hello-world/README.md b/rsocket/examples/fire-and-forget-hello-world/README.md similarity index 100% rename from examples/fire-and-forget-hello-world/README.md rename to rsocket/examples/fire-and-forget-hello-world/README.md diff --git a/examples/request-response-hello-world/README.md b/rsocket/examples/request-response-hello-world/README.md similarity index 100% rename from examples/request-response-hello-world/README.md rename to rsocket/examples/request-response-hello-world/README.md diff --git a/examples/request-response-hello-world/RequestResponseHelloWorld_Client.cpp b/rsocket/examples/request-response-hello-world/RequestResponseHelloWorld_Client.cpp similarity index 53% rename from examples/request-response-hello-world/RequestResponseHelloWorld_Client.cpp rename to rsocket/examples/request-response-hello-world/RequestResponseHelloWorld_Client.cpp index 449089b18..f9f935a05 100644 --- a/examples/request-response-hello-world/RequestResponseHelloWorld_Client.cpp +++ b/rsocket/examples/request-response-hello-world/RequestResponseHelloWorld_Client.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include @@ -6,16 +18,13 @@ #include #include -#include "examples/util/ExampleSubscriber.h" #include "rsocket/RSocket.h" +#include "rsocket/examples/util/ExampleSubscriber.h" #include "rsocket/transports/tcp/TcpConnectionFactory.h" #include "yarpl/Single.h" -using namespace rsocket_example; using namespace rsocket; -using namespace yarpl; -using namespace yarpl::single; DEFINE_string(host, "localhost", "host to connect to"); DEFINE_int32(port, 9898, "host:port to connect to"); @@ -28,8 +37,10 @@ int main(int argc, char* argv[]) { folly::SocketAddress address; address.setFromHostPort(FLAGS_host, FLAGS_port); + folly::ScopedEventBaseThread worker; auto client = RSocket::createConnectedClient( - std::make_unique(std::move(address))) + std::make_unique( + *worker.getEventBase(), std::move(address))) .get(); client->getRequester() diff --git a/examples/request-response-hello-world/RequestResponseHelloWorld_Server.cpp b/rsocket/examples/request-response-hello-world/RequestResponseHelloWorld_Server.cpp similarity index 55% rename from examples/request-response-hello-world/RequestResponseHelloWorld_Server.cpp rename to rsocket/examples/request-response-hello-world/RequestResponseHelloWorld_Server.cpp index 789f076c9..7bba6813a 100644 --- a/examples/request-response-hello-world/RequestResponseHelloWorld_Server.cpp +++ b/rsocket/examples/request-response-hello-world/RequestResponseHelloWorld_Server.cpp @@ -1,6 +1,19 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include +#include #include #include @@ -11,32 +24,33 @@ #include "yarpl/Single.h" using namespace rsocket; -using namespace yarpl; using namespace yarpl::single; DEFINE_int32(port, 9898, "port to connect to"); +namespace { class HelloRequestResponseResponder : public rsocket::RSocketResponder { public: - Reference> handleRequestResponse(Payload request, StreamId) - override { + std::shared_ptr> handleRequestResponse( + Payload request, + StreamId) override { std::cout << "HelloRequestResponseRequestResponder.handleRequestResponse " << request << std::endl; // string from payload data auto requestString = request.moveDataToString(); - return Single::create([name = std::move(requestString)]( - auto subscriber) { - - std::stringstream ss; - ss << "Hello " << name << "!"; - std::string s = ss.str(); - subscriber->onSubscribe(SingleSubscriptions::empty()); - subscriber->onSuccess(Payload(s, "metadata")); - }); + return Single::create( + [name = std::move(requestString)](auto subscriber) { + std::stringstream ss; + ss << "Hello " << name << "!"; + std::string s = ss.str(); + subscriber->onSubscribe(SingleSubscriptions::empty()); + subscriber->onSuccess(Payload(s, "metadata")); + }); } }; +} // namespace int main(int argc, char* argv[]) { FLAGS_logtostderr = true; diff --git a/rsocket/examples/resumption/ColdResumption_Client.cpp b/rsocket/examples/resumption/ColdResumption_Client.cpp new file mode 100644 index 000000000..8d443bf0f --- /dev/null +++ b/rsocket/examples/resumption/ColdResumption_Client.cpp @@ -0,0 +1,233 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include +#include + +#include "rsocket/RSocket.h" + +#include "rsocket/test/test_utils/ColdResumeManager.h" +#include "rsocket/transports/tcp/TcpConnectionFactory.h" + +using namespace rsocket; +using namespace yarpl::flowable; + +DEFINE_string(host, "localhost", "host to connect to"); +DEFINE_int32(port, 9898, "host:port to connect to"); + +typedef std::map>> + HelloSubscribers; + +namespace { + +class HelloSubscriber : public Subscriber { + public: + void request(int n) { + while (!subscription_) { + std::this_thread::yield(); + } + subscription_->request(n); + } + + int rcvdCount() const { + return count_; + }; + + protected: + void onSubscribe(std::shared_ptr subscription) override { + subscription_ = subscription; + } + + void onNext(Payload) noexcept override { + count_++; + } + + void onComplete() override {} + void onError(folly::exception_wrapper) override {} + + private: + std::shared_ptr subscription_; + std::atomic count_{0}; +}; + +class HelloResumeHandler : public ColdResumeHandler { + public: + explicit HelloResumeHandler(HelloSubscribers subscribers) + : subscribers_(std::move(subscribers)) {} + + std::string generateStreamToken(const Payload& payload, StreamId, StreamType) + const override { + auto streamToken = + payload.data->cloneAsValue().moveToFbString().toStdString(); + VLOG(3) << "Generated token: " << streamToken; + return streamToken; + } + + std::shared_ptr> handleRequesterResumeStream( + std::string streamToken, + size_t consumerAllowance) override { + CHECK(subscribers_.find(streamToken) != subscribers_.end()); + LOG(INFO) << "Resuming " << streamToken << " stream with allowance " + << consumerAllowance; + return subscribers_[streamToken]; + } + + private: + HelloSubscribers subscribers_; +}; + +SetupParameters getSetupParams(ResumeIdentificationToken token) { + SetupParameters setupParameters; + setupParameters.resumable = true; + setupParameters.token = token; + return setupParameters; +} + +std::unique_ptr getConnFactory( + folly::EventBase* eventBase) { + folly::SocketAddress address; + address.setFromHostPort(FLAGS_host, FLAGS_port); + return std::make_unique(*eventBase, address); +} +} // namespace + +// There are three sessions and three streams. +// There is cold-resumption between the three sessions. +// The first stream lasts through all three sessions. +// The second stream lasts through the second and third session. +// the third stream lives only in the third session. + +int main(int argc, char* argv[]) { + FLAGS_logtostderr = true; + FLAGS_minloglevel = 0; + folly::init(&argc, &argv); + + folly::ScopedEventBaseThread worker; + + auto token = ResumeIdentificationToken::generateNew(); + + std::string firstPayload = "First"; + std::string secondPayload = "Second"; + std::string thirdPayload = "Third"; + + { + auto resumeManager = std::make_shared( + RSocketStats::noop(), "" /* inputFile */); + { + auto firstSub = std::make_shared(); + auto coldResumeHandler = std::make_shared( + HelloSubscribers({{firstPayload, firstSub}})); + auto firstClient = RSocket::createConnectedClient( + getConnFactory(worker.getEventBase()), + getSetupParams(token), + nullptr, // responder + kDefaultKeepaliveInterval, + nullptr, // stats + nullptr, // connectionEvents + resumeManager, + coldResumeHandler) + .get(); + firstClient->getRequester() + ->requestStream(Payload(firstPayload)) + ->subscribe(firstSub); + firstSub->request(7); + while (firstSub->rcvdCount() < 3) { + std::this_thread::yield(); + } + firstClient->disconnect(std::runtime_error("disconnect from client")); + } + worker.getEventBase()->runInEventBaseThreadAndWait( + [resumeManager = std::move(resumeManager)]() { + // We want to persist state after RSocketStateMachine of the client + // has been completely destroyed and before we start the next scope. + // Since the RSocketStateMachine's destruction proceeds + // asynchronously in worker thread, we have to schedule the + // persistence in the worker thread. + resumeManager->persistState("/tmp/firstResumption.json"); + }); + } + + LOG(INFO) << "============== First Cold Resumption ================"; + + { + auto resumeManager = std::make_shared( + RSocketStats::noop(), "/tmp/firstResumption.json" /* inputFile */); + { + auto firstSub = std::make_shared(); + auto coldResumeHandler = std::make_shared( + HelloSubscribers({{firstPayload, firstSub}})); + auto secondClient = RSocket::createResumedClient( + getConnFactory(worker.getEventBase()), + token, + resumeManager, + coldResumeHandler) + .get(); + + firstSub->request(3); + + // Create another stream to verify StreamIds are set properly after + // resumption + auto secondSub = std::make_shared(); + secondClient->getRequester() + ->requestStream(Payload(secondPayload)) + ->subscribe(secondSub); + secondSub->request(5); + firstSub->request(4); + while (secondSub->rcvdCount() < 1) { + std::this_thread::yield(); + } + } + worker.getEventBase()->runInEventBaseThreadAndWait( + [resumeManager = std::move(resumeManager)]() { + // Refer to comments in the above scope. + resumeManager->persistState("/tmp/secondResumption.json"); + }); + } + + LOG(INFO) << "============== Second Cold Resumption ================"; + + { + auto resumeManager = std::make_shared( + RSocketStats::noop(), "/tmp/secondResumption.json" /* inputFile */); + auto firstSub = std::make_shared(); + auto secondSub = std::make_shared(); + auto coldResumeHandler = + std::make_shared(HelloSubscribers( + {{firstPayload, firstSub}, {secondPayload, secondSub}})); + auto thirdClient = RSocket::createResumedClient( + getConnFactory(worker.getEventBase()), + token, + resumeManager, + coldResumeHandler) + .get(); + + firstSub->request(6); + secondSub->request(5); + + // Create another stream to verify StreamIds are set properly after + // resumption + auto thirdSub = std::make_shared(); + thirdClient->getRequester() + ->requestStream(Payload(thirdPayload)) + ->subscribe(thirdSub); + thirdSub->request(5); + + getchar(); + } + + return 0; +} diff --git a/examples/warm-resumption/WarmResumption_Server.cpp b/rsocket/examples/resumption/Resumption_Server.cpp similarity index 68% rename from examples/warm-resumption/WarmResumption_Server.cpp rename to rsocket/examples/resumption/Resumption_Server.cpp index 3511635b2..36ccb6bb7 100644 --- a/examples/warm-resumption/WarmResumption_Server.cpp +++ b/rsocket/examples/resumption/Resumption_Server.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -17,21 +29,17 @@ DEFINE_int32(port, 9898, "Port to accept connections on"); class HelloStreamRequestResponder : public RSocketResponder { public: - yarpl::Reference> handleRequestStream( + std::shared_ptr> handleRequestStream( rsocket::Payload request, rsocket::StreamId) override { auto requestString = request.moveDataToString(); - return Flowables::range(1, 1000)->map([name = std::move(requestString)]( - int64_t v) { - std::stringstream ss; - ss << "Hello " << name << " " << v << "!"; - std::string s = ss.str(); - return Payload(s, "metadata"); - }); + return Flowable<>::range(1, 1000)->map( + [name = std::move(requestString)](int64_t v) { + return Payload(folly::to(v), "metadata"); + }); } }; - class HelloServiceHandler : public RSocketServiceHandler { public: folly::Expected onNewSetup( @@ -67,7 +75,7 @@ int main(int argc, char* argv[]) { TcpConnectionAcceptor::Options opts; opts.address = folly::SocketAddress("::", FLAGS_port); - opts.threads = 1; + opts.threads = 3; auto rs = RSocket::createServer( std::make_unique(std::move(opts))); diff --git a/examples/warm-resumption/WarmResumption_Client.cpp b/rsocket/examples/resumption/WarmResumption_Client.cpp similarity index 61% rename from examples/warm-resumption/WarmResumption_Client.cpp rename to rsocket/examples/resumption/WarmResumption_Client.cpp index fff23ea4b..f83a24e85 100644 --- a/examples/warm-resumption/WarmResumption_Client.cpp +++ b/rsocket/examples/resumption/WarmResumption_Client.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include @@ -6,14 +18,13 @@ #include #include -#include "examples/util/ExampleSubscriber.h" #include "rsocket/RSocket.h" +#include "rsocket/examples/util/ExampleSubscriber.h" #include "rsocket/internal/ClientResumeStatusCallback.h" #include "rsocket/transports/tcp/TcpConnectionFactory.h" #include "yarpl/Flowable.h" -using namespace rsocket_example; using namespace rsocket; DEFINE_string(host, "localhost", "host to connect to"); @@ -21,8 +32,7 @@ DEFINE_int32(port, 9898, "host:port to connect to"); namespace { -class HelloSubscriber : public virtual yarpl::Refcounted, - public yarpl::flowable::Subscriber { +class HelloSubscriber : public yarpl::flowable::Subscriber { public: void request(int n) { LOG(INFO) << "... requesting " << n; @@ -43,7 +53,7 @@ class HelloSubscriber : public virtual yarpl::Refcounted, }; protected: - void onSubscribe(yarpl::Reference + void onSubscribe(std::shared_ptr subscription) noexcept override { subscription_ = subscription; } @@ -57,24 +67,26 @@ class HelloSubscriber : public virtual yarpl::Refcounted, LOG(INFO) << "Received: onComplete"; } - void onError(std::exception_ptr) noexcept override { + void onError(folly::exception_wrapper) noexcept override { LOG(INFO) << "Received: onError "; } private: - yarpl::Reference subscription_{nullptr}; + std::shared_ptr subscription_{nullptr}; std::atomic count_{0}; }; -} +} // namespace -std::shared_ptr getClientAndRequestStream( - yarpl::Reference subscriber) { +std::unique_ptr getClientAndRequestStream( + folly::EventBase* eventBase, + std::shared_ptr subscriber) { folly::SocketAddress address; address.setFromHostPort(FLAGS_host, FLAGS_port); SetupParameters setupParameters; setupParameters.resumable = true; auto client = RSocket::createConnectedClient( - std::make_unique(std::move(address)), + std::make_unique( + *eventBase, std::move(address)), std::move(setupParameters)) .get(); client->getRequester()->requestStream(Payload("Jane"))->subscribe(subscriber); @@ -86,8 +98,10 @@ int main(int argc, char* argv[]) { FLAGS_minloglevel = 0; folly::init(&argc, &argv); - auto subscriber1 = yarpl::make_ref(); - auto client = getClientAndRequestStream(subscriber1); + folly::ScopedEventBaseThread worker1; + + auto subscriber1 = std::make_shared(); + auto client = getClientAndRequestStream(worker1.getEventBase(), subscriber1); subscriber1->request(7); @@ -96,11 +110,11 @@ int main(int argc, char* argv[]) { } client->disconnect(std::runtime_error("disconnect triggered from client")); - folly::ScopedEventBaseThread worker_; + folly::ScopedEventBaseThread worker2; client->resume() - .via(worker_.getEventBase()) - .then([subscriber1] { + .via(worker2.getEventBase()) + .thenValue([subscriber1](folly::Unit) { // continue with the old client. subscriber1->request(3); while (subscriber1->rcvdCount() < 10) { @@ -108,7 +122,7 @@ int main(int argc, char* argv[]) { } subscriber1->cancel(); }) - .onError([](folly::exception_wrapper ex) { + .thenError([&](folly::exception_wrapper ex) { LOG(INFO) << "Resumption Failed: " << ex.what(); try { ex.throw_exception(); @@ -120,8 +134,9 @@ int main(int argc, char* argv[]) { LOG(INFO) << "UnknownException " << typeid(e).name(); } // Create a new client - auto subscriber2 = yarpl::make_ref(); - auto client = getClientAndRequestStream(subscriber2); + auto subscriber2 = std::make_shared(); + auto client = + getClientAndRequestStream(worker1.getEventBase(), subscriber2); subscriber2->request(7); while (subscriber2->rcvdCount() < 7) { std::this_thread::yield(); diff --git a/examples/stream-hello-world/README.md b/rsocket/examples/stream-hello-world/README.md similarity index 100% rename from examples/stream-hello-world/README.md rename to rsocket/examples/stream-hello-world/README.md diff --git a/examples/stream-hello-world/StreamHelloWorld_Client.cpp b/rsocket/examples/stream-hello-world/StreamHelloWorld_Client.cpp similarity index 53% rename from examples/stream-hello-world/StreamHelloWorld_Client.cpp rename to rsocket/examples/stream-hello-world/StreamHelloWorld_Client.cpp index 4c224ab05..07fc75d9f 100644 --- a/examples/stream-hello-world/StreamHelloWorld_Client.cpp +++ b/rsocket/examples/stream-hello-world/StreamHelloWorld_Client.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include @@ -6,13 +18,12 @@ #include #include -#include "examples/util/ExampleSubscriber.h" #include "rsocket/RSocket.h" +#include "rsocket/examples/util/ExampleSubscriber.h" #include "rsocket/transports/tcp/TcpConnectionFactory.h" #include "yarpl/Flowable.h" -using namespace rsocket_example; using namespace rsocket; DEFINE_string(host, "localhost", "host to connect to"); @@ -23,11 +34,14 @@ int main(int argc, char* argv[]) { FLAGS_minloglevel = 0; folly::init(&argc, &argv); + folly::ScopedEventBaseThread worker; + folly::SocketAddress address; address.setFromHostPort(FLAGS_host, FLAGS_port); auto client = RSocket::createConnectedClient( - std::make_unique(std::move(address))) + std::make_unique( + *worker.getEventBase(), std::move(address))) .get(); client->getRequester() diff --git a/examples/stream-hello-world/StreamHelloWorld_Server.cpp b/rsocket/examples/stream-hello-world/StreamHelloWorld_Server.cpp similarity index 60% rename from examples/stream-hello-world/StreamHelloWorld_Server.cpp rename to rsocket/examples/stream-hello-world/StreamHelloWorld_Server.cpp index 1d57fb467..c3af1a157 100644 --- a/examples/stream-hello-world/StreamHelloWorld_Server.cpp +++ b/rsocket/examples/stream-hello-world/StreamHelloWorld_Server.cpp @@ -1,6 +1,19 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include +#include #include #include @@ -18,7 +31,7 @@ DEFINE_int32(port, 9898, "port to connect to"); class HelloStreamRequestResponder : public rsocket::RSocketResponder { public: /// Handles a new inbound Stream requested by the other end. - yarpl::Reference> handleRequestStream( + std::shared_ptr> handleRequestStream( rsocket::Payload request, rsocket::StreamId) override { std::cout << "HelloStreamRequestResponder.handleRequestStream " << request @@ -27,13 +40,13 @@ class HelloStreamRequestResponder : public rsocket::RSocketResponder { // string from payload data auto requestString = request.moveDataToString(); - return Flowables::range(1, 10)->map([name = std::move(requestString)]( - int64_t v) { - std::stringstream ss; - ss << "Hello " << name << " " << v << "!"; - std::string s = ss.str(); - return Payload(s, "metadata"); - }); + return Flowable<>::range(1, 10)->map( + [name = std::move(requestString)](int64_t v) { + std::stringstream ss; + ss << "Hello " << name << " " << v << "!"; + std::string s = ss.str(); + return Payload(s, "metadata"); + }); } }; diff --git a/examples/stream-observable-to-flowable/README.md b/rsocket/examples/stream-observable-to-flowable/README.md similarity index 100% rename from examples/stream-observable-to-flowable/README.md rename to rsocket/examples/stream-observable-to-flowable/README.md diff --git a/examples/stream-observable-to-flowable/StreamObservableToFlowable_Client.cpp b/rsocket/examples/stream-observable-to-flowable/StreamObservableToFlowable_Client.cpp similarity index 57% rename from examples/stream-observable-to-flowable/StreamObservableToFlowable_Client.cpp rename to rsocket/examples/stream-observable-to-flowable/StreamObservableToFlowable_Client.cpp index 40726bc95..18cc8be0c 100644 --- a/examples/stream-observable-to-flowable/StreamObservableToFlowable_Client.cpp +++ b/rsocket/examples/stream-observable-to-flowable/StreamObservableToFlowable_Client.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include @@ -6,15 +18,13 @@ #include #include -#include "examples/util/ExampleSubscriber.h" #include "rsocket/RSocket.h" +#include "rsocket/examples/util/ExampleSubscriber.h" #include "rsocket/transports/tcp/TcpConnectionFactory.h" #include "yarpl/Flowable.h" -using namespace rsocket_example; using namespace rsocket; -using yarpl::flowable::Subscribers; DEFINE_string(host, "localhost", "host to connect to"); DEFINE_int32(port, 9898, "host:port to connect to"); @@ -24,11 +34,14 @@ int main(int argc, char* argv[]) { FLAGS_minloglevel = 0; folly::init(&argc, &argv); + folly::ScopedEventBaseThread worker; + folly::SocketAddress address; address.setFromHostPort(FLAGS_host, FLAGS_port); auto client = RSocket::createConnectedClient( - std::make_unique(std::move(address))) + std::make_unique( + *worker.getEventBase(), std::move(address))) .get(); client->getRequester() diff --git a/examples/stream-observable-to-flowable/StreamObservableToFlowable_Server.cpp b/rsocket/examples/stream-observable-to-flowable/StreamObservableToFlowable_Server.cpp similarity index 59% rename from examples/stream-observable-to-flowable/StreamObservableToFlowable_Server.cpp rename to rsocket/examples/stream-observable-to-flowable/StreamObservableToFlowable_Server.cpp index d179b1655..30f2b3d5e 100644 --- a/examples/stream-observable-to-flowable/StreamObservableToFlowable_Server.cpp +++ b/rsocket/examples/stream-observable-to-flowable/StreamObservableToFlowable_Server.cpp @@ -1,7 +1,20 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include +#include #include #include @@ -10,7 +23,6 @@ #include "rsocket/RSocket.h" #include "rsocket/transports/tcp/TcpConnectionAcceptor.h" #include "yarpl/Observable.h" -#include "yarpl/schedulers/ThreadScheduler.h" using namespace rsocket; using namespace yarpl; @@ -22,7 +34,7 @@ DEFINE_int32(port, 9898, "port to connect to"); class PushStreamRequestResponder : public rsocket::RSocketResponder { public: /// Handles a new inbound Stream requested by the other end. - yarpl::Reference> handleRequestStream( + std::shared_ptr> handleRequestStream( Payload request, rsocket::StreamId) override { std::cout << "PushStreamRequestResponder.handleRequestStream " << request @@ -45,26 +57,24 @@ class PushStreamRequestResponder : public rsocket::RSocketResponder { // This examples uses BackpressureStrategy::DROP which simply // drops any events emitted from the Observable if the Flowable // does not have any credits from the Subscriber. - return Observable::create([name = std::move(requestString)]( - Reference> s) { - // Must make this async since it's an infinite stream - // and will block the IO thread. - // Using a raw thread right now since the 'subscribeOn' - // operator is not ready yet. This can eventually - // be replaced with use of 'subscribeOn'. - std::thread([s, name]() { - auto subscription = Subscriptions::atomicBoolSubscription(); - s->onSubscribe(subscription); - int64_t v = 0; - while (!subscription->isCancelled()) { - std::stringstream ss; - ss << "Event[" << name << "]-" << ++v << "!"; - std::string payloadData = ss.str(); - s->onNext(Payload(payloadData, "metadata")); - } - }).detach(); - - }) + return Observable::create( + [name = std::move(requestString)]( + std::shared_ptr> s) { + // Must make this async since it's an infinite stream + // and will block the IO thread. + // Using a raw thread right now since the 'subscribeOn' + // operator is not ready yet. This can eventually + // be replaced with use of 'subscribeOn'. + std::thread([s, name]() { + int64_t v = 0; + while (!s->isUnsubscribed()) { + std::stringstream ss; + ss << "Event[" << name << "]-" << ++v << "!"; + std::string payloadData = ss.str(); + s->onNext(Payload(payloadData, "metadata")); + } + }).detach(); + }) ->toFlowable(BackpressureStrategy::DROP); } }; diff --git a/examples/util/ExampleSubscriber.cpp b/rsocket/examples/util/ExampleSubscriber.cpp similarity index 70% rename from examples/util/ExampleSubscriber.cpp rename to rsocket/examples/util/ExampleSubscriber.cpp index d74744d98..6ad535d36 100644 --- a/examples/util/ExampleSubscriber.cpp +++ b/rsocket/examples/util/ExampleSubscriber.cpp @@ -1,6 +1,18 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -#include "examples/util/ExampleSubscriber.h" +#include "rsocket/examples/util/ExampleSubscriber.h" #include using namespace ::rsocket; @@ -23,7 +35,7 @@ ExampleSubscriber::ExampleSubscriber(int initialRequest, int numToTake) } void ExampleSubscriber::onSubscribe( - yarpl::Reference subscription) noexcept { + std::shared_ptr subscription) noexcept { LOG(INFO) << "ExampleSubscriber " << this << " onSubscribe, requesting " << initialRequest_; subscription_ = std::move(subscription); @@ -55,12 +67,8 @@ void ExampleSubscriber::onComplete() noexcept { terminalEventCV_.notify_all(); } -void ExampleSubscriber::onError(std::exception_ptr ex) noexcept { - try { - std::rethrow_exception(ex); - } catch (const std::exception& e) { - LOG(ERROR) << "ExampleSubscriber " << this << " onError: " << e.what(); - } +void ExampleSubscriber::onError(folly::exception_wrapper ex) noexcept { + LOG(ERROR) << "ExampleSubscriber " << this << " onError: " << ex; terminated_ = true; terminalEventCV_.notify_all(); } @@ -73,4 +81,4 @@ void ExampleSubscriber::awaitTerminalEvent() { terminalEventCV_.wait(lk, [this] { return terminated_; }); LOG(INFO) << "ExampleSubscriber " << this << " unblocked"; } -} +} // namespace rsocket_example diff --git a/examples/util/ExampleSubscriber.h b/rsocket/examples/util/ExampleSubscriber.h similarity index 51% rename from examples/util/ExampleSubscriber.h rename to rsocket/examples/util/ExampleSubscriber.h index c814ac38f..24a1caa23 100644 --- a/examples/util/ExampleSubscriber.h +++ b/rsocket/examples/util/ExampleSubscriber.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -20,11 +32,11 @@ class ExampleSubscriber : public yarpl::flowable::Subscriber { ~ExampleSubscriber(); ExampleSubscriber(int initialRequest, int numToTake); - void onSubscribe(yarpl::Reference + void onSubscribe(std::shared_ptr subscription) noexcept override; void onNext(rsocket::Payload) noexcept override; void onComplete() noexcept override; - void onError(std::exception_ptr ex) noexcept override; + void onError(folly::exception_wrapper ex) noexcept override; void awaitTerminalEvent(); @@ -34,9 +46,9 @@ class ExampleSubscriber : public yarpl::flowable::Subscriber { int numToTake_; int requested_; int received_; - yarpl::Reference subscription_; + std::shared_ptr subscription_; bool terminated_{false}; std::mutex m_; std::condition_variable terminalEventCV_; }; -} +} // namespace rsocket_example diff --git a/examples/util/README.md b/rsocket/examples/util/README.md similarity index 100% rename from examples/util/README.md rename to rsocket/examples/util/README.md diff --git a/rsocket/framing/ErrorCode.cpp b/rsocket/framing/ErrorCode.cpp index 4b9b4767d..6ee11c348 100644 --- a/rsocket/framing/ErrorCode.cpp +++ b/rsocket/framing/ErrorCode.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/framing/ErrorCode.h" @@ -31,4 +43,4 @@ std::ostream& operator<<(std::ostream& os, ErrorCode errorCode) { } return os << "ErrorCode[" << static_cast(errorCode) << "]"; } -} +} // namespace rsocket diff --git a/rsocket/framing/ErrorCode.h b/rsocket/framing/ErrorCode.h index f6f0ce31d..93f741aaa 100644 --- a/rsocket/framing/ErrorCode.h +++ b/rsocket/framing/ErrorCode.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -40,4 +52,4 @@ enum class ErrorCode : uint32_t { }; std::ostream& operator<<(std::ostream&, ErrorCode); -} +} // namespace rsocket diff --git a/rsocket/framing/Frame.cpp b/rsocket/framing/Frame.cpp index d46073a67..9b3d8cc53 100644 --- a/rsocket/framing/Frame.cpp +++ b/rsocket/framing/Frame.cpp @@ -1,9 +1,20 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/framing/Frame.h" #include -#include #include #include #include @@ -12,97 +23,25 @@ namespace rsocket { -const uint32_t Frame_LEASE::kMaxTtl; -const uint32_t Frame_LEASE::kMaxNumRequests; -const uint32_t Frame_SETUP::kMaxKeepaliveTime; -const uint32_t Frame_SETUP::kMaxLifetime; - -std::unique_ptr FrameBufferAllocator::allocate(size_t size) { - // Purposely leak the allocator, since it's hard to deterministically - // guarantee that threads will stop using it before it would get statically - // destructed. - static auto* singleton = new FrameBufferAllocator; - return singleton->allocateBuffer(size); -} - -std::unique_ptr FrameBufferAllocator::allocateBuffer( - size_t size) { - return folly::IOBuf::createCombined(size); -} - -std::ostream& -writeFlags(std::ostream& os, FrameFlags frameFlags, FrameType frameType) { - constexpr const char* kEmpty = "0x00"; - constexpr const char* kMetadata = "METADATA"; - constexpr const char* kResumeEnable = "RESUME_ENABLE"; - constexpr const char* kLease = "LEASE"; - constexpr const char* kKeepAliveRespond = "KEEPALIVE_RESPOND"; - constexpr const char* kFollows = "FOLLOWS"; - constexpr const char* kComplete = "COMPLETE"; - constexpr const char* kNext = "NEXT"; - - static std::map>> - flagToNameMap{{FrameType::REQUEST_N, {}}, - {FrameType::REQUEST_RESPONSE, - {{FrameFlags::METADATA, kMetadata}, - {FrameFlags::FOLLOWS, kFollows}}}, - {FrameType::REQUEST_FNF, - {{FrameFlags::METADATA, kMetadata}, - {FrameFlags::FOLLOWS, kFollows}}}, - {FrameType::METADATA_PUSH, {}}, - {FrameType::CANCEL, {}}, - {FrameType::PAYLOAD, - {{FrameFlags::METADATA, kMetadata}, - {FrameFlags::FOLLOWS, kFollows}, - {FrameFlags::COMPLETE, kComplete}, - {FrameFlags::NEXT, kNext}}}, - {FrameType::ERROR, {{FrameFlags::METADATA, kMetadata}}}, - {FrameType::KEEPALIVE, - {{FrameFlags::KEEPALIVE_RESPOND, kKeepAliveRespond}}}, - {FrameType::SETUP, - {{FrameFlags::METADATA, kMetadata}, - {FrameFlags::RESUME_ENABLE, kResumeEnable}, - {FrameFlags::LEASE, kLease}}}, - {FrameType::LEASE, {{FrameFlags::METADATA, kMetadata}}}, - {FrameType::RESUME, {}}, - {FrameType::REQUEST_CHANNEL, - {{FrameFlags::METADATA, kMetadata}, - {FrameFlags::FOLLOWS, kFollows}, - {FrameFlags::COMPLETE, kComplete}}}, - {FrameType::REQUEST_STREAM, - {{FrameFlags::METADATA, kMetadata}, - {FrameFlags::FOLLOWS, kFollows}}}}; - - FrameFlags foundFlags = FrameFlags::EMPTY; - - // Search the corresponding string value for each flag, insert the missing - // ones as empty - const std::vector>& allowedFlags = - flagToNameMap[frameType]; - - std::string delimiter = ""; - for (const auto& pair : allowedFlags) { - if (!!(frameFlags & pair.first)) { - os << delimiter << pair.second; - delimiter = "|"; - foundFlags |= pair.first; - } - } +namespace detail { - if (foundFlags != frameFlags) { - os << frameFlags; - } else if (delimiter.empty()) { - os << kEmpty; - } - return os; +FrameFlags getFlags(const Payload& p) { + return p.metadata ? FrameFlags::METADATA : FrameFlags::EMPTY_; } -std::ostream& operator<<(std::ostream& os, const FrameHeader& header) { - os << header.type_ << "["; - return writeFlags(os, header.flags_, header.type_) << ", " << header.streamId_ << "]"; +void checkFlags(const Payload& p, FrameFlags flags) { + if (bool(p.metadata) != bool(flags & FrameFlags::METADATA)) { + throw std::invalid_argument{ + "Value of METADATA flag doesn't match payload metadata"}; + } } -/// @} +} // namespace detail + +constexpr uint32_t Frame_LEASE::kMaxTtl; +constexpr uint32_t Frame_LEASE::kMaxNumRequests; +constexpr uint32_t Frame_SETUP::kMaxKeepaliveTime; +constexpr uint32_t Frame_SETUP::kMaxLifetime; std::ostream& operator<<(std::ostream& os, const Frame_REQUEST_Base& frame) { return os << frame.header_ << "(" << frame.requestN_ << ", " @@ -141,42 +80,65 @@ std::ostream& operator<<(std::ostream& os, const Frame_PAYLOAD& frame) { return os << frame.header_ << ", " << frame.payload_; } -Frame_ERROR Frame_ERROR::unexpectedFrame() { - return Frame_ERROR( - 0, ErrorCode::CONNECTION_ERROR, Payload("unexpected frame")); +Frame_ERROR Frame_ERROR::invalidSetup(folly::StringPiece message) { + return connectionErr(ErrorCode::INVALID_SETUP, message); +} + +Frame_ERROR Frame_ERROR::unsupportedSetup(folly::StringPiece message) { + return connectionErr(ErrorCode::UNSUPPORTED_SETUP, message); } -Frame_ERROR Frame_ERROR::invalidFrame() { - return Frame_ERROR(0, ErrorCode::CONNECTION_ERROR, Payload("invalid frame")); +Frame_ERROR Frame_ERROR::rejectedSetup(folly::StringPiece message) { + return connectionErr(ErrorCode::REJECTED_SETUP, message); } -Frame_ERROR Frame_ERROR::badSetupFrame(const std::string& message) { - return Frame_ERROR(0, ErrorCode::INVALID_SETUP, Payload(message)); +Frame_ERROR Frame_ERROR::rejectedResume(folly::StringPiece message) { + return connectionErr(ErrorCode::REJECTED_RESUME, message); } -Frame_ERROR Frame_ERROR::rejectedSetup(const std::string& message) { - return Frame_ERROR(0, ErrorCode::REJECTED_SETUP, Payload(message)); +Frame_ERROR Frame_ERROR::connectionError(folly::StringPiece message) { + return connectionErr(ErrorCode::CONNECTION_ERROR, message); } -Frame_ERROR Frame_ERROR::connectionError(const std::string& message) { - return Frame_ERROR(0, ErrorCode::CONNECTION_ERROR, Payload(message)); +Frame_ERROR Frame_ERROR::applicationError( + StreamId stream, + folly::StringPiece message) { + return streamErr(ErrorCode::APPLICATION_ERROR, message, stream); } -Frame_ERROR Frame_ERROR::rejectedResume(const std::string& message) { - return Frame_ERROR(0, ErrorCode::REJECTED_RESUME, Payload(message)); +Frame_ERROR Frame_ERROR::applicationError(StreamId stream, Payload&& payload) { + if (stream == 0) { + throw std::invalid_argument{"Can't make stream error for stream zero"}; + } + return Frame_ERROR(stream, ErrorCode::APPLICATION_ERROR, std::move(payload)); } -Frame_ERROR Frame_ERROR::error(StreamId streamId, Payload&& payload) { - DCHECK(streamId) << "streamId MUST be non-0"; - return Frame_ERROR(streamId, ErrorCode::INVALID, std::move(payload)); +Frame_ERROR Frame_ERROR::rejected(StreamId stream, folly::StringPiece message) { + return streamErr(ErrorCode::REJECTED, message, stream); } -Frame_ERROR Frame_ERROR::applicationError( - StreamId streamId, - Payload&& payload) { - DCHECK(streamId) << "streamId MUST be non-0"; - return Frame_ERROR( - streamId, ErrorCode::APPLICATION_ERROR, std::move(payload)); +Frame_ERROR Frame_ERROR::canceled(StreamId stream, folly::StringPiece message) { + return streamErr(ErrorCode::CANCELED, message, stream); +} + +Frame_ERROR Frame_ERROR::invalid(StreamId stream, folly::StringPiece message) { + return streamErr(ErrorCode::INVALID, message, stream); +} + +Frame_ERROR Frame_ERROR::connectionErr( + ErrorCode err, + folly::StringPiece message) { + return Frame_ERROR{0, err, Payload{message}}; +} + +Frame_ERROR Frame_ERROR::streamErr( + ErrorCode err, + folly::StringPiece message, + StreamId stream) { + if (stream == 0) { + throw std::invalid_argument{"Can't make stream error for stream zero"}; + } + return Frame_ERROR{stream, err, Payload{message}}; } std::ostream& operator<<(std::ostream& os, const Frame_ERROR& frame) { @@ -192,7 +154,8 @@ std::ostream& operator<<(std::ostream& os, const Frame_KEEPALIVE& frame) { std::ostream& operator<<(std::ostream& os, const Frame_SETUP& frame) { return os << frame.header_ << ", Version: " << frame.versionMajor_ << "." - << frame.versionMinor_ << ", " << frame.payload_; + << frame.versionMinor_ << ", " + << "Token: " << frame.token_ << ", " << frame.payload_; } void Frame_SETUP::moveToSetupPayload(SetupParameters& setupPayload) { @@ -200,7 +163,7 @@ void Frame_SETUP::moveToSetupPayload(SetupParameters& setupPayload) { setupPayload.dataMimeType = std::move(dataMimeType_); setupPayload.payload = std::move(payload_); setupPayload.token = std::move(token_); - setupPayload.resumable = !!(header_.flags_ & FrameFlags::RESUME_ENABLE); + setupPayload.resumable = !!(header_.flags & FrameFlags::RESUME_ENABLE); setupPayload.protocolVersion = ProtocolVersion(versionMajor_, versionMinor_); } @@ -212,8 +175,8 @@ std::ostream& operator<<(std::ostream& os, const Frame_LEASE& frame) { std::ostream& operator<<(std::ostream& os, const Frame_RESUME& frame) { return os << frame.header_ << ", (" - << "token" - << ", @server " << frame.lastReceivedServerPosition_ << ", @client " + << "token " << frame.token_ << ", @server " + << frame.lastReceivedServerPosition_ << ", @client " << frame.clientPosition_ << ")"; } @@ -222,11 +185,13 @@ std::ostream& operator<<(std::ostream& os, const Frame_RESUME_OK& frame) { } std::ostream& operator<<(std::ostream& os, const Frame_REQUEST_CHANNEL& frame) { - return os << frame.header_ << ", " << frame.payload_; + return os << frame.header_ << ", initialRequestN=" << frame.requestN_ << ", " + << frame.payload_; } std::ostream& operator<<(std::ostream& os, const Frame_REQUEST_STREAM& frame) { - return os << frame.header_ << ", " << frame.payload_; + return os << frame.header_ << ", initialRequestN=" << frame.requestN_ << ", " + << frame.payload_; } } // namespace rsocket diff --git a/rsocket/framing/Frame.h b/rsocket/framing/Frame.h index e55e25bd4..8de331f1a 100644 --- a/rsocket/framing/Frame.h +++ b/rsocket/framing/Frame.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -12,8 +24,10 @@ #include "rsocket/Payload.h" #include "rsocket/framing/ErrorCode.h" #include "rsocket/framing/FrameFlags.h" +#include "rsocket/framing/FrameHeader.h" #include "rsocket/framing/FrameType.h" -#include "rsocket/internal/Common.h" +#include "rsocket/framing/ProtocolVersion.h" +#include "rsocket/framing/ResumeIdentificationToken.h" namespace folly { template @@ -21,47 +35,22 @@ class Optional; namespace io { class Cursor; class QueueAppender; -} -} +} // namespace io +} // namespace folly namespace rsocket { -class FrameHeader { - public: - FrameHeader() { -#ifndef NDEBUG - type_ = FrameType::RESERVED; -#endif // NDEBUG - } - FrameHeader(FrameType type, FrameFlags flags, StreamId streamId) - : type_(type), flags_(flags), streamId_(streamId) {} +namespace detail { - bool flagsComplete() const { - return !!(flags_ & FrameFlags::COMPLETE); - } +FrameFlags getFlags(const Payload&); - bool flagsNext() const { - return !!(flags_ & FrameFlags::NEXT); - } +void checkFlags(const Payload&, FrameFlags); - FrameType type_{}; - FrameFlags flags_{}; - StreamId streamId_{}; -}; +} // namespace detail -std::ostream& operator<<(std::ostream&, const FrameHeader&); +using ResumePosition = int64_t; +constexpr ResumePosition kUnspecifiedResumePosition = -1; -class FrameBufferAllocator { - public: - static std::unique_ptr allocate(size_t size); - - virtual ~FrameBufferAllocator() = default; - - private: - virtual std::unique_ptr allocateBuffer(size_t size); -}; - -/// @{ /// Frames do not form hierarchy, as we never perform type erasure on a frame. /// We use inheritance only to save code duplication. /// @@ -81,7 +70,7 @@ class Frame_REQUEST_N { Frame_REQUEST_N() = default; Frame_REQUEST_N(StreamId streamId, uint32_t requestN) - : header_(FrameType::REQUEST_N, FrameFlags::EMPTY, streamId), + : header_(FrameType::REQUEST_N, FrameFlags::EMPTY_, streamId), requestN_(requestN) { DCHECK(requestN_ > 0); DCHECK(requestN_ <= kMaxRequestN); @@ -101,12 +90,10 @@ class Frame_REQUEST_Base { FrameFlags flags, uint32_t requestN, Payload payload) - : header_(frameType, flags | payload.getFlags(), streamId), + : header_(frameType, flags | detail::getFlags(payload), streamId), requestN_(requestN), payload_(std::move(payload)) { - // to verify the client didn't set - // METADATA and provided none - payload_.checkFlags(header_.flags_); + detail::checkFlags(payload_, header_.flags); // TODO: DCHECK(requestN_ > 0); DCHECK(requestN_ <= Frame_REQUEST_N::kMaxRequestN); } @@ -190,11 +177,10 @@ class Frame_REQUEST_RESPONSE { Frame_REQUEST_RESPONSE(StreamId streamId, FrameFlags flags, Payload payload) : header_( FrameType::REQUEST_RESPONSE, - (flags & AllowedFlags) | payload.getFlags(), + (flags & AllowedFlags) | detail::getFlags(payload), streamId), payload_(std::move(payload)) { - payload_.checkFlags(header_.flags_); // to verify the client didn't set - // METADATA and provided none + detail::checkFlags(payload_, header_.flags); } FrameHeader header_; @@ -211,11 +197,10 @@ class Frame_REQUEST_FNF { Frame_REQUEST_FNF(StreamId streamId, FrameFlags flags, Payload payload) : header_( FrameType::REQUEST_FNF, - (flags & AllowedFlags) | payload.getFlags(), + (flags & AllowedFlags) | detail::getFlags(payload), streamId), payload_(std::move(payload)) { - payload_.checkFlags(header_.flags_); // to verify the client didn't set - // METADATA and provided none + detail::checkFlags(payload_, header_.flags); } FrameHeader header_; @@ -241,7 +226,7 @@ class Frame_CANCEL { public: Frame_CANCEL() = default; explicit Frame_CANCEL(StreamId streamId) - : header_(FrameType::CANCEL, FrameFlags::EMPTY, streamId) {} + : header_(FrameType::CANCEL, FrameFlags::EMPTY_, streamId) {} FrameHeader header_; }; @@ -256,11 +241,10 @@ class Frame_PAYLOAD { Frame_PAYLOAD(StreamId streamId, FrameFlags flags, Payload payload) : header_( FrameType::PAYLOAD, - (flags & AllowedFlags) | payload.getFlags(), + (flags & AllowedFlags) | detail::getFlags(payload), streamId), payload_(std::move(payload)) { - payload_.checkFlags(header_.flags_); // to verify the client didn't set - // METADATA and provided none + detail::checkFlags(payload_, header_.flags); } static Frame_PAYLOAD complete(StreamId streamId); @@ -276,19 +260,29 @@ class Frame_ERROR { Frame_ERROR() = default; Frame_ERROR(StreamId streamId, ErrorCode errorCode, Payload payload) - : header_(FrameType::ERROR, payload.getFlags(), streamId), + : header_(FrameType::ERROR, detail::getFlags(payload), streamId), errorCode_(errorCode), payload_(std::move(payload)) {} - static Frame_ERROR unexpectedFrame(); - static Frame_ERROR invalidFrame(); - static Frame_ERROR badSetupFrame(const std::string& message); - static Frame_ERROR rejectedSetup(const std::string& message); - static Frame_ERROR connectionError(const std::string& message); - static Frame_ERROR rejectedResume(const std::string& message); - static Frame_ERROR error(StreamId streamId, Payload&& payload); - static Frame_ERROR applicationError(StreamId streamId, Payload&& payload); + // Connection errors. + static Frame_ERROR invalidSetup(folly::StringPiece); + static Frame_ERROR unsupportedSetup(folly::StringPiece); + static Frame_ERROR rejectedSetup(folly::StringPiece); + static Frame_ERROR rejectedResume(folly::StringPiece); + static Frame_ERROR connectionError(folly::StringPiece); + + // Stream errors. + static Frame_ERROR applicationError(StreamId, folly::StringPiece); + static Frame_ERROR applicationError(StreamId, Payload&&); + static Frame_ERROR rejected(StreamId, folly::StringPiece); + static Frame_ERROR canceled(StreamId, folly::StringPiece); + static Frame_ERROR invalid(StreamId, folly::StringPiece); + + private: + static Frame_ERROR connectionErr(ErrorCode, folly::StringPiece); + static Frame_ERROR streamErr(ErrorCode, folly::StringPiece, StreamId); + public: FrameHeader header_; ErrorCode errorCode_{}; Payload payload_; @@ -340,7 +334,7 @@ class Frame_SETUP { Payload payload) : header_( FrameType::SETUP, - (flags & AllowedFlags) | payload.getFlags(), + (flags & AllowedFlags) | detail::getFlags(payload), 0), versionMajor_(versionMajor), versionMinor_(versionMinor), @@ -350,8 +344,7 @@ class Frame_SETUP { metadataMimeType_(metadataMimeType), dataMimeType_(dataMimeType), payload_(std::move(payload)) { - payload_.checkFlags(header_.flags_); // to verify the client didn't set - // METADATA and provided none + detail::checkFlags(payload_, header_.flags); DCHECK(keepaliveTime_ > 0); DCHECK(maxLifetime_ > 0); DCHECK(keepaliveTime_ <= kMaxKeepaliveTime); @@ -387,7 +380,7 @@ class Frame_LEASE { std::unique_ptr metadata = std::unique_ptr()) : header_( FrameType::LEASE, - metadata ? FrameFlags::METADATA : FrameFlags::EMPTY, + metadata ? FrameFlags::METADATA : FrameFlags::EMPTY_, 0), ttl_(ttl), numberOfRequests_(numberOfRequests), @@ -414,7 +407,7 @@ class Frame_RESUME { ResumePosition lastReceivedServerPosition, ResumePosition clientPosition, ProtocolVersion protocolVersion) - : header_(FrameType::RESUME, FrameFlags::EMPTY, 0), + : header_(FrameType::RESUME, FrameFlags::EMPTY_, 0), versionMajor_(protocolVersion.major), versionMinor_(protocolVersion.minor), token_(token), @@ -435,12 +428,12 @@ class Frame_RESUME_OK { public: Frame_RESUME_OK() = default; explicit Frame_RESUME_OK(ResumePosition position) - : header_(FrameType::RESUME_OK, FrameFlags::EMPTY, 0), + : header_(FrameType::RESUME_OK, FrameFlags::EMPTY_, 0), position_(position) {} FrameHeader header_; ResumePosition position_{}; }; std::ostream& operator<<(std::ostream&, const Frame_RESUME_OK&); -/// @} -} + +} // namespace rsocket diff --git a/rsocket/framing/FrameFlags.cpp b/rsocket/framing/FrameFlags.cpp index df37ad40e..d95399aa1 100644 --- a/rsocket/framing/FrameFlags.cpp +++ b/rsocket/framing/FrameFlags.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/framing/FrameFlags.h" @@ -10,4 +22,4 @@ namespace rsocket { std::ostream& operator<<(std::ostream& os, FrameFlags flags) { return os << std::bitset<16>{raw(flags)}; } -} +} // namespace rsocket diff --git a/rsocket/framing/FrameFlags.h b/rsocket/framing/FrameFlags.h index bedea38c1..7ab7eacf7 100644 --- a/rsocket/framing/FrameFlags.h +++ b/rsocket/framing/FrameFlags.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -6,10 +18,11 @@ #include namespace rsocket { - enum class FrameFlags : uint16_t { - EMPTY = 0x000, - IGNORE = 0x200, + // Note that win32 defines EMPTY and IGNORE so we use a trailing + // underscore to avoid a collision + EMPTY_ = 0x000, + IGNORE_ = 0x200, METADATA = 0x100, // SETUP. @@ -57,6 +70,6 @@ constexpr FrameFlags operator~(FrameFlags a) { return static_cast(~raw(a)); } -std::ostream& operator<<(std::ostream&, FrameFlags); +std::ostream& operator<<(std::ostream& ostr, FrameFlags a); -} +} // namespace rsocket diff --git a/rsocket/framing/FrameHeader.cpp b/rsocket/framing/FrameHeader.cpp new file mode 100644 index 000000000..3ee16dfca --- /dev/null +++ b/rsocket/framing/FrameHeader.cpp @@ -0,0 +1,106 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/framing/FrameHeader.h" + +#include +#include +#include + +namespace rsocket { + +namespace { + +using FlagString = std::pair; + +constexpr std::array kMetadata = { + {std::make_pair(FrameFlags::METADATA, "METADATA")}}; +constexpr std::array kKeepaliveRespond = { + {std::make_pair(FrameFlags::KEEPALIVE_RESPOND, "KEEPALIVE_RESPOND")}}; +constexpr std::array kMetadataFollows = { + {std::make_pair(FrameFlags::METADATA, "METADATA"), + std::make_pair(FrameFlags::FOLLOWS, "FOLLOWS")}}; +constexpr std::array kMetadataResumeEnableLease = { + {std::make_pair(FrameFlags::METADATA, "METADATA"), + std::make_pair(FrameFlags::RESUME_ENABLE, "RESUME_ENABLE"), + std::make_pair(FrameFlags::LEASE, "LEASE")}}; +constexpr std::array kMetadataFollowsComplete = { + {std::make_pair(FrameFlags::METADATA, "METADATA"), + std::make_pair(FrameFlags::FOLLOWS, "FOLLOWS"), + std::make_pair(FrameFlags::COMPLETE, "COMPLETE")}}; +constexpr std::array kMetadataFollowsCompleteNext = { + {std::make_pair(FrameFlags::METADATA, "METADATA"), + std::make_pair(FrameFlags::FOLLOWS, "FOLLOWS"), + std::make_pair(FrameFlags::COMPLETE, "COMPLETE"), + std::make_pair(FrameFlags::NEXT, "NEXT")}}; + +template +constexpr auto toRange(const std::array& arr) { + return folly::Range{arr.data(), arr.size()}; +} + +// constexpr -- Old versions of C++ compiler doesn't support +// compound-statements in constexpr function (no switch statement) +folly::Range allowedFlags(FrameType type) { + switch (type) { + case FrameType::SETUP: + return toRange(kMetadataResumeEnableLease); + case FrameType::LEASE: + case FrameType::ERROR: + return toRange(kMetadata); + case FrameType::KEEPALIVE: + return toRange(kKeepaliveRespond); + case FrameType::REQUEST_RESPONSE: + case FrameType::REQUEST_FNF: + case FrameType::REQUEST_STREAM: + return toRange(kMetadataFollows); + case FrameType::REQUEST_CHANNEL: + return toRange(kMetadataFollowsComplete); + case FrameType::PAYLOAD: + return toRange(kMetadataFollowsCompleteNext); + default: + return {}; + } +} + +std::ostream& +writeFlags(std::ostream& os, FrameFlags frameFlags, FrameType frameType) { + FrameFlags foundFlags = FrameFlags::EMPTY_; + + std::string delimiter; + for (const auto& pair : allowedFlags(frameType)) { + if (!!(frameFlags & pair.first)) { + os << delimiter << pair.second; + delimiter = "|"; + foundFlags |= pair.first; + } + } + + if (foundFlags != frameFlags) { + os << frameFlags; + } else if (delimiter.empty()) { + os << "0x00"; + } + return os; +} + +} // namespace + +std::ostream& operator<<(std::ostream& os, const FrameHeader& header) { + os << header.type << "["; + return writeFlags(os, header.flags, header.type) + << ", " << header.streamId << "]"; +} + +} // namespace rsocket diff --git a/rsocket/framing/FrameHeader.h b/rsocket/framing/FrameHeader.h new file mode 100644 index 000000000..cb67c895b --- /dev/null +++ b/rsocket/framing/FrameHeader.h @@ -0,0 +1,52 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "rsocket/framing/FrameFlags.h" +#include "rsocket/framing/FrameType.h" +#include "rsocket/internal/Common.h" + +namespace rsocket { + +/// Header that begins every RSocket frame. +class FrameHeader { + public: + FrameHeader() {} + + FrameHeader(FrameType ty, FrameFlags fflags, StreamId stream) + : type{ty}, flags{fflags}, streamId{stream} {} + + bool flagsComplete() const { + return !!(flags & FrameFlags::COMPLETE); + } + + bool flagsNext() const { + return !!(flags & FrameFlags::NEXT); + } + + bool flagsFollows() const { + return !!(flags & FrameFlags::FOLLOWS); + } + + FrameType type{FrameType::RESERVED}; + FrameFlags flags{FrameFlags::EMPTY_}; + StreamId streamId{0}; +}; + +std::ostream& operator<<(std::ostream&, const FrameHeader&); + +} // namespace rsocket diff --git a/rsocket/framing/FrameProcessor.h b/rsocket/framing/FrameProcessor.h index e66de9cb0..70c5eae3e 100644 --- a/rsocket/framing/FrameProcessor.h +++ b/rsocket/framing/FrameProcessor.h @@ -1,13 +1,21 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include "rsocket/internal/Common.h" - -namespace folly { -class IOBuf; -class exception_wrapper; -} +#include +#include namespace rsocket { @@ -19,4 +27,4 @@ class FrameProcessor { virtual void onTerminal(folly::exception_wrapper) = 0; }; -} // reactivesocket +} // namespace rsocket diff --git a/rsocket/framing/FrameSerializer.cpp b/rsocket/framing/FrameSerializer.cpp index 6221199ee..92904944b 100644 --- a/rsocket/framing/FrameSerializer.cpp +++ b/rsocket/framing/FrameSerializer.cpp @@ -1,52 +1,25 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/framing/FrameSerializer.h" - -#include -#include - -#include "rsocket/framing/FrameSerializer_v0.h" -#include "rsocket/framing/FrameSerializer_v0_1.h" #include "rsocket/framing/FrameSerializer_v1_0.h" -DEFINE_string( - rs_use_protocol_version, - "", - "override for the ReactiveSocket protocol version to be used" - " [MAJOR.MINOR]."); - namespace rsocket { -constexpr const ProtocolVersion ProtocolVersion::Latest = - FrameSerializerV1_0::Version; - -ProtocolVersion FrameSerializer::getCurrentProtocolVersion() { - if (FLAGS_rs_use_protocol_version.empty()) { - return ProtocolVersion::Latest; - } - - if (FLAGS_rs_use_protocol_version == "*") { - return ProtocolVersion::Unknown; - } - - if (FLAGS_rs_use_protocol_version.size() != 3) { - LOG(ERROR) << "unknown protocol version " << FLAGS_rs_use_protocol_version - << " defaulting to v" << ProtocolVersion::Latest; - return ProtocolVersion::Latest; - } - - return ProtocolVersion( - folly::to(FLAGS_rs_use_protocol_version[0] - '0'), - folly::to(FLAGS_rs_use_protocol_version[2] - '0')); -} - std::unique_ptr FrameSerializer::createFrameSerializer( const ProtocolVersion& protocolVersion) { - if (protocolVersion == FrameSerializerV0::Version) { - return std::make_unique(); - } else if (protocolVersion == FrameSerializerV0_1::Version) { - return std::make_unique(); - } else if (protocolVersion == FrameSerializerV1_0::Version) { + if (protocolVersion == FrameSerializerV1_0::Version) { return std::make_unique(); } @@ -56,21 +29,37 @@ std::unique_ptr FrameSerializer::createFrameSerializer( return nullptr; } -std::unique_ptr FrameSerializer::createCurrentVersion() { - return createFrameSerializer(getCurrentProtocolVersion()); -} - std::unique_ptr FrameSerializer::createAutodetectedSerializer( const folly::IOBuf& firstFrame) { auto detectedVersion = FrameSerializerV1_0::detectProtocolVersion(firstFrame); - if (detectedVersion == ProtocolVersion::Unknown) { - detectedVersion = FrameSerializerV0_1::detectProtocolVersion(firstFrame); - } return createFrameSerializer(detectedVersion); } -std::ostream& operator<<(std::ostream& os, const ProtocolVersion& version) { - return os << version.major << "." << version.minor; +bool& FrameSerializer::preallocateFrameSizeField() { + return preallocateFrameSizeField_; +} + +folly::IOBufQueue FrameSerializer::createBufferQueue(size_t bufferSize) const { + const auto prependSize = + preallocateFrameSizeField_ ? frameLengthFieldSize() : 0; + auto buf = folly::IOBuf::createCombined(bufferSize + prependSize); + buf->advance(prependSize); + folly::IOBufQueue queue(folly::IOBufQueue::cacheChainLength()); + queue.append(std::move(buf)); + return queue; +} + +folly::Optional FrameSerializer::peekStreamId( + const ProtocolVersion& protocolVersion, + const folly::IOBuf& frame, + bool skipFrameLengthBytes) { + if (protocolVersion == FrameSerializerV1_0::Version) { + return FrameSerializerV1_0().peekStreamId(frame, skipFrameLengthBytes); + } + + auto* msg = "unknown protocol version"; + DCHECK(false) << msg; + return folly::none; } -} // reactivesocket +} // namespace rsocket diff --git a/rsocket/framing/FrameSerializer.h b/rsocket/framing/FrameSerializer.h index 8cb86272d..7ee0bafae 100644 --- a/rsocket/framing/FrameSerializer.h +++ b/rsocket/framing/FrameSerializer.h @@ -1,11 +1,23 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include -#include + #include -#include + #include "rsocket/framing/Frame.h" namespace rsocket { @@ -15,76 +27,89 @@ class FrameSerializer { public: virtual ~FrameSerializer() = default; - virtual ProtocolVersion protocolVersion() = 0; + virtual ProtocolVersion protocolVersion() const = 0; - static ProtocolVersion getCurrentProtocolVersion(); static std::unique_ptr createFrameSerializer( const ProtocolVersion& protocolVersion); - static std::unique_ptr createCurrentVersion(); static std::unique_ptr createAutodetectedSerializer( const folly::IOBuf& firstFrame); - virtual FrameType peekFrameType(const folly::IOBuf& in) = 0; - virtual folly::Optional peekStreamId(const folly::IOBuf& in) = 0; + static folly::Optional peekStreamId( + const ProtocolVersion& protocolVersion, + const folly::IOBuf& frame, + bool skipFrameLengthBytes); + virtual FrameType peekFrameType(const folly::IOBuf& in) const = 0; + virtual folly::Optional peekStreamId( + const folly::IOBuf& in, + bool skipFrameLengthBytes) const = 0; + + virtual std::unique_ptr serializeOut( + Frame_REQUEST_STREAM&&) const = 0; virtual std::unique_ptr serializeOut( - Frame_REQUEST_STREAM&&) = 0; + Frame_REQUEST_CHANNEL&&) const = 0; virtual std::unique_ptr serializeOut( - Frame_REQUEST_CHANNEL&&) = 0; + Frame_REQUEST_RESPONSE&&) const = 0; virtual std::unique_ptr serializeOut( - Frame_REQUEST_RESPONSE&&) = 0; - virtual std::unique_ptr serializeOut(Frame_REQUEST_FNF&&) = 0; - virtual std::unique_ptr serializeOut(Frame_REQUEST_N&&) = 0; - virtual std::unique_ptr serializeOut(Frame_METADATA_PUSH&&) = 0; - virtual std::unique_ptr serializeOut(Frame_CANCEL&&) = 0; - virtual std::unique_ptr serializeOut(Frame_PAYLOAD&&) = 0; - virtual std::unique_ptr serializeOut(Frame_ERROR&&) = 0; + Frame_REQUEST_FNF&&) const = 0; virtual std::unique_ptr serializeOut( - Frame_KEEPALIVE&&, - bool) = 0; - virtual std::unique_ptr serializeOut(Frame_SETUP&&) = 0; - virtual std::unique_ptr serializeOut(Frame_LEASE&&) = 0; - virtual std::unique_ptr serializeOut(Frame_RESUME&&) = 0; - virtual std::unique_ptr serializeOut(Frame_RESUME_OK&&) = 0; + Frame_REQUEST_N&&) const = 0; + virtual std::unique_ptr serializeOut( + Frame_METADATA_PUSH&&) const = 0; + virtual std::unique_ptr serializeOut(Frame_CANCEL&&) const = 0; + virtual std::unique_ptr serializeOut(Frame_PAYLOAD&&) const = 0; + virtual std::unique_ptr serializeOut(Frame_ERROR&&) const = 0; + virtual std::unique_ptr serializeOut( + Frame_KEEPALIVE&&) const = 0; + virtual std::unique_ptr serializeOut(Frame_SETUP&&) const = 0; + virtual std::unique_ptr serializeOut(Frame_LEASE&&) const = 0; + virtual std::unique_ptr serializeOut(Frame_RESUME&&) const = 0; + virtual std::unique_ptr serializeOut( + Frame_RESUME_OK&&) const = 0; virtual bool deserializeFrom( Frame_REQUEST_STREAM&, - std::unique_ptr) = 0; + std::unique_ptr) const = 0; virtual bool deserializeFrom( Frame_REQUEST_CHANNEL&, - std::unique_ptr) = 0; + std::unique_ptr) const = 0; virtual bool deserializeFrom( Frame_REQUEST_RESPONSE&, - std::unique_ptr) = 0; + std::unique_ptr) const = 0; virtual bool deserializeFrom( Frame_REQUEST_FNF&, - std::unique_ptr) = 0; - virtual bool deserializeFrom( - Frame_REQUEST_N&, - std::unique_ptr) = 0; + std::unique_ptr) const = 0; + virtual bool deserializeFrom(Frame_REQUEST_N&, std::unique_ptr) + const = 0; virtual bool deserializeFrom( Frame_METADATA_PUSH&, - std::unique_ptr) = 0; - virtual bool deserializeFrom( - Frame_CANCEL&, - std::unique_ptr) = 0; - virtual bool deserializeFrom( - Frame_PAYLOAD&, - std::unique_ptr) = 0; - virtual bool deserializeFrom(Frame_ERROR&, std::unique_ptr) = 0; - virtual bool deserializeFrom( - Frame_KEEPALIVE&, - std::unique_ptr, - bool supportsResumability) = 0; - virtual bool deserializeFrom(Frame_SETUP&, std::unique_ptr) = 0; - virtual bool deserializeFrom(Frame_LEASE&, std::unique_ptr) = 0; - virtual bool deserializeFrom( - Frame_RESUME&, - std::unique_ptr) = 0; - virtual bool deserializeFrom( - Frame_RESUME_OK&, - std::unique_ptr) = 0; + std::unique_ptr) const = 0; + virtual bool deserializeFrom(Frame_CANCEL&, std::unique_ptr) + const = 0; + virtual bool deserializeFrom(Frame_PAYLOAD&, std::unique_ptr) + const = 0; + virtual bool deserializeFrom(Frame_ERROR&, std::unique_ptr) + const = 0; + virtual bool deserializeFrom(Frame_KEEPALIVE&, std::unique_ptr) + const = 0; + virtual bool deserializeFrom(Frame_SETUP&, std::unique_ptr) + const = 0; + virtual bool deserializeFrom(Frame_LEASE&, std::unique_ptr) + const = 0; + virtual bool deserializeFrom(Frame_RESUME&, std::unique_ptr) + const = 0; + virtual bool deserializeFrom(Frame_RESUME_OK&, std::unique_ptr) + const = 0; + + virtual size_t frameLengthFieldSize() const = 0; + bool& preallocateFrameSizeField(); + + protected: + folly::IOBufQueue createBufferQueue(size_t bufferSize) const; + + private: + bool preallocateFrameSizeField_{false}; }; -} // reactivesocket +} // namespace rsocket diff --git a/rsocket/framing/FrameSerializer_v0.cpp b/rsocket/framing/FrameSerializer_v0.cpp deleted file mode 100644 index cd4b1c1e0..000000000 --- a/rsocket/framing/FrameSerializer_v0.cpp +++ /dev/null @@ -1,775 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "rsocket/framing/FrameSerializer_v0.h" - -#include - -namespace rsocket { - -constexpr const ProtocolVersion FrameSerializerV0::Version; -constexpr const size_t FrameSerializerV0::kFrameHeaderSize; // bytes - -namespace { -constexpr static const auto kMaxMetadataLength = 0xFFFFFF; // 24bit max value - -enum class FrameType_V0 : uint16_t { - RESERVED = 0x0000, - SETUP = 0x0001, - LEASE = 0x0002, - KEEPALIVE = 0x0003, - REQUEST_RESPONSE = 0x0004, - REQUEST_FNF = 0x0005, - REQUEST_STREAM = 0x0006, - REQUEST_SUB = 0x0007, - REQUEST_CHANNEL = 0x0008, - REQUEST_N = 0x0009, - CANCEL = 0x000A, - RESPONSE = 0x000B, - ERROR = 0x000C, - METADATA_PUSH = 0x000D, - RESUME = 0x000E, - RESUME_OK = 0x000F, - EXT = 0xFFFF, -}; - -enum class FrameFlags_V0 : uint16_t { - EMPTY = 0x0000, - IGNORE = 0x8000, - METADATA = 0x4000, - - FOLLOWS = 0x2000, - KEEPALIVE_RESPOND = 0x2000, - LEASE = 0x2000, - COMPLETE = 0x1000, - RESUME_ENABLE = 0x0800, -}; - -constexpr inline FrameFlags_V0 operator&(FrameFlags_V0 a, FrameFlags_V0 b) { - return static_cast( - static_cast(a) & static_cast(b)); -} - -inline uint16_t& operator|=(uint16_t& a, FrameFlags_V0 b) { - return (a |= static_cast(b)); -} - -constexpr inline bool operator!(FrameFlags_V0 a) { - return !static_cast(a); -} -} // namespace - -static folly::IOBufQueue createBufferQueue(size_t bufferSize) { - auto buf = rsocket::FrameBufferAllocator::allocate(bufferSize); - folly::IOBufQueue queue(folly::IOBufQueue::cacheChainLength()); - queue.append(std::move(buf)); - return queue; -} - -ProtocolVersion FrameSerializerV0::protocolVersion() { - return Version; -} - -static uint16_t serializeFrameType(FrameType frameType) { - switch (frameType) { - case FrameType::RESERVED: - case FrameType::SETUP: - case FrameType::LEASE: - case FrameType::KEEPALIVE: - case FrameType::REQUEST_RESPONSE: - case FrameType::REQUEST_FNF: - case FrameType::REQUEST_STREAM: - return static_cast(frameType); - - case FrameType::REQUEST_CHANNEL: - case FrameType::REQUEST_N: - case FrameType::CANCEL: - case FrameType::PAYLOAD: - case FrameType::ERROR: - case FrameType::METADATA_PUSH: - case FrameType::RESUME: - case FrameType::RESUME_OK: - return static_cast(frameType) + 1; - - case FrameType::EXT: - return static_cast(FrameType_V0::EXT); - - default: - CHECK(false); - return 0; - } -} - -static FrameType deserializeFrameType(uint16_t frameType) { - if (frameType > static_cast(FrameType_V0::RESUME_OK) && - frameType != static_cast(FrameType_V0::EXT)) { - return FrameType::RESERVED; - } - - switch (static_cast(frameType)) { - case FrameType_V0::RESERVED: - case FrameType_V0::SETUP: - case FrameType_V0::LEASE: - case FrameType_V0::KEEPALIVE: - case FrameType_V0::REQUEST_RESPONSE: - case FrameType_V0::REQUEST_FNF: - case FrameType_V0::REQUEST_STREAM: - return static_cast(frameType); - - case FrameType_V0::REQUEST_SUB: - return FrameType::REQUEST_STREAM; - - case FrameType_V0::REQUEST_CHANNEL: - case FrameType_V0::REQUEST_N: - case FrameType_V0::CANCEL: - case FrameType_V0::RESPONSE: - case FrameType_V0::ERROR: - case FrameType_V0::METADATA_PUSH: - case FrameType_V0::RESUME: - case FrameType_V0::RESUME_OK: - return static_cast(frameType - 1); - - case FrameType_V0::EXT: - return FrameType::EXT; - - default: - CHECK(false); - return FrameType::RESERVED; - } -} - -static uint16_t serializeFrameFlags(FrameFlags frameType) { - uint16_t result = 0; - if (!!(frameType & FrameFlags::IGNORE)) { - result |= FrameFlags_V0::IGNORE; - } - if (!!(frameType & FrameFlags::METADATA)) { - result |= FrameFlags_V0::METADATA; - } - return result; -} - -static FrameFlags deserializeFrameFlags(FrameFlags_V0 flags) { - FrameFlags result = FrameFlags::EMPTY; - - if (!!(flags & FrameFlags_V0::IGNORE)) { - result |= FrameFlags::IGNORE; - } - if (!!(flags & FrameFlags_V0::METADATA)) { - result |= FrameFlags::METADATA; - } - return result; -} - -static void serializeHeaderInto( - folly::io::QueueAppender& appender, - const FrameHeader& header, - uint16_t extraFlags) { - appender.writeBE(serializeFrameType(header.type_)); - appender.writeBE(serializeFrameFlags(header.flags_) | extraFlags); - appender.writeBE(header.streamId_); -} - -static void deserializeHeaderFrom( - folly::io::Cursor& cur, - FrameHeader& header, - FrameFlags_V0& flags) { - header.type_ = deserializeFrameType(cur.readBE()); - - flags = static_cast(cur.readBE()); - header.flags_ = deserializeFrameFlags(flags); - - header.streamId_ = cur.readBE(); -} - -static void serializeMetadataInto( - folly::io::QueueAppender& appender, - std::unique_ptr metadata) { - if (metadata == nullptr) { - return; - } - - // Use signed int because the first bit in metadata length is reserved. - if (metadata->length() >= kMaxMetadataLength - sizeof(uint32_t)) { - CHECK(false) << "Metadata is too big to serialize"; - } - - appender.writeBE( - static_cast(metadata->length() + sizeof(uint32_t))); - appender.insert(std::move(metadata)); -} - -std::unique_ptr FrameSerializerV0::deserializeMetadataFrom( - folly::io::Cursor& cur, - FrameFlags flags) { - if (!(flags & FrameFlags::METADATA)) { - return nullptr; - } - - const auto length = cur.readBE(); - - if (length >= kMaxMetadataLength) { - throw std::runtime_error("Metadata is too big to deserialize"); - } - - if (length <= sizeof(uint32_t)) { - throw std::runtime_error("Metadata is too small to encode its size"); - } - - const auto metadataPayloadLength = - length - static_cast(sizeof(uint32_t)); - - // TODO: Check if metadataPayloadLength exceeds frame length minus frame - // header size. - - std::unique_ptr metadata; - cur.clone(metadata, metadataPayloadLength); - return metadata; -} - -static std::unique_ptr deserializeDataFrom( - folly::io::Cursor& cur) { - std::unique_ptr data; - auto totalLength = cur.totalLength(); - - if (totalLength > 0) { - cur.clone(data, totalLength); - } - return data; -} - -static Payload deserializePayloadFrom( - folly::io::Cursor& cur, - FrameFlags flags) { - auto metadata = FrameSerializerV0::deserializeMetadataFrom(cur, flags); - auto data = deserializeDataFrom(cur); - return Payload(std::move(data), std::move(metadata)); -} - -static void serializePayloadInto( - folly::io::QueueAppender& appender, - Payload&& payload) { - serializeMetadataInto(appender, std::move(payload.metadata)); - if (payload.data) { - appender.insert(std::move(payload.data)); - } -} - -static uint32_t payloadFramingSize(const Payload& payload) { - return (payload.metadata != nullptr ? sizeof(uint32_t) : 0); -} - -static std::unique_ptr serializeOutInternal( - Frame_REQUEST_Base&& frame) { - auto queue = createBufferQueue( - FrameSerializerV0::kFrameHeaderSize + sizeof(uint32_t) + - payloadFramingSize(frame.payload_)); - uint16_t extraFlags = 0; - if (!!(frame.header_.flags_ & FrameFlags::FOLLOWS)) { - extraFlags |= FrameFlags_V0::FOLLOWS; - } - if (!!(frame.header_.flags_ & FrameFlags::COMPLETE)) { - extraFlags |= FrameFlags_V0::COMPLETE; - } - - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, extraFlags); - - appender.writeBE(frame.requestN_); - serializePayloadInto(appender, std::move(frame.payload_)); - return queue.move(); -} - -static bool deserializeFromInternal( - Frame_REQUEST_Base& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - - if (!!(flags & FrameFlags_V0::FOLLOWS)) { - frame.header_.flags_ |= FrameFlags::FOLLOWS; - } - if (!!(flags & FrameFlags_V0::COMPLETE)) { - frame.header_.flags_ |= FrameFlags::COMPLETE; - } - - frame.requestN_ = cur.readBE(); - frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags_); - } catch (...) { - return false; - } - return true; -} - -FrameType FrameSerializerV0::peekFrameType(const folly::IOBuf& in) { - folly::io::Cursor cur(&in); - try { - return deserializeFrameType(cur.readBE()); - } catch (...) { - return FrameType::RESERVED; - } -} - -folly::Optional FrameSerializerV0::peekStreamId( - const folly::IOBuf& in) { - folly::io::Cursor cur(&in); - try { - cur.skip(sizeof(uint16_t)); // type - cur.skip(sizeof(uint16_t)); // flags - return folly::make_optional(cur.readBE()); - } catch (...) { - return folly::none; - } -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_REQUEST_STREAM&& frame) { - return serializeOutInternal(std::move(frame)); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_REQUEST_CHANNEL&& frame) { - return serializeOutInternal(std::move(frame)); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_REQUEST_RESPONSE&& frame) { - uint16_t extraFlags = 0; - if (!!(frame.header_.flags_ & FrameFlags::FOLLOWS)) { - extraFlags |= FrameFlags_V0::FOLLOWS; - } - - auto queue = - createBufferQueue(kFrameHeaderSize + payloadFramingSize(frame.payload_)); - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, extraFlags); - serializePayloadInto(appender, std::move(frame.payload_)); - return queue.move(); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_REQUEST_FNF&& frame) { - uint16_t extraFlags = 0; - if (!!(frame.header_.flags_ & FrameFlags::FOLLOWS)) { - extraFlags |= FrameFlags_V0::FOLLOWS; - } - - auto queue = - createBufferQueue(kFrameHeaderSize + payloadFramingSize(frame.payload_)); - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, extraFlags); - serializePayloadInto(appender, std::move(frame.payload_)); - return queue.move(); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_REQUEST_N&& frame) { - auto queue = createBufferQueue(kFrameHeaderSize + sizeof(uint32_t)); - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, /*extraFlags=*/0); - appender.writeBE(frame.requestN_); - return queue.move(); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_METADATA_PUSH&& frame) { - auto queue = createBufferQueue(kFrameHeaderSize + sizeof(uint32_t)); - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, /*extraFlags=*/0); - serializeMetadataInto(appender, std::move(frame.metadata_)); - return queue.move(); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_CANCEL&& frame) { - auto queue = createBufferQueue(kFrameHeaderSize); - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, /*extraFlags=*/0); - return queue.move(); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_PAYLOAD&& frame) { - uint16_t extraFlags = 0; - if (!!(frame.header_.flags_ & FrameFlags::FOLLOWS)) { - extraFlags |= FrameFlags_V0::FOLLOWS; - } - if (!!(frame.header_.flags_ & FrameFlags::COMPLETE)) { - extraFlags |= FrameFlags_V0::COMPLETE; - } - - auto queue = - createBufferQueue(kFrameHeaderSize + payloadFramingSize(frame.payload_)); - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, extraFlags); - serializePayloadInto(appender, std::move(frame.payload_)); - return queue.move(); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_ERROR&& frame) { - auto queue = createBufferQueue( - kFrameHeaderSize + sizeof(uint32_t) + payloadFramingSize(frame.payload_)); - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, /*extraFlags=*/0); - appender.writeBE(static_cast(frame.errorCode_)); - serializePayloadInto(appender, std::move(frame.payload_)); - return queue.move(); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_KEEPALIVE&& frame, - bool resumeable) { - uint16_t extraFlags = 0; - if (!!(frame.header_.flags_ & FrameFlags::KEEPALIVE_RESPOND)) { - extraFlags |= FrameFlags_V0::KEEPALIVE_RESPOND; - } - - auto queue = createBufferQueue(kFrameHeaderSize + sizeof(int64_t)); - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, extraFlags); - // TODO: Remove hack: - // https://github.com/ReactiveSocket/reactivesocket-cpp/issues/243 - if (resumeable) { - appender.writeBE(frame.position_); - } - if (frame.data_) { - appender.insert(std::move(frame.data_)); - } - return queue.move(); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_SETUP&& frame) { - auto queue = createBufferQueue( - kFrameHeaderSize + 3 * sizeof(uint32_t) + frame.token_.data().size() + 2 + - frame.metadataMimeType_.length() + frame.dataMimeType_.length() + - payloadFramingSize(frame.payload_)); - uint16_t extraFlags = 0; - if (!!(frame.header_.flags_ & FrameFlags::RESUME_ENABLE)) { - extraFlags |= FrameFlags_V0::RESUME_ENABLE; - } - if (!!(frame.header_.flags_ & FrameFlags::LEASE)) { - extraFlags |= FrameFlags_V0::LEASE; - } - - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - - serializeHeaderInto(appender, frame.header_, extraFlags); - CHECK( - frame.versionMajor_ != ProtocolVersion::Unknown.major || - frame.versionMinor_ != ProtocolVersion::Unknown.minor); - appender.writeBE(static_cast(frame.versionMajor_)); - appender.writeBE(static_cast(frame.versionMinor_)); - appender.writeBE(static_cast(frame.keepaliveTime_)); - appender.writeBE(static_cast(frame.maxLifetime_)); - - // TODO: Remove hack: - // https://github.com/ReactiveSocket/reactivesocket-cpp/issues/243 - if (!!(frame.header_.flags_ & FrameFlags::RESUME_ENABLE)) { - appender.push(frame.token_.data().data(), frame.token_.data().size()); - } - - CHECK( - frame.metadataMimeType_.length() <= std::numeric_limits::max()); - appender.writeBE(static_cast(frame.metadataMimeType_.length())); - appender.push( - reinterpret_cast(frame.metadataMimeType_.data()), - frame.metadataMimeType_.length()); - - CHECK(frame.dataMimeType_.length() <= std::numeric_limits::max()); - appender.writeBE(static_cast(frame.dataMimeType_.length())); - appender.push( - reinterpret_cast(frame.dataMimeType_.data()), - frame.dataMimeType_.length()); - - serializePayloadInto(appender, std::move(frame.payload_)); - return queue.move(); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_LEASE&& frame) { - auto queue = createBufferQueue( - kFrameHeaderSize + 3 * 2 * sizeof(uint32_t) + - (frame.metadata_ ? sizeof(uint32_t) : 0)); - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, /*extraFlags=*/0); - appender.writeBE(static_cast(frame.ttl_)); - appender.writeBE(static_cast(frame.numberOfRequests_)); - serializeMetadataInto(appender, std::move(frame.metadata_)); - return queue.move(); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_RESUME&& frame) { - auto queue = createBufferQueue(kFrameHeaderSize + 16 + sizeof(int64_t)); - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, /*extraFlags=*/0); - CHECK(frame.token_.data().size() <= 16); - appender.push(frame.token_.data().data(), frame.token_.data().size()); - appender.writeBE(frame.lastReceivedServerPosition_); - return queue.move(); -} - -std::unique_ptr FrameSerializerV0::serializeOut( - Frame_RESUME_OK&& frame) { - auto queue = createBufferQueue(kFrameHeaderSize + sizeof(int64_t)); - folly::io::QueueAppender appender(&queue, /* do not grow */ 0); - serializeHeaderInto(appender, frame.header_, /*extraFlags=*/0); - appender.writeBE(frame.position_); - return queue.move(); -} - -bool FrameSerializerV0::deserializeFrom( - Frame_REQUEST_STREAM& frame, - std::unique_ptr in) { - return deserializeFromInternal(frame, std::move(in)); -} - -bool FrameSerializerV0::deserializeFrom( - Frame_REQUEST_CHANNEL& frame, - std::unique_ptr in) { - return deserializeFromInternal(frame, std::move(in)); -} - -bool FrameSerializerV0::deserializeFrom( - Frame_REQUEST_RESPONSE& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - - if (!!(flags & FrameFlags_V0::FOLLOWS)) { - frame.header_.flags_ |= FrameFlags::FOLLOWS; - } - - frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags_); - } catch (...) { - return false; - } - return true; -} - -bool FrameSerializerV0::deserializeFrom( - Frame_REQUEST_FNF& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - - if (!!(flags & FrameFlags_V0::FOLLOWS)) { - frame.header_.flags_ |= FrameFlags::FOLLOWS; - } - - frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags_); - } catch (...) { - return false; - } - return true; -} - -bool FrameSerializerV0::deserializeFrom( - Frame_REQUEST_N& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - frame.requestN_ = cur.readBE(); - } catch (...) { - return false; - } - return true; -} - -bool FrameSerializerV0::deserializeFrom( - Frame_METADATA_PUSH& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - frame.metadata_ = deserializeMetadataFrom(cur, frame.header_.flags_); - } catch (...) { - return false; - } - return frame.metadata_ != nullptr; -} - -bool FrameSerializerV0::deserializeFrom( - Frame_CANCEL& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - } catch (...) { - return false; - } - return true; -} - -bool FrameSerializerV0::deserializeFrom( - Frame_PAYLOAD& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - - if (!!(flags & FrameFlags_V0::FOLLOWS)) { - frame.header_.flags_ |= FrameFlags::FOLLOWS; - } - if (!!(flags & FrameFlags_V0::COMPLETE)) { - frame.header_.flags_ |= FrameFlags::COMPLETE; - } - - frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags_); - } catch (...) { - return false; - } - return true; -} - -bool FrameSerializerV0::deserializeFrom( - Frame_ERROR& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - frame.errorCode_ = static_cast(cur.readBE()); - frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags_); - } catch (...) { - return false; - } - return true; -} - -bool FrameSerializerV0::deserializeFrom( - Frame_KEEPALIVE& frame, - std::unique_ptr in, - bool resumable) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - - if (!!(flags & FrameFlags_V0::KEEPALIVE_RESPOND)) { - frame.header_.flags_ |= FrameFlags::KEEPALIVE_RESPOND; - } - - // TODO: Remove hack: - // https://github.com/ReactiveSocket/reactivesocket-cpp/issues/243 - if (resumable) { - frame.position_ = cur.readBE(); - } else { - frame.position_ = 0; - } - frame.data_ = deserializeDataFrom(cur); - } catch (...) { - return false; - } - return true; -} - -bool FrameSerializerV0::deserializeFrom( - Frame_SETUP& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - - if (!!(flags & FrameFlags_V0::RESUME_ENABLE)) { - frame.header_.flags_ |= FrameFlags::RESUME_ENABLE; - } - if (!!(flags & FrameFlags_V0::LEASE)) { - frame.header_.flags_ |= FrameFlags::LEASE; - } - - frame.versionMajor_ = cur.readBE(); - frame.versionMinor_ = cur.readBE(); - frame.keepaliveTime_ = - std::min(cur.readBE(), Frame_SETUP::kMaxKeepaliveTime); - frame.maxLifetime_ = - std::min(cur.readBE(), Frame_SETUP::kMaxLifetime); - - // TODO: Remove hack: - // https://github.com/ReactiveSocket/reactivesocket-cpp/issues/243 - if (!!(frame.header_.flags_ & FrameFlags::RESUME_ENABLE)) { - std::vector data(16); - cur.pull(data.data(), data.size()); - frame.token_.set(std::move(data)); - } else { - frame.token_ = ResumeIdentificationToken(); - } - - auto mdmtLen = cur.readBE(); - frame.metadataMimeType_ = cur.readFixedString(mdmtLen); - - auto dmtLen = cur.readBE(); - frame.dataMimeType_ = cur.readFixedString(dmtLen); - frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags_); - } catch (...) { - return false; - } - return true; -} - -bool FrameSerializerV0::deserializeFrom( - Frame_LEASE& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - frame.ttl_ = std::min(cur.readBE(), Frame_LEASE::kMaxTtl); - frame.numberOfRequests_ = - std::min(cur.readBE(), Frame_LEASE::kMaxNumRequests); - frame.metadata_ = deserializeMetadataFrom(cur, frame.header_.flags_); - } catch (...) { - return false; - } - return true; -} - -bool FrameSerializerV0::deserializeFrom( - Frame_RESUME& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - std::vector data(16); - cur.pull(data.data(), data.size()); - auto protocolVer = protocolVersion(); - frame.versionMajor_ = protocolVer.major; - frame.versionMinor_ = protocolVer.minor; - frame.token_.set(std::move(data)); - frame.lastReceivedServerPosition_ = cur.readBE(); - frame.clientPosition_ = kUnspecifiedResumePosition; - } catch (...) { - return false; - } - return true; -} - -bool FrameSerializerV0::deserializeFrom( - Frame_RESUME_OK& frame, - std::unique_ptr in) { - folly::io::Cursor cur(in.get()); - try { - FrameFlags_V0 flags; - deserializeHeaderFrom(cur, frame.header_, flags); - frame.position_ = cur.readBE(); - } catch (...) { - return false; - } - return true; -} - -} // reactivesocket diff --git a/rsocket/framing/FrameSerializer_v0.h b/rsocket/framing/FrameSerializer_v0.h deleted file mode 100644 index 74e356351..000000000 --- a/rsocket/framing/FrameSerializer_v0.h +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "rsocket/framing/FrameSerializer.h" - -namespace rsocket { - -class FrameSerializerV0 : public FrameSerializer { - public: - constexpr static const ProtocolVersion Version = ProtocolVersion(0, 0); - constexpr static const size_t kFrameHeaderSize = 8; // bytes - - ProtocolVersion protocolVersion() override; - - FrameType peekFrameType(const folly::IOBuf& in) override; - folly::Optional peekStreamId(const folly::IOBuf& in) override; - - std::unique_ptr serializeOut(Frame_REQUEST_STREAM&&) override; - std::unique_ptr serializeOut(Frame_REQUEST_CHANNEL&&) override; - std::unique_ptr serializeOut(Frame_REQUEST_RESPONSE&&) override; - std::unique_ptr serializeOut(Frame_REQUEST_FNF&&) override; - std::unique_ptr serializeOut(Frame_REQUEST_N&&) override; - std::unique_ptr serializeOut(Frame_METADATA_PUSH&&) override; - std::unique_ptr serializeOut(Frame_CANCEL&&) override; - std::unique_ptr serializeOut(Frame_PAYLOAD&&) override; - std::unique_ptr serializeOut(Frame_ERROR&&) override; - std::unique_ptr serializeOut(Frame_KEEPALIVE&&, bool) override; - std::unique_ptr serializeOut(Frame_SETUP&&) override; - std::unique_ptr serializeOut(Frame_LEASE&&) override; - std::unique_ptr serializeOut(Frame_RESUME&&) override; - std::unique_ptr serializeOut(Frame_RESUME_OK&&) override; - - bool deserializeFrom(Frame_REQUEST_STREAM&, std::unique_ptr) - override; - bool deserializeFrom(Frame_REQUEST_CHANNEL&, std::unique_ptr) - override; - bool deserializeFrom(Frame_REQUEST_RESPONSE&, std::unique_ptr) - override; - bool deserializeFrom(Frame_REQUEST_FNF&, std::unique_ptr) - override; - bool deserializeFrom(Frame_REQUEST_N&, std::unique_ptr) - override; - bool deserializeFrom(Frame_METADATA_PUSH&, std::unique_ptr) - override; - bool deserializeFrom(Frame_CANCEL&, std::unique_ptr) override; - bool deserializeFrom(Frame_PAYLOAD&, std::unique_ptr) override; - bool deserializeFrom(Frame_ERROR&, std::unique_ptr) override; - bool deserializeFrom(Frame_KEEPALIVE&, std::unique_ptr, bool) - override; - bool deserializeFrom(Frame_SETUP&, std::unique_ptr) override; - bool deserializeFrom(Frame_LEASE&, std::unique_ptr) override; - bool deserializeFrom(Frame_RESUME&, std::unique_ptr) override; - bool deserializeFrom(Frame_RESUME_OK&, std::unique_ptr) - override; - - static std::unique_ptr deserializeMetadataFrom( - folly::io::Cursor& cur, - FrameFlags flags); -}; -} // reactivesocket diff --git a/rsocket/framing/FrameSerializer_v0_1.cpp b/rsocket/framing/FrameSerializer_v0_1.cpp deleted file mode 100644 index b42322437..000000000 --- a/rsocket/framing/FrameSerializer_v0_1.cpp +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "rsocket/framing/FrameSerializer_v0_1.h" - -#include - -namespace rsocket { - -constexpr const ProtocolVersion FrameSerializerV0_1::Version; -constexpr const size_t FrameSerializerV0_1::kMinBytesNeededForAutodetection; - -ProtocolVersion FrameSerializerV0_1::protocolVersion() { - return Version; -} - -ProtocolVersion FrameSerializerV0_1::detectProtocolVersion( - const folly::IOBuf& firstFrame, - size_t skipBytes) { - // SETUP frame - // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // | Frame Type = SETUP |0|M|L|S| Flags | - // +-------------------------------+-+-+-+-+-----------------------+ - // | Stream ID = 0 | - // +-------------------------------+-------------------------------+ - // | Major Version | Minor Version | - // +-------------------------------+-------------------------------+ - // ... - // +-------------------------------+-------------------------------+ - - // RESUME frame - // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // | Frame Type = RESUME | Flags | - // +-------------------------------+-------------------------------+ - // | Stream ID = 0 | - // +-------------------------------+-------------------------------+ - // | | - // | Resume Identification Token | - // | | - // | | - // +-------------------------------+-------------------------------+ - // | Resume Position | - // | | - // +-------------------------------+-------------------------------+ - - folly::io::Cursor cur(&firstFrame); - try { - cur.skip(skipBytes); - - auto frameType = cur.readBE(); - cur.skip(sizeof(uint16_t)); // flags - auto streamId = cur.readBE(); - - constexpr static const auto kSETUP = 0x0001; - constexpr static const auto kRESUME = 0x000E; - - VLOG(4) << "frameType=" << frameType << "streamId=" << streamId; - - if (frameType == kSETUP && streamId == 0) { - auto majorVersion = cur.readBE(); - auto minorVersion = cur.readBE(); - - VLOG(4) << "majorVersion=" << majorVersion - << " minorVersion=" << minorVersion; - - if (majorVersion == 0 && (minorVersion == 0 || minorVersion == 1)) { - return ProtocolVersion(majorVersion, minorVersion); - } - } else if (frameType == kRESUME && streamId == 0) { - return FrameSerializerV0_1::Version; - } - } catch (...) { - } - return ProtocolVersion::Unknown; -} - -} // reactivesocket diff --git a/rsocket/framing/FrameSerializer_v0_1.h b/rsocket/framing/FrameSerializer_v0_1.h deleted file mode 100644 index c53774f8b..000000000 --- a/rsocket/framing/FrameSerializer_v0_1.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "rsocket/framing/FrameSerializer_v0.h" - -namespace rsocket { - -class FrameSerializerV0_1 : public FrameSerializerV0 { - public: - constexpr static const ProtocolVersion Version = ProtocolVersion(0, 1); - constexpr static const size_t kMinBytesNeededForAutodetection = 12; // bytes - - static ProtocolVersion detectProtocolVersion( - const folly::IOBuf& firstFrame, - size_t skipBytes = 0); - - ProtocolVersion protocolVersion() override; -}; -} // reactivesocket diff --git a/rsocket/framing/FrameSerializer_v1_0.cpp b/rsocket/framing/FrameSerializer_v1_0.cpp index f76764f7a..446246d11 100644 --- a/rsocket/framing/FrameSerializer_v1_0.cpp +++ b/rsocket/framing/FrameSerializer_v1_0.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/framing/FrameSerializer_v1_0.h" @@ -11,21 +23,14 @@ constexpr const size_t FrameSerializerV1_0::kFrameHeaderSize; constexpr const size_t FrameSerializerV1_0::kMinBytesNeededForAutodetection; namespace { -constexpr const auto kMedatadaLengthSize = 3; // bytes -constexpr const auto kMaxMetadataLength = 0xFFFFFF; // 24bit max value +constexpr const uint32_t kMedatadaLengthSize = 3u; // bytes +constexpr const uint32_t kMaxMetadataLength = 0xFFFFFFu; // 24bit max value } // namespace -ProtocolVersion FrameSerializerV1_0::protocolVersion() { +ProtocolVersion FrameSerializerV1_0::protocolVersion() const { return Version; } -static folly::IOBufQueue createBufferQueue(size_t bufferSize) { - auto buf = rsocket::FrameBufferAllocator::allocate(bufferSize); - folly::IOBufQueue queue(folly::IOBufQueue::cacheChainLength()); - queue.append(std::move(buf)); - return queue; -} - static FrameType deserializeFrameType(uint16_t frameType) { if (frameType > static_cast(FrameType::RESUME_OK) && frameType != static_cast(FrameType::EXT)) { @@ -37,12 +42,12 @@ static FrameType deserializeFrameType(uint16_t frameType) { static void serializeHeaderInto( folly::io::QueueAppender& appender, const FrameHeader& header) { - appender.writeBE(static_cast(header.streamId_)); + appender.writeBE(static_cast(header.streamId)); - auto type = static_cast(header.type_); // 6 bit - auto flags = static_cast(header.flags_); // 10 bit - appender.writeBE(static_cast((type << 2) | (flags >> 8))); - appender.writeBE(static_cast(flags)); // lower 8 bits + auto type = static_cast(header.type); // 6 bit + auto flags = static_cast(header.flags); // 10 bit + appender.write(static_cast((type << 2) | (flags >> 8))); + appender.write(static_cast(flags)); // lower 8 bits } static void deserializeHeaderFrom(folly::io::Cursor& cur, FrameHeader& header) { @@ -50,10 +55,10 @@ static void deserializeHeaderFrom(folly::io::Cursor& cur, FrameHeader& header) { if (streamId < 0) { throw std::runtime_error("invalid stream id"); } - header.streamId_ = static_cast(streamId); + header.streamId = static_cast(streamId); uint16_t type = cur.readBE(); // |Frame Type |I|M| - header.type_ = deserializeFrameType(type >> 2); - header.flags_ = + header.type = deserializeFrameType(type >> 2); + header.flags = static_cast(((type & 0x3) << 8) | cur.readBE()); } @@ -64,13 +69,12 @@ static void serializeMetadataInto( return; } - // Use signed int because the first bit in metadata length is reserved. - if (metadata->length() > kMaxMetadataLength) { - CHECK(false) << "Metadata is too big to serialize"; - } - // metadata length field not included in the medatadata length - uint32_t metadataLength = static_cast(metadata->length()); + uint32_t metadataLength = + static_cast(metadata->computeChainDataLength()); + CHECK_LT(metadataLength, kMaxMetadataLength) + << "Metadata is too big to serialize"; + appender.write(static_cast(metadataLength >> 16)); // first byte appender.write( static_cast((metadataLength >> 8) & 0xFF)); // second byte @@ -91,9 +95,8 @@ std::unique_ptr FrameSerializerV1_0::deserializeMetadataFrom( metadataLength |= static_cast(cur.read() << 8); metadataLength |= cur.read(); - if (metadataLength > kMaxMetadataLength) { - throw std::runtime_error("Metadata is too big to deserialize"); - } + CHECK_LE(metadataLength, kMaxMetadataLength) + << "Read out the 24-bit integer incorrectly somehow"; std::unique_ptr metadata; cur.clone(metadata, metadataLength); @@ -132,8 +135,8 @@ static uint32_t payloadFramingSize(const Payload& payload) { return (payload.metadata != nullptr ? kMedatadaLengthSize : 0); } -static std::unique_ptr serializeOutInternal( - Frame_REQUEST_Base&& frame) { +std::unique_ptr FrameSerializerV1_0::serializeOutInternal( + Frame_REQUEST_Base&& frame) const { auto queue = createBufferQueue( FrameSerializerV1_0::kFrameHeaderSize + sizeof(uint32_t) + payloadFramingSize(frame.payload_)); @@ -159,7 +162,7 @@ static bool deserializeFromInternal( throw std::runtime_error("invalid request N"); } frame.requestN_ = static_cast(requestN); - frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags_); + frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags); } catch (...) { return false; } @@ -174,7 +177,7 @@ static size_t getResumeIdTokenFramingLength( : 0; } -FrameType FrameSerializerV1_0::peekFrameType(const folly::IOBuf& in) { +FrameType FrameSerializerV1_0::peekFrameType(const folly::IOBuf& in) const { folly::io::Cursor cur(&in); try { cur.skip(sizeof(int32_t)); // streamId @@ -186,9 +189,13 @@ FrameType FrameSerializerV1_0::peekFrameType(const folly::IOBuf& in) { } folly::Optional FrameSerializerV1_0::peekStreamId( - const folly::IOBuf& in) { + const folly::IOBuf& in, + bool skipFrameLengthBytes) const { folly::io::Cursor cur(&in); try { + if (skipFrameLengthBytes) { + cur.skip(3); // skip 3 bytes for frame length + } auto streamId = cur.readBE(); if (streamId < 0) { return folly::none; @@ -200,17 +207,17 @@ folly::Optional FrameSerializerV1_0::peekStreamId( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_REQUEST_STREAM&& frame) { + Frame_REQUEST_STREAM&& frame) const { return serializeOutInternal(std::move(frame)); } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_REQUEST_CHANNEL&& frame) { + Frame_REQUEST_CHANNEL&& frame) const { return serializeOutInternal(std::move(frame)); } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_REQUEST_RESPONSE&& frame) { + Frame_REQUEST_RESPONSE&& frame) const { auto queue = createBufferQueue(kFrameHeaderSize + payloadFramingSize(frame.payload_)); folly::io::QueueAppender appender(&queue, /* do not grow */ 0); @@ -220,7 +227,7 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_REQUEST_FNF&& frame) { + Frame_REQUEST_FNF&& frame) const { auto queue = createBufferQueue(kFrameHeaderSize + payloadFramingSize(frame.payload_)); folly::io::QueueAppender appender(&queue, /* do not grow */ 0); @@ -230,7 +237,7 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_REQUEST_N&& frame) { + Frame_REQUEST_N&& frame) const { auto queue = createBufferQueue(kFrameHeaderSize + sizeof(uint32_t)); folly::io::QueueAppender appender(&queue, /* do not grow */ 0); serializeHeaderInto(appender, frame.header_); @@ -239,7 +246,7 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_METADATA_PUSH&& frame) { + Frame_METADATA_PUSH&& frame) const { auto queue = createBufferQueue(kFrameHeaderSize); folly::io::QueueAppender appender(&queue, /* do not grow */ 0); serializeHeaderInto(appender, frame.header_); @@ -250,7 +257,7 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_CANCEL&& frame) { + Frame_CANCEL&& frame) const { auto queue = createBufferQueue(kFrameHeaderSize); folly::io::QueueAppender appender(&queue, /* do not grow */ 0); serializeHeaderInto(appender, frame.header_); @@ -258,7 +265,7 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_PAYLOAD&& frame) { + Frame_PAYLOAD&& frame) const { auto queue = createBufferQueue(kFrameHeaderSize + payloadFramingSize(frame.payload_)); folly::io::QueueAppender appender(&queue, /* do not grow */ 0); @@ -268,7 +275,7 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_ERROR&& frame) { + Frame_ERROR&& frame) const { auto queue = createBufferQueue( kFrameHeaderSize + sizeof(uint32_t) + payloadFramingSize(frame.payload_)); folly::io::QueueAppender appender(&queue, /* do not grow */ 0); @@ -279,8 +286,7 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_KEEPALIVE&& frame, - bool /*resumeable*/) { + Frame_KEEPALIVE&& frame) const { auto queue = createBufferQueue(kFrameHeaderSize + sizeof(int64_t)); folly::io::QueueAppender appender(&queue, /* do not grow */ 0); serializeHeaderInto(appender, frame.header_); @@ -292,11 +298,11 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_SETUP&& frame) { + Frame_SETUP&& frame) const { auto queue = createBufferQueue( kFrameHeaderSize + sizeof(uint16_t) + sizeof(uint16_t) + sizeof(int32_t) + sizeof(int32_t) + - getResumeIdTokenFramingLength(frame.header_.flags_, frame.token_) + + getResumeIdTokenFramingLength(frame.header_.flags, frame.token_) + +sizeof(uint8_t) + frame.metadataMimeType_.length() + sizeof(uint8_t) + frame.dataMimeType_.length() + payloadFramingSize(frame.payload_)); folly::io::QueueAppender appender(&queue, /* do not grow */ 0); @@ -310,7 +316,7 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( appender.writeBE(static_cast(frame.keepaliveTime_)); appender.writeBE(static_cast(frame.maxLifetime_)); - if (!!(frame.header_.flags_ & FrameFlags::RESUME_ENABLE)) { + if (!!(frame.header_.flags & FrameFlags::RESUME_ENABLE)) { appender.writeBE( static_cast(frame.token_.data().size())); appender.push(frame.token_.data().data(), frame.token_.data().size()); @@ -334,7 +340,7 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_LEASE&& frame) { + Frame_LEASE&& frame) const { auto queue = createBufferQueue(kFrameHeaderSize + sizeof(int32_t) + sizeof(int32_t)); folly::io::QueueAppender appender(&queue, /* do not grow */ 0); @@ -348,7 +354,7 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_RESUME&& frame) { + Frame_RESUME&& frame) const { auto queue = createBufferQueue( kFrameHeaderSize + sizeof(uint16_t) + sizeof(uint16_t) + sizeof(uint16_t) + frame.token_.data().size() + sizeof(int32_t) + @@ -371,7 +377,7 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( } std::unique_ptr FrameSerializerV1_0::serializeOut( - Frame_RESUME_OK&& frame) { + Frame_RESUME_OK&& frame) const { auto queue = createBufferQueue(kFrameHeaderSize + sizeof(int64_t)); folly::io::QueueAppender appender(&queue, /* do not grow */ 0); serializeHeaderInto(appender, frame.header_); @@ -381,23 +387,23 @@ std::unique_ptr FrameSerializerV1_0::serializeOut( bool FrameSerializerV1_0::deserializeFrom( Frame_REQUEST_STREAM& frame, - std::unique_ptr in) { + std::unique_ptr in) const { return deserializeFromInternal(frame, std::move(in)); } bool FrameSerializerV1_0::deserializeFrom( Frame_REQUEST_CHANNEL& frame, - std::unique_ptr in) { + std::unique_ptr in) const { return deserializeFromInternal(frame, std::move(in)); } bool FrameSerializerV1_0::deserializeFrom( Frame_REQUEST_RESPONSE& frame, - std::unique_ptr in) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); - frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags_); + frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags); } catch (...) { return false; } @@ -406,11 +412,11 @@ bool FrameSerializerV1_0::deserializeFrom( bool FrameSerializerV1_0::deserializeFrom( Frame_REQUEST_FNF& frame, - std::unique_ptr in) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); - frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags_); + frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags); } catch (...) { return false; } @@ -419,7 +425,7 @@ bool FrameSerializerV1_0::deserializeFrom( bool FrameSerializerV1_0::deserializeFrom( Frame_REQUEST_N& frame, - std::unique_ptr in) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); @@ -436,7 +442,7 @@ bool FrameSerializerV1_0::deserializeFrom( bool FrameSerializerV1_0::deserializeFrom( Frame_METADATA_PUSH& frame, - std::unique_ptr in) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); @@ -451,7 +457,7 @@ bool FrameSerializerV1_0::deserializeFrom( bool FrameSerializerV1_0::deserializeFrom( Frame_CANCEL& frame, - std::unique_ptr in) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); @@ -463,11 +469,11 @@ bool FrameSerializerV1_0::deserializeFrom( bool FrameSerializerV1_0::deserializeFrom( Frame_PAYLOAD& frame, - std::unique_ptr in) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); - frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags_); + frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags); } catch (...) { return false; } @@ -476,12 +482,12 @@ bool FrameSerializerV1_0::deserializeFrom( bool FrameSerializerV1_0::deserializeFrom( Frame_ERROR& frame, - std::unique_ptr in) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); frame.errorCode_ = static_cast(cur.readBE()); - frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags_); + frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags); } catch (...) { return false; } @@ -490,8 +496,7 @@ bool FrameSerializerV1_0::deserializeFrom( bool FrameSerializerV1_0::deserializeFrom( Frame_KEEPALIVE& frame, - std::unique_ptr in, - bool /*resumable*/) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); @@ -509,7 +514,7 @@ bool FrameSerializerV1_0::deserializeFrom( bool FrameSerializerV1_0::deserializeFrom( Frame_SETUP& frame, - std::unique_ptr in) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); @@ -529,7 +534,7 @@ bool FrameSerializerV1_0::deserializeFrom( } frame.maxLifetime_ = static_cast(maxLifetime); - if (!!(frame.header_.flags_ & FrameFlags::RESUME_ENABLE)) { + if (!!(frame.header_.flags & FrameFlags::RESUME_ENABLE)) { auto resumeTokenSize = cur.readBE(); std::vector data(resumeTokenSize); cur.pull(data.data(), data.size()); @@ -543,7 +548,7 @@ bool FrameSerializerV1_0::deserializeFrom( auto dmtLen = cur.readBE(); frame.dataMimeType_ = cur.readFixedString(dmtLen); - frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags_); + frame.payload_ = deserializePayloadFrom(cur, frame.header_.flags); } catch (...) { return false; } @@ -552,7 +557,7 @@ bool FrameSerializerV1_0::deserializeFrom( bool FrameSerializerV1_0::deserializeFrom( Frame_LEASE& frame, - std::unique_ptr in) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); @@ -577,7 +582,7 @@ bool FrameSerializerV1_0::deserializeFrom( bool FrameSerializerV1_0::deserializeFrom( Frame_RESUME& frame, - std::unique_ptr in) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); @@ -609,7 +614,7 @@ bool FrameSerializerV1_0::deserializeFrom( bool FrameSerializerV1_0::deserializeFrom( Frame_RESUME_OK& frame, - std::unique_ptr in) { + std::unique_ptr in) const { folly::io::Cursor cur(in.get()); try { deserializeHeaderFrom(cur, frame.header_); @@ -679,4 +684,7 @@ ProtocolVersion FrameSerializerV1_0::detectProtocolVersion( return ProtocolVersion::Unknown; } -} // reactivesocket +size_t FrameSerializerV1_0::frameLengthFieldSize() const { + return 3; // bytes +} +} // namespace rsocket diff --git a/rsocket/framing/FrameSerializer_v1_0.h b/rsocket/framing/FrameSerializer_v1_0.h index 1807e2911..f636584dd 100644 --- a/rsocket/framing/FrameSerializer_v1_0.h +++ b/rsocket/framing/FrameSerializer_v1_0.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -12,55 +24,74 @@ class FrameSerializerV1_0 : public FrameSerializer { constexpr static const size_t kFrameHeaderSize = 6; // bytes constexpr static const size_t kMinBytesNeededForAutodetection = 10; // bytes - ProtocolVersion protocolVersion() override; + ProtocolVersion protocolVersion() const override; static ProtocolVersion detectProtocolVersion( const folly::IOBuf& firstFrame, size_t skipBytes = 0); - FrameType peekFrameType(const folly::IOBuf& in) override; - folly::Optional peekStreamId(const folly::IOBuf& in) override; + FrameType peekFrameType(const folly::IOBuf& in) const override; + folly::Optional peekStreamId( + const folly::IOBuf& in, + bool skipFrameLengthBytes) const override; - std::unique_ptr serializeOut(Frame_REQUEST_STREAM&&) override; - std::unique_ptr serializeOut(Frame_REQUEST_CHANNEL&&) override; - std::unique_ptr serializeOut(Frame_REQUEST_RESPONSE&&) override; - std::unique_ptr serializeOut(Frame_REQUEST_FNF&&) override; - std::unique_ptr serializeOut(Frame_REQUEST_N&&) override; - std::unique_ptr serializeOut(Frame_METADATA_PUSH&&) override; - std::unique_ptr serializeOut(Frame_CANCEL&&) override; - std::unique_ptr serializeOut(Frame_PAYLOAD&&) override; - std::unique_ptr serializeOut(Frame_ERROR&&) override; - std::unique_ptr serializeOut(Frame_KEEPALIVE&&, bool) override; - std::unique_ptr serializeOut(Frame_SETUP&&) override; - std::unique_ptr serializeOut(Frame_LEASE&&) override; - std::unique_ptr serializeOut(Frame_RESUME&&) override; - std::unique_ptr serializeOut(Frame_RESUME_OK&&) override; + std::unique_ptr serializeOut( + Frame_REQUEST_STREAM&&) const override; + std::unique_ptr serializeOut( + Frame_REQUEST_CHANNEL&&) const override; + std::unique_ptr serializeOut( + Frame_REQUEST_RESPONSE&&) const override; + std::unique_ptr serializeOut( + Frame_REQUEST_FNF&&) const override; + std::unique_ptr serializeOut(Frame_REQUEST_N&&) const override; + std::unique_ptr serializeOut( + Frame_METADATA_PUSH&&) const override; + std::unique_ptr serializeOut(Frame_CANCEL&&) const override; + std::unique_ptr serializeOut(Frame_PAYLOAD&&) const override; + std::unique_ptr serializeOut(Frame_ERROR&&) const override; + std::unique_ptr serializeOut(Frame_KEEPALIVE&&) const override; + std::unique_ptr serializeOut(Frame_SETUP&&) const override; + std::unique_ptr serializeOut(Frame_LEASE&&) const override; + std::unique_ptr serializeOut(Frame_RESUME&&) const override; + std::unique_ptr serializeOut(Frame_RESUME_OK&&) const override; bool deserializeFrom(Frame_REQUEST_STREAM&, std::unique_ptr) - override; + const override; bool deserializeFrom(Frame_REQUEST_CHANNEL&, std::unique_ptr) - override; + const override; bool deserializeFrom(Frame_REQUEST_RESPONSE&, std::unique_ptr) - override; + const override; bool deserializeFrom(Frame_REQUEST_FNF&, std::unique_ptr) - override; + const override; bool deserializeFrom(Frame_REQUEST_N&, std::unique_ptr) - override; + const override; bool deserializeFrom(Frame_METADATA_PUSH&, std::unique_ptr) - override; - bool deserializeFrom(Frame_CANCEL&, std::unique_ptr) override; - bool deserializeFrom(Frame_PAYLOAD&, std::unique_ptr) override; - bool deserializeFrom(Frame_ERROR&, std::unique_ptr) override; - bool deserializeFrom(Frame_KEEPALIVE&, std::unique_ptr, bool) - override; - bool deserializeFrom(Frame_SETUP&, std::unique_ptr) override; - bool deserializeFrom(Frame_LEASE&, std::unique_ptr) override; - bool deserializeFrom(Frame_RESUME&, std::unique_ptr) override; + const override; + bool deserializeFrom(Frame_CANCEL&, std::unique_ptr) + const override; + bool deserializeFrom(Frame_PAYLOAD&, std::unique_ptr) + const override; + bool deserializeFrom(Frame_ERROR&, std::unique_ptr) + const override; + bool deserializeFrom(Frame_KEEPALIVE&, std::unique_ptr) + const override; + bool deserializeFrom(Frame_SETUP&, std::unique_ptr) + const override; + bool deserializeFrom(Frame_LEASE&, std::unique_ptr) + const override; + bool deserializeFrom(Frame_RESUME&, std::unique_ptr) + const override; bool deserializeFrom(Frame_RESUME_OK&, std::unique_ptr) - override; + const override; static std::unique_ptr deserializeMetadataFrom( folly::io::Cursor& cur, FrameFlags flags); + + private: + std::unique_ptr serializeOutInternal( + Frame_REQUEST_Base&& frame) const; + + size_t frameLengthFieldSize() const override; }; -} // reactivesocket +} // namespace rsocket diff --git a/rsocket/framing/FrameTransport.cpp b/rsocket/framing/FrameTransport.cpp deleted file mode 100644 index 3bcd05170..000000000 --- a/rsocket/framing/FrameTransport.cpp +++ /dev/null @@ -1,233 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "rsocket/framing/FrameTransport.h" - -#include -#include -#include - -#include "rsocket/DuplexConnection.h" -#include "rsocket/framing/FrameProcessor.h" - -namespace rsocket { - -using namespace yarpl::flowable; - -FrameTransport::FrameTransport(std::unique_ptr connection) - : connection_(std::move(connection)) { - CHECK(connection_); -} - -FrameTransport::~FrameTransport() { - VLOG(6) << "~FrameTransport"; -} - -void FrameTransport::connect() { - Lock lock(mutex_); - - DCHECK(connection_); - - if (connectionOutput_) { - // Already connected. - return; - } - - connectionOutput_ = connection_->getOutput(); - connectionOutput_->onSubscribe(yarpl::get_ref(this)); - - // The onSubscribe call on the previous line may have called the terminating - // signal which would call disconnect/close. - if (connection_) { - // This may call ::onSubscribe in-line, which calls ::request on the - // provided subscription, which might deliver frames in-line. It can also - // call onComplete which will call disconnect/close and reset the - // connection_ while still inside of the connection_::setInput method. We - // will create a hard reference for that case and keep the object alive - // until setInput method returns - auto connectionCopy = connection_; - connectionCopy->setInput(yarpl::get_ref(this)); - } -} - -void FrameTransport::setFrameProcessor( - std::shared_ptr frameProcessor) { - Lock lock(mutex_); - - frameProcessor_ = std::move(frameProcessor); - if (frameProcessor_) { - CHECK(!isClosed()); - connect(); - } - - drainWrites(lock); - drainReads(lock); -} - -void FrameTransport::close() { - closeImpl(folly::exception_wrapper()); -} - -void FrameTransport::closeWithError(folly::exception_wrapper ew) { - if (!ew) { - VLOG(1) << "FrameTransport::closeWithError() called with empty exception"; - ew = std::runtime_error("Undefined error"); - } - closeImpl(std::move(ew)); -} - -void FrameTransport::closeImpl(folly::exception_wrapper ew) { - Lock lock(mutex_); - - // Make sure we never try to call back into the processor. - frameProcessor_ = nullptr; - - if (!connection_) { - return; - } - - auto oldConnection = std::move(connection_); - - // Send terminal signals to the DuplexConnection's input and output before - // tearing it down. We must do this per DuplexConnection specification (see - // interface definition). - if (auto subscriber = std::move(connectionOutput_)) { - if (ew) { - subscriber->onError(ew.to_exception_ptr()); - } else { - subscriber->onComplete(); - } - } - if (auto subscription = std::move(connectionInputSub_)) { - subscription->cancel(); - } -} - -void FrameTransport::onSubscribe(yarpl::Reference subscription) { - Lock lock(mutex_); - - if (!connection_) { - return; - } - - CHECK(!connectionInputSub_); - CHECK(frameProcessor_); - connectionInputSub_ = std::move(subscription); - connectionInputSub_->request(kMaxRequestN); -} - -void FrameTransport::onNext(std::unique_ptr frame) { - Lock lock(mutex_); - - if (connection_ && frameProcessor_) { - frameProcessor_->processFrame(std::move(frame)); - } else { - pendingReads_.emplace_back(std::move(frame)); - } -} - -void FrameTransport::terminateProcessor(folly::exception_wrapper ex) { - // This method can be executed multiple times while terminating. - - std::shared_ptr frameProcessor; - { - Lock lock(mutex_); - if (!frameProcessor_) { - pendingTerminal_ = std::move(ex); - return; - } - frameProcessor = std::move(frameProcessor_); - } - - if (frameProcessor) { - VLOG(3) << this << " terminating frame processor ex=" << ex.what(); - frameProcessor->onTerminal(std::move(ex)); - } -} - -void FrameTransport::onComplete() { - VLOG(6) << "onComplete"; - terminateProcessor(folly::exception_wrapper()); -} - -void FrameTransport::onError(std::exception_ptr eptr) { - VLOG(6) << "onError" << folly::exceptionStr(eptr); - - try { - std::rethrow_exception(eptr); - } catch (const std::exception& exn) { - folly::exception_wrapper ew{std::move(eptr), exn}; - terminateProcessor(std::move(ew)); - } -} - -void FrameTransport::request(int64_t n) { - Lock lock(mutex_); - - if (!connection_) { - // request(n) can be delivered during disconnecting. We don't care for it - // anymore. - return; - } - - if (writeAllowance_.release(n) > 0) { - // There are no pending wfrites or we already have this method on the - // stack. - return; - } - - drainWrites(lock); -} - -void FrameTransport::cancel() { - VLOG(6) << "cancel"; - terminateProcessor(folly::exception_wrapper()); -} - -void FrameTransport::outputFrameOrEnqueue(std::unique_ptr frame) { - Lock lock(mutex_); - - // We allow sending frames even without a frame processor so it's possible to - // send terminal frames without expecting anything in return. - if (connection_) { - drainWrites(lock); - if (pendingWrites_.empty() && writeAllowance_.tryAcquire()) { - connectionOutput_->onNext(std::move(frame)); - return; - } - } - - // We either have no allowance to perform the operation, or the queue has not - // been drained (e.g. we're looping in ::request), or we are disconnected. - pendingWrites_.emplace_back(std::move(frame)); -} - -void FrameTransport::drainReads(const FrameTransport::Lock&) { - if (!frameProcessor_) { - return; - } - - while (!pendingReads_.empty()) { - auto frame = std::move(pendingReads_.front()); - pendingReads_.pop_front(); - frameProcessor_->processFrame(std::move(frame)); - } - - if (pendingTerminal_) { - terminateProcessor(std::move(*pendingTerminal_)); - pendingTerminal_ = folly::none; - } -} - -void FrameTransport::drainWrites(const FrameTransport::Lock&) { - if (!connection_) { - return; - } - - // Drain the queue or the allowance. - while (!pendingWrites_.empty() && writeAllowance_.tryAcquire()) { - auto frame = std::move(pendingWrites_.front()); - pendingWrites_.pop_front(); - connectionOutput_->onNext(std::move(frame)); - } -} -} diff --git a/rsocket/framing/FrameTransport.h b/rsocket/framing/FrameTransport.h index 1304ca01d..6c5ed3ef1 100644 --- a/rsocket/framing/FrameTransport.h +++ b/rsocket/framing/FrameTransport.h @@ -1,111 +1,38 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include -#include -#include +#include -#include -#include - -#include "rsocket/internal/AllowanceSemaphore.h" -#include "rsocket/internal/Common.h" -#include "yarpl/flowable/Subscriber.h" -#include "yarpl/flowable/Subscription.h" +#include "rsocket/DuplexConnection.h" +#include "rsocket/framing/FrameProcessor.h" namespace rsocket { -class DuplexConnection; -class FrameProcessor; - -class FrameTransport final : - /// Registered as an input in the DuplexConnection. - public yarpl::flowable::Subscriber>, - /// Receives signals about connection writability. - public yarpl::flowable::Subscription { +// Refer to FrameTransportImpl for documentation on the implementation +class FrameTransport { public: - explicit FrameTransport(std::unique_ptr connection); - ~FrameTransport(); - - void setFrameProcessor(std::shared_ptr); - - /// Enqueues provided frame to be written to the underlying connection. - /// Enqueuing a terminal frame does not end the stream. - /// - /// This signal corresponds to Subscriber::onNext. - void outputFrameOrEnqueue(std::unique_ptr); - - /// Cancel the input, complete the output, and close the underlying - /// connection. - void close(); - - /// Cancel the input, error the output, and close the underlying connection. - /// This must be closed with a non-empty exception_wrapper. - void closeWithError(folly::exception_wrapper); - - bool isClosed() const { - return !connection_; - } - - bool outputQueueEmpty() const { - return pendingWrites_.empty(); - } - - private: - // TODO(t15924567): Recursive locks are evil! This should instead use a - // synchronization abstraction which preserves FIFO ordering. However, this is - // incrementally better than the race conditions which existed here before. - // - // Further reading: - // https://groups.google.com/forum/?hl=en#!topic/comp.programming.threads/tcrTKnfP8HI%5B1-25%5D - using Mutex = std::recursive_mutex; - using Lock = std::lock_guard; - - void connect(); - - // Subscriber. - - void onSubscribe(yarpl::Reference) override; - void onNext(std::unique_ptr) override; - void onComplete() override; - void onError(std::exception_ptr) override; - - // Subscription. - - void request(int64_t) override; - void cancel() override; - - /// Drain all pending reads and any pending terminal signal into the - /// FrameProcessor. - /// - /// TODO: This always sends the payloads first and then follows with the - /// terminal signal, regardless if terminal signal was sent before the - /// payloads. Not clear if that is desirable. - void drainReads(const Lock&); - - /// Drain all pending writes into the output subscriber. - void drainWrites(const Lock&); - - /// Terminates the FrameProcessor. Will queue up the exception if no - /// processor is set, overwriting any previously queued exception. - void terminateProcessor(folly::exception_wrapper); - - void closeImpl(folly::exception_wrapper); - - mutable Mutex mutex_; - - std::shared_ptr frameProcessor_; - - AllowanceSemaphore writeAllowance_; - std::shared_ptr connection_; + virtual ~FrameTransport() = default; + virtual void setFrameProcessor(std::shared_ptr) = 0; + virtual void outputFrameOrDrop(std::unique_ptr) = 0; + virtual void close() = 0; - yarpl::Reference>> - connectionOutput_; - yarpl::Reference connectionInputSub_; + // Just for observation purposes! + // TODO(T25011919): remove + virtual DuplexConnection* getConnection() = 0; - std::deque> pendingWrites_; - std::deque> pendingReads_; - folly::Optional pendingTerminal_; + virtual bool isConnectionFramed() const = 0; }; -} +} // namespace rsocket diff --git a/rsocket/framing/FrameTransportImpl.cpp b/rsocket/framing/FrameTransportImpl.cpp new file mode 100644 index 000000000..8e49b9bac --- /dev/null +++ b/rsocket/framing/FrameTransportImpl.cpp @@ -0,0 +1,136 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/framing/FrameTransportImpl.h" + +#include +#include +#include + +#include "rsocket/DuplexConnection.h" +#include "rsocket/framing/FrameProcessor.h" + +namespace rsocket { + +using namespace yarpl::flowable; + +FrameTransportImpl::FrameTransportImpl( + std::unique_ptr connection) + : connection_(std::move(connection)) { + CHECK(connection_); +} + +FrameTransportImpl::~FrameTransportImpl() { + VLOG(1) << "~FrameTransport (" << this << ")"; +} + +void FrameTransportImpl::connect() { + CHECK(connection_); + + // The onSubscribe call on the previous line may have called the terminating + // signal which would call disconnect/close. + if (connection_) { + // This may call ::onSubscribe in-line, which calls ::request on the + // provided subscription, which might deliver frames in-line. It can also + // call onComplete which will call disconnect/close and reset the + // connection_ while still inside of the connection_::setInput method. We + // will create a hard reference for that case and keep the object alive + // until setInput method returns + auto connectionCopy = connection_; + connectionCopy->setInput(shared_from_this()); + } +} + +void FrameTransportImpl::setFrameProcessor( + std::shared_ptr frameProcessor) { + frameProcessor_ = std::move(frameProcessor); + if (frameProcessor_) { + CHECK(!isClosed()); + connect(); + } +} + +void FrameTransportImpl::close() { + // Make sure we never try to call back into the processor. + frameProcessor_ = nullptr; + + if (!connection_) { + return; + } + connection_.reset(); + + if (auto subscription = std::move(connectionInputSub_)) { + subscription->cancel(); + } +} + +void FrameTransportImpl::onSubscribe( + std::shared_ptr subscription) { + if (!connection_) { + return; + } + + CHECK(!connectionInputSub_); + CHECK(frameProcessor_); + connectionInputSub_ = std::move(subscription); + connectionInputSub_->request(std::numeric_limits::max()); +} + +void FrameTransportImpl::onNext(std::unique_ptr frame) { + // Copy in case frame processing calls through to close(). + if (auto const processor = frameProcessor_) { + processor->processFrame(std::move(frame)); + } +} + +void FrameTransportImpl::terminateProcessor(folly::exception_wrapper ex) { + // This method can be executed multiple times while terminating. + + if (!frameProcessor_) { + // already terminated + return; + } + + if (auto conn_sub = std::move(connectionInputSub_)) { + conn_sub->cancel(); + } + + auto frameProcessor = std::move(frameProcessor_); + VLOG(3) << this << " terminating frame processor ex=" << ex.what(); + frameProcessor->onTerminal(std::move(ex)); +} + +void FrameTransportImpl::onComplete() { + VLOG(3) << "FrameTransport received onComplete"; + terminateProcessor(folly::exception_wrapper()); +} + +void FrameTransportImpl::onError(folly::exception_wrapper ex) { + VLOG(3) << "FrameTransport received onError: " << ex.what(); + terminateProcessor(std::move(ex)); +} + +void FrameTransportImpl::outputFrameOrDrop( + std::unique_ptr frame) { + if (connection_) { + connection_->send(std::move(frame)); + } +} + +bool FrameTransportImpl::isConnectionFramed() const { + CHECK(connection_); + return connection_->isFramed(); +} + +} // namespace rsocket diff --git a/rsocket/framing/FrameTransportImpl.h b/rsocket/framing/FrameTransportImpl.h new file mode 100644 index 000000000..36ce9b526 --- /dev/null +++ b/rsocket/framing/FrameTransportImpl.h @@ -0,0 +1,77 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "rsocket/DuplexConnection.h" +#include "rsocket/internal/Common.h" +#include "yarpl/flowable/Subscription.h" + +#include "rsocket/framing/FrameTransport.h" + +namespace rsocket { + +class FrameProcessor; + +class FrameTransportImpl + : public FrameTransport, + /// Registered as an input in the DuplexConnection. + public DuplexConnection::Subscriber, + public std::enable_shared_from_this { + public: + explicit FrameTransportImpl(std::unique_ptr connection); + ~FrameTransportImpl(); + + void setFrameProcessor(std::shared_ptr) override; + + /// Writes the frame directly to output. If the connection was closed it will + /// drop the frame. + void outputFrameOrDrop(std::unique_ptr) override; + + /// Cancel the input and close the underlying connection. + void close() override; + + bool isClosed() const { + return !connection_; + } + + DuplexConnection* getConnection() override { + return connection_.get(); + } + + bool isConnectionFramed() const override; + + // Subscriber. + + void onSubscribe(std::shared_ptr) override; + void onNext(std::unique_ptr) override; + void onComplete() override; + void onError(folly::exception_wrapper) override; + + private: + void connect(); + + /// Terminates the FrameProcessor. Will queue up the exception if no + /// processor is set, overwriting any previously queued exception. + void terminateProcessor(folly::exception_wrapper); + + std::shared_ptr frameProcessor_; + std::shared_ptr connection_; + + std::shared_ptr connectionOutput_; + std::shared_ptr connectionInputSub_; +}; + +} // namespace rsocket diff --git a/rsocket/framing/FrameType.cpp b/rsocket/framing/FrameType.cpp index 4cf04eabf..8fb4fd140 100644 --- a/rsocket/framing/FrameType.cpp +++ b/rsocket/framing/FrameType.cpp @@ -1,50 +1,72 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/framing/FrameType.h" -#include - #include +#include + namespace rsocket { -std::ostream& operator<<(std::ostream& os, FrameType type) { +constexpr folly::StringPiece kUnknown{"UNKNOWN_FRAME_TYPE"}; + +folly::StringPiece toString(FrameType type) { switch (type) { case FrameType::RESERVED: - return os << "RESERVED"; + return "RESERVED"; case FrameType::SETUP: - return os << "SETUP"; + return "SETUP"; case FrameType::LEASE: - return os << "LEASE"; + return "LEASE"; case FrameType::KEEPALIVE: - return os << "KEEPALIVE"; + return "KEEPALIVE"; case FrameType::REQUEST_RESPONSE: - return os << "REQUEST_RESPONSE"; + return "REQUEST_RESPONSE"; case FrameType::REQUEST_FNF: - return os << "REQUEST_FNF"; + return "REQUEST_FNF"; case FrameType::REQUEST_STREAM: - return os << "REQUEST_STREAM"; + return "REQUEST_STREAM"; case FrameType::REQUEST_CHANNEL: - return os << "REQUEST_CHANNEL"; + return "REQUEST_CHANNEL"; case FrameType::REQUEST_N: - return os << "REQUEST_N"; + return "REQUEST_N"; case FrameType::CANCEL: - return os << "CANCEL"; + return "CANCEL"; case FrameType::PAYLOAD: - return os << "PAYLOAD"; + return "PAYLOAD"; case FrameType::ERROR: - return os << "ERROR"; + return "ERROR"; case FrameType::METADATA_PUSH: - return os << "METADATA_PUSH"; + return "METADATA_PUSH"; case FrameType::RESUME: - return os << "RESUME"; + return "RESUME"; case FrameType::RESUME_OK: - return os << "RESUME_OK"; + return "RESUME_OK"; case FrameType::EXT: - return os << "EXT"; + return "EXT"; default: - break; + DLOG(FATAL) << "Unknown frame type"; + return kUnknown; } - return os << "Unknown FrameType[" << static_cast(type) << "]"; } + +std::ostream& operator<<(std::ostream& os, FrameType type) { + auto const str = toString(type); + if (str == kUnknown) { + return os << "Unknown FrameType[" << static_cast(type) << "]"; + } + return os << str; } +} // namespace rsocket diff --git a/rsocket/framing/FrameType.h b/rsocket/framing/FrameType.h index 3c8d04a9d..726f9cd75 100644 --- a/rsocket/framing/FrameType.h +++ b/rsocket/framing/FrameType.h @@ -1,10 +1,24 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include #include +#include + namespace rsocket { enum class FrameType : uint8_t { @@ -26,5 +40,8 @@ enum class FrameType : uint8_t { EXT = 0x3F, }; +folly::StringPiece toString(FrameType); + std::ostream& operator<<(std::ostream&, FrameType); -} + +} // namespace rsocket diff --git a/rsocket/framing/FramedDuplexConnection.cpp b/rsocket/framing/FramedDuplexConnection.cpp index 11f971b81..9dec14a76 100644 --- a/rsocket/framing/FramedDuplexConnection.cpp +++ b/rsocket/framing/FramedDuplexConnection.cpp @@ -1,13 +1,89 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/framing/FramedDuplexConnection.h" -#include "rsocket/framing/FrameSerializer.h" +#include +#include "rsocket/framing/FrameSerializer_v1_0.h" #include "rsocket/framing/FramedReader.h" -#include "rsocket/framing/FramedWriter.h" namespace rsocket { -using namespace yarpl::flowable; +namespace { + +constexpr auto kMaxFrameLength = 0xFFFFFF; // 24bit max value + +template +void writeFrameLength( + TWriter& cur, + size_t frameLength, + size_t frameSizeFieldLength) { + DCHECK(frameSizeFieldLength > 0); + + // starting from the highest byte + // frameSizeFieldLength == 3 => shift = [16,8,0] + // frameSizeFieldLength == 4 => shift = [24,16,8,0] + auto shift = (frameSizeFieldLength - 1) * 8; + + while (frameSizeFieldLength--) { + const auto byte = (frameLength >> shift) & 0xFF; + cur.write(static_cast(byte)); + shift -= 8; + } +} + +size_t getFrameSizeFieldLength(ProtocolVersion version) { + CHECK(version != ProtocolVersion::Unknown); + if (version < FrameSerializerV1_0::Version) { + return sizeof(int32_t); + } else { + return 3; // bytes + } +} + +std::unique_ptr prependSize( + ProtocolVersion version, + std::unique_ptr payload) { + CHECK(payload); + + const auto frameSizeFieldLength = getFrameSizeFieldLength(version); + const auto payloadLength = payload->computeChainDataLength(); + + CHECK_LE(payloadLength, kMaxFrameLength) + << "payloadLength: " << payloadLength + << " kMaxFrameLength: " << kMaxFrameLength; + + if (payload->headroom() >= frameSizeFieldLength) { + // move the data pointer back and write value to the payload + payload->prepend(frameSizeFieldLength); + folly::io::RWPrivateCursor cur(payload.get()); + writeFrameLength(cur, payloadLength, frameSizeFieldLength); + VLOG(4) << "writing frame length=" << payload->length() << std::endl + << hexDump(payload->clone()->moveToFbString()); + return payload; + } else { + auto newPayload = folly::IOBuf::createCombined(frameSizeFieldLength); + folly::io::Appender appender(newPayload.get(), /* do not grow */ 0); + writeFrameLength(appender, payloadLength, frameSizeFieldLength); + newPayload->appendChain(std::move(payload)); + VLOG(4) << "writing frame length=" << newPayload->computeChainDataLength() + << std::endl + << hexDump(newPayload->clone()->moveToFbString()); + return newPayload; + } +} + +} // namespace FramedDuplexConnection::~FramedDuplexConnection() {} @@ -17,19 +93,21 @@ FramedDuplexConnection::FramedDuplexConnection( : inner_(std::move(connection)), protocolVersion_(std::make_shared(protocolVersion)) {} -yarpl::Reference>> -FramedDuplexConnection::getOutput() noexcept { - return yarpl::make_ref( - inner_->getOutput(), protocolVersion_); +void FramedDuplexConnection::send(std::unique_ptr buf) { + if (!inner_) { + return; + } + + auto sized = prependSize(*protocolVersion_, std::move(buf)); + inner_->send(std::move(sized)); } void FramedDuplexConnection::setInput( - yarpl::Reference>> framesSink) { - if(!inputReader_) { - inputReader_ = yarpl::make_ref(protocolVersion_); + std::shared_ptr framesSink) { + if (!inputReader_) { + inputReader_ = std::make_shared(protocolVersion_); inner_->setInput(inputReader_); } inputReader_->setInput(std::move(framesSink)); } - -} // reactivesocket +} // namespace rsocket diff --git a/rsocket/framing/FramedDuplexConnection.h b/rsocket/framing/FramedDuplexConnection.h index e8259cfff..2073266ea 100644 --- a/rsocket/framing/FramedDuplexConnection.h +++ b/rsocket/framing/FramedDuplexConnection.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -8,7 +20,6 @@ namespace rsocket { class FramedReader; -class FramedWriter; struct ProtocolVersion; class FramedDuplexConnection : public virtual DuplexConnection { @@ -19,20 +30,21 @@ class FramedDuplexConnection : public virtual DuplexConnection { ~FramedDuplexConnection(); - yarpl::Reference>> - getOutput() noexcept override; + void send(std::unique_ptr) override; - void setInput(yarpl::Reference>> - framesSink) override; + void setInput(std::shared_ptr) override; - bool isFramed() override { + bool isFramed() const override { return true; } + DuplexConnection* getConnection() { + return inner_.get(); + } + private: - std::unique_ptr inner_; - yarpl::Reference inputReader_; - std::shared_ptr protocolVersion_; + const std::unique_ptr inner_; + std::shared_ptr inputReader_; + const std::shared_ptr protocolVersion_; }; - -} // reactivesocket +} // namespace rsocket diff --git a/rsocket/framing/FramedReader.cpp b/rsocket/framing/FramedReader.cpp index 6e393c24e..02edba694 100644 --- a/rsocket/framing/FramedReader.cpp +++ b/rsocket/framing/FramedReader.cpp @@ -1,89 +1,89 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/framing/FramedReader.h" #include -#include "rsocket/framing/FrameSerializer_v0_1.h" #include "rsocket/framing/FrameSerializer_v1_0.h" +#include "rsocket/internal/Common.h" namespace rsocket { using namespace yarpl::flowable; namespace { -constexpr auto kFrameLengthFieldLengthV0_1 = sizeof(int32_t); -constexpr auto kFrameLengthFieldLengthV1_0 = 3; // bytes -} // namespace -size_t FramedReader::getFrameSizeFieldLength() const { - DCHECK(*protocolVersion_ != ProtocolVersion::Unknown); - if (*protocolVersion_ < FrameSerializerV1_0::Version) { - return kFrameLengthFieldLengthV0_1; - } else { - return kFrameLengthFieldLengthV1_0; // bytes - } +constexpr size_t kFrameLengthFieldLengthV1_0 = 3; + +/// Get the byte size of the frame length field in an RSocket frame. +size_t frameSizeFieldLength(ProtocolVersion version) { + DCHECK_NE(version, ProtocolVersion::Unknown); + return kFrameLengthFieldLengthV1_0; } -size_t FramedReader::getFrameMinimalLength() const { - DCHECK(*protocolVersion_ != ProtocolVersion::Unknown); - if (*protocolVersion_ < FrameSerializerV1_0::Version) { - return FrameSerializerV0::kFrameHeaderSize + getFrameSizeFieldLength(); - } else { - return FrameSerializerV1_0::kFrameHeaderSize; - } +/// Get the minimum size for a valid RSocket frame (including its frame length +/// field). +size_t minimalFrameLength(ProtocolVersion version) { + DCHECK_NE(version, ProtocolVersion::Unknown); + return FrameSerializerV1_0::kFrameHeaderSize; } -size_t FramedReader::getFrameSizeWithLengthField(size_t frameSize) const { - DCHECK(*protocolVersion_ != ProtocolVersion::Unknown); - if (*protocolVersion_ < FrameSerializerV1_0::Version) { - return frameSize; - } else { - return frameSize + getFrameSizeFieldLength(); - } +/// Compute the length of the entire frame (including its frame length field), +/// if given only its frame length field. +size_t frameSizeWithLengthField(ProtocolVersion version, size_t frameSize) { + return version < FrameSerializerV1_0::Version + ? frameSize + : frameSize + frameSizeFieldLength(version); } -size_t FramedReader::getPayloadSize(size_t frameSize) const { - DCHECK(*protocolVersion_ != ProtocolVersion::Unknown); - if (*protocolVersion_ < FrameSerializerV1_0::Version) { - return frameSize - getFrameSizeFieldLength(); - } else { - return frameSize; - } +/// Compute the length of the frame (excluding its frame length field), if given +/// only its frame length field. +size_t frameSizeWithoutLengthField(ProtocolVersion version, size_t frameSize) { + DCHECK_NE(version, ProtocolVersion::Unknown); + return version < FrameSerializerV1_0::Version + ? frameSize - frameSizeFieldLength(version) + : frameSize; } +} // namespace size_t FramedReader::readFrameLength() const { - auto frameSizeFieldLength = getFrameSizeFieldLength(); - DCHECK(frameSizeFieldLength > 0); + const auto fieldLength = frameSizeFieldLength(*version_); + DCHECK_GT(fieldLength, 0); - folly::io::Cursor cur(payloadQueue_.front()); + folly::io::Cursor cur{payloadQueue_.front()}; size_t frameLength = 0; - // start reading the highest byte - // frameSizeFieldLength == 3 => shift = [16,8,0] - // frameSizeFieldLength == 4 => shift = [24,16,8,0] - auto shift = (frameSizeFieldLength - 1) * 8; - - while (frameSizeFieldLength--) { - frameLength |= static_cast(cur.read() << shift); - shift -= 8; + // Reading of arbitrary-sized big-endian integer. + for (size_t i = 0; i < fieldLength; ++i) { + frameLength <<= 8; + frameLength |= cur.read(); } + return frameLength; } -void FramedReader::onSubscribe( - yarpl::Reference subscription) noexcept { - SubscriberBase::onSubscribe(subscription); - subscription->request(kMaxRequestN); +void FramedReader::onSubscribe(std::shared_ptr subscription) { + subscription_ = std::move(subscription); + subscription_->request(std::numeric_limits::max()); } -void FramedReader::onNext(std::unique_ptr payload) noexcept { - if (payload) { - VLOG(4) << "incoming bytes length=" << payload->length() << std::endl - << hexDump(payload->clone()->moveToFbString()); - payloadQueue_.append(std::move(payload)); - parseFrames(); - } +void FramedReader::onNext(std::unique_ptr payload) { + VLOG(4) << "incoming bytes length=" << payload->length() << '\n' + << hexDump(payload->clone()->moveToFbString()); + payloadQueue_.append(std::move(payload)); + parseFrames(); } void FramedReader::parseFrames() { @@ -91,124 +91,127 @@ void FramedReader::parseFrames() { return; } + // Delivering onNext can trigger termination and destroy this instance. + auto const self = shared_from_this(); + dispatchingFrames_ = true; - while (allowance_.canAcquire() && frames_) { + while (allowance_.canConsume(1) && inner_) { if (!ensureOrAutodetectProtocolVersion()) { - // at this point we dont have enough bytes on the wire - // or we errored out + // At this point we dont have enough bytes on the wire or we errored out. break; } - if (payloadQueue_.chainLength() < getFrameSizeFieldLength()) { - // we don't even have the next frame size value + auto const frameSizeFieldLen = frameSizeFieldLength(*version_); + if (payloadQueue_.chainLength() < frameSizeFieldLen) { + // We don't even have the next frame size value. break; } - const auto nextFrameSize = readFrameLength(); - - // so if the size value is less than minimal frame length something is wrong - if (nextFrameSize < getFrameMinimalLength()) { - error("invalid data stream"); + auto const nextFrameSize = readFrameLength(); + if (nextFrameSize < minimalFrameLength(*version_)) { + error("Invalid frame - Frame size smaller than minimum"); break; } if (payloadQueue_.chainLength() < - getFrameSizeWithLengthField(nextFrameSize)) { - // need to accumulate more data + frameSizeWithLengthField(*version_, nextFrameSize)) { + // Need to accumulate more data. break; } - payloadQueue_.trimStart(getFrameSizeFieldLength()); - auto payloadSize = getPayloadSize(nextFrameSize); - // IOBufQueue::split(0) returns a null unique_ptr, so we create an empty - // IOBuf object and pass a unique_ptr to it instead. This simplifies - // clients' code because they can assume the pointer is non-null. - auto nextFrame = payloadSize != 0 ? payloadQueue_.split(payloadSize) - : folly::IOBuf::create(0); - CHECK(allowance_.tryAcquire(1)); + payloadQueue_.trimStart(frameSizeFieldLen); + const auto payloadSize = + frameSizeWithoutLengthField(*version_, nextFrameSize); + + DCHECK_GT(payloadSize, 0) + << "folly::IOBufQueue::split(0) returns a nullptr, can't have that"; + auto nextFrame = payloadQueue_.split(payloadSize); - VLOG(4) << "parsed frame length=" << nextFrame->length() << std::endl + CHECK(allowance_.tryConsume(1)); + + VLOG(4) << "parsed frame length=" << nextFrame->length() << '\n' << hexDump(nextFrame->clone()->moveToFbString()); - frames_->onNext(std::move(nextFrame)); + inner_->onNext(std::move(nextFrame)); } + dispatchingFrames_ = false; } -void FramedReader::onComplete() noexcept { - completed_ = true; - payloadQueue_.move(); // equivalent to clear(), releases the buffers - if (auto subscriber = std::move(frames_)) { +void FramedReader::onComplete() { + payloadQueue_.move(); + auto subscription = std::move(subscription_); + if (auto subscriber = std::move(inner_)) { + // After this call the instance can be destroyed! subscriber->onComplete(); } } -void FramedReader::onError(std::exception_ptr ex) noexcept { - completed_ = true; - payloadQueue_.move(); // equivalent to clear(), releases the buffers - if (auto subscriber = std::move(frames_)) { +void FramedReader::onError(folly::exception_wrapper ex) { + payloadQueue_.move(); + auto subscription = std::move(subscription_); + if (auto subscriber = std::move(inner_)) { + // After this call the instance can be destroyed! subscriber->onError(std::move(ex)); } } -void FramedReader::request(int64_t n) noexcept { - allowance_.release(n); +void FramedReader::request(int64_t n) { + allowance_.add(n); parseFrames(); } -void FramedReader::cancel() noexcept { - allowance_.drain(); - frames_ = nullptr; +void FramedReader::cancel() { + allowance_.consumeAll(); + inner_ = nullptr; } void FramedReader::setInput( - yarpl::Reference>> frames) { - CHECK(!frames_) - << "FrameReader should be closed before setting another input."; - frames_ = std::move(frames); - frames_->onSubscribe(yarpl::get_ref(this)); + std::shared_ptr inner) { + CHECK(!inner_) + << "Must cancel original input to FramedReader before setting a new one"; + inner_ = std::move(inner); + inner_->onSubscribe(shared_from_this()); } bool FramedReader::ensureOrAutodetectProtocolVersion() { - if (*protocolVersion_ != ProtocolVersion::Unknown) { + if (*version_ != ProtocolVersion::Unknown) { return true; } - auto minBytesNeeded = std::max( - FrameSerializerV0_1::kMinBytesNeededForAutodetection, - FrameSerializerV1_0::kMinBytesNeededForAutodetection); - DCHECK(minBytesNeeded > 0); + const auto minBytesNeeded = + FrameSerializerV1_0::kMinBytesNeededForAutodetection; + DCHECK_GT(minBytesNeeded, 0); if (payloadQueue_.chainLength() < minBytesNeeded) { return false; } - DCHECK(minBytesNeeded > kFrameLengthFieldLengthV0_1); - DCHECK(minBytesNeeded > kFrameLengthFieldLengthV1_0); + DCHECK_GT(minBytesNeeded, kFrameLengthFieldLengthV1_0); - bool recognized = FrameSerializerV1_0::detectProtocolVersion( - *payloadQueue_.front(), kFrameLengthFieldLengthV1_0) != - ProtocolVersion::Unknown; - if (recognized) { - *protocolVersion_ = FrameSerializerV1_0::Version; - return true; - } + auto const& firstFrame = *payloadQueue_.front(); - recognized = FrameSerializerV0_1::detectProtocolVersion( - *payloadQueue_.front(), kFrameLengthFieldLengthV0_1) != - ProtocolVersion::Unknown; - if (recognized) { - *protocolVersion_ = FrameSerializerV0_1::Version; + const auto detectedV1 = FrameSerializerV1_0::detectProtocolVersion( + firstFrame, kFrameLengthFieldLengthV1_0); + if (detectedV1 != ProtocolVersion::Unknown) { + *version_ = FrameSerializerV1_0::Version; return true; } - error("could not detect protocol version from framing"); + error("Could not detect protocol version from framing"); return false; } void FramedReader::error(std::string errorMsg) { VLOG(1) << "error: " << errorMsg; - onError(std::make_exception_ptr(std::runtime_error(std::move(errorMsg)))); - SubscriberBase::subscription()->cancel(); + + payloadQueue_.move(); + if (auto subscription = std::move(subscription_)) { + subscription->cancel(); + } + if (auto subscriber = std::move(inner_)) { + // After this call the instance can be destroyed! + subscriber->onError(std::runtime_error{std::move(errorMsg)}); + } } } // namespace rsocket diff --git a/rsocket/framing/FramedReader.h b/rsocket/framing/FramedReader.h index 8c048ecad..d0bc05a4f 100644 --- a/rsocket/framing/FramedReader.h +++ b/rsocket/framing/FramedReader.h @@ -1,59 +1,67 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include -#include "rsocket/internal/AllowanceSemaphore.h" -#include "yarpl/flowable/Subscriber.h" + +#include "rsocket/DuplexConnection.h" +#include "rsocket/framing/ProtocolVersion.h" +#include "rsocket/internal/Allowance.h" #include "yarpl/flowable/Subscription.h" namespace rsocket { -struct ProtocolVersion; +class FramedReader : public DuplexConnection::Subscriber, + public yarpl::flowable::Subscription, + public std::enable_shared_from_this { + public: + explicit FramedReader(std::shared_ptr version) + : version_{std::move(version)} {} + + /// Set the inner subscriber which will be getting full frame payloads. + void setInput(std::shared_ptr); -class FramedReader : public yarpl::flowable::Subscriber>, - public yarpl::flowable::Subscription { - using SubscriberBase = yarpl::flowable::Subscriber>; + /// Cancel the subscription and error the inner subscriber. + void error(std::string); - public: - explicit FramedReader(std::shared_ptr protocolVersion) - : payloadQueue_(folly::IOBufQueue::cacheChainLength()), - protocolVersion_(std::move(protocolVersion)) {} + // Subscriber. - void setInput(yarpl::Reference>> - frames); + void onSubscribe(std::shared_ptr) override; + void onNext(std::unique_ptr) override; + void onComplete() override; + void onError(folly::exception_wrapper) override; + + // Subscription. + + void request(int64_t) override; + void cancel() override; private: - // Subscriber methods - void onSubscribe( - yarpl::Reference subscription) noexcept override; - void onNext(std::unique_ptr element) noexcept override; - void onComplete() noexcept override; - void onError(std::exception_ptr ex) noexcept override; - - // Subscription methods - void request(int64_t n) noexcept override; - void cancel() noexcept override; - - void error(std::string errorMsg); void parseFrames(); bool ensureOrAutodetectProtocolVersion(); - size_t getFrameSizeFieldLength() const; - size_t getFrameMinimalLength() const; - size_t getFrameSizeWithLengthField(size_t frameSize) const; - size_t getPayloadSize(size_t frameSize) const; size_t readFrameLength() const; - yarpl::Reference>> frames_; - - AllowanceSemaphore allowance_{0}; + std::shared_ptr subscription_; + std::shared_ptr inner_; - bool completed_{false}; + Allowance allowance_; bool dispatchingFrames_{false}; - folly::IOBufQueue payloadQueue_; - std::shared_ptr protocolVersion_; + folly::IOBufQueue payloadQueue_{folly::IOBufQueue::cacheChainLength()}; + const std::shared_ptr version_; }; -} // reactivesocket +} // namespace rsocket diff --git a/rsocket/framing/FramedWriter.cpp b/rsocket/framing/FramedWriter.cpp deleted file mode 100644 index 0011e5e32..000000000 --- a/rsocket/framing/FramedWriter.cpp +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "rsocket/framing/FramedWriter.h" - -#include - -#include "rsocket/framing/FrameSerializer_v1_0.h" - -namespace rsocket { - -using namespace yarpl::flowable; - -constexpr static const auto kMaxFrameLength = 0xFFFFFF; // 24bit max value - -template -static void writeFrameLength( - TWriter& cur, - size_t frameLength, - size_t frameSizeFieldLength) { - DCHECK(frameSizeFieldLength > 0); - - // starting from the highest byte - // frameSizeFieldLength == 3 => shift = [16,8,0] - // frameSizeFieldLength == 4 => shift = [24,16,8,0] - auto shift = (frameSizeFieldLength - 1) * 8; - - while (frameSizeFieldLength--) { - auto byte = (frameLength >> shift) & 0xFF; - cur.write(static_cast(byte)); - shift -= 8; - } -} - -size_t FramedWriter::getFrameSizeFieldLength() const { - CHECK(*protocolVersion_ != ProtocolVersion::Unknown); - if (*protocolVersion_ < FrameSerializerV1_0::Version) { - return sizeof(int32_t); - } else { - return 3; // bytes - } -} - -size_t FramedWriter::getPayloadLength(size_t payloadLength) const { - DCHECK(*protocolVersion_ != ProtocolVersion::Unknown); - if (*protocolVersion_ < FrameSerializerV1_0::Version) { - return payloadLength + getFrameSizeFieldLength(); - } else { - return payloadLength; - } -} - -void FramedWriter::onSubscribe( - yarpl::Reference subscription) { - SubscriberBase::onSubscribe(subscription); - stream_->onSubscribe(std::move(subscription)); -} - -std::unique_ptr FramedWriter::appendSize( - std::unique_ptr payload) { - CHECK(payload); - - const auto frameSizeFieldLength = getFrameSizeFieldLength(); - // the frame size includes the payload size and the size value - auto payloadLength = getPayloadLength(payload->computeChainDataLength()); - if (payloadLength > kMaxFrameLength) { - return nullptr; - } - - if (payload->headroom() >= frameSizeFieldLength) { - // move the data pointer back and write value to the payload - payload->prepend(frameSizeFieldLength); - folly::io::RWPrivateCursor cur(payload.get()); - writeFrameLength(cur, payloadLength, frameSizeFieldLength); - VLOG(4) << "writing frame length=" << payload->length() << std::endl - << hexDump(payload->clone()->moveToFbString()); - return payload; - } else { - auto newPayload = folly::IOBuf::createCombined(frameSizeFieldLength); - folly::io::Appender appender(newPayload.get(), /* do not grow */ 0); - writeFrameLength(appender, payloadLength, frameSizeFieldLength); - newPayload->appendChain(std::move(payload)); - VLOG(4) << "writing frame length=" << newPayload->computeChainDataLength() - << std::endl - << hexDump(newPayload->clone()->moveToFbString()); - return newPayload; - } -} - -void FramedWriter::onNext(std::unique_ptr payload) { - auto sizedPayload = appendSize(std::move(payload)); - if (!sizedPayload) { - error("payload too big"); - return; - } - stream_->onNext(std::move(sizedPayload)); -} - -void FramedWriter::onNextMultiple( - std::vector> payloads) { - folly::IOBufQueue payloadQueue; - - for (auto& payload : payloads) { - auto sizedPayload = appendSize(std::move(payload)); - if (!sizedPayload) { - error("payload too big"); - return; - } - payloadQueue.append(std::move(sizedPayload)); - } - stream_->onNext(payloadQueue.move()); -} - -void FramedWriter::error(std::string errorMsg) { - VLOG(1) << "error: " << errorMsg; - onError(std::make_exception_ptr(std::runtime_error(std::move(errorMsg)))); - SubscriberBase::subscription()->cancel(); -} - -void FramedWriter::onComplete() { - SubscriberBase::onComplete(); - stream_->onComplete(); - stream_ = nullptr; -} - -void FramedWriter::onError(std::exception_ptr ex) { - SubscriberBase::onError(ex); - stream_->onError(std::move(ex)); - stream_ = nullptr; -} - -} // reactivesocket diff --git a/rsocket/framing/FramedWriter.h b/rsocket/framing/FramedWriter.h deleted file mode 100644 index 356f6c129..000000000 --- a/rsocket/framing/FramedWriter.h +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include -#include "yarpl/flowable/Subscriber.h" - -namespace folly { -class IOBuf; -} - -namespace rsocket { - -struct ProtocolVersion; - -class FramedWriter : public yarpl::flowable::Subscriber> { - using SubscriberBase = yarpl::flowable::Subscriber>; - - public: - explicit FramedWriter( - yarpl::Reference stream, - std::shared_ptr protocolVersion) - : stream_(std::move(stream)), - protocolVersion_(std::move(protocolVersion)) {} - - void onNextMultiple(std::vector> element); - - private: - // Subscriber methods - void onSubscribe( - yarpl::Reference subscription) override; - void onNext(std::unique_ptr element) override; - void onComplete() override; - void onError(std::exception_ptr ex) override; - - void error(std::string errorMsg); - - size_t getFrameSizeFieldLength() const; - size_t getPayloadLength(size_t payloadLength) const; - - std::unique_ptr appendSize( - std::unique_ptr payload); - - yarpl::Reference stream_; - std::shared_ptr protocolVersion_; -}; - -} // reactivesocket diff --git a/rsocket/framing/Framer.cpp b/rsocket/framing/Framer.cpp new file mode 100644 index 000000000..0fb97763b --- /dev/null +++ b/rsocket/framing/Framer.cpp @@ -0,0 +1,204 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/framing/Framer.h" +#include +#include "rsocket/framing/FrameSerializer_v1_0.h" + +namespace rsocket { + +namespace { + +constexpr size_t kFrameLengthFieldLengthV1_0 = 3; +constexpr auto kMaxFrameLength = 0xFFFFFF; // 24bit max value + +template +void writeFrameLength( + TWriter& cur, + size_t frameLength, + size_t frameSizeFieldLength) { + DCHECK(frameSizeFieldLength > 0); + + // starting from the highest byte + // frameSizeFieldLength == 3 => shift = [16,8,0] + // frameSizeFieldLength == 4 => shift = [24,16,8,0] + auto shift = (frameSizeFieldLength - 1) * 8; + + while (frameSizeFieldLength--) { + const auto byte = (frameLength >> shift) & 0xFF; + cur.write(static_cast(byte)); + shift -= 8; + } +} +} // namespace + +/// Get the byte size of the frame length field in an RSocket frame. +size_t Framer::frameSizeFieldLength() const { + DCHECK_NE(protocolVersion_, ProtocolVersion::Unknown); + if (protocolVersion_ < FrameSerializerV1_0::Version) { + return sizeof(int32_t); + } else { + return 3; // bytes + } +} + +/// Get the minimum size for a valid RSocket frame (including its frame length +/// field). +size_t Framer::minimalFrameLength() const { + DCHECK_NE(protocolVersion_, ProtocolVersion::Unknown); + return FrameSerializerV1_0::kFrameHeaderSize; +} + +/// Compute the length of the entire frame (including its frame length field), +/// if given only its frame length field. +size_t Framer::frameSizeWithLengthField(size_t frameSize) const { + return protocolVersion_ < FrameSerializerV1_0::Version + ? frameSize + : frameSize + frameSizeFieldLength(); +} + +/// Compute the length of the frame (excluding its frame length field), if given +/// only its frame length field. +size_t Framer::frameSizeWithoutLengthField(size_t frameSize) const { + DCHECK_NE(protocolVersion_, ProtocolVersion::Unknown); + return protocolVersion_ < FrameSerializerV1_0::Version + ? frameSize - frameSizeFieldLength() + : frameSize; +} + +size_t Framer::readFrameLength() const { + const auto fieldLength = frameSizeFieldLength(); + DCHECK_GT(fieldLength, 0); + + folly::io::Cursor cur{payloadQueue_.front()}; + size_t frameLength = 0; + + // Reading of arbitrary-sized big-endian integer. + for (size_t i = 0; i < fieldLength; ++i) { + frameLength <<= 8; + frameLength |= cur.read(); + } + + return frameLength; +} + +void Framer::addFrameChunk(std::unique_ptr payload) { + payloadQueue_.append(std::move(payload)); + parseFrames(); +} + +void Framer::parseFrames() { + if (payloadQueue_.empty() || !ensureOrAutodetectProtocolVersion()) { + // At this point we dont have enough bytes on the wire or we errored out. + return; + } + + while (!payloadQueue_.empty()) { + auto const frameSizeFieldLen = frameSizeFieldLength(); + if (payloadQueue_.chainLength() < frameSizeFieldLen) { + // We don't even have the next frame size value. + break; + } + + auto const nextFrameSize = readFrameLength(); + if (nextFrameSize < minimalFrameLength()) { + error("Invalid frame - Frame size smaller than minimum"); + break; + } + + if (payloadQueue_.chainLength() < frameSizeWithLengthField(nextFrameSize)) { + // Need to accumulate more data. + break; + } + + auto payloadSize = frameSizeWithoutLengthField(nextFrameSize); + if (stripFrameLengthField_) { + payloadQueue_.trimStart(frameSizeFieldLen); + } else { + payloadSize += frameSizeFieldLen; + } + + DCHECK_GT(payloadSize, 0) + << "folly::IOBufQueue::split(0) returns a nullptr, can't have that"; + auto nextFrame = payloadQueue_.split(payloadSize); + onFrame(std::move(nextFrame)); + } +} + +bool Framer::ensureOrAutodetectProtocolVersion() { + if (protocolVersion_ != ProtocolVersion::Unknown) { + return true; + } + + const auto minBytesNeeded = + FrameSerializerV1_0::kMinBytesNeededForAutodetection; + DCHECK_GT(minBytesNeeded, 0); + if (payloadQueue_.chainLength() < minBytesNeeded) { + return false; + } + + DCHECK_GT(minBytesNeeded, kFrameLengthFieldLengthV1_0); + + auto const& firstFrame = *payloadQueue_.front(); + + const auto detectedV1 = FrameSerializerV1_0::detectProtocolVersion( + firstFrame, kFrameLengthFieldLengthV1_0); + if (detectedV1 != ProtocolVersion::Unknown) { + protocolVersion_ = FrameSerializerV1_0::Version; + return true; + } + + error("Could not detect protocol version from data"); + return false; +} + +std::unique_ptr Framer::prependSize( + std::unique_ptr payload) { + CHECK(payload); + + const auto frameSizeFieldLengthValue = frameSizeFieldLength(); + const auto payloadLength = payload->computeChainDataLength(); + + CHECK_LE(payloadLength, kMaxFrameLength) + << "payloadLength: " << payloadLength + << " kMaxFrameLength: " << kMaxFrameLength; + + if (payload->headroom() >= frameSizeFieldLengthValue) { + // move the data pointer back and write value to the payload + payload->prepend(frameSizeFieldLengthValue); + folly::io::RWPrivateCursor cur(payload.get()); + writeFrameLength(cur, payloadLength, frameSizeFieldLengthValue); + return payload; + } else { + auto newPayload = folly::IOBuf::createCombined(frameSizeFieldLengthValue); + folly::io::Appender appender(newPayload.get(), /* do not grow */ 0); + writeFrameLength(appender, payloadLength, frameSizeFieldLengthValue); + newPayload->appendChain(std::move(payload)); + return newPayload; + } +} + +StreamId Framer::peekStreamId( + const folly::IOBuf& frame, + bool skipFrameLengthBytes) const { + return FrameSerializer::peekStreamId( + protocolVersion_, frame, skipFrameLengthBytes) + .value(); +} + +std::unique_ptr Framer::drainPayloadQueue() { + return payloadQueue_.move(); +} + +} // namespace rsocket diff --git a/rsocket/framing/Framer.h b/rsocket/framing/Framer.h new file mode 100644 index 000000000..2ff740492 --- /dev/null +++ b/rsocket/framing/Framer.h @@ -0,0 +1,73 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "rsocket/framing/ProtocolVersion.h" +#include "rsocket/internal/Common.h" + +namespace rsocket { + +/// +/// Frames class is used to parse individual rsocket frames from the stream of +/// incoming payload chunks. Every time a frame is parsed the onFrame method is +/// invoked. +/// Each rsocket frame is prepended with the frame length by +/// prependSize method. +/// +class Framer { + public: + Framer(ProtocolVersion protocolVersion, bool stripFrameLengthField) + : protocolVersion_{protocolVersion}, + stripFrameLengthField_{stripFrameLengthField} {} + virtual ~Framer() {} + + /// For processing incoming frame chunks + void addFrameChunk(std::unique_ptr); + + /// Prepends payload size to the beginning of he IOBuf based on the + /// set protocol version + std::unique_ptr prependSize( + std::unique_ptr payload); + + /// derived class can override this method to react to termination + virtual void error(const char*) = 0; + virtual void onFrame(std::unique_ptr) = 0; + + ProtocolVersion protocolVersion() const { + return protocolVersion_; + } + + StreamId peekStreamId(const folly::IOBuf& frame, bool) const; + + std::unique_ptr drainPayloadQueue(); + + private: + // to explicitly trigger parsing frames + void parseFrames(); + bool ensureOrAutodetectProtocolVersion(); + + size_t readFrameLength() const; + size_t frameSizeFieldLength() const; + size_t minimalFrameLength() const; + size_t frameSizeWithLengthField(size_t frameSize) const; + size_t frameSizeWithoutLengthField(size_t frameSize) const; + + folly::IOBufQueue payloadQueue_{folly::IOBufQueue::cacheChainLength()}; + ProtocolVersion protocolVersion_; + const bool stripFrameLengthField_; +}; + +} // namespace rsocket diff --git a/rsocket/framing/ProtocolVersion.cpp b/rsocket/framing/ProtocolVersion.cpp new file mode 100644 index 000000000..ee8f54c5f --- /dev/null +++ b/rsocket/framing/ProtocolVersion.cpp @@ -0,0 +1,32 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/framing/ProtocolVersion.h" + +#include +#include + +namespace rsocket { + +const ProtocolVersion ProtocolVersion::Unknown = ProtocolVersion( + std::numeric_limits::max(), + std::numeric_limits::max()); + +const ProtocolVersion ProtocolVersion::Latest = ProtocolVersion(1, 0); + +std::ostream& operator<<(std::ostream& os, const ProtocolVersion& version) { + return os << version.major << "." << version.minor; +} + +} // namespace rsocket diff --git a/rsocket/framing/ProtocolVersion.h b/rsocket/framing/ProtocolVersion.h new file mode 100644 index 000000000..3daf24dad --- /dev/null +++ b/rsocket/framing/ProtocolVersion.h @@ -0,0 +1,75 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace rsocket { + +// Bug in GCC: https://bugzilla.redhat.com/show_bug.cgi?id=130601 +#pragma push_macro("major") +#pragma push_macro("minor") +#undef major +#undef minor + +struct ProtocolVersion { + uint16_t major{}; + uint16_t minor{}; + + constexpr ProtocolVersion() = default; + constexpr ProtocolVersion(uint16_t _major, uint16_t _minor) + : major(_major), minor(_minor) {} + + static const ProtocolVersion Unknown; + static const ProtocolVersion Latest; +}; + +#pragma pop_macro("major") +#pragma pop_macro("minor") + +std::ostream& operator<<(std::ostream&, const ProtocolVersion&); + +constexpr bool operator==( + const ProtocolVersion& left, + const ProtocolVersion& right) { + return left.major == right.major && left.minor == right.minor; +} + +constexpr bool operator!=( + const ProtocolVersion& left, + const ProtocolVersion& right) { + return !(left == right); +} + +constexpr bool operator<( + const ProtocolVersion& left, + const ProtocolVersion& right) { + return left != ProtocolVersion::Unknown && + right != ProtocolVersion::Unknown && + (left.major < right.major || + (left.major == right.major && left.minor < right.minor)); +} + +constexpr bool operator>( + const ProtocolVersion& left, + const ProtocolVersion& right) { + return left != ProtocolVersion::Unknown && + right != ProtocolVersion::Unknown && + (left.major > right.major || + (left.major == right.major && left.minor > right.minor)); +} + +} // namespace rsocket diff --git a/rsocket/framing/ResumeIdentificationToken.cpp b/rsocket/framing/ResumeIdentificationToken.cpp new file mode 100644 index 000000000..31011f14d --- /dev/null +++ b/rsocket/framing/ResumeIdentificationToken.cpp @@ -0,0 +1,86 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/framing/ResumeIdentificationToken.h" + +#include +#include +#include +#include + +#include +#include + +namespace rsocket { + +constexpr const char* kHexChars = "0123456789abcdef"; + +ResumeIdentificationToken::ResumeIdentificationToken() {} + +ResumeIdentificationToken::ResumeIdentificationToken(const std::string& token) { + const auto getNibble = [&token](size_t i) { + uint8_t nibble; + if (token[i] >= '0' && token[i] <= '9') { + nibble = token[i] - '0'; + } else if (token[i] >= 'a' && token[i] <= 'f') { + nibble = token[i] - 'a' + 10; + } else { + throw std::invalid_argument("ResumeToken not in right format: " + token); + } + return nibble; + }; + if (token.size() < 2 || token[0] != '0' || token[1] != 'x' || + (token.size() % 2) != 0) { + throw std::invalid_argument("ResumeToken not in right format: " + token); + } + for (size_t i = 2 /* skipping '0x' */; i < token.size(); i += 2) { + const uint8_t firstNibble = getNibble(i + 0); + const uint8_t secondNibble = getNibble(i + 1); + bits_.push_back((firstNibble << 4) | secondNibble); + } +} + +ResumeIdentificationToken ResumeIdentificationToken::generateNew() { + constexpr size_t kSize = 16; + std::vector data; + data.reserve(kSize); + for (size_t i = 0; i < kSize; i++) { + data.push_back(static_cast(folly::Random::rand32())); + } + return ResumeIdentificationToken(std::move(data)); +} + +void ResumeIdentificationToken::set(std::vector newBits) { + CHECK(newBits.size() <= std::numeric_limits::max()); + bits_ = std::move(newBits); +} + +std::string ResumeIdentificationToken::str() const { + std::stringstream out; + out << *this; + return out.str(); +} + +std::ostream& operator<<( + std::ostream& out, + const ResumeIdentificationToken& token) { + out << "0x"; + for (const auto b : token.data()) { + out << kHexChars[(b & 0xF0) >> 4]; + out << kHexChars[b & 0x0F]; + } + return out; +} + +} // namespace rsocket diff --git a/rsocket/framing/ResumeIdentificationToken.h b/rsocket/framing/ResumeIdentificationToken.h new file mode 100644 index 000000000..be276ec3e --- /dev/null +++ b/rsocket/framing/ResumeIdentificationToken.h @@ -0,0 +1,65 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +namespace rsocket { + +class ResumeIdentificationToken { + public: + /// Creates an empty token. + ResumeIdentificationToken(); + + // The string token and ::str() function should complement each other. The + // string representation should be of the format + // 0x44ab7cf01fd290b63140d01ee789cfb6 + explicit ResumeIdentificationToken(const std::string&); + + static ResumeIdentificationToken generateNew(); + + const std::vector& data() const { + return bits_; + } + + void set(std::vector newBits); + + bool operator==(const ResumeIdentificationToken& right) const { + return data() == right.data(); + } + + bool operator!=(const ResumeIdentificationToken& right) const { + return data() != right.data(); + } + + bool operator<(const ResumeIdentificationToken& right) const { + return data() < right.data(); + } + + std::string str() const; + + private: + explicit ResumeIdentificationToken(std::vector bits) + : bits_(std::move(bits)) {} + + std::vector bits_; +}; + +std::ostream& operator<<(std::ostream&, const ResumeIdentificationToken&); + +} // namespace rsocket diff --git a/rsocket/framing/ScheduledFrameProcessor.cpp b/rsocket/framing/ScheduledFrameProcessor.cpp new file mode 100644 index 000000000..e1abeade9 --- /dev/null +++ b/rsocket/framing/ScheduledFrameProcessor.cpp @@ -0,0 +1,43 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/framing/ScheduledFrameProcessor.h" + +namespace rsocket { + +ScheduledFrameProcessor::ScheduledFrameProcessor( + std::shared_ptr processor, + folly::EventBase* evb) + : evb_{evb}, processor_{std::move(processor)} {} + +ScheduledFrameProcessor::~ScheduledFrameProcessor() = default; + +void ScheduledFrameProcessor::processFrame( + std::unique_ptr ioBuf) { + CHECK(processor_) << "Calling processFrame() after onTerminal()"; + + evb_->runInEventBaseThread( + [processor = processor_, buf = std::move(ioBuf)]() mutable { + processor->processFrame(std::move(buf)); + }); +} + +void ScheduledFrameProcessor::onTerminal(folly::exception_wrapper ew) { + evb_->runInEventBaseThread( + [e = std::move(ew), processor = std::move(processor_)]() mutable { + processor->onTerminal(std::move(e)); + }); +} + +} // namespace rsocket diff --git a/rsocket/framing/ScheduledFrameProcessor.h b/rsocket/framing/ScheduledFrameProcessor.h new file mode 100644 index 000000000..e4546af79 --- /dev/null +++ b/rsocket/framing/ScheduledFrameProcessor.h @@ -0,0 +1,44 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "rsocket/framing/FrameProcessor.h" + +namespace rsocket { + +// This class is a wrapper around FrameProcessor which ensures all methods of +// FrameProcessor get executed in a particular EventBase. +// +// This is currently used in the server where the resumed Transport of the +// client is on a different EventBase compared to the EventBase on which the +// original RSocketStateMachine was constructed for the client. Here the +// transport uses this class to schedule events of the RSocketStateMachine +// (FrameProcessor) in the original EventBase. +class ScheduledFrameProcessor : public FrameProcessor { + public: + ScheduledFrameProcessor(std::shared_ptr, folly::EventBase*); + ~ScheduledFrameProcessor(); + + void processFrame(std::unique_ptr) override; + void onTerminal(folly::exception_wrapper) override; + + private: + folly::EventBase* const evb_; + std::shared_ptr processor_; +}; + +} // namespace rsocket diff --git a/rsocket/framing/ScheduledFrameTransport.cpp b/rsocket/framing/ScheduledFrameTransport.cpp new file mode 100644 index 000000000..88f715f16 --- /dev/null +++ b/rsocket/framing/ScheduledFrameTransport.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/framing/ScheduledFrameTransport.h" + +#include "rsocket/framing/ScheduledFrameProcessor.h" + +namespace rsocket { + +ScheduledFrameTransport::~ScheduledFrameTransport() = default; + +void ScheduledFrameTransport::setFrameProcessor( + std::shared_ptr fp) { + CHECK(frameTransport_) << "Inner transport already closed"; + + transportEvb_->runInEventBaseThread([stateMachineEvb = stateMachineEvb_, + transport = frameTransport_, + fp = std::move(fp)]() mutable { + auto scheduledFP = std::make_shared( + std::move(fp), stateMachineEvb); + transport->setFrameProcessor(std::move(scheduledFP)); + }); +} + +void ScheduledFrameTransport::outputFrameOrDrop( + std::unique_ptr ioBuf) { + CHECK(frameTransport_) << "Inner transport already closed"; + + transportEvb_->runInEventBaseThread( + [transport = frameTransport_, buf = std::move(ioBuf)]() mutable { + transport->outputFrameOrDrop(std::move(buf)); + }); +} + +void ScheduledFrameTransport::close() { + CHECK(frameTransport_) << "Inner transport already closed"; + + transportEvb_->runInEventBaseThread( + [transport = std::move(frameTransport_)]() { transport->close(); }); +} + +bool ScheduledFrameTransport::isConnectionFramed() const { + CHECK(frameTransport_) << "Inner transport already closed"; + return frameTransport_->isConnectionFramed(); +} + +} // namespace rsocket diff --git a/rsocket/framing/ScheduledFrameTransport.h b/rsocket/framing/ScheduledFrameTransport.h new file mode 100644 index 000000000..cc53f9444 --- /dev/null +++ b/rsocket/framing/ScheduledFrameTransport.h @@ -0,0 +1,63 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "rsocket/framing/FrameTransport.h" + +namespace rsocket { + +// This class is a wrapper around FrameTransport which ensures all methods of +// FrameTransport get executed in a particular EventBase. +// +// This is currently used in the server where the resumed Transport of the +// client is on a different EventBase compared to the EventBase on which the +// original RSocketStateMachine was constructed for the client. Here the +// RSocketStateMachine uses this class to schedule events of the Transport in +// the new EventBase. +class ScheduledFrameTransport : public FrameTransport { + public: + ScheduledFrameTransport( + std::shared_ptr frameTransport, + folly::EventBase* transportEvb, + folly::EventBase* stateMachineEvb) + : transportEvb_(transportEvb), + stateMachineEvb_(stateMachineEvb), + frameTransport_(std::move(frameTransport)) {} + + ~ScheduledFrameTransport(); + + void setFrameProcessor(std::shared_ptr) override; + void outputFrameOrDrop(std::unique_ptr) override; + void close() override; + bool isConnectionFramed() const override; + + private: + DuplexConnection* getConnection() override { + DLOG(FATAL) + << "ScheduledFrameTransport doesn't support getConnection method, " + "because it can create safe usage issues when EventBase of the " + "transport and the RSocketClient is not the same."; + return nullptr; + } + + private: + folly::EventBase* const transportEvb_; + folly::EventBase* const stateMachineEvb_; + std::shared_ptr frameTransport_; +}; + +} // namespace rsocket diff --git a/rsocket/internal/Allowance.h b/rsocket/internal/Allowance.h new file mode 100644 index 000000000..059dd3c47 --- /dev/null +++ b/rsocket/internal/Allowance.h @@ -0,0 +1,85 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace rsocket { + +class Allowance { + public: + using ValueType = size_t; + + Allowance() = default; + + explicit Allowance(ValueType initialValue) : value_(initialValue) {} + + bool tryConsume(ValueType n) { + if (!canConsume(n)) { + return false; + } + value_ -= n; + return true; + } + + ValueType add(ValueType n) { + auto old_value = value_; + value_ += n; + if (old_value > value_) { + value_ = max(); + } + return old_value; + } + + bool canConsume(ValueType n) const { + return value_ >= n; + } + + ValueType consumeAll() { + return consumeUpTo(max()); + } + + ValueType consumeUpTo(ValueType limit) { + if (limit > value_) { + limit = value_; + } + value_ -= limit; + return limit; + } + + explicit operator bool() const { + return value_; + } + + ValueType get() const { + return value_; + } + + static ValueType max() { + return std::numeric_limits::max(); + } + + private: + static_assert( + !std::numeric_limits::is_signed, + "Allowance representation must be an unsigned type"); + static_assert( + std::numeric_limits::is_integer, + "Allowance representation must be an integer type"); + ValueType value_{0}; +}; +} // namespace rsocket diff --git a/rsocket/internal/AllowanceSemaphore.h b/rsocket/internal/AllowanceSemaphore.h deleted file mode 100644 index 20611b7d9..000000000 --- a/rsocket/internal/AllowanceSemaphore.h +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include -#include - -namespace rsocket { - -class AllowanceSemaphore { - public: - using ValueType = size_t; - - AllowanceSemaphore() = default; - - explicit AllowanceSemaphore(ValueType initialValue) : value_(initialValue) {} - - bool tryAcquire(ValueType n = 1) { - if (!canAcquire(n)) { - return false; - } - value_ -= n; - return true; - } - - ValueType release(ValueType n) { - auto old_value = value_; - value_ += n; - if (old_value > value_) { - value_ = max(); - } - return old_value; - } - - bool canAcquire(ValueType n = 1) const { - return value_ >= n; - } - - ValueType drain() { - return drainWithLimit(max()); - } - - ValueType drainWithLimit(ValueType limit) { - if (limit > value_) { - limit = value_; - } - value_ -= limit; - return limit; - } - - explicit operator bool() const { - return value_; - } - - static ValueType max() { - return std::numeric_limits::max(); - } - - private: - static_assert( - !std::numeric_limits::is_signed, - "Allowance representation must be an unsigned type"); - static_assert( - std::numeric_limits::is_integer, - "Allowance representation must be an integer type"); - ValueType value_{0}; -}; -} // reactivesocket diff --git a/rsocket/internal/ClientResumeStatusCallback.h b/rsocket/internal/ClientResumeStatusCallback.h index 0ec5078df..abe20fc9d 100644 --- a/rsocket/internal/ClientResumeStatusCallback.h +++ b/rsocket/internal/ClientResumeStatusCallback.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -18,4 +30,4 @@ class ClientResumeStatusCallback { virtual void onResumeError(folly::exception_wrapper ex) noexcept = 0; }; -} // reactivesocket +} // namespace rsocket diff --git a/rsocket/internal/Common.cpp b/rsocket/internal/Common.cpp index 292387864..fbd33f592 100644 --- a/rsocket/internal/Common.cpp +++ b/rsocket/internal/Common.cpp @@ -1,22 +1,29 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/internal/Common.h" +#include + #include #include #include +#include #include namespace rsocket { -namespace { -constexpr const char* HEX_CHARS = {"0123456789abcdef"}; -} - -constexpr const ProtocolVersion ProtocolVersion::Unknown = ProtocolVersion( - std::numeric_limits::max(), - std::numeric_limits::max()); - static const char* getTerminatingSignalErrorMessage(int terminatingSignal) { switch (static_cast(terminatingSignal)) { case StreamCompletionSignal::CONNECTION_END: @@ -44,6 +51,37 @@ static const char* getTerminatingSignalErrorMessage(int terminatingSignal) { } } +folly::StringPiece toString(StreamType t) { + switch (t) { + case StreamType::REQUEST_RESPONSE: + return "REQUEST_RESPONSE"; + case StreamType::STREAM: + return "STREAM"; + case StreamType::CHANNEL: + return "CHANNEL"; + case StreamType::FNF: + return "FNF"; + default: + DCHECK(false); + return "(invalid StreamType)"; + } +} + +std::ostream& operator<<(std::ostream& os, StreamType t) { + return os << toString(t); +} + +std::ostream& operator<<(std::ostream& os, RSocketMode mode) { + switch (mode) { + case RSocketMode::CLIENT: + return os << "CLIENT"; + case RSocketMode::SERVER: + return os << "SERVER"; + } + DLOG(FATAL) << "Invalid RSocketMode"; + return os << "INVALID_RSOCKET_MODE"; +} + std::string to_string(StreamCompletionSignal signal) { switch (signal) { case StreamCompletionSignal::COMPLETE: @@ -69,6 +107,7 @@ std::string to_string(StreamCompletionSignal signal) { } // this should be never hit because the switch is over all cases LOG(FATAL) << "unknown StreamCompletionSignal=" << static_cast(signal); + return ""; } std::ostream& operator<<(std::ostream& os, StreamCompletionSignal signal) { @@ -79,35 +118,23 @@ StreamInterruptedException::StreamInterruptedException(int _terminatingSignal) : std::runtime_error(getTerminatingSignalErrorMessage(_terminatingSignal)), terminatingSignal(_terminatingSignal) {} -ResumeIdentificationToken::ResumeIdentificationToken() {} +std::string humanify(std::unique_ptr const& buf) { + std::string ret; + size_t cursor = 0; -ResumeIdentificationToken ResumeIdentificationToken::generateNew() { - constexpr size_t kSize = 16; - std::vector data; - data.reserve(kSize); - for (size_t i = 0; i < kSize; i++) { - data.push_back(static_cast(folly::Random::rand32())); + for (const auto& range : *buf) { + for (const unsigned char chr : range) { + if (cursor >= 20) + goto outer; + ret += chr; + cursor++; + } } - return ResumeIdentificationToken(std::move(data)); -} +outer: -void ResumeIdentificationToken::set(std::vector newBits) { - CHECK(newBits.size() <= std::numeric_limits::max()); - bits_ = std::move(newBits); + return folly::humanify(ret); } - -std::ostream& operator<<( - std::ostream& out, - const ResumeIdentificationToken& token) { - out << "0x"; - for (auto b : token.data()) { - out << HEX_CHARS[(b & 0xF0) >> 4]; - out << HEX_CHARS[b & 0x0F]; - } - return out; -} - std::string hexDump(folly::StringPiece s) { - return folly::hexDump(s.data(), s.size()); + return folly::hexDump(s.data(), std::min(0xFF, s.size())); } -} // reactivesocket +} // namespace rsocket diff --git a/rsocket/internal/Common.h b/rsocket/internal/Common.h index 8dd021036..a096a5545 100644 --- a/rsocket/internal/Common.h +++ b/rsocket/internal/Common.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -21,18 +33,18 @@ class IOBuf; template class Range; typedef Range StringPiece; -} +} // namespace folly namespace rsocket { -constexpr int64_t kMaxRequestN = std::numeric_limits::max(); - /// A unique identifier of a stream. using StreamId = uint32_t; -using ResumePosition = int64_t; -constexpr const ResumePosition kUnspecifiedResumePosition = -1; +constexpr std::chrono::seconds kDefaultKeepaliveInterval{5}; + +constexpr int64_t kMaxRequestN = std::numeric_limits::max(); +std::string humanify(std::unique_ptr const&); std::string hexDump(folly::StringPiece s); /// Indicates the reason why the stream stateMachine received a terminal signal @@ -50,7 +62,9 @@ enum class StreamCompletionSignal { SOCKET_CLOSED, }; -enum class ReactiveSocketMode { SERVER, CLIENT }; +enum class RSocketMode : uint8_t { SERVER, CLIENT }; + +std::ostream& operator<<(std::ostream&, RSocketMode); enum class StreamType { REQUEST_RESPONSE, @@ -59,111 +73,23 @@ enum class StreamType { FNF, }; +folly::StringPiece toString(StreamType); +std::ostream& operator<<(std::ostream&, StreamType); + +enum class RequestOriginator { + LOCAL, + REMOTE, +}; + std::string to_string(StreamCompletionSignal); std::ostream& operator<<(std::ostream&, StreamCompletionSignal); class StreamInterruptedException : public std::runtime_error { public: explicit StreamInterruptedException(int _terminatingSignal); - int terminatingSignal; + const int terminatingSignal; }; -class ResumeIdentificationToken { - public: - /// Creates an empty token. - ResumeIdentificationToken(); - static ResumeIdentificationToken generateNew(); - - const std::vector& data() const { - return bits_; - } - - void set(std::vector newBits); - - bool operator==(const ResumeIdentificationToken& right) const { - return data() == right.data(); - } - - bool operator!=(const ResumeIdentificationToken& right) const { - return data() != right.data(); - } - - bool operator<(const ResumeIdentificationToken& right) const { - return data() < right.data(); - } - - private: - explicit ResumeIdentificationToken(std::vector bits) - : bits_(std::move(bits)) {} - - std::vector bits_; -}; - -std::ostream& operator<<(std::ostream&, const ResumeIdentificationToken&); - -// bug in GCC: https://bugzilla.redhat.com/show_bug.cgi?id=130601 -#pragma push_macro("major") -#pragma push_macro("minor") -#undef major -#undef minor - -struct ProtocolVersion { - uint16_t major{}; - uint16_t minor{}; - - constexpr ProtocolVersion() = default; - constexpr ProtocolVersion(uint16_t _major, uint16_t _minor) - : major(_major), minor(_minor) {} - - static const ProtocolVersion Unknown; - static const ProtocolVersion Latest; -}; - -#pragma pop_macro("major") -#pragma pop_macro("minor") - -std::ostream& operator<<(std::ostream&, const ProtocolVersion&); - -constexpr inline bool operator==( - const ProtocolVersion& left, - const ProtocolVersion& right) { - return left.major == right.major && left.minor == right.minor; -} - -constexpr inline bool operator!=( - const ProtocolVersion& left, - const ProtocolVersion& right) { - return !(left == right); -} - -constexpr inline bool operator<( - const ProtocolVersion& left, - const ProtocolVersion& right) { - return left != ProtocolVersion::Unknown && - right != ProtocolVersion::Unknown && - (left.major < right.major || - (left.major == right.major && left.minor < right.minor)); -} - -constexpr inline bool operator>( - const ProtocolVersion& left, - const ProtocolVersion& right) { - return left != ProtocolVersion::Unknown && - right != ProtocolVersion::Unknown && - (left.major > right.major || - (left.major == right.major && left.minor > right.minor)); -} - class FrameSink; -// Client Side Keepalive Timer -class KeepaliveTimer { - public: - virtual ~KeepaliveTimer() = default; - - virtual std::chrono::milliseconds keepaliveTime() = 0; - virtual void stop() = 0; - virtual void start(const std::shared_ptr& connection) = 0; - virtual void keepaliveReceived() = 0; -}; -} // reactivesocket +} // namespace rsocket diff --git a/rsocket/internal/ConnectionSet.cpp b/rsocket/internal/ConnectionSet.cpp new file mode 100644 index 000000000..0ed32db6a --- /dev/null +++ b/rsocket/internal/ConnectionSet.cpp @@ -0,0 +1,108 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/internal/ConnectionSet.h" + +#include "rsocket/statemachine/RSocketStateMachine.h" + +#include + +namespace rsocket { + +ConnectionSet::ConnectionSet() {} + +ConnectionSet::~ConnectionSet() { + if (!shutDown_) { + shutdownAndWait(); + } +} + +void ConnectionSet::shutdownAndWait() { + VLOG(1) << "Started ConnectionSet::shutdownAndWait"; + shutDown_ = true; + + SCOPE_EXIT { + VLOG(1) << "Finished ConnectionSet::shutdownAndWait"; + }; + + StateMachineMap map; + + // Move all the connections out of the synchronized map so we don't block + // while closing the state machines. + { + const auto locked = machines_.lock(); + if (locked->empty()) { + VLOG(2) << "No connections to close, early exit"; + return; + } + + targetRemoves_ = removes_ + locked->size(); + map.swap(*locked); + } + + VLOG(2) << "Need to close " << map.size() << " connections"; + + for (auto& kv : map) { + auto rsocket = std::move(kv.first); + auto evb = kv.second; + + const auto close = [rs = std::move(rsocket)] { + rs->close({}, StreamCompletionSignal::SOCKET_CLOSED); + }; + + // We could be closing on the same thread as the state machine. In that + // case, close the state machine inline, otherwise we hang. + if (evb->isInEventBaseThread()) { + VLOG(3) << "Closing connection inline"; + close(); + } else { + VLOG(3) << "Closing connection asynchronously"; + evb->runInEventBaseThread(close); + } + } + + VLOG(2) << "Waiting for connections to close"; + shutdownDone_.wait(); + VLOG(2) << "Connections have closed"; +} + +bool ConnectionSet::insert( + std::shared_ptr machine, + folly::EventBase* evb) { + VLOG(4) << "insert(" << machine.get() << ", " << evb << ")"; + + if (shutDown_) { + return false; + } + machines_.lock()->emplace(std::move(machine), evb); + return true; +} + +void ConnectionSet::remove(RSocketStateMachine& machine) { + VLOG(4) << "remove(" << &machine << ")"; + + const auto locked = machines_.lock(); + auto const result = locked->erase(machine.shared_from_this()); + DCHECK_LE(result, 1); + + if (++removes_ == targetRemoves_) { + shutdownDone_.post(); + } +} + +size_t ConnectionSet::size() const { + return machines_.lock()->size(); +} + +} // namespace rsocket diff --git a/rsocket/internal/ConnectionSet.h b/rsocket/internal/ConnectionSet.h new file mode 100644 index 000000000..b679b96f2 --- /dev/null +++ b/rsocket/internal/ConnectionSet.h @@ -0,0 +1,60 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include +#include +#include + +#include "rsocket/statemachine/RSocketStateMachine.h" + +namespace folly { +class EventBase; +} + +namespace rsocket { + +/// Set of RSocketStateMachine objects. Stores them until they call +/// RSocketStateMachine::close(). +/// +/// Also tracks which EventBase is controlling each state machine so that they +/// can be closed on the correct thread. +class ConnectionSet : public RSocketStateMachine::CloseCallback { + public: + ConnectionSet(); + virtual ~ConnectionSet(); + + bool insert(std::shared_ptr, folly::EventBase*); + void remove(RSocketStateMachine&) override; + + size_t size() const; + + void shutdownAndWait(); + + private: + using StateMachineMap = std:: + unordered_map, folly::EventBase*>; + + folly::Synchronized machines_; + folly::Baton<> shutdownDone_; + size_t removes_{0}; + size_t targetRemoves_{0}; + std::atomic shutDown_{false}; +}; + +} // namespace rsocket diff --git a/rsocket/internal/FollyKeepaliveTimer.cpp b/rsocket/internal/FollyKeepaliveTimer.cpp deleted file mode 100644 index 8398722d9..000000000 --- a/rsocket/internal/FollyKeepaliveTimer.cpp +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "rsocket/internal/FollyKeepaliveTimer.h" - -namespace rsocket { - -FollyKeepaliveTimer::FollyKeepaliveTimer( - folly::EventBase& eventBase, - std::chrono::milliseconds period) - : eventBase_(eventBase), - generation_(std::make_shared(0)), - period_(period) {} - -FollyKeepaliveTimer::~FollyKeepaliveTimer() { - stop(); -} - -std::chrono::milliseconds FollyKeepaliveTimer::keepaliveTime() { - return period_; -} - -void FollyKeepaliveTimer::schedule() { - auto scheduledGeneration = *generation_; - auto generation = generation_; - eventBase_.runAfterDelay( - [this, generation, scheduledGeneration]() { - if (*generation == scheduledGeneration) { - sendKeepalive(); - } - }, - static_cast(keepaliveTime().count())); -} - -void FollyKeepaliveTimer::sendKeepalive() { - if (pending_) { - // Make sure connection_ is not deleted (via external call to stop) - // while we still mid-operation - auto localPtr = connection_; - stop(); - // TODO: we need to use max lifetime from the setup frame for this - localPtr->disconnectOrCloseWithError( - Frame_ERROR::connectionError("no response to keepalive")); - } else { - connection_->sendKeepalive(); - pending_ = true; - schedule(); - } -} - -// must be called from the same thread as start -void FollyKeepaliveTimer::stop() { - *generation_ += 1; - pending_ = false; - connection_ = nullptr; -} - -// must be called from the same thread as stop -void FollyKeepaliveTimer::start(const std::shared_ptr& connection) { - connection_ = connection; - *generation_ += 1; - DCHECK(!pending_); - - schedule(); -} - -void FollyKeepaliveTimer::keepaliveReceived() { - pending_ = false; -} -} diff --git a/rsocket/internal/FollyKeepaliveTimer.h b/rsocket/internal/FollyKeepaliveTimer.h deleted file mode 100644 index 37e44e638..000000000 --- a/rsocket/internal/FollyKeepaliveTimer.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include - -#include "rsocket/statemachine/RSocketStateMachine.h" - -namespace rsocket { - -class FollyKeepaliveTimer : public KeepaliveTimer { - public: - FollyKeepaliveTimer( - folly::EventBase& eventBase, - std::chrono::milliseconds period); - - ~FollyKeepaliveTimer(); - - std::chrono::milliseconds keepaliveTime() override; - - void schedule(); - - void stop() override; - - void start(const std::shared_ptr& connection) override; - - void sendKeepalive(); - - void keepaliveReceived() override; - - private: - std::shared_ptr connection_; - folly::EventBase& eventBase_; - std::shared_ptr generation_; - std::chrono::milliseconds period_; - std::atomic pending_{false}; -}; -} diff --git a/rsocket/internal/KeepaliveTimer.cpp b/rsocket/internal/KeepaliveTimer.cpp new file mode 100644 index 000000000..6fdaa39d0 --- /dev/null +++ b/rsocket/internal/KeepaliveTimer.cpp @@ -0,0 +1,87 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/internal/KeepaliveTimer.h" + +namespace rsocket { + +KeepaliveTimer::KeepaliveTimer( + std::chrono::milliseconds period, + folly::EventBase& eventBase) + : eventBase_(eventBase), + generation_(std::make_shared(0)), + period_(period) {} + +KeepaliveTimer::~KeepaliveTimer() { + stop(); +} + +std::chrono::milliseconds KeepaliveTimer::keepaliveTime() const { + return period_; +} + +void KeepaliveTimer::schedule() { + const auto scheduledGeneration = *generation_; + const auto generation = generation_; + eventBase_.runAfterDelay( + [this, + wpConnection = std::weak_ptr(connection_), + generation, + scheduledGeneration]() { + auto spConnection = wpConnection.lock(); + if (!spConnection) { + return; + } + if (*generation == scheduledGeneration) { + sendKeepalive(*spConnection); + } + }, + static_cast(keepaliveTime().count())); +} + +void KeepaliveTimer::sendKeepalive(FrameSink& sink) { + if (pending_) { + stop(); + // TODO: we need to use max lifetime from the setup frame for this + sink.disconnectOrCloseWithError( + Frame_ERROR::connectionError("no response to keepalive")); + } else { + // this must happen before sendKeepalive as it can potentially result in + // stop() being called + pending_ = true; + sink.sendKeepalive(); + schedule(); + } +} + +// must be called from the same thread as start +void KeepaliveTimer::stop() { + *generation_ += 1; + pending_ = false; + connection_.reset(); +} + +// must be called from the same thread as stop +void KeepaliveTimer::start(const std::shared_ptr& connection) { + connection_ = connection; + *generation_ += 1; + DCHECK(!pending_); + + schedule(); +} + +void KeepaliveTimer::keepaliveReceived() { + pending_ = false; +} +} // namespace rsocket diff --git a/rsocket/internal/KeepaliveTimer.h b/rsocket/internal/KeepaliveTimer.h new file mode 100644 index 000000000..51bb6c3c2 --- /dev/null +++ b/rsocket/internal/KeepaliveTimer.h @@ -0,0 +1,48 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "rsocket/statemachine/RSocketStateMachine.h" + +namespace rsocket { + +class KeepaliveTimer { + public: + KeepaliveTimer(std::chrono::milliseconds period, folly::EventBase& eventBase); + + ~KeepaliveTimer(); + + std::chrono::milliseconds keepaliveTime() const; + + void schedule(); + + void stop(); + + void start(const std::shared_ptr& connection); + + void sendKeepalive(FrameSink& sink); + + void keepaliveReceived(); + + private: + std::shared_ptr connection_; + folly::EventBase& eventBase_; + const std::shared_ptr generation_; + const std::chrono::milliseconds period_; + std::atomic pending_{false}; +}; +} // namespace rsocket diff --git a/rsocket/internal/RSocketConnectionManager.cpp b/rsocket/internal/RSocketConnectionManager.cpp deleted file mode 100644 index 5ca326390..000000000 --- a/rsocket/internal/RSocketConnectionManager.cpp +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "rsocket/internal/RSocketConnectionManager.h" - -#include -#include -#include - -#include "rsocket/RSocketConnectionEvents.h" -#include "rsocket/statemachine/RSocketStateMachine.h" - -namespace rsocket { - -RSocketConnectionManager::~RSocketConnectionManager() { - // Asynchronously close all existing ReactiveSockets. If there are none, then - // we can do an early exit. - VLOG(1) << "Destroying RSocketConnectionManager..."; - auto scopeGuard = folly::makeGuard([]{ VLOG(1) << "Destroying RSocketConnectionManager... DONE"; }); - - { - auto locked = sockets_.lock(); - if (locked->empty()) { - return; - } - - shutdown_.emplace(); - - for (auto& connectionPair : *locked) { - // close() has to be called on the same executor as the socket - auto& executor_ = connectionPair.second; - executor_.add([rs = std::move(connectionPair.first)] { - rs->close( - folly::exception_wrapper(), StreamCompletionSignal::SOCKET_CLOSED); - }); - } - } - - // Wait for all ReactiveSockets to close. - shutdown_->wait(); - DCHECK(sockets_.lock()->empty()); -} - -void RSocketConnectionManager::manageConnection( - std::shared_ptr socket, - folly::EventBase& eventBase) { - class ConnectionEventsWrapper : public RSocketConnectionEvents { - public: - ConnectionEventsWrapper( - RSocketConnectionManager& connectionManager, - std::shared_ptr socket, - folly::EventBase& eventBase) - : connectionManager_(connectionManager), - socket_(std::move(socket)), - eventBase_(eventBase) {} - - void onConnected() override { - if (inner) { - inner->onConnected(); - } - } - - void onDisconnected(const folly::exception_wrapper& ex) override { - if (inner) { - inner->onDisconnected(ex); - } - } - - void onClosed(const folly::exception_wrapper& ex) override { - // Enqueue another event to remove and delete it. We cannot delete - // the RSocketStateMachine now as it still needs to finish processing - // the onClosed handlers in the stack frame above us. - eventBase_.add([connectionManager = &connectionManager_, socket = std::move(socket_)] { - connectionManager->removeConnection(socket); - }); - - if (inner) { - inner->onClosed(ex); - } - } - - RSocketConnectionManager& connectionManager_; - std::shared_ptr socket_; - folly::EventBase& eventBase_; - - std::shared_ptr inner; - }; - - auto connectionEventsWrapper = - std::make_shared(*this, socket, eventBase); - connectionEventsWrapper->inner = std::move(socket->connectionEvents()); - socket->connectionEvents() = std::move(connectionEventsWrapper); - - sockets_.lock()->insert({std::move(socket), eventBase}); -} - -void RSocketConnectionManager::removeConnection( - const std::shared_ptr& socket) { - auto locked = sockets_.lock(); - locked->erase(socket); - - VLOG(2) << "Removed RSocketStateMachine"; - - if (shutdown_ && locked->empty()) { - shutdown_->post(); - } -} -} // namespace rsocket diff --git a/rsocket/internal/RSocketConnectionManager.h b/rsocket/internal/RSocketConnectionManager.h deleted file mode 100644 index 0634fd8d2..000000000 --- a/rsocket/internal/RSocketConnectionManager.h +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include -#include - -#include -#include - -namespace folly { -class EventBase; -} - -namespace rsocket { - -class RSocketStateMachine; - -class RSocketConnectionManager { - public: - ~RSocketConnectionManager(); - - void manageConnection( - std::shared_ptr, - folly::EventBase&); - - private: - void removeConnection(const std::shared_ptr&); - - /// Set of currently open ReactiveSockets. - folly::Synchronized< - std::unordered_map< - std::shared_ptr, - folly::EventBase&>, - std::mutex> - sockets_; - - folly::Optional> shutdown_; -}; -} // namespace rsocket diff --git a/rsocket/internal/ResumeCache.cpp b/rsocket/internal/ResumeCache.cpp deleted file mode 100644 index d02cef90b..000000000 --- a/rsocket/internal/ResumeCache.cpp +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "rsocket/internal/ResumeCache.h" - -#include - -#include "rsocket/framing/Frame.h" -#include "rsocket/framing/FrameTransport.h" -#include "rsocket/statemachine/RSocketStateMachine.h" - -namespace { - -using rsocket::FrameType; - -bool shouldTrackFrame(const FrameType frameType) { - switch (frameType) { - case FrameType::REQUEST_CHANNEL: - case FrameType::REQUEST_STREAM: - case FrameType::REQUEST_RESPONSE: - case FrameType::REQUEST_FNF: - case FrameType::REQUEST_N: - case FrameType::CANCEL: - case FrameType::ERROR: - case FrameType::PAYLOAD: - return true; - case FrameType::RESERVED: - case FrameType::SETUP: - case FrameType::LEASE: - case FrameType::KEEPALIVE: - case FrameType::METADATA_PUSH: - case FrameType::RESUME: - case FrameType::RESUME_OK: - case FrameType::EXT: - default: - return false; - } -} - -} // anonymous - -namespace rsocket { - -ResumeCache::~ResumeCache() { - clearFrames(position_); -} - -void ResumeCache::trackReceivedFrame( - const folly::IOBuf& serializedFrame, - const FrameType frameType, - const StreamId streamId) { - onStreamOpen(streamId, frameType); - if (shouldTrackFrame(frameType)) { - VLOG(6) << "received frame " << frameType; - // TODO(tmont): this could be expensive, find a better way to get length - impliedPosition_ += serializedFrame.computeChainDataLength(); - } -} - -void ResumeCache::trackSentFrame( - const folly::IOBuf& serializedFrame, - const FrameType frameType, - const folly::Optional streamIdPtr) { - if (streamIdPtr) { - const StreamId streamId = *streamIdPtr; - onStreamOpen(streamId, frameType); - } - - if (shouldTrackFrame(frameType)) { - // TODO(tmont): this could be expensive, find a better way to get length - auto frameDataLength = serializedFrame.computeChainDataLength(); - - // if the frame is too huge, we don't cache it - if (frameDataLength > capacity_) { - resetUpToPosition(position_); - position_ += frameDataLength; - DCHECK(size_ == 0); - return; - } - - addFrame(serializedFrame, frameDataLength); - position_ += frameDataLength; - } -} - -void ResumeCache::resetUpToPosition(ResumePosition position) { - if (position <= resetPosition_) { - return; - } - - if (position > position_) { - position = position_; - } - - clearFrames(position); - - resetPosition_ = position; - DCHECK(frames_.empty() || frames_.front().first == resetPosition_); -} - -bool ResumeCache::isPositionAvailable(ResumePosition position) const { - return (position_ == position) || - std::binary_search( - frames_.begin(), - frames_.end(), - std::make_pair(position, std::unique_ptr()), - [](decltype(frames_.back()) pairA, - decltype(frames_.back()) pairB) { - return pairA.first < pairB.first; - }); -} - -void ResumeCache::addFrame(const folly::IOBuf& frame, size_t frameDataLength) { - size_ += frameDataLength; - while (size_ > capacity_) { - evictFrame(); - } - frames_.emplace_back(position_, frame.clone()); - stats_->resumeBufferChanged(1, static_cast(frameDataLength)); -} - -void ResumeCache::evictFrame() { - DCHECK(!frames_.empty()); - - auto position = - frames_.size() > 1 ? std::next(frames_.begin())->first : position_; - resetUpToPosition(position); -} - -void ResumeCache::clearFrames(ResumePosition position) { - if (frames_.empty()) { - return; - } - DCHECK(position <= position_); - DCHECK(position >= resetPosition_); - - auto end = std::lower_bound( - frames_.begin(), - frames_.end(), - position, - [](decltype(frames_.back()) pair, ResumePosition pos) { - return pair.first < pos; - }); - DCHECK(end == frames_.end() || end->first >= resetPosition_); - auto pos = end == frames_.end() ? position : end->first; - stats_->resumeBufferChanged( - -static_cast(std::distance(frames_.begin(), end)), - -static_cast(pos - resetPosition_)); - - frames_.erase(frames_.begin(), end); - size_ -= static_cast(pos - resetPosition_); -} - -void ResumeCache::sendFramesFromPosition( - ResumePosition position, - FrameTransport& frameTransport) const { - DCHECK(isPositionAvailable(position)); - - if (position == position_) { - // idle resumption - return; - } - - auto found = std::lower_bound( - frames_.begin(), - frames_.end(), - position, - [](decltype(frames_.back()) pair, ResumePosition pos) { - return pair.first < pos; - }); - - DCHECK(found != frames_.end()); - DCHECK(found->first == position); - - while (found != frames_.end()) { - frameTransport.outputFrameOrEnqueue(found->second->clone()); - found++; - } -} - -void ResumeCache::onStreamClosed(StreamId streamId) { - // This is crude. We could try to preserve the stream type in - // RSocketStateMachine and pass it down explicitly here. - activeRequestStreams_.erase(streamId); - activeRequestChannels_.erase(streamId); - activeRequestResponses_.erase(streamId); -} - -void ResumeCache::onStreamOpen(StreamId streamId, FrameType frameType) { - if (frameType == FrameType::REQUEST_STREAM) { - activeRequestStreams_.insert(streamId); - } else if (frameType == FrameType::REQUEST_CHANNEL) { - activeRequestChannels_.insert(streamId); - } else if (frameType == FrameType::REQUEST_RESPONSE) { - activeRequestResponses_.insert(streamId); - } -} - -} // reactivesocket diff --git a/rsocket/internal/ResumeCache.h b/rsocket/internal/ResumeCache.h deleted file mode 100644 index 14a546cf0..000000000 --- a/rsocket/internal/ResumeCache.h +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include -#include - -#include - -#include "rsocket/RSocketStats.h" -#include "rsocket/internal/Common.h" - -namespace folly { -class IOBuf; -} - -namespace rsocket { - -class RSocketStateMachine; -class FrameTransport; - -// This class stores information necessary to resume the RSocket session. The -// stored information fall into two categories. (1) Sent: Here we have a -// buffer queue of sent frames (limited by capacity). We have two pointers - -// position_ and resetPosition_, which track the position (in bytes) of the -// first and last frames we have in queue. (2) Rcvd: We have a -// impliedPosition_ byte counter, which determines the bytes until which we -// have received data from the other side. -class ResumeCache { - public: - explicit ResumeCache( - std::shared_ptr stats, - size_t capacity = DEFAULT_CAPACITY) - : stats_(std::move(stats)), capacity_(capacity) {} - ~ResumeCache(); - - // Tracks a received frame. - void trackReceivedFrame( - const folly::IOBuf& serializedFrame, - const FrameType frameType, - const StreamId streamId); - - // Tracks a sent frame. - void trackSentFrame( - const folly::IOBuf& serializedFrame, - const FrameType frameType, - const folly::Optional streamIdPtr); - - // Resets the send buffer buffer until the given position. - // This is triggered on KeepAlive reception or when we hit capacity. - void resetUpToPosition(ResumePosition position); - - bool isPositionAvailable(ResumePosition position) const; - - void sendFramesFromPosition( - ResumePosition position, - FrameTransport& transport) const; - - ResumePosition lastResetPosition() const { - return resetPosition_; - } - - ResumePosition position() const { - return position_; - } - - ResumePosition impliedPosition() { - return impliedPosition_; - } - - bool canResumeFrom(ResumePosition clientPosition) const { - return clientPosition <= impliedPosition_; - } - - size_t size() const { - return size_; - } - - void onStreamOpen(StreamId streamId, FrameType frameType); - - void onStreamClosed(StreamId streamId); - - private: - void addFrame(const folly::IOBuf&, size_t); - void evictFrame(); - - // Called before clearing cached frames to update stats. - void clearFrames(ResumePosition position); - - std::shared_ptr stats_; - - // End position of the send buffer queue - ResumePosition position_{0}; - // Start position of the send buffer queue - ResumePosition resetPosition_{0}; - // Inferred position of the rcvd frames - ResumePosition impliedPosition_{0}; - - // Active REQUEST_STREAMs are preserved here - std::set activeRequestStreams_; - - // Active REQUEST_CHANNELs are preserved here - std::set activeRequestChannels_; - - // Active REQUEST_RESPONSEs are preserved here - std::set activeRequestResponses_; - - std::deque>> frames_; - - constexpr static size_t DEFAULT_CAPACITY = 1024 * 1024; // 1MB - const size_t capacity_; - size_t size_{0}; -}; -} diff --git a/rsocket/internal/ScheduledRSocketResponder.cpp b/rsocket/internal/ScheduledRSocketResponder.cpp index a047c0513..d534657c8 100644 --- a/rsocket/internal/ScheduledRSocketResponder.cpp +++ b/rsocket/internal/ScheduledRSocketResponder.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/internal/ScheduledRSocketResponder.h" @@ -11,67 +23,60 @@ namespace rsocket { ScheduledRSocketResponder::ScheduledRSocketResponder( std::shared_ptr inner, - folly::EventBase& eventBase) : inner_(std::move(inner)), - eventBase_(eventBase) {} + folly::EventBase& eventBase) + : inner_(std::move(inner)), eventBase_(eventBase) {} -yarpl::Reference> +std::shared_ptr> ScheduledRSocketResponder::handleRequestResponse( Payload request, StreamId streamId) { - auto innerFlowable = inner_->handleRequestResponse(std::move(request), - streamId); + auto innerFlowable = + inner_->handleRequestResponse(std::move(request), streamId); return yarpl::single::Singles::create( - [innerFlowable = std::move(innerFlowable), eventBase = &eventBase_]( - yarpl::Reference> - observer) { - innerFlowable->subscribe(yarpl::make_ref< - ScheduledSingleObserver> - (std::move(observer), *eventBase)); - }); + [innerFlowable = std::move(innerFlowable), eventBase = &eventBase_]( + std::shared_ptr> observer) { + innerFlowable->subscribe( + std::make_shared>( + std::move(observer), *eventBase)); + }); } -yarpl::Reference> +std::shared_ptr> ScheduledRSocketResponder::handleRequestStream( Payload request, StreamId streamId) { - auto innerFlowable = inner_->handleRequestStream(std::move(request), - streamId); - return yarpl::flowable::Flowables::fromPublisher( - [innerFlowable = std::move(innerFlowable), eventBase = &eventBase_]( - yarpl::Reference> - subscriber) { - innerFlowable->subscribe(yarpl::make_ref< - ScheduledSubscriber> - (std::move(subscriber), *eventBase)); - }); + auto innerFlowable = + inner_->handleRequestStream(std::move(request), streamId); + return yarpl::flowable::internal::flowableFromSubscriber( + [innerFlowable = std::move(innerFlowable), eventBase = &eventBase_]( + std::shared_ptr> subscriber) { + innerFlowable->subscribe(std::make_shared>( + std::move(subscriber), *eventBase)); + }); } -yarpl::Reference> +std::shared_ptr> ScheduledRSocketResponder::handleRequestChannel( Payload request, - yarpl::Reference> - requestStream, + std::shared_ptr> requestStream, StreamId streamId) { - auto requestStreamFlowable = yarpl::flowable::Flowables::fromPublisher( - [requestStream = std::move(requestStream), eventBase = &eventBase_]( - yarpl::Reference> - subscriber) { - requestStream->subscribe(yarpl::make_ref< - ScheduledSubscriptionSubscriber> - (std::move(subscriber), *eventBase)); - }); - auto innerFlowable = inner_->handleRequestChannel(std::move(request), - std::move( - requestStreamFlowable), - streamId); - return yarpl::flowable::Flowables::fromPublisher( - [innerFlowable = std::move(innerFlowable), eventBase = &eventBase_]( - yarpl::Reference> - subscriber) { - innerFlowable->subscribe(yarpl::make_ref< - ScheduledSubscriber> - (std::move(subscriber), *eventBase)); - }); + auto requestStreamFlowable = + yarpl::flowable::internal::flowableFromSubscriber( + [requestStream = std::move(requestStream), eventBase = &eventBase_]( + std::shared_ptr> + subscriber) { + requestStream->subscribe( + std::make_shared>( + std::move(subscriber), *eventBase)); + }); + auto innerFlowable = inner_->handleRequestChannel( + std::move(request), std::move(requestStreamFlowable), streamId); + return yarpl::flowable::internal::flowableFromSubscriber( + [innerFlowable = std::move(innerFlowable), eventBase = &eventBase_]( + std::shared_ptr> subscriber) { + innerFlowable->subscribe(std::make_shared>( + std::move(subscriber), *eventBase)); + }); } void ScheduledRSocketResponder::handleFireAndForget( @@ -80,4 +85,4 @@ void ScheduledRSocketResponder::handleFireAndForget( inner_->handleFireAndForget(std::move(request), streamId); } -} // rsocket +} // namespace rsocket diff --git a/rsocket/internal/ScheduledRSocketResponder.h b/rsocket/internal/ScheduledRSocketResponder.h index e943c6ef2..fe9039dcc 100644 --- a/rsocket/internal/ScheduledRSocketResponder.h +++ b/rsocket/internal/ScheduledRSocketResponder.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -20,30 +32,24 @@ class ScheduledRSocketResponder : public RSocketResponder { std::shared_ptr inner, folly::EventBase& eventBase); - yarpl::Reference> - handleRequestResponse( + std::shared_ptr> handleRequestResponse( Payload request, StreamId streamId) override; - yarpl::Reference> - handleRequestStream( + std::shared_ptr> handleRequestStream( Payload request, StreamId streamId) override; - yarpl::Reference> - handleRequestChannel( + std::shared_ptr> handleRequestChannel( Payload request, - yarpl::Reference> - requestStream, + std::shared_ptr> requestStream, StreamId streamId) override; - void handleFireAndForget( - Payload request, - StreamId streamId) override; + void handleFireAndForget(Payload request, StreamId streamId) override; private: - std::shared_ptr inner_; + const std::shared_ptr inner_; folly::EventBase& eventBase_; }; -} // rsocket +} // namespace rsocket diff --git a/rsocket/internal/ScheduledSingleObserver.h b/rsocket/internal/ScheduledSingleObserver.h index 1db49cc35..167b5458e 100644 --- a/rsocket/internal/ScheduledSingleObserver.h +++ b/rsocket/internal/ScheduledSingleObserver.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -17,24 +29,23 @@ namespace rsocket { // application code so that calls to on{Subscribe,Success,Error} are // scheduled on the right EventBase. // -template +template class ScheduledSingleObserver : public yarpl::single::SingleObserver { public: ScheduledSingleObserver( - yarpl::Reference> observer, - folly::EventBase& eventBase) : - inner_(std::move(observer)), eventBase_(eventBase) {} + std::shared_ptr> observer, + folly::EventBase& eventBase) + : inner_(std::move(observer)), eventBase_(eventBase) {} - void onSubscribe( - yarpl::Reference subscription) override { + void onSubscribe(std::shared_ptr + subscription) override { if (eventBase_.isInEventBaseThread()) { inner_->onSubscribe(std::move(subscription)); } else { eventBase_.runInEventBaseThread( - [inner = inner_, subscription = std::move(subscription)] - { - inner->onSubscribe(std::move(subscription)); - }); + [inner = inner_, subscription = std::move(subscription)] { + inner->onSubscribe(std::move(subscription)); + }); } } @@ -44,26 +55,26 @@ class ScheduledSingleObserver : public yarpl::single::SingleObserver { inner_->onSuccess(std::move(value)); } else { eventBase_.runInEventBaseThread( - [inner = inner_, value = std::move(value)]() mutable { - inner->onSuccess(std::move(value)); - }); + [inner = inner_, value = std::move(value)]() mutable { + inner->onSuccess(std::move(value)); + }); } } // No further calls to the subscription after this method is invoked. - void onError(std::exception_ptr ex) override { + void onError(folly::exception_wrapper ex) override { if (eventBase_.isInEventBaseThread()) { inner_->onError(std::move(ex)); } else { eventBase_.runInEventBaseThread( - [inner = inner_, ex = std::move(ex)]() mutable { - inner->onError(std::move(ex)); - }); + [inner = inner_, ex = std::move(ex)]() mutable { + inner->onError(std::move(ex)); + }); } } private: - yarpl::Reference> inner_; + const std::shared_ptr> inner_; folly::EventBase& eventBase_; }; @@ -73,18 +84,19 @@ class ScheduledSingleObserver : public yarpl::single::SingleObserver { // application code will be wrapped with a scheduled subscription to make the // call to Subscription::cancel safe. // -template -class ScheduledSubscriptionSingleObserver : public yarpl::single::SingleObserver { +template +class ScheduledSubscriptionSingleObserver + : public yarpl::single::SingleObserver { public: ScheduledSubscriptionSingleObserver( - yarpl::Reference> observer, - folly::EventBase& eventBase) : - inner_(std::move(observer)), eventBase_(eventBase) {} + std::shared_ptr> observer, + folly::EventBase& eventBase) + : inner_(std::move(observer)), eventBase_(eventBase) {} - void onSubscribe( - yarpl::Reference subscription) override { - inner_->onSubscribe( - yarpl::make_ref(std::move(subscription), eventBase_)); + void onSubscribe(std::shared_ptr + subscription) override { + inner_->onSubscribe(std::make_shared( + std::move(subscription), eventBase_)); } // No further calls to the subscription after this method is invoked. @@ -93,12 +105,12 @@ class ScheduledSubscriptionSingleObserver : public yarpl::single::SingleObserver } // No further calls to the subscription after this method is invoked. - void onError(std::exception_ptr ex) override { + void onError(folly::exception_wrapper ex) override { inner_->onError(std::move(ex)); } private: - yarpl::Reference> inner_; + const std::shared_ptr> inner_; folly::EventBase& eventBase_; }; -} // rsocket +} // namespace rsocket diff --git a/rsocket/internal/ScheduledSingleSubscription.cpp b/rsocket/internal/ScheduledSingleSubscription.cpp index 4f5167608..b56f76c0d 100644 --- a/rsocket/internal/ScheduledSingleSubscription.cpp +++ b/rsocket/internal/ScheduledSingleSubscription.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/internal/ScheduledSingleSubscription.h" @@ -7,20 +19,16 @@ namespace rsocket { ScheduledSingleSubscription::ScheduledSingleSubscription( - yarpl::Reference inner, - folly::EventBase& eventBase) : inner_(std::move(inner)), - eventBase_(eventBase) { -} + std::shared_ptr inner, + folly::EventBase& eventBase) + : inner_(std::move(inner)), eventBase_(eventBase) {} void ScheduledSingleSubscription::cancel() { if (eventBase_.isInEventBaseThread()) { inner_->cancel(); } else { - eventBase_.runInEventBaseThread([inner = inner_] - { - inner->cancel(); - }); + eventBase_.runInEventBaseThread([inner = inner_] { inner->cancel(); }); } } -} // rsocket +} // namespace rsocket diff --git a/rsocket/internal/ScheduledSingleSubscription.h b/rsocket/internal/ScheduledSingleSubscription.h index 5877c4914..1d29412e4 100644 --- a/rsocket/internal/ScheduledSingleSubscription.h +++ b/rsocket/internal/ScheduledSingleSubscription.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -11,20 +23,20 @@ class EventBase; namespace rsocket { // -// A decorator of the SingleSubscription object which schedules the method calls on the -// provided EventBase +// A decorator of the SingleSubscription object which schedules the method calls +// on the provided EventBase // class ScheduledSingleSubscription : public yarpl::single::SingleSubscription { public: ScheduledSingleSubscription( - yarpl::Reference inner, + std::shared_ptr inner, folly::EventBase& eventBase); void cancel() override; private: - yarpl::Reference inner_; + const std::shared_ptr inner_; folly::EventBase& eventBase_; }; -} // rsocket +} // namespace rsocket diff --git a/rsocket/internal/ScheduledSubscriber.h b/rsocket/internal/ScheduledSubscriber.h index 7dc21b588..f73ee44f3 100644 --- a/rsocket/internal/ScheduledSubscriber.h +++ b/rsocket/internal/ScheduledSubscriber.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -18,24 +30,23 @@ namespace rsocket { // right EventBase. // -template +template class ScheduledSubscriber : public yarpl::flowable::Subscriber { public: ScheduledSubscriber( - yarpl::Reference> inner, - folly::EventBase& eventBase) : inner_(std::move(inner)), - eventBase_(eventBase) {} + std::shared_ptr> inner, + folly::EventBase& eventBase) + : inner_(std::move(inner)), eventBase_(eventBase) {} void onSubscribe( - yarpl::Reference subscription) override { + std::shared_ptr subscription) override { if (eventBase_.isInEventBaseThread()) { inner_->onSubscribe(std::move(subscription)); } else { eventBase_.runInEventBaseThread( - [inner = inner_, subscription = std::move(subscription)] - { - inner->onSubscribe(std::move(subscription)); - }); + [inner = inner_, subscription = std::move(subscription)] { + inner->onSubscribe(std::move(subscription)); + }); } } @@ -45,21 +56,18 @@ class ScheduledSubscriber : public yarpl::flowable::Subscriber { inner_->onComplete(); } else { eventBase_.runInEventBaseThread( - [inner = inner_] - { - inner->onComplete(); - }); + [inner = inner_] { inner->onComplete(); }); } } - void onError(std::exception_ptr ex) override { + void onError(folly::exception_wrapper ex) override { if (eventBase_.isInEventBaseThread()) { inner_->onError(std::move(ex)); } else { eventBase_.runInEventBaseThread( - [inner = inner_, ex = std::move(ex)]() mutable { - inner->onError(std::move(ex)); - }); + [inner = inner_, ex = std::move(ex)]() mutable { + inner->onError(std::move(ex)); + }); } } @@ -68,14 +76,14 @@ class ScheduledSubscriber : public yarpl::flowable::Subscriber { inner_->onNext(std::move(value)); } else { eventBase_.runInEventBaseThread( - [inner = inner_, value = std::move(value)]() mutable { - inner->onNext(std::move(value)); - }); + [inner = inner_, value = std::move(value)]() mutable { + inner->onNext(std::move(value)); + }); } } private: - yarpl::Reference> inner_; + const std::shared_ptr> inner_; folly::EventBase& eventBase_; }; @@ -87,36 +95,38 @@ class ScheduledSubscriber : public yarpl::flowable::Subscriber { // wrapped in the ScheduledSubscription since the application code calls // request and cancel from any thread. // -template +template class ScheduledSubscriptionSubscriber : public yarpl::flowable::Subscriber { public: ScheduledSubscriptionSubscriber( - yarpl::Reference> inner, - folly::EventBase& eventBase) : inner_(std::move(inner)), - eventBase_(eventBase) {} + std::shared_ptr> inner, + folly::EventBase& eventBase) + : inner_(std::move(inner)), eventBase_(eventBase) {} void onSubscribe( - yarpl::Reference subscription) override { - inner_->onSubscribe( - yarpl::make_ref(subscription, eventBase_)); + std::shared_ptr sub) override { + auto scheduled = + std::make_shared(std::move(sub), eventBase_); + inner_->onSubscribe(std::move(scheduled)); } - // No further calls to the subscription after this method is invoked. - void onComplete() override { - inner_->onComplete(); + void onNext(T value) override { + inner_->onNext(std::move(value)); } - void onError(std::exception_ptr ex) override { - inner_->onError(std::move(ex)); + void onComplete() override { + auto inner = std::move(inner_); + inner->onComplete(); } - void onNext(T value) override { - inner_->onNext(std::move(value)); + void onError(folly::exception_wrapper ew) override { + auto inner = std::move(inner_); + inner->onError(std::move(ew)); } private: - yarpl::Reference> inner_; + std::shared_ptr> inner_; folly::EventBase& eventBase_; }; -} // rsocket +} // namespace rsocket diff --git a/rsocket/internal/ScheduledSubscription.cpp b/rsocket/internal/ScheduledSubscription.cpp index 761f9aa0e..a92687aa9 100644 --- a/rsocket/internal/ScheduledSubscription.cpp +++ b/rsocket/internal/ScheduledSubscription.cpp @@ -1,37 +1,42 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/internal/ScheduledSubscription.h" -#include - namespace rsocket { ScheduledSubscription::ScheduledSubscription( - yarpl::Reference inner, - folly::EventBase& eventBase) : inner_(std::move(inner)), - eventBase_(eventBase) { -} + std::shared_ptr inner, + folly::EventBase& eventBase) + : inner_{std::move(inner)}, eventBase_{eventBase} {} -void ScheduledSubscription::request(int64_t n) noexcept { +void ScheduledSubscription::request(int64_t n) { if (eventBase_.isInEventBaseThread()) { inner_->request(n); } else { - eventBase_.runInEventBaseThread([inner = inner_, n] - { - inner->request(n); - }); + eventBase_.runInEventBaseThread([inner = inner_, n] { inner->request(n); }); } } -void ScheduledSubscription::cancel() noexcept { +void ScheduledSubscription::cancel() { if (eventBase_.isInEventBaseThread()) { - inner_->cancel(); + auto inner = std::move(inner_); + inner->cancel(); } else { - eventBase_.runInEventBaseThread([inner = inner_] - { - inner->cancel(); - }); + eventBase_.runInEventBaseThread( + [inner = std::move(inner_)] { inner->cancel(); }); } } -} // rsocket +} // namespace rsocket diff --git a/rsocket/internal/ScheduledSubscription.h b/rsocket/internal/ScheduledSubscription.h index 9e595472f..14c058cb4 100644 --- a/rsocket/internal/ScheduledSubscription.h +++ b/rsocket/internal/ScheduledSubscription.h @@ -1,32 +1,39 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include "yarpl/flowable/Subscription.h" +#include -namespace folly { -class EventBase; -} +#include "yarpl/flowable/Subscription.h" namespace rsocket { -// -// A decorator of the Subscription object which schedules the method calls on the -// provided EventBase -// +// A wrapper over Subscription that schedules all of the subscription's methods +// on an EventBase. class ScheduledSubscription : public yarpl::flowable::Subscription { public: ScheduledSubscription( - yarpl::Reference inner, - folly::EventBase& eventBase); - - void request(int64_t n) noexcept override; + std::shared_ptr, + folly::EventBase&); - void cancel() noexcept override; + void request(int64_t) override; + void cancel() override; private: - yarpl::Reference inner_; + std::shared_ptr inner_; folly::EventBase& eventBase_; }; -} // rsocket +} // namespace rsocket diff --git a/rsocket/internal/SetupResumeAcceptor.cpp b/rsocket/internal/SetupResumeAcceptor.cpp index 2ff93d1f1..828e4dbdd 100644 --- a/rsocket/internal/SetupResumeAcceptor.cpp +++ b/rsocket/internal/SetupResumeAcceptor.cpp @@ -1,77 +1,97 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/internal/SetupResumeAcceptor.h" #include #include -#include "rsocket/DuplexConnection.h" #include "rsocket/framing/Frame.h" #include "rsocket/framing/FrameProcessor.h" #include "rsocket/framing/FrameSerializer.h" -#include "rsocket/framing/FrameTransport.h" namespace rsocket { -namespace { - -folly::exception_wrapper error(folly::StringPiece message) { - std::runtime_error exn{message.str()}; - auto eptr = std::make_exception_ptr(exn); - return folly::exception_wrapper{std::move(eptr), exn}; -} -} - -class OneFrameProcessor : public FrameProcessor { +/// Subscriber that owns a connection, sets itself as that connection's input, +/// and reads out a single frame before cancelling. +class SetupResumeAcceptor::OneFrameSubscriber final + : public yarpl::flowable::BaseSubscriber> { public: - OneFrameProcessor( + OneFrameSubscriber( SetupResumeAcceptor& acceptor, - yarpl::Reference transport, + std::unique_ptr connection, SetupResumeAcceptor::OnSetup onSetup, SetupResumeAcceptor::OnResume onResume) - : acceptor_(acceptor), - transport_(std::move(transport)), - onSetup_(std::move(onSetup)), - onResume_(std::move(onResume)) { - DCHECK(transport_); + : acceptor_{acceptor}, + connection_{std::move(connection)}, + onSetup_{std::move(onSetup)}, + onResume_{std::move(onResume)} { + DCHECK(connection_); DCHECK(onSetup_); DCHECK(onResume_); + DCHECK(acceptor_.inOwnerThread()); + } + + void setInput() { + DCHECK(acceptor_.inOwnerThread()); + connection_->setInput(ref_from_this(this)); + } + + /// Shut down the DuplexConnection, breaking the cycle between it and this + /// subscriber. Expects the DuplexConnection's destructor to call + /// onComplete/onError on its input subscriber (this). + void close() { + auto self = ref_from_this(this); + connection_.reset(); } - void processFrame(std::unique_ptr buf) override { + void onSubscribeImpl() override { + DCHECK(acceptor_.inOwnerThread()); + this->request(1); + } + + void onNextImpl(std::unique_ptr buf) override { + DCHECK(connection_) << "OneFrameSubscriber received more than one frame"; + DCHECK(acceptor_.inOwnerThread()); + + this->cancel(); // calls onTerminateImpl + acceptor_.processFrame( - std::move(transport_), + std::move(connection_), std::move(buf), std::move(onSetup_), std::move(onResume_)); - // No more code here as the instance might be gone by now. } - void onTerminal(folly::exception_wrapper ew) override { - onSetup_ = nullptr; - onResume_ = nullptr; + void onCompleteImpl() override {} + void onErrorImpl(folly::exception_wrapper) override {} - acceptor_.close(std::move(transport_), std::move(ew)); - // No more code here as the instance might be gone by now. + void onTerminateImpl() override { + DCHECK(acceptor_.inOwnerThread()); + acceptor_.remove(ref_from_this(this)); } private: SetupResumeAcceptor& acceptor_; - yarpl::Reference transport_; + std::unique_ptr connection_; SetupResumeAcceptor::OnSetup onSetup_; SetupResumeAcceptor::OnResume onResume_; }; -SetupResumeAcceptor::SetupResumeAcceptor( - ProtocolVersion version, - folly::EventBase* eventBase) - : eventBase_(eventBase) { +SetupResumeAcceptor::SetupResumeAcceptor(folly::EventBase* eventBase) + : eventBase_{eventBase} { CHECK(eventBase_); - - // If the version is unknown we'll try to autodetect it from the first frame. - if (version != ProtocolVersion::Unknown) { - defaultSerializer_ = FrameSerializer::createFrameSerializer(version); - } } SetupResumeAcceptor::~SetupResumeAcceptor() { @@ -79,20 +99,20 @@ SetupResumeAcceptor::~SetupResumeAcceptor() { } void SetupResumeAcceptor::processFrame( - yarpl::Reference transport, + std::unique_ptr connection, std::unique_ptr buf, SetupResumeAcceptor::OnSetup onSetup, SetupResumeAcceptor::OnResume onResume) { - DCHECK(eventBase_->isInEventBaseThread()); + DCHECK(inOwnerThread()); + DCHECK(connection); if (closed_) { - transport->closeWithError(error("SetupResumeAcceptor is shutting down")); return; } - auto serializer = createSerializer(*buf); + const auto serializer = FrameSerializer::createAutodetectedSerializer(*buf); if (!serializer) { - close(std::move(transport), error("Unable to detect protocol version")); + VLOG(2) << "Unable to detect protocol version"; return; } @@ -100,9 +120,9 @@ void SetupResumeAcceptor::processFrame( case FrameType::SETUP: { Frame_SETUP frame; if (!serializer->deserializeFrom(frame, std::move(buf))) { - transport->outputFrameOrEnqueue( - serializer->serializeOut(Frame_ERROR::invalidFrame())); - close(std::move(transport), error("Cannot decode SETUP frame")); + constexpr auto msg = "Cannot decode SETUP frame"; + auto err = serializer->serializeOut(Frame_ERROR::connectionError(msg)); + connection->send(std::move(err)); break; } @@ -112,34 +132,22 @@ void SetupResumeAcceptor::processFrame( frame.moveToSetupPayload(params); if (serializer->protocolVersion() != params.protocolVersion) { - constexpr folly::StringPiece message{ - "SETUP frame has invalid protocol version"}; - transport->outputFrameOrEnqueue(serializer->serializeOut( - Frame_ERROR::badSetupFrame(message.str()))); - close(transport, error(message)); + constexpr auto msg = "SETUP frame has invalid protocol version"; + auto err = serializer->serializeOut(Frame_ERROR::invalidSetup(msg)); + connection->send(std::move(err)); break; } - remove(transport); - - try { - onSetup(transport, std::move(params)); - } catch (const std::exception& exn) { - folly::exception_wrapper ew{std::current_exception(), exn}; - auto errFrame = Frame_ERROR::rejectedSetup(ew.what().toStdString()); - transport->outputFrameOrEnqueue( - serializer->serializeOut(std::move(errFrame))); - close(std::move(transport), std::move(ew)); - } + onSetup(std::move(connection), std::move(params)); break; } case FrameType::RESUME: { Frame_RESUME frame; if (!serializer->deserializeFrom(frame, std::move(buf))) { - transport->outputFrameOrEnqueue( - serializer->serializeOut(Frame_ERROR::invalidFrame())); - close(std::move(transport), error("Cannot decode RESUME frame")); + constexpr auto msg = "Cannot decode RESUME frame"; + auto err = serializer->serializeOut(Frame_ERROR::connectionError(msg)); + connection->send(std::move(err)); break; } @@ -152,33 +160,20 @@ void SetupResumeAcceptor::processFrame( ProtocolVersion(frame.versionMajor_, frame.versionMinor_)); if (serializer->protocolVersion() != params.protocolVersion) { - constexpr folly::StringPiece message{ - "RESUME frame has invalid protocol version"}; - transport->outputFrameOrEnqueue(serializer->serializeOut( - Frame_ERROR::badSetupFrame(message.str()))); - close(std::move(transport), error(message)); + constexpr auto msg = "RESUME frame has invalid protocol version"; + auto err = serializer->serializeOut(Frame_ERROR::rejectedResume(msg)); + connection->send(std::move(err)); break; } - remove(transport); - - try { - onResume(transport, std::move(params)); - } catch (const std::exception& exn) { - folly::exception_wrapper ew{std::current_exception(), exn}; - auto errFrame = Frame_ERROR::rejectedResume(ew.what().toStdString()); - transport->outputFrameOrEnqueue( - serializer->serializeOut(std::move(errFrame))); - close(std::move(transport), std::move(ew)); - } + onResume(std::move(connection), std::move(params)); break; } default: { - transport->outputFrameOrEnqueue( - serializer->serializeOut(Frame_ERROR::unexpectedFrame())); - close( - std::move(transport), error("Invalid frame, expected SETUP/RESUME")); + constexpr auto msg = "Invalid frame, expected SETUP/RESUME"; + auto err = serializer->serializeOut(Frame_ERROR::connectionError(msg)); + connection->send(std::move(err)); break; } } @@ -188,57 +183,27 @@ void SetupResumeAcceptor::accept( std::unique_ptr connection, OnSetup onSetup, OnResume onResume) { - auto transport = yarpl::make_ref(std::move(connection)); - auto processor = std::make_shared( - *this, transport, std::move(onSetup), std::move(onResume)); - connections_.insert(transport); - // Transport can receive frames right away. - transport->setFrameProcessor(std::move(processor)); -} - -std::shared_ptr SetupResumeAcceptor::createSerializer( - const folly::IOBuf& frame) { - if (defaultSerializer_) { - return defaultSerializer_; - } + DCHECK(inOwnerThread()); - auto serializer = FrameSerializer::createAutodetectedSerializer(frame); - if (!serializer) { - VLOG(2) << "Unable to detect protocol version"; - return nullptr; + if (closed_) { + return; } - VLOG(3) << "Detected protocol version " << serializer->protocolVersion(); - return std::move(serializer); -} - -void SetupResumeAcceptor::close( - yarpl::Reference tport, - folly::exception_wrapper e) { - DCHECK(eventBase_->isInEventBaseThread()); - - // This method always gets called with a FrameTransport::onNext() stack frame - // above it. Closing the transport too early will destroy it and we'll unwind - // back up and try to access it. - eventBase_->runInEventBaseThread( - [ this, transport = std::move(tport), ew = std::move(e) ]() mutable { - if (ew) { - transport->closeWithError(std::move(ew)); - } else { - transport->close(); - } - connections_.erase(transport); - }); + const auto subscriber = std::make_shared( + *this, std::move(connection), std::move(onSetup), std::move(onResume)); + connections_.insert(subscriber); + subscriber->setInput(); } void SetupResumeAcceptor::remove( - const yarpl::Reference& transport) { - transport->setFrameProcessor(nullptr); - connections_.erase(transport); + const std::shared_ptr& + subscriber) { + DCHECK(inOwnerThread()); + connections_.erase(subscriber); } folly::Future SetupResumeAcceptor::close() { - if (eventBase_->isInEventBaseThread()) { + if (inOwnerThread()) { closeAll(); return folly::makeFuture(); } @@ -246,12 +211,18 @@ folly::Future SetupResumeAcceptor::close() { } void SetupResumeAcceptor::closeAll() { - DCHECK(eventBase_->isInEventBaseThread()); + DCHECK(inOwnerThread()); closed_ = true; - for (auto& connection : connections_) { - connection->closeWithError(error("SetupResumeAcceptor is shutting down")); + auto connections = std::move(connections_); + for (auto& connection : connections) { + connection->close(); } } + +bool SetupResumeAcceptor::inOwnerThread() const { + return eventBase_->isInEventBaseThread(); } + +} // namespace rsocket diff --git a/rsocket/internal/SetupResumeAcceptor.h b/rsocket/internal/SetupResumeAcceptor.h index 758b6072e..7ae246e78 100644 --- a/rsocket/internal/SetupResumeAcceptor.h +++ b/rsocket/internal/SetupResumeAcceptor.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -8,8 +20,8 @@ #include #include +#include "rsocket/DuplexConnection.h" #include "rsocket/RSocketParameters.h" -#include "rsocket/internal/Common.h" #include "yarpl/Refcounted.h" namespace folly { @@ -17,14 +29,10 @@ class EventBase; class Executor; class IOBuf; class exception_wrapper; -} +} // namespace folly namespace rsocket { -class DuplexConnection; -class FrameSerializer; -class FrameTransport; - /// Acceptor of DuplexConnections that lets us decide whether the connection is /// trying to setup a new connection or resume an existing one. /// @@ -32,12 +40,12 @@ class FrameTransport; /// SetupResumeAcceptor::accept() entry point is not thread-safe. class SetupResumeAcceptor final { public: - using OnSetup = - folly::Function, SetupParameters)>; - using OnResume = - folly::Function, ResumeParameters)>; + using OnSetup = folly::Function< + void(std::unique_ptr, SetupParameters) noexcept>; + using OnResume = folly::Function< + void(std::unique_ptr, ResumeParameters) noexcept>; - SetupResumeAcceptor(ProtocolVersion, folly::EventBase*); + explicit SetupResumeAcceptor(folly::EventBase*); ~SetupResumeAcceptor(); /// Wait for and process the first frame on a DuplexConnection, calling the @@ -45,36 +53,38 @@ class SetupResumeAcceptor final { void accept(std::unique_ptr, OnSetup, OnResume); /// Close all open connections, and prevent new ones from being accepted. Can - /// be called from any thread. + /// be called from any thread, and also after the EventBase has been + /// destroyed, provided we know the ID of the owner thread. folly::Future close(); private: - friend class OneFrameProcessor; + class OneFrameSubscriber; void processFrame( - yarpl::Reference, + std::unique_ptr, std::unique_ptr, OnSetup, OnResume); - /// Close and remove a FrameTransport from the set. - void close(yarpl::Reference, folly::exception_wrapper); - - /// Remove a FrameTransport from the set. Drop the attached OneFrameProcessor - /// if it has one. - void remove(const yarpl::Reference&); + /// Remove a OneFrameSubscriber from the set. + void remove(const std::shared_ptr&); /// Close all open connections. void closeAll(); - /// Get the default FrameSerializer if one exists, otherwise try to autodetect - /// the correct FrameSerializer from the given frame. - std::shared_ptr createSerializer(const folly::IOBuf&); + /// Whether we're running in the thread that owns this object. If the ctor + /// specified an owner thread ID, then this will not access the EventBase + /// pointer. + /// + /// Useful if the EventBase has been destroyed but we still want to do some + /// work within the owner thread. + bool inOwnerThread() const; + + std::unordered_set> connections_; - std::unordered_set> connections_; bool closed_{false}; - std::shared_ptr defaultSerializer_; - folly::EventBase* eventBase_; + folly::EventBase* const eventBase_; }; -} + +} // namespace rsocket diff --git a/rsocket/internal/StackTraceUtils.h b/rsocket/internal/StackTraceUtils.h index 4d8d05069..b99d5b943 100644 --- a/rsocket/internal/StackTraceUtils.h +++ b/rsocket/internal/StackTraceUtils.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -14,4 +26,4 @@ inline std::string getStackTrace() { } #endif -} // reactivesocket +} // namespace rsocket diff --git a/rsocket/internal/SwappableEventBase.cpp b/rsocket/internal/SwappableEventBase.cpp index 7abd96fe0..f745a9365 100644 --- a/rsocket/internal/SwappableEventBase.cpp +++ b/rsocket/internal/SwappableEventBase.cpp @@ -1,34 +1,47 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "SwappableEventBase.h" namespace rsocket { bool SwappableEventBase::runInEventBaseThread(CbFunc cb) { - std::lock_guard l(hasSebDtored_->l_); + const std::lock_guard l(hasSebDtored_->l_); - if(this->isSwapping()) { + if (this->isSwapping()) { queued_.push_back(std::move(cb)); return false; } - return eb_->runInEventBaseThread([eb = eb_, cb_ = std::move(cb)]() mutable { - return cb_(*eb); - }); + eb_->runInEventBaseThread( + [eb = eb_, cb_ = std::move(cb)]() mutable { return cb_(*eb); }); + + return true; } void SwappableEventBase::setEventBase(folly::EventBase& newEb) { - std::lock_guard l(hasSebDtored_->l_); + const std::lock_guard l(hasSebDtored_->l_); auto const alreadySwapping = this->isSwapping(); nextEb_ = &newEb; - if(alreadySwapping) { + if (alreadySwapping) { return; } eb_->runInEventBaseThread([this, hasSebDtored = hasSebDtored_]() { - std::lock_guard lInner(hasSebDtored->l_); - if(hasSebDtored->destroyed_) { + const std::lock_guard lInner(hasSebDtored->l_); + if (hasSebDtored->destroyed_) { // SEB was destroyed, any queued callbacks were appended to the old eb_ return; } @@ -38,10 +51,9 @@ void SwappableEventBase::setEventBase(folly::EventBase& newEb) { // enqueue tasks that were being buffered while this was waiting // for the previous EB to drain - for(auto& cb : queued_) { - eb_->runInEventBaseThread([cb = std::move(cb), eb = eb_]() mutable { - return cb(*eb); - }); + for (auto& cb : queued_) { + eb_->runInEventBaseThread( + [cb = std::move(cb), eb = eb_]() mutable { return cb(*eb); }); } queued_.clear(); @@ -53,13 +65,12 @@ bool SwappableEventBase::isSwapping() const { } SwappableEventBase::~SwappableEventBase() { - std::lock_guard l(hasSebDtored_->l_); + const std::lock_guard l(hasSebDtored_->l_); hasSebDtored_->destroyed_ = true; - for(auto& cb : queued_) { - eb_->runInEventBaseThread([cb = std::move(cb), eb = eb_]() mutable { - return cb(*eb); - }); + for (auto& cb : queued_) { + eb_->runInEventBaseThread( + [cb = std::move(cb), eb = eb_]() mutable { return cb(*eb); }); } queued_.clear(); } diff --git a/rsocket/internal/SwappableEventBase.h b/rsocket/internal/SwappableEventBase.h index 97df41368..456eb67bf 100644 --- a/rsocket/internal/SwappableEventBase.h +++ b/rsocket/internal/SwappableEventBase.h @@ -1,9 +1,21 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include #include +#include #include namespace rsocket { @@ -18,16 +30,16 @@ class SwappableEventBase final { // lock for synchronization on destroyed_, and all members of the parent SEB std::mutex l_; // has the SEB's destructor ran? - bool destroyed_ {false}; + bool destroyed_{false}; }; -public: + public: using CbFunc = folly::Function; explicit SwappableEventBase(folly::EventBase& eb) - : eb_(&eb), - nextEb_(nullptr), - hasSebDtored_(std::make_shared()) {} + : eb_(&eb), + nextEb_(nullptr), + hasSebDtored_(std::make_shared()) {} // Run or enqueue 'cb', in order with all prior calls to runInEventBaseThread // If setEventBase has been called, and the prior EventBase is still @@ -47,7 +59,7 @@ class SwappableEventBase final { // there are any pending by the time the SEB is destroyed ~SwappableEventBase(); -private: + private: folly::EventBase* eb_; folly::EventBase* nextEb_; // also indicate if we're in the middle of a swap @@ -66,5 +78,4 @@ class SwappableEventBase final { std::vector queued_; }; - -} /* ns rsocket */ +} // namespace rsocket diff --git a/rsocket/internal/WarmResumeManager.cpp b/rsocket/internal/WarmResumeManager.cpp new file mode 100644 index 000000000..c67de86e9 --- /dev/null +++ b/rsocket/internal/WarmResumeManager.cpp @@ -0,0 +1,173 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/internal/WarmResumeManager.h" + +#include + +namespace rsocket { + +WarmResumeManager::~WarmResumeManager() { + clearFrames(lastSentPosition_); +} + +void WarmResumeManager::trackReceivedFrame( + size_t frameLength, + FrameType frameType, + StreamId streamId, + size_t consumerAllowance) { + if (shouldTrackFrame(frameType)) { + VLOG(6) << "Track received frame " << frameType << " StreamId: " << streamId + << " Allowance: " << consumerAllowance; + impliedPosition_ += frameLength; + } +} + +void WarmResumeManager::trackSentFrame( + const folly::IOBuf& serializedFrame, + FrameType frameType, + StreamId, + size_t consumerAllowance) { + if (shouldTrackFrame(frameType)) { + // TODO(tmont): this could be expensive, find a better way to get length + const auto frameDataLength = serializedFrame.computeChainDataLength(); + + VLOG(6) << "Track sent frame " << frameType + << " Allowance: " << consumerAllowance; + // If the frame is too huge, we don't cache it. + // We empty the entire cache instead. + if (frameDataLength > capacity_) { + resetUpToPosition(lastSentPosition_); + lastSentPosition_ += frameDataLength; + firstSentPosition_ += frameDataLength; + DCHECK(firstSentPosition_ == lastSentPosition_); + DCHECK(size_ == 0); + return; + } + + addFrame(serializedFrame, frameDataLength); + lastSentPosition_ += frameDataLength; + } +} + +void WarmResumeManager::resetUpToPosition(ResumePosition position) { + if (position <= firstSentPosition_) { + return; + } + + if (position > lastSentPosition_) { + position = lastSentPosition_; + } + + clearFrames(position); + + firstSentPosition_ = position; + DCHECK(frames_.empty() || frames_.front().first == firstSentPosition_); +} + +bool WarmResumeManager::isPositionAvailable(ResumePosition position) const { + return (lastSentPosition_ == position) || + std::binary_search( + frames_.begin(), + frames_.end(), + std::make_pair(position, std::unique_ptr()), + [](decltype(frames_.back()) pairA, + decltype(frames_.back()) pairB) { + return pairA.first < pairB.first; + }); +} + +void WarmResumeManager::addFrame( + const folly::IOBuf& frame, + size_t frameDataLength) { + size_ += frameDataLength; + while (size_ > capacity_) { + evictFrame(); + } + frames_.emplace_back(lastSentPosition_, frame.clone()); + stats_->resumeBufferChanged(1, static_cast(frameDataLength)); +} + +void WarmResumeManager::evictFrame() { + DCHECK(!frames_.empty()); + + const auto position = frames_.size() > 1 ? std::next(frames_.begin())->first + : lastSentPosition_; + resetUpToPosition(position); +} + +void WarmResumeManager::clearFrames(ResumePosition position) { + if (frames_.empty()) { + return; + } + DCHECK(position <= lastSentPosition_); + DCHECK(position >= firstSentPosition_); + + const auto end = std::lower_bound( + frames_.begin(), + frames_.end(), + position, + [](decltype(frames_.back()) pair, ResumePosition pos) { + return pair.first < pos; + }); + DCHECK(end == frames_.end() || end->first >= firstSentPosition_); + const auto pos = end == frames_.end() ? position : end->first; + stats_->resumeBufferChanged( + -static_cast(std::distance(frames_.begin(), end)), + -static_cast(pos - firstSentPosition_)); + + frames_.erase(frames_.begin(), end); + size_ -= static_cast(pos - firstSentPosition_); +} + +void WarmResumeManager::sendFramesFromPosition( + ResumePosition position, + FrameTransport& frameTransport) const { + DCHECK(isPositionAvailable(position)); + + if (position == lastSentPosition_) { + // idle resumption + return; + } + + auto found = std::lower_bound( + frames_.begin(), + frames_.end(), + position, + [](decltype(frames_.back()) pair, ResumePosition pos) { + return pair.first < pos; + }); + + DCHECK(found != frames_.end()); + DCHECK(found->first == position); + + while (found != frames_.end()) { + frameTransport.outputFrameOrDrop(found->second->clone()); + found++; + } +} + +std::shared_ptr ResumeManager::makeEmpty() { + class Empty : public WarmResumeManager { + public: + Empty() : WarmResumeManager(nullptr, 0) {} + bool shouldTrackFrame(FrameType) const override { + return false; + } + }; + + return std::make_shared(); +} + +} // namespace rsocket diff --git a/rsocket/internal/WarmResumeManager.h b/rsocket/internal/WarmResumeManager.h new file mode 100644 index 000000000..b14969ca9 --- /dev/null +++ b/rsocket/internal/WarmResumeManager.h @@ -0,0 +1,116 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include + +#include "rsocket/RSocketStats.h" +#include "rsocket/ResumeManager.h" + +namespace folly { +class IOBuf; +} + +namespace rsocket { + +class RSocketStateMachine; +class FrameTransport; + +class WarmResumeManager : public ResumeManager { + public: + explicit WarmResumeManager( + std::shared_ptr stats, + size_t capacity = DEFAULT_CAPACITY) + : stats_(std::move(stats)), capacity_(capacity) {} + ~WarmResumeManager(); + + void trackReceivedFrame( + size_t frameLength, + FrameType frameType, + StreamId streamId, + size_t consumerAllowance) override; + + void trackSentFrame( + const folly::IOBuf& serializedFrame, + FrameType frameType, + StreamId streamId, + size_t consumerAllowance) override; + + void resetUpToPosition(ResumePosition position) override; + + bool isPositionAvailable(ResumePosition position) const override; + + void sendFramesFromPosition( + ResumePosition position, + FrameTransport& transport) const override; + + ResumePosition firstSentPosition() const override { + return firstSentPosition_; + } + + ResumePosition lastSentPosition() const override { + return lastSentPosition_; + } + + ResumePosition impliedPosition() const override { + return impliedPosition_; + } + + // No action to perform for WarmResumeManager + void onStreamOpen(StreamId, RequestOriginator, std::string, StreamType) + override {} + + // No action to perform for WarmResumeManager + void onStreamClosed(StreamId) override {} + + const StreamResumeInfos& getStreamResumeInfos() const override { + LOG(FATAL) << "Not Implemented for Warm Resumption"; + folly::assume_unreachable(); + } + + StreamId getLargestUsedStreamId() const override { + LOG(FATAL) << "Not Implemented for Warm Resumption"; + folly::assume_unreachable(); + } + + size_t size() const { + return size_; + } + + protected: + void addFrame(const folly::IOBuf&, size_t); + void evictFrame(); + + // Called before clearing cached frames to update stats. + void clearFrames(ResumePosition position); + + const std::shared_ptr stats_; + + // Start position of the send buffer queue + ResumePosition firstSentPosition_{0}; + // End position of the send buffer queue + ResumePosition lastSentPosition_{0}; + // Inferred position of the rcvd frames + ResumePosition impliedPosition_{0}; + + std::deque>> frames_; + + constexpr static size_t DEFAULT_CAPACITY = 1024 * 1024; // 1MB + const size_t capacity_; + size_t size_{0}; +}; +} // namespace rsocket diff --git a/rsocket/statemachine/ChannelRequester.cpp b/rsocket/statemachine/ChannelRequester.cpp index 637376732..6798613a1 100644 --- a/rsocket/statemachine/ChannelRequester.cpp +++ b/rsocket/statemachine/ChannelRequester.cpp @@ -1,133 +1,155 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/statemachine/ChannelRequester.h" -#include "yarpl/utils/ExceptionString.h" namespace rsocket { -using namespace yarpl; -using namespace yarpl::flowable; - -ChannelRequester::ChannelRequester(const ConsumerBase::Parameters& params) - : ConsumerBase(params), PublisherBase(/*initialRequestN=*/1) {} - void ChannelRequester::onSubscribe( - Reference subscription) noexcept { + std::shared_ptr subscription) { CHECK(!requested_); publisherSubscribe(std::move(subscription)); + + if (hasInitialRequest_) { + initStream(std::move(request_)); + } } -void ChannelRequester::onNext(Payload request) noexcept { - if(!requested_) { - requested_ = true; - - size_t initialN = initialResponseAllowance_.drainWithLimit( - Frame_REQUEST_N::kMaxRequestN); - size_t remainingN = initialResponseAllowance_.drain(); - // Send as much as possible with the initial request. - CHECK_GE(Frame_REQUEST_N::kMaxRequestN, initialN); - newStream( - StreamType::CHANNEL, - static_cast(initialN), - std::move(request), - false); - // We must inform ConsumerBase about an implicit allowance we have - // requested from the remote end. - ConsumerBase::addImplicitAllowance(initialN); - // Pump the remaining allowance into the ConsumerBase _after_ sending the - // initial request. - if (remainingN) { - ConsumerBase::generateRequest(remainingN); - } +void ChannelRequester::onNext(Payload request) { + if (!requested_) { + initStream(std::move(request)); return; } - checkPublisherOnNext(); - writePayload(std::move(request), false); + if (!publisherClosed()) { + writePayload(std::move(request)); + } } // TODO: consolidate code in onCompleteImpl, onErrorImpl, cancelImpl -void ChannelRequester::onComplete() noexcept { +void ChannelRequester::onComplete() { if (!requested_) { - closeStream(StreamCompletionSignal::CANCEL); + endStream(StreamCompletionSignal::CANCEL); + removeFromWriter(); return; } - publisherComplete(); - completeStream(); - tryCompleteChannel(); + if (!publisherClosed()) { + publisherComplete(); + writeComplete(); + tryCompleteChannel(); + } } -void ChannelRequester::onError(std::exception_ptr ex) noexcept { +void ChannelRequester::onError(folly::exception_wrapper ex) { if (!requested_) { - closeStream(StreamCompletionSignal::CANCEL); + endStream(StreamCompletionSignal::CANCEL); + removeFromWriter(); return; } - publisherComplete(); - applicationError(yarpl::exceptionStr(ex)); - tryCompleteChannel(); + if (!publisherClosed()) { + publisherComplete(); + endStream(StreamCompletionSignal::ERROR); + writeApplicationError(ex.get_exception()->what()); + tryCompleteChannel(); + } } -void ChannelRequester::request(int64_t n) noexcept { +void ChannelRequester::request(int64_t n) { if (!requested_) { // The initial request has not been sent out yet, hence we must accumulate // the unsynchronised allowance, portion of which will be sent out with // the initial request frame, and the rest will be dispatched via // ConsumerBase:request (ultimately by sending REQUEST_N frames). - initialResponseAllowance_.release(n); + initialResponseAllowance_.add(n); return; } - checkConsumerRequest(); ConsumerBase::generateRequest(n); } -void ChannelRequester::cancel() noexcept { +void ChannelRequester::cancel() { if (!requested_) { - closeStream(StreamCompletionSignal::CANCEL); + endStream(StreamCompletionSignal::CANCEL); + removeFromWriter(); return; } cancelConsumer(); - cancelStream(); + writeCancel(); tryCompleteChannel(); } -void ChannelRequester::endStream(StreamCompletionSignal signal) { - terminatePublisher(); - ConsumerBase::endStream(signal); -} - -void ChannelRequester::tryCompleteChannel() { - if (publisherClosed() && consumerClosed()) { - closeStream(StreamCompletionSignal::COMPLETE); - } -} - void ChannelRequester::handlePayload( Payload&& payload, - bool complete, - bool next) { + bool flagsComplete, + bool flagsNext, + bool flagsFollows) { CHECK(requested_); - processPayload(std::move(payload), next); + bool finalComplete = processFragmentedPayload( + std::move(payload), flagsNext, flagsComplete, flagsFollows); - if (complete) { + if (finalComplete) { completeConsumer(); tryCompleteChannel(); } } -void ChannelRequester::handleError(folly::exception_wrapper ex) { +void ChannelRequester::handleRequestN(uint32_t n) { CHECK(requested_); - errorConsumer(std::move(ex)); - tryCompleteChannel(); + PublisherBase::processRequestN(n); } -void ChannelRequester::handleRequestN(uint32_t n) { +void ChannelRequester::handleError(folly::exception_wrapper ew) { CHECK(requested_); - PublisherBase::processRequestN(n); + errorConsumer(std::move(ew)); + terminatePublisher(); } void ChannelRequester::handleCancel() { CHECK(requested_); - publisherComplete(); + terminatePublisher(); tryCompleteChannel(); } -} // reactivesocket + +void ChannelRequester::endStream(StreamCompletionSignal signal) { + terminatePublisher(); + ConsumerBase::endStream(signal); +} + +void ChannelRequester::initStream(Payload&& request) { + requested_ = true; + + const size_t initialN = initialResponseAllowance_.consumeUpTo(kMaxRequestN); + const size_t remainingN = initialResponseAllowance_.consumeAll(); + + // Send as much as possible with the initial request. + CHECK_GE(static_cast(kMaxRequestN), initialN); + newStream( + StreamType::CHANNEL, static_cast(initialN), std::move(request)); + // We must inform ConsumerBase about an implicit allowance we have + // requested from the remote end. + ConsumerBase::addImplicitAllowance(initialN); + // Pump the remaining allowance into the ConsumerBase _after_ sending the + // initial request. + if (remainingN) { + ConsumerBase::generateRequest(remainingN); + } +} + +void ChannelRequester::tryCompleteChannel() { + if (publisherClosed() && consumerClosed()) { + endStream(StreamCompletionSignal::COMPLETE); + removeFromWriter(); + } +} + +} // namespace rsocket diff --git a/rsocket/statemachine/ChannelRequester.h b/rsocket/statemachine/ChannelRequester.h index ea2d9be82..7c05b2028 100644 --- a/rsocket/statemachine/ChannelRequester.h +++ b/rsocket/statemachine/ChannelRequester.h @@ -1,18 +1,24 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include - #include "rsocket/Payload.h" #include "rsocket/statemachine/ConsumerBase.h" #include "rsocket/statemachine/PublisherBase.h" #include "yarpl/flowable/Subscriber.h" -namespace folly { -class exception_wrapper; -} - namespace rsocket { /// Implementation of stream stateMachine that represents a Channel requester. @@ -20,31 +26,49 @@ class ChannelRequester : public ConsumerBase, public PublisherBase, public yarpl::flowable::Subscriber { public: - explicit ChannelRequester(const ConsumerBase::Parameters& params); + ChannelRequester( + Payload request, + std::shared_ptr writer, + StreamId streamId) + : ConsumerBase(std::move(writer), streamId), + PublisherBase(0 /*initialRequestN*/), + request_(std::move(request)), + hasInitialRequest_(true) {} - private: - void onSubscribe(yarpl::Reference - subscription) noexcept override; - void onNext(Payload) noexcept override; - void onComplete() noexcept override; - void onError(std::exception_ptr) noexcept override; - - // implementation from ConsumerBase::SubscriptionBase - void request(int64_t) noexcept override; - void cancel() noexcept override; - - void handlePayload(Payload&& payload, bool complete, bool flagsNext) override; - void handleRequestN(uint32_t n) override; - void handleError(folly::exception_wrapper errorPayload) override; + ChannelRequester(std::shared_ptr writer, StreamId streamId) + : ConsumerBase(std::move(writer), streamId), + PublisherBase(1 /*initialRequestN*/) {} + + void onSubscribe(std::shared_ptr) override; + void onNext(Payload) override; + void onComplete() override; + void onError(folly::exception_wrapper) override; + + void request(int64_t) override; + void cancel() override; + + void handlePayload( + Payload&& payload, + bool flagsComplete, + bool flagsNext, + bool flagsFollows) override; + void handleRequestN(uint32_t) override; + void handleError(folly::exception_wrapper) override; void handleCancel() override; void endStream(StreamCompletionSignal) override; + + private: + void initStream(Payload&&); void tryCompleteChannel(); - /// An allowance accumulated before the stream is initialised. - /// Remaining part of the allowance is forwarded to the ConsumerBase. - AllowanceSemaphore initialResponseAllowance_; + /// An allowance accumulated before the stream is initialised. Remaining part + /// of the allowance is forwarded to the ConsumerBase. + Allowance initialResponseAllowance_; + + Payload request_; bool requested_{false}; + bool hasInitialRequest_{false}; }; -} // reactivesocket +} // namespace rsocket diff --git a/rsocket/statemachine/ChannelResponder.cpp b/rsocket/statemachine/ChannelResponder.cpp index f2298f5d9..db366e70d 100644 --- a/rsocket/statemachine/ChannelResponder.cpp +++ b/rsocket/statemachine/ChannelResponder.cpp @@ -1,99 +1,122 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/statemachine/ChannelResponder.h" -#include "yarpl/utils/ExceptionString.h" namespace rsocket { -using namespace yarpl; -using namespace yarpl::flowable; - void ChannelResponder::onSubscribe( - Reference subscription) noexcept { + std::shared_ptr subscription) { publisherSubscribe(std::move(subscription)); } -void ChannelResponder::onNext(Payload response) noexcept { - checkPublisherOnNext(); - writePayload(std::move(response), false); -} - -void ChannelResponder::onComplete() noexcept { - publisherComplete(); - completeStream(); - tryCompleteChannel(); +void ChannelResponder::onNext(Payload response) { + if (!publisherClosed()) { + writePayload(std::move(response)); + } } -void ChannelResponder::onError(std::exception_ptr ex) noexcept { - publisherComplete(); - applicationError(yarpl::exceptionStr(ex)); - tryCompleteChannel(); +void ChannelResponder::onComplete() { + if (!publisherClosed()) { + publisherComplete(); + writeComplete(); + tryCompleteChannel(); + } } -void ChannelResponder::tryCompleteChannel() { - if (publisherClosed() && consumerClosed()) { - closeStream(StreamCompletionSignal::COMPLETE); +void ChannelResponder::onError(folly::exception_wrapper ex) { + if (!publisherClosed()) { + publisherComplete(); + endStream(StreamCompletionSignal::ERROR); + if (!ex.with_exception([this](rsocket::ErrorWithPayload& err) { + writeApplicationError(std::move(err.payload)); + })) { + writeApplicationError(ex.get_exception()->what()); + } + tryCompleteChannel(); } } -void ChannelResponder::request(int64_t n) noexcept { - checkConsumerRequest(); +void ChannelResponder::request(int64_t n) { ConsumerBase::generateRequest(n); } -void ChannelResponder::cancel() noexcept { +void ChannelResponder::cancel() { cancelConsumer(); - cancelStream(); + writeCancel(); tryCompleteChannel(); } -void ChannelResponder::endStream(StreamCompletionSignal signal) { - terminatePublisher(); - ConsumerBase::endStream(signal); -} - -// TODO: remove this unused function -void ChannelResponder::processInitialFrame(Frame_REQUEST_CHANNEL&& frame) { - onNextPayloadFrame( - frame.requestN_, - std::move(frame.payload_), - frame.header_.flagsComplete(), - true); -} - void ChannelResponder::handlePayload( Payload&& payload, - bool complete, - bool flagsNext) { - onNextPayloadFrame(0, std::move(payload), complete, flagsNext); -} + bool flagsComplete, + bool flagsNext, + bool flagsFollows) { + payloadFragments_.addPayload(std::move(payload), flagsNext, flagsComplete); + + if (flagsFollows) { + // there will be more fragments to come + return; + } -void ChannelResponder::onNextPayloadFrame( - uint32_t requestN, - Payload&& payload, - bool complete, - bool next) { - processRequestN(requestN); - processPayload(std::move(payload), next); + bool finalFlagsComplete, finalFlagsNext; + Payload finalPayload; + + std::tie(finalPayload, finalFlagsNext, finalFlagsComplete) = + payloadFragments_.consumePayloadAndFlags(); + + if (newStream_) { + newStream_ = false; + auto channelOutputSubscriber = onNewStreamReady( + StreamType::CHANNEL, + std::move(finalPayload), + std::static_pointer_cast(shared_from_this())); + subscribe(std::move(channelOutputSubscriber)); + } else { + processPayload(std::move(finalPayload), finalFlagsNext); + } - if (complete) { + if (finalFlagsComplete) { completeConsumer(); tryCompleteChannel(); } } -void ChannelResponder::handleCancel() { - publisherComplete(); - tryCompleteChannel(); -} - void ChannelResponder::handleRequestN(uint32_t n) { processRequestN(n); } -void ChannelResponder::handleError( - folly::exception_wrapper ex) { - errorConsumer(std::move(ex)); +void ChannelResponder::handleError(folly::exception_wrapper ew) { + errorConsumer(std::move(ew)); + terminatePublisher(); +} + +void ChannelResponder::handleCancel() { + terminatePublisher(); tryCompleteChannel(); } + +void ChannelResponder::endStream(StreamCompletionSignal signal) { + terminatePublisher(); + ConsumerBase::endStream(signal); +} + +void ChannelResponder::tryCompleteChannel() { + if (publisherClosed() && consumerClosed()) { + endStream(StreamCompletionSignal::COMPLETE); + removeFromWriter(); + } } + +} // namespace rsocket diff --git a/rsocket/statemachine/ChannelResponder.h b/rsocket/statemachine/ChannelResponder.h index 2fa1b7239..c0e6de708 100644 --- a/rsocket/statemachine/ChannelResponder.h +++ b/rsocket/statemachine/ChannelResponder.h @@ -1,9 +1,19 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include - #include "rsocket/statemachine/ConsumerBase.h" #include "rsocket/statemachine/PublisherBase.h" #include "yarpl/flowable/Subscriber.h" @@ -15,37 +25,37 @@ class ChannelResponder : public ConsumerBase, public PublisherBase, public yarpl::flowable::Subscriber { public: - explicit ChannelResponder( - uint32_t initialRequestN, - const ConsumerBase::Parameters& params) - : ConsumerBase(params), PublisherBase(initialRequestN) {} - - void processInitialFrame(Frame_REQUEST_CHANNEL&&); + ChannelResponder( + std::shared_ptr writer, + StreamId streamId, + uint32_t initialRequestN) + : ConsumerBase(std::move(writer), streamId), + PublisherBase(initialRequestN) {} + + void onSubscribe(std::shared_ptr) override; + void onNext(Payload) override; + void onComplete() override; + void onError(folly::exception_wrapper) override; + + void request(int64_t) override; + void cancel() override; + + void handlePayload( + Payload&& payload, + bool flagsComplete, + bool flagsNext, + bool flagsFollows) override; - private: - void onSubscribe(yarpl::Reference - subscription) noexcept override; - void onNext(Payload) noexcept override; - void onComplete() noexcept override; - void onError(std::exception_ptr) noexcept override; - - // implementation from ConsumerBase::SubscriptionBase - void request(int64_t n) noexcept override; - void cancel() noexcept override; - - void handlePayload(Payload&& payload, bool complete, bool flagsNext) override; - void handleRequestN(uint32_t n) override; + void handleRequestN(uint32_t) override; + void handleError(folly::exception_wrapper) override; void handleCancel() override; - void handleError(folly::exception_wrapper ex) override; - - void onNextPayloadFrame( - uint32_t requestN, - Payload&& payload, - bool complete, - bool next); void endStream(StreamCompletionSignal) override; + private: void tryCompleteChannel(); + + bool newStream_{true}; }; -} // reactivesocket + +} // namespace rsocket diff --git a/rsocket/statemachine/ConsumerBase.cpp b/rsocket/statemachine/ConsumerBase.cpp index 8dde632c2..21d1cedc5 100644 --- a/rsocket/statemachine/ConsumerBase.cpp +++ b/rsocket/statemachine/ConsumerBase.cpp @@ -1,117 +1,151 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/statemachine/ConsumerBase.h" -#include #include -#include "rsocket/Payload.h" -#include "yarpl/flowable/Subscription.h" +#include namespace rsocket { -using namespace yarpl; -using namespace yarpl::flowable; - void ConsumerBase::subscribe( - Reference> subscriber) { - if (Base::isTerminated()) { - subscriber->onSubscribe(yarpl::flowable::Subscription::empty()); + std::shared_ptr> subscriber) { + if (state_ == State::CLOSED) { + subscriber->onSubscribe(yarpl::flowable::Subscription::create()); subscriber->onComplete(); return; } DCHECK(!consumingSubscriber_); consumingSubscriber_ = std::move(subscriber); - consumingSubscriber_->onSubscribe(Reference(this)); -} - -void ConsumerBase::checkConsumerRequest() { - DCHECK(consumingSubscriber_); - CHECK(state_ == State::RESPONDING); + consumingSubscriber_->onSubscribe(shared_from_this()); } void ConsumerBase::cancelConsumer() { state_ = State::CLOSED; + VLOG(5) << "ConsumerBase::cancelConsumer()"; consumingSubscriber_ = nullptr; } +void ConsumerBase::addImplicitAllowance(size_t n) { + allowance_.add(n); + activeRequests_.add(n); +} void ConsumerBase::generateRequest(size_t n) { - allowance_.release(n); - pendingAllowance_.release(n); + allowance_.add(n); + pendingAllowance_.add(n); sendRequests(); } void ConsumerBase::endStream(StreamCompletionSignal signal) { + VLOG(5) << "ConsumerBase::endStream(" << signal << ")"; + state_ = State::CLOSED; if (auto subscriber = std::move(consumingSubscriber_)) { if (signal == StreamCompletionSignal::COMPLETE || signal == StreamCompletionSignal::CANCEL) { // TODO: remove CANCEL + VLOG(5) << "Closing ConsumerBase subscriber with calling onComplete"; subscriber->onComplete(); } else { - subscriber->onError(std::make_exception_ptr( - StreamInterruptedException(static_cast(signal)))); + VLOG(5) << "Closing ConsumerBase subscriber with calling onError"; + subscriber->onError(StreamInterruptedException(static_cast(signal))); } } - Base::endStream(signal); } -//void ConsumerBase::pauseStream(RequestHandler& requestHandler) { -// if (consumingSubscriber_) { -// requestHandler.onSubscriberPaused(consumingSubscriber_); -// } -//} -// -//void ConsumerBase::resumeStream(RequestHandler& requestHandler) { -// if (consumingSubscriber_) { -// requestHandler.onSubscriberResumed(consumingSubscriber_); -// } -//} +size_t ConsumerBase::getConsumerAllowance() const { + return allowance_.get(); +} void ConsumerBase::processPayload(Payload&& payload, bool onNext) { - if (payload || onNext) { - // Frames carry application-level payloads are taken into account when - // figuring out flow control allowance. - if (allowance_.tryAcquire()) { - sendRequests(); - consumingSubscriber_->onNext(std::move(payload)); - } else { - handleFlowControlError(); - return; - } + if (!payload && !onNext) { + return; } + + // Frames carrying application-level payloads are taken into account when + // figuring out flow control allowance. + if (!allowance_.tryConsume(1) || !activeRequests_.tryConsume(1)) { + handleFlowControlError(); + return; + } + + sendRequests(); + if (consumingSubscriber_) { + consumingSubscriber_->onNext(std::move(payload)); + } else { + LOG(ERROR) << "Consuming subscriber is missing, might be a race on " + << "cancel/onNext"; + } +} + +bool ConsumerBase::processFragmentedPayload( + Payload&& payload, + bool flagsNext, + bool flagsComplete, + bool flagsFollows) { + payloadFragments_.addPayload(std::move(payload), flagsNext, flagsComplete); + + if (flagsFollows) { + // there will be more fragments to come + return false; + } + + bool finalFlagsComplete, finalFlagsNext; + Payload finalPayload; + + std::tie(finalPayload, finalFlagsNext, finalFlagsComplete) = + payloadFragments_.consumePayloadAndFlags(); + processPayload(std::move(finalPayload), finalFlagsNext); + return finalFlagsComplete; } void ConsumerBase::completeConsumer() { state_ = State::CLOSED; + VLOG(5) << "ConsumerBase::completeConsumer()"; if (auto subscriber = std::move(consumingSubscriber_)) { subscriber->onComplete(); } } -void ConsumerBase::errorConsumer(folly::exception_wrapper ex) { +void ConsumerBase::errorConsumer(folly::exception_wrapper ew) { state_ = State::CLOSED; + VLOG(5) << "ConsumerBase::errorConsumer()"; if (auto subscriber = std::move(consumingSubscriber_)) { - subscriber->onError(ex.to_exception_ptr()); + subscriber->onError(std::move(ew)); } } void ConsumerBase::sendRequests() { - // TODO(stupaq): batch if remote end has some spare allowance - // TODO(stupaq): limit how much is synced to the other end - size_t toSync = Frame_REQUEST_N::kMaxRequestN; - toSync = pendingAllowance_.drainWithLimit(toSync); - if (toSync > 0) { - writeRequestN(static_cast(toSync)); + auto toSync = std::min(pendingAllowance_.get(), kMaxRequestN); + auto actives = activeRequests_.get(); + if (actives <= toSync) { + toSync = pendingAllowance_.consumeUpTo(toSync); + if (toSync > 0) { + writeRequestN(static_cast(toSync)); + activeRequests_.add(toSync); + } } } void ConsumerBase::handleFlowControlError() { if (auto subscriber = std::move(consumingSubscriber_)) { - subscriber->onError( - std::make_exception_ptr(std::runtime_error("surplus response"))); + subscriber->onError(std::runtime_error("Surplus response")); } - errorStream("flow control error"); + writeInvalidError("Flow control error"); + endStream(StreamCompletionSignal::ERROR); + removeFromWriter(); } -} +} // namespace rsocket diff --git a/rsocket/statemachine/ConsumerBase.h b/rsocket/statemachine/ConsumerBase.h index f0b73665f..773c1350c 100644 --- a/rsocket/statemachine/ConsumerBase.h +++ b/rsocket/statemachine/ConsumerBase.h @@ -1,82 +1,87 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include -#include -#include - #include "rsocket/Payload.h" -#include "rsocket/internal/AllowanceSemaphore.h" -#include "rsocket/internal/Common.h" -#include "rsocket/statemachine/RSocketStateMachine.h" +#include "rsocket/internal/Allowance.h" #include "rsocket/statemachine/StreamStateMachineBase.h" +#include "yarpl/flowable/Subscriber.h" #include "yarpl/flowable/Subscription.h" namespace rsocket { -enum class StreamCompletionSignal; - /// A class that represents a flow-control-aware consumer of data. class ConsumerBase : public StreamStateMachineBase, - public yarpl::flowable::Subscription { - using Base = StreamStateMachineBase; - + public yarpl::flowable::Subscription, + public std::enable_shared_from_this { public: - using Base::Base; + using StreamStateMachineBase::StreamStateMachineBase; + + void subscribe(std::shared_ptr>); /// Adds implicit allowance. /// /// This portion of allowance will not be synced to the remote end, but will /// count towards the limit of allowance the remote PublisherBase may use. - void addImplicitAllowance(size_t n) { - allowance_.release(n); - } - - /// @{ - void subscribe( - yarpl::Reference> subscriber); + void addImplicitAllowance(size_t); - void generateRequest(size_t n); - /// @} - - protected: - void checkConsumerRequest(); - void cancelConsumer(); + void generateRequest(size_t); bool consumerClosed() const { return state_ == State::CLOSED; } - void endStream(StreamCompletionSignal signal) override; - -// void pauseStream(RequestHandler& requestHandler) override; -// void resumeStream(RequestHandler& requestHandler) override; + size_t getConsumerAllowance() const override; + void endStream(StreamCompletionSignal) override; + protected: void processPayload(Payload&&, bool onNext); + // returns true if the stream is completed + bool + processFragmentedPayload(Payload&&, bool next, bool complete, bool follows); + + void cancelConsumer(); void completeConsumer(); - void errorConsumer(folly::exception_wrapper ex); + void errorConsumer(folly::exception_wrapper); private: + enum class State : uint8_t { + RESPONDING, + CLOSED, + }; + void sendRequests(); void handleFlowControlError(); - /// A Subscriber that will consume payloads. - /// This is responsible for delivering a terminal signal to the - /// Subscriber once the stream ends. - yarpl::Reference> consumingSubscriber_; + /// A Subscriber that will consume payloads. This is responsible for + /// delivering a terminal signal to the Subscriber once the stream ends. + std::shared_ptr> consumingSubscriber_; /// A total, net allowance (requested less delivered) by this consumer. - AllowanceSemaphore allowance_; + Allowance allowance_; /// An allowance that have yet to be synced to the other end by sending /// REQUEST_N frames. - AllowanceSemaphore pendingAllowance_; + Allowance pendingAllowance_; - enum class State : uint8_t { - RESPONDING, - CLOSED, - } state_{State::RESPONDING}; + /// The number of already requested payload count. Prevent excessive requestN + /// calls. + Allowance activeRequests_; + + State state_{State::RESPONDING}; }; -} + +} // namespace rsocket diff --git a/rsocket/statemachine/FireAndForgetResponder.cpp b/rsocket/statemachine/FireAndForgetResponder.cpp new file mode 100644 index 000000000..2a15b87a6 --- /dev/null +++ b/rsocket/statemachine/FireAndForgetResponder.cpp @@ -0,0 +1,45 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/statemachine/FireAndForgetResponder.h" + +namespace rsocket { + +using namespace yarpl::flowable; + +void FireAndForgetResponder::handlePayload( + Payload&& payload, + bool /*flagsComplete*/, + bool /*flagsNext*/, + bool flagsFollows) { + payloadFragments_.addPayloadIgnoreFlags(std::move(payload)); + + if (flagsFollows) { + // there will be more fragments to come + return; + } + + Payload finalPayload = payloadFragments_.consumePayloadIgnoreFlags(); + onNewStreamReady( + StreamType::FNF, + std::move(finalPayload), + std::shared_ptr>(nullptr)); + removeFromWriter(); +} + +void FireAndForgetResponder::handleCancel() { + removeFromWriter(); +} + +} // namespace rsocket diff --git a/rsocket/statemachine/FireAndForgetResponder.h b/rsocket/statemachine/FireAndForgetResponder.h new file mode 100644 index 000000000..bf9ad3397 --- /dev/null +++ b/rsocket/statemachine/FireAndForgetResponder.h @@ -0,0 +1,41 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "rsocket/statemachine/StreamStateMachineBase.h" +#include "yarpl/flowable/Subscriber.h" +#include "yarpl/single/SingleObserver.h" +#include "yarpl/single/SingleSubscription.h" + +namespace rsocket { + +/// Helper class for handling receiving fragmented payload +class FireAndForgetResponder : public StreamStateMachineBase { + public: + FireAndForgetResponder( + std::shared_ptr writer, + StreamId streamId) + : StreamStateMachineBase(std::move(writer), streamId) {} + + void handlePayload( + Payload&& payload, + bool flagsComplete, + bool flagsNext, + bool flagsFollows) override; + + private: + void handleCancel() override; +}; +} // namespace rsocket diff --git a/rsocket/statemachine/PublisherBase.cpp b/rsocket/statemachine/PublisherBase.cpp index bf5d2b33f..867ae4255 100644 --- a/rsocket/statemachine/PublisherBase.cpp +++ b/rsocket/statemachine/PublisherBase.cpp @@ -1,18 +1,28 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/statemachine/PublisherBase.h" #include -#include "rsocket/statemachine/RSocketStateMachine.h" - namespace rsocket { PublisherBase::PublisherBase(uint32_t initialRequestN) - : initialRequestN_(initialRequestN) {} + : initialRequestN_(initialRequestN) {} void PublisherBase::publisherSubscribe( - yarpl::Reference subscription) { + std::shared_ptr subscription) { if (state_ == State::CLOSED) { subscription->cancel(); return; @@ -20,15 +30,10 @@ void PublisherBase::publisherSubscribe( DCHECK(!producingSubscription_); producingSubscription_ = std::move(subscription); if (initialRequestN_) { - producingSubscription_->request(initialRequestN_.drain()); + producingSubscription_->request(initialRequestN_.consumeAll()); } } -void PublisherBase::checkPublisherOnNext() { - DCHECK(producingSubscription_); - CHECK(state_ == State::RESPONDING); -} - void PublisherBase::publisherComplete() { state_ = State::CLOSED; producingSubscription_ = nullptr; @@ -39,16 +44,16 @@ bool PublisherBase::publisherClosed() const { } void PublisherBase::processRequestN(uint32_t requestN) { - if (!requestN || state_ == State::CLOSED) { + if (requestN == 0 || state_ == State::CLOSED) { return; } - // we might not have the subscription set yet as there can be REQUEST_N - // frames scheduled on the executor before onSubscribe method + // We might not have the subscription set yet as there can be REQUEST_N frames + // scheduled on the executor before onSubscribe method. if (producingSubscription_) { producingSubscription_->request(requestN); } else { - initialRequestN_.release(requestN); + initialRequestN_.add(requestN); } } @@ -58,4 +63,5 @@ void PublisherBase::terminatePublisher() { subscription->cancel(); } } -} + +} // namespace rsocket diff --git a/rsocket/statemachine/PublisherBase.h b/rsocket/statemachine/PublisherBase.h index 4a9d6e208..b5df39909 100644 --- a/rsocket/statemachine/PublisherBase.h +++ b/rsocket/statemachine/PublisherBase.h @@ -1,42 +1,46 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include "rsocket/Payload.h" -#include "rsocket/internal/AllowanceSemaphore.h" +#include "rsocket/internal/Allowance.h" #include "yarpl/flowable/Subscription.h" namespace rsocket { -enum class StreamCompletionSignal; - /// A class that represents a flow-control-aware producer of data. class PublisherBase { public: explicit PublisherBase(uint32_t initialRequestN); - void publisherSubscribe( - yarpl::Reference subscription); - - void checkPublisherOnNext(); + void publisherSubscribe(std::shared_ptr); + void processRequestN(uint32_t); void publisherComplete(); - bool publisherClosed() const; - - void processRequestN(uint32_t requestN); + bool publisherClosed() const; void terminatePublisher(); private: - /// A Subscription that constrols production of payloads. - /// This is responsible for delivering a terminal signal to the - /// Subscription once the stream ends. - yarpl::Reference producingSubscription_; - AllowanceSemaphore initialRequestN_; - enum class State : uint8_t { RESPONDING, CLOSED, - } state_{State::RESPONDING}; + }; + + std::shared_ptr producingSubscription_; + Allowance initialRequestN_; + State state_{State::RESPONDING}; }; -} + +} // namespace rsocket diff --git a/rsocket/statemachine/RSocketStateMachine.cpp b/rsocket/statemachine/RSocketStateMachine.cpp index d54aa824a..4d914052a 100644 --- a/rsocket/statemachine/RSocketStateMachine.cpp +++ b/rsocket/statemachine/RSocketStateMachine.cpp @@ -1,11 +1,25 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/statemachine/RSocketStateMachine.h" #include +#include #include #include -#include +#include +#include #include "rsocket/DuplexConnection.h" #include "rsocket/RSocketConnectionEvents.h" @@ -14,38 +28,90 @@ #include "rsocket/RSocketStats.h" #include "rsocket/framing/Frame.h" #include "rsocket/framing/FrameSerializer.h" -#include "rsocket/framing/FrameTransport.h" +#include "rsocket/framing/FrameTransportImpl.h" #include "rsocket/internal/ClientResumeStatusCallback.h" -#include "rsocket/internal/ResumeCache.h" +#include "rsocket/internal/ScheduledSubscriber.h" +#include "rsocket/internal/WarmResumeManager.h" +#include "rsocket/statemachine/ChannelRequester.h" #include "rsocket/statemachine/ChannelResponder.h" -#include "rsocket/statemachine/StreamState.h" +#include "rsocket/statemachine/FireAndForgetResponder.h" +#include "rsocket/statemachine/RequestResponseRequester.h" +#include "rsocket/statemachine/RequestResponseResponder.h" +#include "rsocket/statemachine/StreamRequester.h" +#include "rsocket/statemachine/StreamResponder.h" #include "rsocket/statemachine/StreamStateMachineBase.h" +#include "yarpl/flowable/Subscription.h" +#include "yarpl/single/SingleSubscriptions.h" + namespace rsocket { +namespace { + +void disconnectError( + std::shared_ptr> subscriber) { + std::runtime_error exn{"RSocket connection is disconnected or closed"}; + subscriber->onSubscribe(yarpl::flowable::Subscription::create()); + subscriber->onError(std::move(exn)); +} + +void disconnectError( + std::shared_ptr> observer) { + auto exn = folly::make_exception_wrapper( + "RSocket connection is disconnected or closed"); + observer->onSubscribe(yarpl::single::SingleSubscriptions::empty()); + observer->onError(std::move(exn)); +} + +} // namespace + RSocketStateMachine::RSocketStateMachine( - folly::Executor& executor, std::shared_ptr requestResponder, std::unique_ptr keepaliveTimer, - ReactiveSocketMode mode, + RSocketMode mode, std::shared_ptr stats, - std::shared_ptr connectionEvents) - : mode_(mode), - resumeCache_(std::make_shared(stats)), - streamState_(std::make_shared(*stats)), - requestResponder_(std::move(requestResponder)), - keepaliveTimer_(std::move(keepaliveTimer)), - streamsFactory_(*this, mode), - stats_(stats), - connectionEvents_(connectionEvents), - executor_(executor) { + std::shared_ptr connectionEvents, + std::shared_ptr resumeManager, + std::shared_ptr coldResumeHandler) + : RSocketStateMachine( + std::make_shared( + std::move(requestResponder)), + std::move(keepaliveTimer), + mode, + std::move(stats), + std::move(connectionEvents), + std::move(resumeManager), + std::move(coldResumeHandler)) {} + +RSocketStateMachine::RSocketStateMachine( + std::shared_ptr requestResponder, + std::unique_ptr keepaliveTimer, + RSocketMode mode, + std::shared_ptr stats, + std::shared_ptr connectionEvents, + std::shared_ptr resumeManager, + std::shared_ptr coldResumeHandler) + : mode_{mode}, + stats_{stats ? stats : RSocketStats::noop()}, + // Streams initiated by a client MUST use odd-numbered and streams + // initiated by the server MUST use even-numbered stream identifiers + nextStreamId_(mode == RSocketMode::CLIENT ? 1 : 2), + resumeManager_(std::move(resumeManager)), + requestResponder_{std::move(requestResponder)}, + keepaliveTimer_{std::move(keepaliveTimer)}, + coldResumeHandler_{std::move(coldResumeHandler)}, + connectionEvents_{connectionEvents} { + CHECK(resumeManager_) + << "provide ResumeManager::makeEmpty() instead of nullptr"; + // We deliberately do not "open" input or output to avoid having c'tor on the // stack when processing any signals from the connection. See ::connect and // ::onSubscribe. - CHECK(streamState_); + CHECK(requestResponder_); stats_->socketCreated(); + VLOG(2) << "Creating RSocketStateMachine"; } RSocketStateMachine::~RSocketStateMachine() { @@ -53,105 +119,167 @@ RSocketStateMachine::~RSocketStateMachine() { // automatons destroyed on different threads can be the last ones referencing // this. - VLOG(6) << "~RSocketStateMachine"; + VLOG(3) << "~RSocketStateMachine"; // We rely on SubscriptionPtr and SubscriberPtr to dispatch appropriate // terminal signals. DCHECK(!resumeCallback_); - DCHECK(isDisconnectedOrClosed()); // the instance should be closed by via + DCHECK(isDisconnected()); // the instance should be closed by via // close method } void RSocketStateMachine::setResumable(bool resumable) { - debugCheckCorrectExecutor(); // We should set this flag before we are connected - DCHECK(isDisconnectedOrClosed()); - remoteResumeable_ = isResumable_ = resumable; + DCHECK(isDisconnected()); + isResumable_ = resumable; } -bool RSocketStateMachine::connectServer( - yarpl::Reference frameTransport, +void RSocketStateMachine::connectServer( + std::shared_ptr frameTransport, const SetupParameters& setupParams) { setResumable(setupParams.resumable); - return connect(std::move(frameTransport), true, setupParams.protocolVersion); + setProtocolVersionOrThrow(setupParams.protocolVersion, frameTransport); + connect(std::move(frameTransport)); + sendPendingFrames(); } bool RSocketStateMachine::resumeServer( - yarpl::Reference frameTransport, + std::shared_ptr frameTransport, const ResumeParameters& resumeParams) { - return connect( - std::move(frameTransport), false, resumeParams.protocolVersion) && - resumeFromPositionOrClose( - resumeParams.serverPosition, resumeParams.clientPosition); -} - -bool RSocketStateMachine::connect( - yarpl::Reference frameTransport, - bool sendingPendingFrames, - ProtocolVersion protocolVersion) { - debugCheckCorrectExecutor(); - CHECK(isDisconnectedOrClosed()); - CHECK(frameTransport); - CHECK(!frameTransport->isClosed()); - if (protocolVersion != ProtocolVersion::Unknown) { - if (frameSerializer_) { - if (frameSerializer_->protocolVersion() != protocolVersion) { - DCHECK(false); - std::runtime_error exn("Protocol version mismatch"); - frameTransport->closeWithError(std::move(exn)); - return false; - } - } else { - frameSerializer_ = - FrameSerializer::createFrameSerializer(protocolVersion); - if (!frameSerializer_) { - DCHECK(false); - std::runtime_error exn("Invalid protocol version"); - frameTransport->closeWithError(std::move(exn)); - return false; - } - } + const folly::Optional clientAvailable = + (resumeParams.clientPosition == kUnspecifiedResumePosition) + ? folly::none + : folly::make_optional( + resumeManager_->impliedPosition() - resumeParams.clientPosition); + + const int64_t serverAvailable = + resumeManager_->lastSentPosition() - resumeManager_->firstSentPosition(); + const int64_t serverDelta = + resumeManager_->lastSentPosition() - resumeParams.serverPosition; + + if (frameTransport) { + stats_->socketDisconnected(); } + closeFrameTransport( + std::runtime_error{"Connection being resumed, dropping old connection"}); + setProtocolVersionOrThrow(resumeParams.protocolVersion, frameTransport); + connect(std::move(frameTransport)); + + const auto result = resumeFromPositionOrClose( + resumeParams.serverPosition, resumeParams.clientPosition); + + stats_->serverResume( + clientAvailable, + serverAvailable, + serverDelta, + result ? RSocketStats::ResumeOutcome::SUCCESS + : RSocketStats::ResumeOutcome::FAILURE); + + return result; +} + +void RSocketStateMachine::connectClient( + std::shared_ptr transport, + SetupParameters params) { + auto const version = params.protocolVersion == ProtocolVersion::Unknown + ? ProtocolVersion::Latest + : params.protocolVersion; + + setProtocolVersionOrThrow(version, transport); + setResumable(params.resumable); - frameTransport_ = std::move(frameTransport); + Frame_SETUP frame( + (params.resumable ? FrameFlags::RESUME_ENABLE : FrameFlags::EMPTY_) | + (params.payload.metadata ? FrameFlags::METADATA : FrameFlags::EMPTY_), + version.major, + version.minor, + getKeepaliveTime(), + Frame_SETUP::kMaxLifetime, + std::move(params.token), + std::move(params.metadataMimeType), + std::move(params.dataMimeType), + std::move(params.payload)); + + // TODO: when the server returns back that it doesn't support resumability, we + // should retry without resumability + + VLOG(3) << "Out: " << frame; + + connect(std::move(transport)); + // making sure we send setup frame first + outputFrame(frameSerializer_->serializeOut(std::move(frame))); + // then the rest of the cached frames will be sent + sendPendingFrames(); +} + +void RSocketStateMachine::resumeClient( + ResumeIdentificationToken token, + std::shared_ptr transport, + std::unique_ptr resumeCallback, + ProtocolVersion version) { + // Cold-resumption. Set the serializer. + if (!frameSerializer_) { + CHECK(coldResumeHandler_); + coldResumeInProgress_ = true; + } + + setProtocolVersionOrThrow( + version == ProtocolVersion::Unknown ? ProtocolVersion::Latest : version, + transport); + + Frame_RESUME resumeFrame( + std::move(token), + resumeManager_->impliedPosition(), + resumeManager_->firstSentPosition(), + frameSerializer_->protocolVersion()); + VLOG(3) << "Out: " << resumeFrame; + + // Disconnect a previous client if there is one. + disconnect(std::runtime_error{"Resuming client on a different connection"}); + + setResumable(true); + reconnect(std::move(transport), std::move(resumeCallback)); + outputFrame(frameSerializer_->serializeOut(std::move(resumeFrame))); +} + +void RSocketStateMachine::connect(std::shared_ptr transport) { + VLOG(2) << "Connecting to transport " << transport.get(); + + CHECK(isDisconnected()); + CHECK(transport); + + // Keep a reference to the argument, make sure the instance survives until + // setFrameProcessor() returns. There can be terminating signals processed in + // that call which will nullify frameTransport_. + frameTransport_ = transport; + + CHECK(frameSerializer_); + frameSerializer_->preallocateFrameSizeField() = + transport->isConnectionFramed(); if (connectionEvents_) { connectionEvents_->onConnected(); } - // We need to create a hard reference to frameTransport_ to make sure the - // instance survives until the setFrameProcessor returns. There can be - // terminating signals processed in that call which will nullify - // frameTransport_. - auto frameTransportCopy = frameTransport_; + // Keep a reference to stats, as processing frames might close this instance. + auto const stats = stats_; + frameTransport_->setFrameProcessor(shared_from_this()); + stats->socketConnected(); +} - // Keep a reference to this, as processing frames might close the - // ReactiveSocket instance. - auto copyThis = shared_from_this(); - frameTransport_->setFrameProcessor(copyThis); +void RSocketStateMachine::sendPendingFrames() { + DCHECK(!resumeCallback_); - if (sendingPendingFrames) { - DCHECK(!resumeCallback_); - // we are free to try to send frames again - // not all frames might be sent if the connection breaks, the rest of them - // will queue up again - auto outputFrames = streamState_->moveOutputPendingFrames(); - for (auto& frame : outputFrames) { - outputFrameOrEnqueue(std::move(frame)); - } + StreamsWriterImpl::sendPendingFrames(); - // TODO: turn on only after setup frame was received - if (keepaliveTimer_) { - keepaliveTimer_->start(shared_from_this()); - } + // TODO: turn on only after setup frame was received + if (keepaliveTimer_) { + keepaliveTimer_->start(shared_from_this()); } - - return true; } void RSocketStateMachine::disconnect(folly::exception_wrapper ex) { - debugCheckCorrectExecutor(); - VLOG(6) << "disconnect"; - if (isDisconnectedOrClosed()) { + VLOG(2) << "Disconnecting transport"; + if (isDisconnected()) { return; } @@ -159,43 +287,46 @@ void RSocketStateMachine::disconnect(folly::exception_wrapper ex) { connectionEvents_->onDisconnected(ex); } - closeFrameTransport(std::move(ex), StreamCompletionSignal::CONNECTION_END); - pauseStreams(); + closeFrameTransport(std::move(ex)); + + if (connectionEvents_) { + connectionEvents_->onStreamsPaused(); + } + stats_->socketDisconnected(); } void RSocketStateMachine::close( folly::exception_wrapper ex, StreamCompletionSignal signal) { - debugCheckCorrectExecutor(); - - if (isClosed_) { + if (isClosed()) { return; } + isClosed_ = true; stats_->socketClosed(signal); VLOG(6) << "close"; - if (resumeCallback_) { - resumeCallback_->onResumeError( - ConnectionException(ex ? ex.what().c_str() : "RS closing")); - resumeCallback_.reset(); + if (auto resumeCallback = std::move(resumeCallback_)) { + resumeCallback->onResumeError( + ConnectionException(ex ? ex.get_exception()->what() : "RS closing")); } - auto connectionEvents = std::move(connectionEvents_); - if (connectionEvents) { - connectionEvents->onClosed(ex); + closeStreams(signal); + closeFrameTransport(ex); + + if (auto connectionEvents = std::move(connectionEvents_)) { + connectionEvents->onClosed(std::move(ex)); } - closeStreams(signal); - closeFrameTransport(std::move(ex), signal); + if (closeCallback_) { + closeCallback_->remove(*this); + } } -void RSocketStateMachine::closeFrameTransport( - folly::exception_wrapper ex, - StreamCompletionSignal signal) { - if (isDisconnectedOrClosed()) { +void RSocketStateMachine::closeFrameTransport(folly::exception_wrapper ex) { + if (isDisconnected()) { DCHECK(!resumeCallback_); return; } @@ -205,37 +336,30 @@ void RSocketStateMachine::closeFrameTransport( keepaliveTimer_->stop(); } - if (resumeCallback_) { - resumeCallback_->onResumeError( - ConnectionException(ex ? ex.what().c_str() : "connection closing")); - resumeCallback_.reset(); + if (auto resumeCallback = std::move(resumeCallback_)) { + resumeCallback->onResumeError(ConnectionException( + ex ? ex.get_exception()->what() : "connection closing")); } // Echo the exception to the frameTransport only if the frameTransport started // closing with error. Otherwise we sent some error frame over the wire and // we are closing the transport cleanly. - if (signal == StreamCompletionSignal::CONNECTION_ERROR) { - frameTransport_->closeWithError(std::move(ex)); - } else { + if (frameTransport_) { frameTransport_->close(); + frameTransport_ = nullptr; } - - frameTransport_ = nullptr; } void RSocketStateMachine::disconnectOrCloseWithError(Frame_ERROR&& errorFrame) { - debugCheckCorrectExecutor(); if (isResumable_) { - disconnect(std::runtime_error(errorFrame.payload_.data->cloneAsValue() - .moveToFbString() - .toStdString())); + std::runtime_error exn{errorFrame.payload_.moveDataToString()}; + disconnect(std::move(exn)); } else { closeWithError(std::move(errorFrame)); } } void RSocketStateMachine::closeWithError(Frame_ERROR&& error) { - debugCheckCorrectExecutor(); VLOG(3) << "closeWithError " << error.payload_.data->cloneAsValue().moveToFbString(); @@ -265,204 +389,293 @@ void RSocketStateMachine::closeWithError(Frame_ERROR&& error) { signal = StreamCompletionSignal::ERROR; } - auto exception = std::runtime_error( - error.payload_.data->cloneAsValue().moveToFbString().toStdString()); - + std::runtime_error exn{error.payload_.cloneDataToString()}; if (frameSerializer_) { - outputFrameOrEnqueue(std::move(error)); + outputFrameOrEnqueue(frameSerializer_->serializeOut(std::move(error))); } - close(std::move(exception), signal); + close(std::move(exn), signal); } void RSocketStateMachine::reconnect( - yarpl::Reference newFrameTransport, + std::shared_ptr newFrameTransport, std::unique_ptr resumeCallback) { - debugCheckCorrectExecutor(); CHECK(newFrameTransport); CHECK(resumeCallback); + CHECK(!resumeCallback_); CHECK(isResumable_); - CHECK(mode_ == ReactiveSocketMode::CLIENT); + CHECK(mode_ == RSocketMode::CLIENT); // TODO: output frame buffer should not be written to the new connection until // we receive resume ok resumeCallback_ = std::move(resumeCallback); - connect(std::move(newFrameTransport), false, ProtocolVersion::Unknown); + connect(std::move(newFrameTransport)); } -void RSocketStateMachine::addStream( - StreamId streamId, - yarpl::Reference stateMachine) { - debugCheckCorrectExecutor(); - auto result = - streamState_->streams_.emplace(streamId, std::move(stateMachine)); - (void)result; - assert(result.second); -} - -void RSocketStateMachine::endStream( - StreamId streamId, - StreamCompletionSignal signal) { - debugCheckCorrectExecutor(); - VLOG(6) << "endStream"; - // The signal must be idempotent. - if (!endStreamInternal(streamId, signal)) { +void RSocketStateMachine::requestStream( + Payload request, + std::shared_ptr> responseSink) { + if (isDisconnected()) { + disconnectError(std::move(responseSink)); return; } - DCHECK( - signal == StreamCompletionSignal::CANCEL || - signal == StreamCompletionSignal::COMPLETE || - signal == StreamCompletionSignal::APPLICATION_ERROR || - signal == StreamCompletionSignal::ERROR); + + auto const streamId = getNextStreamId(); + auto stateMachine = std::make_shared( + shared_from_this(), streamId, std::move(request)); + const auto result = streams_.emplace(streamId, stateMachine); + DCHECK(result.second); + stateMachine->subscribe(std::move(responseSink)); } -bool RSocketStateMachine::endStreamInternal( - StreamId streamId, - StreamCompletionSignal signal) { - VLOG(6) << "endStreamInternal"; - auto it = streamState_->streams_.find(streamId); - if (it == streamState_->streams_.end()) { - // Unsubscribe handshake initiated by the connection, we're done. - return false; +std::shared_ptr> +RSocketStateMachine::requestChannel( + Payload request, + bool hasInitialRequest, + std::shared_ptr> responseSink) { + if (isDisconnected()) { + disconnectError(std::move(responseSink)); + return nullptr; } - resumeCache_->onStreamClosed(streamId); - - // Remove from the map before notifying the stateMachine. - auto stateMachine = std::move(it->second); - streamState_->streams_.erase(it); - stateMachine->endStream(signal); - return true; + auto const streamId = getNextStreamId(); + std::shared_ptr stateMachine; + if (hasInitialRequest) { + stateMachine = std::make_shared( + std::move(request), shared_from_this(), streamId); + } else { + stateMachine = + std::make_shared(shared_from_this(), streamId); + } + const auto result = streams_.emplace(streamId, stateMachine); + DCHECK(result.second); + stateMachine->subscribe(std::move(responseSink)); + return stateMachine; } -void RSocketStateMachine::closeStreams(StreamCompletionSignal signal) { - // Close all streams. - while (!streamState_->streams_.empty()) { - auto oldSize = streamState_->streams_.size(); - auto result = - endStreamInternal(streamState_->streams_.begin()->first, signal); - (void)oldSize; - (void)result; - // TODO(stupaq): what kind of a user action could violate these - // assertions? - assert(result); - assert(streamState_->streams_.size() == oldSize - 1); +void RSocketStateMachine::requestResponse( + Payload request, + std::shared_ptr> responseSink) { + if (isDisconnected()) { + disconnectError(std::move(responseSink)); + return; } -} -void RSocketStateMachine::pauseStreams() { - // for (auto& streamKV : streamState_->streams_) { - // streamKV.second->pauseStream(*requestHandler_); - // } + auto const streamId = getNextStreamId(); + auto stateMachine = std::make_shared( + shared_from_this(), streamId, std::move(request)); + const auto result = streams_.emplace(streamId, stateMachine); + DCHECK(result.second); + stateMachine->subscribe(std::move(responseSink)); } -void RSocketStateMachine::resumeStreams() { - // for (auto& streamKV : streamState_->streams_) { - // streamKV.second->resumeStream(*requestHandler_); - // } +void RSocketStateMachine::closeStreams(StreamCompletionSignal signal) { + while (!streams_.empty()) { + auto it = streams_.begin(); + auto streamStateMachine = std::move(it->second); + streams_.erase(it); + streamStateMachine->endStream(signal); + } } void RSocketStateMachine::processFrame(std::unique_ptr frame) { - auto thisPtr = this->shared_from_this(); - executor_.add([ thisPtr, frame = std::move(frame) ]() mutable { - thisPtr->processFrameImpl(std::move(frame)); - }); -} - -void RSocketStateMachine::processFrameImpl( - std::unique_ptr frame) { if (isClosed()) { + VLOG(4) << "StateMachine has been closed. Discarding incoming frame"; return; } if (!ensureOrAutodetectFrameSerializer(*frame)) { - DLOG(FATAL) << "frame serializer is not set"; - // Failed to autodetect protocol version - closeWithError(Frame_ERROR::invalidFrame()); + constexpr auto msg = "Cannot detect protocol version"; + closeWithError(Frame_ERROR::connectionError(msg)); return; } - auto frameType = frameSerializer_->peekFrameType(*frame); + const auto frameType = frameSerializer_->peekFrameType(*frame); stats_->frameRead(frameType); - auto streamIdPtr = frameSerializer_->peekStreamId(*frame); - if (!streamIdPtr) { - // Failed to deserialize the frame. - closeWithError(Frame_ERROR::invalidFrame()); + const auto optStreamId = frameSerializer_->peekStreamId(*frame, false); + if (!optStreamId) { + constexpr auto msg = "Cannot decode stream ID"; + closeWithError(Frame_ERROR::connectionError(msg)); + return; + } + + const auto frameLength = frame->computeChainDataLength(); + const auto streamId = *optStreamId; + handleFrame(streamId, frameType, std::move(frame)); + resumeManager_->trackReceivedFrame( + frameLength, frameType, streamId, getConsumerAllowance(streamId)); +} + +void RSocketStateMachine::onTerminal(folly::exception_wrapper ex) { + if (isResumable_) { + disconnect(std::move(ex)); return; } - auto streamId = *streamIdPtr; - resumeCache_->trackReceivedFrame(*frame, frameType, streamId); - if (streamId == 0) { - handleConnectionFrame(frameType, std::move(frame)); + const auto termSignal = ex ? StreamCompletionSignal::CONNECTION_ERROR + : StreamCompletionSignal::CONNECTION_END; + close(std::move(ex), termSignal); +} + +void RSocketStateMachine::onKeepAliveFrame( + ResumePosition resumePosition, + std::unique_ptr data, + bool keepAliveRespond) { + resumeManager_->resetUpToPosition(resumePosition); + if (mode_ == RSocketMode::SERVER) { + if (keepAliveRespond) { + sendKeepalive(FrameFlags::EMPTY_, std::move(data)); + } else { + closeWithError(Frame_ERROR::connectionError("keepalive without flag")); + } + } else { + if (keepAliveRespond) { + closeWithError(Frame_ERROR::connectionError( + "client received keepalive with respond flag")); + } else if (keepaliveTimer_) { + keepaliveTimer_->keepaliveReceived(); + } + stats_->keepaliveReceived(); + } +} + +void RSocketStateMachine::onMetadataPushFrame( + std::unique_ptr metadata) { + requestResponder_->handleMetadataPush(std::move(metadata)); +} + +void RSocketStateMachine::onResumeOkFrame(ResumePosition resumePosition) { + if (!resumeCallback_) { + constexpr auto msg = "Received RESUME_OK while not resuming"; + closeWithError(Frame_ERROR::connectionError(msg)); return; } - // during the time when we are resuming we are can't receive any other - // than connection level frames which drives the resumption - // TODO(lehecka): this assertion should be handled more elegantly using - // different state machine - if (resumeCallback_) { - LOG(ERROR) << "received stream frames during resumption"; - closeWithError(Frame_ERROR::invalidFrame()); + if (!resumeManager_->isPositionAvailable(resumePosition)) { + auto const msg = folly::sformat( + "Client cannot resume, server position {} is not available", + resumePosition); + closeWithError(Frame_ERROR::connectionError(msg)); return; } - handleStreamFrame(streamId, frameType, std::move(frame)); -} + if (coldResumeInProgress_) { + setNextStreamId(resumeManager_->getLargestUsedStreamId()); + for (const auto& it : resumeManager_->getStreamResumeInfos()) { + const auto streamId = it.first; + const StreamResumeInfo& streamResumeInfo = it.second; + if (streamResumeInfo.requester == RequestOriginator::LOCAL && + streamResumeInfo.streamType == StreamType::STREAM) { + auto subscriber = coldResumeHandler_->handleRequesterResumeStream( + streamResumeInfo.streamToken, streamResumeInfo.consumerAllowance); + + auto stateMachine = std::make_shared( + shared_from_this(), streamId, Payload()); + // Set requested to true (since cold resumption) + stateMachine->setRequested(streamResumeInfo.consumerAllowance); + const auto result = streams_.emplace(streamId, stateMachine); + DCHECK(result.second); + stateMachine->subscribe( + std::make_shared>( + std::move(subscriber), + *folly::EventBaseManager::get()->getEventBase())); + } + } + coldResumeInProgress_ = false; + } -void RSocketStateMachine::onTerminal(folly::exception_wrapper ex) { - auto thisPtr = this->shared_from_this(); - executor_.add([ thisPtr, e = std::move(ex) ]() mutable { - thisPtr->onTerminalImpl(std::move(e)); - }); + auto resumeCallback = std::move(resumeCallback_); + resumeCallback->onResumeOk(); + resumeFromPosition(resumePosition); } -void RSocketStateMachine::onTerminalImpl(folly::exception_wrapper ex) { - if (isResumable_) { - disconnect(std::move(ex)); +void RSocketStateMachine::onErrorFrame( + StreamId streamId, + ErrorCode errorCode, + Payload payload) { + if (streamId != 0) { + if (!ensureNotInResumption()) { + return; + } + // we ignore messages for streams which don't exist + if (auto stateMachine = getStreamStateMachine(streamId)) { + if (errorCode != ErrorCode::APPLICATION_ERROR) { + // Encapsulate non-user errors with runtime_error, which is more + // suitable for LOGging. + stateMachine->handleError( + std::runtime_error(payload.moveDataToString())); + } else { + // Don't expose user errors + stateMachine->handleError(ErrorWithPayload(std::move(payload))); + } + } } else { - auto termSignal = ex ? StreamCompletionSignal::CONNECTION_ERROR - : StreamCompletionSignal::CONNECTION_END; - close(std::move(ex), termSignal); + // TODO: handle INVALID_SETUP, UNSUPPORTED_SETUP, REJECTED_SETUP + if ((errorCode == ErrorCode::CONNECTION_ERROR || + errorCode == ErrorCode::REJECTED_RESUME) && + resumeCallback_) { + auto resumeCallback = std::move(resumeCallback_); + resumeCallback->onResumeError( + ResumptionException(payload.cloneDataToString())); + // fall through + } + close( + std::runtime_error(payload.moveDataToString()), + StreamCompletionSignal::ERROR); } } -void RSocketStateMachine::handleConnectionFrame( +void RSocketStateMachine::onSetupFrame() { + // this should be processed in SetupResumeAcceptor + onUnexpectedFrame(0); +} + +void RSocketStateMachine::onResumeFrame() { + // this should be processed in SetupResumeAcceptor + onUnexpectedFrame(0); +} + +void RSocketStateMachine::onReservedFrame() { + onUnexpectedFrame(0); +} + +void RSocketStateMachine::onLeaseFrame() { + onUnexpectedFrame(0); +} + +void RSocketStateMachine::onExtFrame() { + onUnexpectedFrame(0); +} + +void RSocketStateMachine::onUnexpectedFrame(StreamId streamId) { + auto&& msg = folly::sformat("Unexpected frame for stream {}", streamId); + closeWithError(Frame_ERROR::connectionError(msg)); +} + +void RSocketStateMachine::handleFrame( + StreamId streamId, FrameType frameType, std::unique_ptr payload) { switch (frameType) { case FrameType::KEEPALIVE: { Frame_KEEPALIVE frame; - if (!deserializeFrameOrError( - remoteResumeable_, frame, std::move(payload))) { + if (!deserializeFrameOrError(frame, std::move(payload))) { return; } - VLOG(3) << "In: " << frame; - resumeCache_->resetUpToPosition(frame.position_); - if (mode_ == ReactiveSocketMode::SERVER) { - if (!!(frame.header_.flags_ & FrameFlags::KEEPALIVE_RESPOND)) { - sendKeepalive(FrameFlags::EMPTY, std::move(frame.data_)); - } else { - closeWithError( - Frame_ERROR::connectionError("keepalive without flag")); - } - } else { - if (!!(frame.header_.flags_ & FrameFlags::KEEPALIVE_RESPOND)) { - closeWithError(Frame_ERROR::connectionError( - "client received keepalive with respond flag")); - } else if (keepaliveTimer_) { - keepaliveTimer_->keepaliveReceived(); - } - } + VLOG(3) << mode_ << " In: " << frame; + onKeepAliveFrame( + frame.position_, + std::move(frame.data_), + !!(frame.header_.flags & FrameFlags::KEEPALIVE_RESPOND)); return; } case FrameType::METADATA_PUSH: { Frame_METADATA_PUSH frame; - if (deserializeFrameOrError(frame, std::move(payload))) { - VLOG(3) << "In: " << frame; - requestResponder_->handleMetadataPush(std::move(frame.metadata_)); + if (!deserializeFrameOrError(frame, std::move(payload))) { + return; } + VLOG(3) << mode_ << " In: " << frame; + onMetadataPushFrame(std::move(frame.metadata_)); return; } case FrameType::RESUME_OK: { @@ -470,21 +683,8 @@ void RSocketStateMachine::handleConnectionFrame( if (!deserializeFrameOrError(frame, std::move(payload))) { return; } - VLOG(3) << "In: " << frame; - if (resumeCallback_) { - if (resumeCache_->isPositionAvailable(frame.position_)) { - resumeCallback_->onResumeOk(); - resumeCallback_.reset(); - resumeFromPosition(frame.position_); - } else { - closeWithError(Frame_ERROR::connectionError(folly::to( - "Client cannot resume, server position ", - frame.position_, - " is not available."))); - } - } else { - closeWithError(Frame_ERROR::invalidFrame()); - } + VLOG(3) << mode_ << " In: " << frame; + onResumeOkFrame(frame.position_); return; } case FrameType::ERROR: { @@ -492,197 +692,297 @@ void RSocketStateMachine::handleConnectionFrame( if (!deserializeFrameOrError(frame, std::move(payload))) { return; } - VLOG(3) << "In: " << frame; - - // TODO: handle INVALID_SETUP, UNSUPPORTED_SETUP, REJECTED_SETUP - - if ((frame.errorCode_ == ErrorCode::CONNECTION_ERROR || - frame.errorCode_ == ErrorCode::REJECTED_RESUME) && - resumeCallback_) { - resumeCallback_->onResumeError( - ResumptionException(frame.payload_.moveDataToString())); - resumeCallback_.reset(); - // fall through - } - - close( - std::runtime_error(frame.payload_.moveDataToString()), - StreamCompletionSignal::ERROR); + VLOG(3) << mode_ << " In: " << frame; + onErrorFrame(streamId, frame.errorCode_, std::move(frame.payload_)); return; } - case FrameType::SETUP: // this should be processed in SetupResumeAcceptor - case FrameType::RESUME: // this should be processed in SetupResumeAcceptor + case FrameType::SETUP: + onSetupFrame(); + return; + case FrameType::RESUME: + onResumeFrame(); + return; case FrameType::RESERVED: + onReservedFrame(); + return; case FrameType::LEASE: - case FrameType::REQUEST_RESPONSE: - case FrameType::REQUEST_FNF: - case FrameType::REQUEST_STREAM: - case FrameType::REQUEST_CHANNEL: - case FrameType::REQUEST_N: - case FrameType::CANCEL: - case FrameType::PAYLOAD: - case FrameType::EXT: - default: - closeWithError(Frame_ERROR::unexpectedFrame()); + onLeaseFrame(); return; - } -} - -void RSocketStateMachine::handleStreamFrame( - StreamId streamId, - FrameType frameType, - std::unique_ptr serializedFrame) { - auto it = streamState_->streams_.find(streamId); - if (it == streamState_->streams_.end()) { - handleUnknownStream(streamId, frameType, std::move(serializedFrame)); - return; - } - - // we are purposely making a copy of the reference here to avoid problems with - // lifetime of the stateMachine when a terminating signal is delivered which - // will cause the stateMachine to be destroyed while in one of its methods - auto stateMachine = it->second; - - switch (frameType) { case FrameType::REQUEST_N: { Frame_REQUEST_N frameRequestN; - if (!deserializeFrameOrError(frameRequestN, std::move(serializedFrame))) { + if (!deserializeFrameOrError(frameRequestN, std::move(payload))) { return; } - VLOG(3) << "In: " << frameRequestN; - stateMachine->handleRequestN(frameRequestN.requestN_); + VLOG(3) << mode_ << " In: " << frameRequestN; + onRequestNFrame(streamId, frameRequestN.requestN_); break; } case FrameType::CANCEL: { - VLOG(3) << "In: " << Frame_CANCEL(); - stateMachine->handleCancel(); + VLOG(3) << mode_ << " In: " << Frame_CANCEL(streamId); + onCancelFrame(streamId); break; } case FrameType::PAYLOAD: { Frame_PAYLOAD framePayload; - if (!deserializeFrameOrError(framePayload, std::move(serializedFrame))) { + if (!deserializeFrameOrError(framePayload, std::move(payload))) { return; } - VLOG(3) << "In: " << framePayload; - stateMachine->handlePayload( + VLOG(3) << mode_ << " In: " << framePayload; + onPayloadFrame( + streamId, std::move(framePayload.payload_), + framePayload.header_.flagsFollows(), framePayload.header_.flagsComplete(), framePayload.header_.flagsNext()); break; } - case FrameType::ERROR: { - Frame_ERROR frameError; - if (!deserializeFrameOrError(frameError, std::move(serializedFrame))) { - return; - } - VLOG(3) << "In: " << frameError; - stateMachine->handleError( - std::runtime_error(frameError.payload_.moveDataToString())); - break; - } - case FrameType::REQUEST_CHANNEL: - case FrameType::REQUEST_RESPONSE: - case FrameType::RESERVED: - case FrameType::SETUP: - case FrameType::LEASE: - case FrameType::KEEPALIVE: - case FrameType::REQUEST_FNF: - case FrameType::REQUEST_STREAM: - case FrameType::METADATA_PUSH: - case FrameType::RESUME: - case FrameType::RESUME_OK: - case FrameType::EXT: - closeWithError(Frame_ERROR::unexpectedFrame()); - break; - default: - // because of compatibility with future frame types we will just ignore - // unknown frames - break; - } -} - -void RSocketStateMachine::handleUnknownStream( - StreamId streamId, - FrameType frameType, - std::unique_ptr serializedFrame) { - DCHECK(streamId != 0); - // TODO: comparing string versions is odd because from version - // 10.0 the lexicographic comparison doesn't work - // we should change the version to struct - if (frameSerializer_->protocolVersion() > ProtocolVersion{0, 0} && - !streamsFactory_.registerNewPeerStreamId(streamId)) { - return; - } - - switch (frameType) { case FrameType::REQUEST_CHANNEL: { Frame_REQUEST_CHANNEL frame; - if (!deserializeFrameOrError(frame, std::move(serializedFrame))) { + if (!deserializeFrameOrError(frame, std::move(payload))) { return; } - VLOG(3) << "In: " << frame; - auto stateMachine = - streamsFactory_.createChannelResponder(frame.requestN_, streamId); - auto requestSink = requestResponder_->handleRequestChannelCore( - std::move(frame.payload_), streamId, stateMachine); - stateMachine->subscribe(requestSink); + VLOG(3) << mode_ << " In: " << frame; + onRequestChannelFrame( + streamId, + frame.requestN_, + std::move(frame.payload_), + frame.header_.flagsComplete(), + frame.header_.flagsNext(), + frame.header_.flagsFollows()); break; } case FrameType::REQUEST_STREAM: { Frame_REQUEST_STREAM frame; - if (!deserializeFrameOrError(frame, std::move(serializedFrame))) { + if (!deserializeFrameOrError(frame, std::move(payload))) { return; } - VLOG(3) << "In: " << frame; - auto stateMachine = - streamsFactory_.createStreamResponder(frame.requestN_, streamId); - requestResponder_->handleRequestStreamCore( - std::move(frame.payload_), streamId, stateMachine); + VLOG(3) << mode_ << " In: " << frame; + onRequestStreamFrame( + streamId, + frame.requestN_, + std::move(frame.payload_), + frame.header_.flagsFollows()); break; } case FrameType::REQUEST_RESPONSE: { Frame_REQUEST_RESPONSE frame; - if (!deserializeFrameOrError(frame, std::move(serializedFrame))) { + if (!deserializeFrameOrError(frame, std::move(payload))) { return; } - VLOG(3) << "In: " << frame; - auto stateMachine = - streamsFactory_.createRequestResponseResponder(streamId); - requestResponder_->handleRequestResponseCore( - std::move(frame.payload_), streamId, stateMachine); + VLOG(3) << mode_ << " In: " << frame; + onRequestResponseFrame( + streamId, std::move(frame.payload_), frame.header_.flagsFollows()); break; } case FrameType::REQUEST_FNF: { Frame_REQUEST_FNF frame; - if (!deserializeFrameOrError(frame, std::move(serializedFrame))) { + if (!deserializeFrameOrError(frame, std::move(payload))) { return; } - VLOG(3) << "In: " << frame; - // no stream tracking is necessary - requestResponder_->handleFireAndForget( - std::move(frame.payload_), streamId); + VLOG(3) << mode_ << " In: " << frame; + onFireAndForgetFrame( + streamId, std::move(frame.payload_), frame.header_.flagsFollows()); break; } - - case FrameType::RESUME: - case FrameType::SETUP: - case FrameType::METADATA_PUSH: - case FrameType::LEASE: - case FrameType::KEEPALIVE: - case FrameType::RESERVED: - case FrameType::REQUEST_N: - case FrameType::CANCEL: - case FrameType::PAYLOAD: - case FrameType::ERROR: - case FrameType::RESUME_OK: case FrameType::EXT: + onExtFrame(); + return; + + default: { + stats_->unknownFrameReceived(); + // per rsocket spec, we will ignore any other unknown frames + return; + } + } +} + +std::shared_ptr +RSocketStateMachine::getStreamStateMachine(StreamId streamId) { + const auto&& it = streams_.find(streamId); + if (it == streams_.end()) { + return nullptr; + } + // we are purposely making a copy of the reference here to avoid problems with + // lifetime of the stateMachine when a terminating signal is delivered which + // will cause the stateMachine to be destroyed while in one of its methods + return it->second; +} + +bool RSocketStateMachine::ensureNotInResumption() { + if (resumeCallback_) { + // during the time when we are resuming we are can't receive any other + // than connection level frames which drives the resumption + // TODO(lehecka): this assertion should be handled more elegantly using + // different state machine + constexpr auto msg = "Received stream frame while resuming"; + LOG(ERROR) << msg; + closeWithError(Frame_ERROR::connectionError(msg)); + return false; + } + return true; +} + +void RSocketStateMachine::onRequestNFrame( + StreamId streamId, + uint32_t requestN) { + if (!ensureNotInResumption()) { + return; + } + // we ignore messages for streams which don't exist + if (auto stateMachine = getStreamStateMachine(streamId)) { + stateMachine->handleRequestN(requestN); + } +} + +void RSocketStateMachine::onCancelFrame(StreamId streamId) { + if (!ensureNotInResumption()) { + return; + } + // we ignore messages for streams which don't exist + if (auto stateMachine = getStreamStateMachine(streamId)) { + stateMachine->handleCancel(); + } +} + +void RSocketStateMachine::onPayloadFrame( + StreamId streamId, + Payload payload, + bool flagsFollows, + bool flagsComplete, + bool flagsNext) { + if (!ensureNotInResumption()) { + return; + } + // we ignore messages for streams which don't exist + if (auto stateMachine = getStreamStateMachine(streamId)) { + stateMachine->handlePayload( + std::move(payload), flagsComplete, flagsNext, flagsFollows); + } +} + +void RSocketStateMachine::onRequestStreamFrame( + StreamId streamId, + uint32_t requestN, + Payload payload, + bool flagsFollows) { + if (!ensureNotInResumption() || !isNewStreamId(streamId)) { + return; + } + auto stateMachine = + std::make_shared(shared_from_this(), streamId, requestN); + const auto result = streams_.emplace(streamId, stateMachine); + DCHECK(result.second); // ensured by calling isNewStreamId + stateMachine->handlePayload(std::move(payload), false, false, flagsFollows); +} + +void RSocketStateMachine::onRequestChannelFrame( + StreamId streamId, + uint32_t requestN, + Payload payload, + bool flagsComplete, + bool flagsNext, + bool flagsFollows) { + if (!ensureNotInResumption() || !isNewStreamId(streamId)) { + return; + } + auto stateMachine = std::make_shared( + shared_from_this(), streamId, requestN); + const auto result = streams_.emplace(streamId, stateMachine); + DCHECK(result.second); // ensured by calling isNewStreamId + stateMachine->handlePayload( + std::move(payload), flagsComplete, flagsNext, flagsFollows); +} + +void RSocketStateMachine::onRequestResponseFrame( + StreamId streamId, + Payload payload, + bool flagsFollows) { + if (!ensureNotInResumption() || !isNewStreamId(streamId)) { + return; + } + auto stateMachine = + std::make_shared(shared_from_this(), streamId); + const auto result = streams_.emplace(streamId, stateMachine); + DCHECK(result.second); // ensured by calling isNewStreamId + stateMachine->handlePayload(std::move(payload), false, false, flagsFollows); +} + +void RSocketStateMachine::onFireAndForgetFrame( + StreamId streamId, + Payload payload, + bool flagsFollows) { + if (!ensureNotInResumption() || !isNewStreamId(streamId)) { + return; + } + auto stateMachine = + std::make_shared(shared_from_this(), streamId); + const auto result = streams_.emplace(streamId, stateMachine); + DCHECK(result.second); // ensured by calling isNewStreamId + stateMachine->handlePayload(std::move(payload), false, false, flagsFollows); +} + +bool RSocketStateMachine::isNewStreamId(StreamId streamId) { + if (frameSerializer_->protocolVersion() > ProtocolVersion{0, 0} && + !registerNewPeerStreamId(streamId)) { + return false; + } + return true; +} + +std::shared_ptr> +RSocketStateMachine::onNewStreamReady( + StreamId streamId, + StreamType streamType, + Payload payload, + std::shared_ptr> response) { + if (coldResumeHandler_ && streamType != StreamType::FNF) { + auto streamToken = + coldResumeHandler_->generateStreamToken(payload, streamId, streamType); + resumeManager_->onStreamOpen( + streamId, RequestOriginator::REMOTE, streamToken, streamType); + } + + switch (streamType) { + case StreamType::CHANNEL: + return requestResponder_->handleRequestChannel( + std::move(payload), streamId, std::move(response)); + + case StreamType::STREAM: + requestResponder_->handleRequestStream( + std::move(payload), streamId, std::move(response)); + return nullptr; + + case StreamType::REQUEST_RESPONSE: + // the other overload method should be called + CHECK(false); + folly::assume_unreachable(); + + case StreamType::FNF: + requestResponder_->handleFireAndForget(std::move(payload), streamId); + return nullptr; + default: - DLOG(ERROR) << "unknown stream frame (streamId=" << streamId - << " frameType=" << frameType << ")"; - closeWithError(Frame_ERROR::unexpectedFrame()); + CHECK(false) << "unknown value: " << streamType; + folly::assume_unreachable(); } } -/// @} + +void RSocketStateMachine::onNewStreamReady( + StreamId streamId, + StreamType streamType, + Payload payload, + std::shared_ptr> response) { + CHECK(streamType == StreamType::REQUEST_RESPONSE); + + if (coldResumeHandler_) { + auto streamToken = + coldResumeHandler_->generateStreamToken(payload, streamId, streamType); + resumeManager_->onStreamOpen( + streamId, RequestOriginator::REMOTE, streamToken, streamType); + } + requestResponder_->handleRequestResponse( + std::move(payload), streamId, std::move(response)); +} void RSocketStateMachine::sendKeepalive(std::unique_ptr data) { sendKeepalive(FrameFlags::KEEPALIVE_RESPOND, std::move(data)); @@ -691,132 +991,109 @@ void RSocketStateMachine::sendKeepalive(std::unique_ptr data) { void RSocketStateMachine::sendKeepalive( FrameFlags flags, std::unique_ptr data) { - debugCheckCorrectExecutor(); Frame_KEEPALIVE pingFrame( - flags, resumeCache_->impliedPosition(), std::move(data)); - VLOG(3) << "Out: " << pingFrame; - outputFrameOrEnqueue( - frameSerializer_->serializeOut(std::move(pingFrame), remoteResumeable_)); -} - -void RSocketStateMachine::tryClientResume( - const ResumeIdentificationToken& token, - yarpl::Reference frameTransport, - std::unique_ptr resumeCallback) { - Frame_RESUME resumeFrame( - token, - resumeCache_->impliedPosition(), - resumeCache_->lastResetPosition(), - frameSerializer_->protocolVersion()); - VLOG(3) << "Out: " << resumeFrame; - frameTransport->outputFrameOrEnqueue( - frameSerializer_->serializeOut(std::move(resumeFrame))); - - // if the client was still connected we will disconnected the old connection - // with a clear error message - disconnect(std::runtime_error("resuming client on a different connection")); - setResumable(true); - reconnect(std::move(frameTransport), std::move(resumeCallback)); + flags, resumeManager_->impliedPosition(), std::move(data)); + VLOG(3) << mode_ << " Out: " << pingFrame; + outputFrameOrEnqueue(frameSerializer_->serializeOut(std::move(pingFrame))); + stats_->keepaliveSent(); } -bool RSocketStateMachine::isPositionAvailable(ResumePosition position) { - debugCheckCorrectExecutor(); - return resumeCache_->isPositionAvailable(position); +bool RSocketStateMachine::isPositionAvailable(ResumePosition position) const { + return resumeManager_->isPositionAvailable(position); } bool RSocketStateMachine::resumeFromPositionOrClose( ResumePosition serverPosition, ResumePosition clientPosition) { - debugCheckCorrectExecutor(); DCHECK(!resumeCallback_); - DCHECK(!isDisconnectedOrClosed()); - DCHECK(mode_ == ReactiveSocketMode::SERVER); + DCHECK(!isDisconnected()); + DCHECK(mode_ == RSocketMode::SERVER); - bool clientPositionExist = (clientPosition == kUnspecifiedResumePosition) || - resumeCache_->canResumeFrom(clientPosition); + const bool clientPositionExist = + (clientPosition == kUnspecifiedResumePosition) || + clientPosition <= resumeManager_->impliedPosition(); if (clientPositionExist && - resumeCache_->isPositionAvailable(serverPosition)) { - Frame_RESUME_OK resumeOkFrame(resumeCache_->impliedPosition()); + resumeManager_->isPositionAvailable(serverPosition)) { + Frame_RESUME_OK resumeOkFrame{resumeManager_->impliedPosition()}; VLOG(3) << "Out: " << resumeOkFrame; - frameTransport_->outputFrameOrEnqueue( + frameTransport_->outputFrameOrDrop( frameSerializer_->serializeOut(std::move(resumeOkFrame))); resumeFromPosition(serverPosition); return true; - } else { - closeWithError(Frame_ERROR::connectionError(folly::to( - "Cannot resume server, client lastServerPosition=", - serverPosition, - " firstClientPosition=", - clientPosition, - " is not available. Last reset position is ", - resumeCache_->lastResetPosition()))); - return false; } + + auto const msg = folly::to( + "Cannot resume server, client lastServerPosition=", + serverPosition, + " firstClientPosition=", + clientPosition, + " is not available. Last reset position is ", + resumeManager_->firstSentPosition()); + + closeWithError(Frame_ERROR::connectionError(msg)); + return false; } void RSocketStateMachine::resumeFromPosition(ResumePosition position) { DCHECK(!resumeCallback_); - DCHECK(!isDisconnectedOrClosed()); - DCHECK(resumeCache_->isPositionAvailable(position)); + DCHECK(!isDisconnected()); + DCHECK(resumeManager_->isPositionAvailable(position)); - resumeStreams(); - resumeCache_->sendFramesFromPosition(position, *frameTransport_); + if (connectionEvents_) { + connectionEvents_->onStreamsResumed(); + } + resumeManager_->sendFramesFromPosition(position, *frameTransport_); - for (auto& frame : streamState_->moveOutputPendingFrames()) { + auto frames = consumePendingOutputFrames(); + for (auto& frame : frames) { outputFrameOrEnqueue(std::move(frame)); } - if (!isDisconnectedOrClosed() && keepaliveTimer_) { + if (!isDisconnected() && keepaliveTimer_) { keepaliveTimer_->start(shared_from_this()); } } -void RSocketStateMachine::outputFrameOrEnqueue( - std::unique_ptr frame) { - debugCheckCorrectExecutor(); +bool RSocketStateMachine::shouldQueue() { // if we are resuming we cant send any frames until we receive RESUME_OK - if (!isDisconnectedOrClosed() && !resumeCallback_) { - outputFrame(std::move(frame)); - } else { - streamState_->enqueueOutputPendingFrame(std::move(frame)); - } + return isDisconnected() || resumeCallback_; } -void RSocketStateMachine::requestFireAndForget(Payload request) { - Frame_REQUEST_FNF frame( - streamsFactory().getNextStreamId(), - FrameFlags::EMPTY, - std::move(std::move(request))); - outputFrameOrEnqueue(std::move(frame)); +void RSocketStateMachine::fireAndForget(Payload request) { + auto const streamId = getNextStreamId(); + Frame_REQUEST_FNF frame{streamId, FrameFlags::EMPTY_, std::move(request)}; + outputFrameOrEnqueue(frameSerializer_->serializeOut(std::move(frame))); } void RSocketStateMachine::metadataPush(std::unique_ptr metadata) { - Frame_METADATA_PUSH metadataPushFrame(std::move(metadata)); - outputFrameOrEnqueue(std::move(metadataPushFrame)); + Frame_METADATA_PUSH metadataPushFrame{std::move(metadata)}; + outputFrameOrEnqueue( + frameSerializer_->serializeOut(std::move(metadataPushFrame))); } void RSocketStateMachine::outputFrame(std::unique_ptr frame) { - DCHECK(!isDisconnectedOrClosed()); + DCHECK(!isDisconnected()); - auto frameType = frameSerializer_->peekFrameType(*frame); + const auto frameType = frameSerializer_->peekFrameType(*frame); stats_->frameWritten(frameType); if (isResumable_) { - auto streamIdPtr = frameSerializer_->peekStreamId(*frame); - resumeCache_->trackSentFrame(*frame, frameType, streamIdPtr); + auto streamIdPtr = frameSerializer_->peekStreamId(*frame, false); + CHECK(streamIdPtr) << "Error in serialized frame."; + resumeManager_->trackSentFrame( + *frame, frameType, *streamIdPtr, getConsumerAllowance(*streamIdPtr)); } - frameTransport_->outputFrameOrEnqueue(std::move(frame)); + frameTransport_->outputFrameOrDrop(std::move(frame)); } uint32_t RSocketStateMachine::getKeepaliveTime() const { - debugCheckCorrectExecutor(); return keepaliveTimer_ ? static_cast(keepaliveTimer_->keepaliveTime().count()) : Frame_SETUP::kMaxKeepaliveTime; } -bool RSocketStateMachine::isDisconnectedOrClosed() const { +bool RSocketStateMachine::isDisconnected() const { return !frameTransport_; } @@ -824,168 +1101,136 @@ bool RSocketStateMachine::isClosed() const { return isClosed_; } -void RSocketStateMachine::debugCheckCorrectExecutor() const { - DCHECK( - !dynamic_cast(&executor_) || - dynamic_cast(&executor_)->isInEventBaseThread()); +void RSocketStateMachine::writeNewStream( + StreamId streamId, + StreamType streamType, + uint32_t initialRequestN, + Payload payload) { + if (coldResumeHandler_ && streamType != StreamType::FNF) { + const auto streamToken = + coldResumeHandler_->generateStreamToken(payload, streamId, streamType); + resumeManager_->onStreamOpen( + streamId, RequestOriginator::LOCAL, streamToken, streamType); + } + + StreamsWriterImpl::writeNewStream( + streamId, streamType, initialRequestN, std::move(payload)); } -void RSocketStateMachine::setFrameSerializer( - std::unique_ptr frameSerializer) { - CHECK(frameSerializer); - // serializer is not interchangeable, it would screw up resumability - // CHECK(!frameSerializer_); - frameSerializer_ = std::move(frameSerializer); +void RSocketStateMachine::onStreamClosed(StreamId streamId) { + streams_.erase(streamId); + resumeManager_->onStreamClosed(streamId); } -void RSocketStateMachine::connectClientSendSetup( - std::unique_ptr connection, - SetupParameters setupParams) { - setFrameSerializer( - setupParams.protocolVersion == ProtocolVersion::Unknown - ? FrameSerializer::createCurrentVersion() - : FrameSerializer::createFrameSerializer( - setupParams.protocolVersion)); +bool RSocketStateMachine::ensureOrAutodetectFrameSerializer( + const folly::IOBuf& firstFrame) { + if (frameSerializer_) { + return true; + } - setResumable(setupParams.resumable); + if (mode_ != RSocketMode::SERVER) { + // this should never happen as clients are initized with FrameSerializer + // instance + DCHECK(false); + return false; + } - auto frameTransport = yarpl::make_ref(std::move(connection)); + auto serializer = FrameSerializer::createAutodetectedSerializer(firstFrame); + if (!serializer) { + LOG(ERROR) << "unable to detect protocol version"; + return false; + } - auto protocolVersion = frameSerializer_->protocolVersion(); + VLOG(2) << "detected protocol version" << serializer->protocolVersion(); + frameSerializer_ = std::move(serializer); + frameSerializer_->preallocateFrameSizeField() = + frameTransport_ && frameTransport_->isConnectionFramed(); - Frame_SETUP frame( - setupParams.resumable ? FrameFlags::RESUME_ENABLE : FrameFlags::EMPTY, - protocolVersion.major, - protocolVersion.minor, - getKeepaliveTime(), - Frame_SETUP::kMaxLifetime, - std::move(setupParams.token), - std::move(setupParams.metadataMimeType), - std::move(setupParams.dataMimeType), - std::move(setupParams.payload)); + return true; +} - // TODO: when the server returns back that it doesn't support resumability, we - // should retry without resumability +size_t RSocketStateMachine::getConsumerAllowance(StreamId streamId) const { + auto const it = streams_.find(streamId); + return it != streams_.end() ? it->second->getConsumerAllowance() : 0; +} - VLOG(3) << "Out: " << frame; - // making sure we send setup frame first - frameTransport->outputFrameOrEnqueue( - frameSerializer_->serializeOut(std::move(frame))); - // then the rest of the cached frames will be sent - connect(std::move(frameTransport), true, ProtocolVersion::Unknown); +void RSocketStateMachine::registerCloseCallback( + RSocketStateMachine::CloseCallback* callback) { + closeCallback_ = callback; } -void RSocketStateMachine::writeNewStream( - StreamId streamId, - StreamType streamType, - uint32_t initialRequestN, - Payload payload, - bool completed) { - switch (streamType) { - case StreamType::CHANNEL: - outputFrameOrEnqueue(Frame_REQUEST_CHANNEL( - streamId, - completed ? FrameFlags::COMPLETE : FrameFlags::EMPTY, - initialRequestN, - std::move(payload))); - break; +DuplexConnection* RSocketStateMachine::getConnection() { + return frameTransport_ ? frameTransport_->getConnection() : nullptr; +} - case StreamType::STREAM: - outputFrameOrEnqueue(Frame_REQUEST_STREAM( - streamId, FrameFlags::EMPTY, initialRequestN, std::move(payload))); - break; +void RSocketStateMachine::setProtocolVersionOrThrow( + ProtocolVersion version, + const std::shared_ptr& transport) { + CHECK(version != ProtocolVersion::Unknown); - case StreamType::REQUEST_RESPONSE: - outputFrameOrEnqueue(Frame_REQUEST_RESPONSE( - streamId, FrameFlags::EMPTY, std::move(payload))); - break; + // TODO(lehecka): this is a temporary guard to make sure the transport is + // explicitly closed when exceptions are thrown. The right solution is to + // automatically close duplex connection in the destructor when unique_ptr + // is released + auto transportGuard = folly::makeGuard([&] { transport->close(); }); - case StreamType::FNF: - outputFrameOrEnqueue( - Frame_REQUEST_FNF(streamId, FrameFlags::EMPTY, std::move(payload))); - break; + if (frameSerializer_) { + if (frameSerializer_->protocolVersion() != version) { + // serializer is not interchangeable, it would screw up resumability + throw std::runtime_error{"Protocol version mismatch"}; + } + } else { + auto frameSerializer = FrameSerializer::createFrameSerializer(version); + if (!frameSerializer) { + throw std::runtime_error{"Invalid protocol version"}; + } - default: - CHECK(false); // unknown type + frameSerializer_ = std::move(frameSerializer); + frameSerializer_->preallocateFrameSizeField() = + frameTransport_ && frameTransport_->isConnectionFramed(); } -} -void RSocketStateMachine::writeRequestN(StreamId streamId, uint32_t n) { - outputFrameOrEnqueue(Frame_REQUEST_N(streamId, n)); + transportGuard.dismiss(); } -void RSocketStateMachine::writePayload( - StreamId streamId, - Payload payload, - bool complete) { - Frame_PAYLOAD frame( - streamId, - FrameFlags::NEXT | (complete ? FrameFlags::COMPLETE : FrameFlags::EMPTY), - std::move(payload)); - outputFrameOrEnqueue(std::move(frame)); -} - -void RSocketStateMachine::writeCloseStream( - StreamId streamId, - StreamCompletionSignal signal, - Payload payload) { - switch (signal) { - case StreamCompletionSignal::COMPLETE: - outputFrameOrEnqueue(Frame_PAYLOAD::complete(streamId)); - break; - - case StreamCompletionSignal::CANCEL: - outputFrameOrEnqueue(Frame_CANCEL(streamId)); - break; - - case StreamCompletionSignal::ERROR: - outputFrameOrEnqueue(Frame_ERROR::error(streamId, std::move(payload))); - break; +StreamId RSocketStateMachine::getNextStreamId() { + constexpr auto limit = + static_cast(std::numeric_limits::max() - 2); - case StreamCompletionSignal::APPLICATION_ERROR: - outputFrameOrEnqueue( - Frame_ERROR::applicationError(streamId, std::move(payload))); - break; + auto const streamId = nextStreamId_; + if (streamId >= limit) { + throw std::runtime_error{"Ran out of stream IDs"}; + } - case StreamCompletionSignal::INVALID_SETUP: - case StreamCompletionSignal::UNSUPPORTED_SETUP: - case StreamCompletionSignal::REJECTED_SETUP: + CHECK_EQ(0, streams_.count(streamId)) + << "Next stream ID already exists in the streams map"; - case StreamCompletionSignal::CONNECTION_ERROR: - case StreamCompletionSignal::CONNECTION_END: - case StreamCompletionSignal::SOCKET_CLOSED: - default: - CHECK(false); // unexpected value - } + nextStreamId_ += 2; + return streamId; } -void RSocketStateMachine::onStreamClosed( - StreamId streamId, - StreamCompletionSignal signal) { - endStream(streamId, signal); +void RSocketStateMachine::setNextStreamId(StreamId streamId) { + nextStreamId_ = streamId + 2; } -bool RSocketStateMachine::ensureOrAutodetectFrameSerializer( - const folly::IOBuf& firstFrame) { - if (frameSerializer_) { - return true; - } - - if (mode_ != ReactiveSocketMode::SERVER) { - // this should never happen as clients are initized with FrameSerializer - // instance - DCHECK(false); +bool RSocketStateMachine::registerNewPeerStreamId(StreamId streamId) { + DCHECK_NE(0, streamId); + if (nextStreamId_ % 2 == streamId % 2) { + // if this is an unknown stream to the socket and this socket is + // generating such stream ids, it is an incoming frame on the stream which + // no longer exist return false; } - - auto serializer = FrameSerializer::createAutodetectedSerializer(firstFrame); - if (!serializer) { - LOG(ERROR) << "unable to detect protocol version"; + if (streamId <= lastPeerStreamId_) { + // receiving frame for a stream which no longer exists return false; } - - VLOG(2) << "detected protocol version" << serializer->protocolVersion(); - frameSerializer_ = std::move(serializer); + lastPeerStreamId_ = streamId; return true; } + +bool RSocketStateMachine::hasStreams() const { + return !streams_.empty(); +} + } // namespace rsocket diff --git a/rsocket/statemachine/RSocketStateMachine.h b/rsocket/statemachine/RSocketStateMachine.h index 1658234d6..71deeb322 100644 --- a/rsocket/statemachine/RSocketStateMachine.h +++ b/rsocket/statemachine/RSocketStateMachine.h @@ -1,34 +1,53 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include +#include #include +#include "rsocket/ColdResumeHandler.h" #include "rsocket/DuplexConnection.h" #include "rsocket/Payload.h" #include "rsocket/RSocketParameters.h" +#include "rsocket/ResumeManager.h" #include "rsocket/framing/FrameProcessor.h" +#include "rsocket/framing/FrameSerializer.h" #include "rsocket/internal/Common.h" -#include "rsocket/statemachine/StreamsFactory.h" +#include "rsocket/internal/KeepaliveTimer.h" +#include "rsocket/statemachine/StreamFragmentAccumulator.h" +#include "rsocket/statemachine/StreamStateMachineBase.h" #include "rsocket/statemachine/StreamsWriter.h" +#include "yarpl/flowable/Subscriber.h" +#include "yarpl/flowable/Subscription.h" +#include "yarpl/single/SingleObserver.h" namespace rsocket { class ClientResumeStatusCallback; class DuplexConnection; -class Frame_ERROR; -class FrameSerializer; class FrameTransport; +class Frame_ERROR; class KeepaliveTimer; -class ResumeCache; class RSocketConnectionEvents; class RSocketParameters; class RSocketResponder; +class RSocketResponderCore; class RSocketStateMachine; class RSocketStats; -class StreamState; -class StreamStateMachineBase; +class ResumeManager; +class RSocketStateMachineTest; class FrameSink { public: @@ -55,249 +74,276 @@ class FrameSink { class RSocketStateMachine final : public FrameSink, public FrameProcessor, - public StreamsWriter, + public StreamsWriterImpl, public std::enable_shared_from_this { public: RSocketStateMachine( - folly::Executor& executor, + std::shared_ptr requestResponder, + std::unique_ptr keepaliveTimer, + RSocketMode mode, + std::shared_ptr stats, + std::shared_ptr connectionEvents, + std::shared_ptr resumeManager, + std::shared_ptr coldResumeHandler); + + RSocketStateMachine( std::shared_ptr requestResponder, - std::unique_ptr keepaliveTimer_, - ReactiveSocketMode mode, + std::unique_ptr keepaliveTimer, + RSocketMode mode, std::shared_ptr stats, - std::shared_ptr connectionEvents = - std::shared_ptr()); + std::shared_ptr connectionEvents, + std::shared_ptr resumeManager, + std::shared_ptr coldResumeHandler); + + ~RSocketStateMachine(); - void closeWithError(Frame_ERROR&& error); - void disconnectOrCloseWithError(Frame_ERROR&& error) override; + /// Create a new connection as a server. + void connectServer(std::shared_ptr, const SetupParameters&); - /// Kicks off connection procedure. - /// - /// May result, depending on the implementation of the DuplexConnection, in - /// processing of one or more frames. - bool connectServer( - yarpl::Reference, - const SetupParameters& setupParams); + /// Resume a connection as a server. + bool resumeServer(std::shared_ptr, const ResumeParameters&); - bool resumeServer( - yarpl::Reference, - const ResumeParameters& resumeParams); + /// Connect as a client. Sends a SETUP frame. + void connectClient(std::shared_ptr, SetupParameters); - /// Disconnects DuplexConnection from the stateMachine. - /// Existing streams will stay intact. - void disconnect(folly::exception_wrapper ex); + /// Resume a connection as a client. Sends a RESUME frame. + void resumeClient( + ResumeIdentificationToken, + std::shared_ptr, + std::unique_ptr, + ProtocolVersion); - /// Terminates underlying connection. - /// - /// This may synchronously deliver terminal signals to all - /// StreamAutomatonBase attached to this ConnectionAutomaton. + /// Disconnect the state machine's connection. Existing streams will stay + /// intact. + void disconnect(folly::exception_wrapper); + + /// Whether the connection has been disconnected or closed. + bool isDisconnected() const; + + /// Send an ERROR frame, and close the connection and all of its streams. + void closeWithError(Frame_ERROR&&); + + /// Disconnect the connection if it is resumable, otherwise send an ERROR + /// frame and close the connection and all of its streams. + void disconnectOrCloseWithError(Frame_ERROR&&) override; + + /// Close the connection and all of its streams. void close(folly::exception_wrapper, StreamCompletionSignal); - /// Terminate underlying connection and connect new connection - void reconnect( - yarpl::Reference, - std::unique_ptr); + void requestStream( + Payload request, + std::shared_ptr> responseSink); - ~RSocketStateMachine(); + std::shared_ptr> requestChannel( + Payload request, + bool hasInitialRequest, + std::shared_ptr> responseSink); - /// A contract exposed to StreamAutomatonBase, modelled after Subscriber - /// and Subscription contracts, while omitting flow control related signals. + void requestResponse( + Payload payload, + std::shared_ptr> responseSink); - /// Adds a stream stateMachine to the connection. - /// - /// This signal corresponds to Subscriber::onSubscribe. - /// - /// No frames will be issued as a result of this call. Stream stateMachine - /// must take care of writing appropriate frames to the connection, using - /// ::writeFrame after calling this method. - void addStream( - StreamId streamId, - yarpl::Reference stateMachine); + /// Send a REQUEST_FNF frame. + void fireAndForget(Payload); - /// Indicates that the stream should be removed from the connection. - /// - /// No frames will be issued as a result of this call. Stream stateMachine - /// must take care of writing appropriate frames to the connection, using - /// ::writeFrame, prior to calling this method. - /// - /// This signal corresponds to Subscriber::{onComplete,onError} and - /// Subscription::cancel. - /// Per ReactiveStreams specification: - /// 1. no other signal can be delivered during or after this one, - /// 2. "unsubscribe handshake" guarantees that the signal will be delivered - /// at least once, even if the stateMachine initiated stream closure, - /// 3. per "unsubscribe handshake", the stateMachine must deliver - /// corresponding - /// terminal signal to the connection. - /// - /// Additionally, in order to simplify implementation of stream stateMachine: - /// 4. the signal bound with a particular StreamId is idempotent and may be - /// delivered multiple times as long as the caller holds shared_ptr to - /// ConnectionAutomaton. - void endStream(StreamId streamId, StreamCompletionSignal signal); + /// Send a METADATA_PUSH frame. + void metadataPush(std::unique_ptr); - void sendKeepalive(std::unique_ptr data) override; + /// Send a KEEPALIVE frame, with the RESPOND flag set. + void sendKeepalive(std::unique_ptr) override; - void setResumable(bool resumable); + class CloseCallback { + public: + virtual ~CloseCallback() = default; + virtual void remove(RSocketStateMachine&) = 0; + }; - bool isPositionAvailable(ResumePosition position); + /// Register a callback to be called when the StateMachine is closed. + /// It will be used to inform the containers, i.e. ConnectionSet or + /// wangle::ConnectionManager, to don't store the StateMachine anymore. + void registerCloseCallback(CloseCallback* callback); - void outputFrameOrEnqueue(std::unique_ptr frame); + DuplexConnection* getConnection(); - template - void outputFrameOrEnqueue(T&& frame) { - VLOG(3) << "Out: " << frame; - outputFrameOrEnqueue(frameSerializer_->serializeOut(std::forward(frame))); - } + // Has active requests? + bool hasStreams() const; - void requestFireAndForget(Payload request); + private: + // connection scope signals + void onKeepAliveFrame( + ResumePosition resumePosition, + std::unique_ptr data, + bool keepAliveRespond); + void onMetadataPushFrame(std::unique_ptr metadata); + void onResumeOkFrame(ResumePosition resumePosition); + void onErrorFrame(StreamId streamId, ErrorCode errorCode, Payload payload); + + // stream scope signals + void onRequestNFrame(StreamId streamId, uint32_t requestN); + void onCancelFrame(StreamId streamId); + void onPayloadFrame( + StreamId streamId, + Payload payload, + bool flagsFollows, + bool flagsComplete, + bool flagsNext); - template - bool deserializeFrameOrError( - TFrame& frame, - std::unique_ptr payload) { - if (frameSerializer_->deserializeFrom(frame, std::move(payload))) { - return true; - } else { - closeWithError(Frame_ERROR::invalidFrame()); - return false; - } - } + void onRequestStreamFrame( + StreamId streamId, + uint32_t requestN, + Payload payload, + bool flagsFollows); + void onRequestChannelFrame( + StreamId streamId, + uint32_t requestN, + Payload payload, + bool flagsComplete, + bool flagsNext, + bool flagsFollows); + void + onRequestResponseFrame(StreamId streamId, Payload payload, bool flagsFollows); + void + onFireAndForgetFrame(StreamId streamId, Payload payload, bool flagsFollows); + void onSetupFrame(); + void onResumeFrame(); + void onReservedFrame(); + void onLeaseFrame(); + void onExtFrame(); + void onUnexpectedFrame(StreamId streamId); + + std::shared_ptr getStreamStateMachine( + StreamId streamId); + + void connect(std::shared_ptr); - template - bool deserializeFrameOrError( - bool resumable, - TFrame& frame, - std::unique_ptr payload) { - if (frameSerializer_->deserializeFrom( - frame, std::move(payload), resumable)) { - return true; - } else { - closeWithError(Frame_ERROR::invalidFrame()); - return false; - } - } + /// Terminate underlying connection and connect new connection + void reconnect( + std::shared_ptr, + std::unique_ptr); + + void setResumable(bool); bool resumeFromPositionOrClose( ResumePosition serverPosition, ResumePosition clientPosition); - uint32_t getKeepaliveTime() const; - bool isDisconnectedOrClosed() const; - bool isClosed() const; - - StreamsFactory& streamsFactory() { - return streamsFactory_; - } - - void connectClientSendSetup( - std::unique_ptr connection, - SetupParameters setupParams); + bool isPositionAvailable(ResumePosition) const; - void metadataPush(std::unique_ptr metadata); + /// Whether the connection has been closed. + bool isClosed() const; - void tryClientResume( - const ResumeIdentificationToken& token, - yarpl::Reference frameTransport, - std::unique_ptr resumeCallback); + uint32_t getKeepaliveTime() const; - void setFrameSerializer(std::unique_ptr); + void sendPendingFrames() override; - RSocketStats& stats() { + // Should buffer the frame if the state machine is disconnected or in the + // process of resuming. + bool shouldQueue() override; + RSocketStats& stats() override { return *stats_; } - std::shared_ptr& connectionEvents() { - return connectionEvents_; + FrameSerializer& serializer() override { + return *frameSerializer_; } - private: - - bool connect( - yarpl::Reference, - bool sendingPendingFrames, - ProtocolVersion protocolVersion); + template + bool deserializeFrameOrError( + TFrame& frame, + std::unique_ptr buf) { + if (frameSerializer_->deserializeFrom(frame, std::move(buf))) { + return true; + } + closeWithError(Frame_ERROR::connectionError("Invalid frame")); + return false; + } - /// Performs the same actions as ::endStream without propagating closure - /// signal to the underlying connection. - /// - /// The call is idempotent and returns false iff a stream has not been found. - bool endStreamInternal(StreamId streamId, StreamCompletionSignal signal); - - /// @{ - /// FrameProcessor methods are implemented with ExecutorBase and automatic - /// marshaling - /// onto the right executor to allow DuplexConnection living on a different - /// executor and calling into ConnectionAutomaton. + // FrameProcessor. void processFrame(std::unique_ptr) override; void onTerminal(folly::exception_wrapper) override; - void processFrameImpl(std::unique_ptr); - void onTerminalImpl(folly::exception_wrapper); - /// @} - - void handleConnectionFrame( - FrameType frameType, - std::unique_ptr); - void handleStreamFrame( - StreamId streamId, - FrameType frameType, - std::unique_ptr frame); - void handleUnknownStream( - StreamId streamId, - FrameType frameType, - std::unique_ptr frame); + void handleFrame(StreamId, FrameType, std::unique_ptr); void closeStreams(StreamCompletionSignal); - void closeFrameTransport( - folly::exception_wrapper, - StreamCompletionSignal signal); - - void sendKeepalive(FrameFlags flags, std::unique_ptr data); + void closeFrameTransport(folly::exception_wrapper); - void resumeFromPosition(ResumePosition position); - void outputFrame(std::unique_ptr); + void sendKeepalive(FrameFlags, std::unique_ptr); - void debugCheckCorrectExecutor() const; - - void pauseStreams(); - void resumeStreams(); + void resumeFromPosition(ResumePosition); + void outputFrame(std::unique_ptr) override; void writeNewStream( StreamId streamId, StreamType streamType, uint32_t initialRequestN, + Payload payload) override; + + std::shared_ptr> onNewStreamReady( + StreamId streamId, + StreamType streamType, Payload payload, - bool completed) override; - void writeRequestN(StreamId streamId, uint32_t n) override; - void writePayload(StreamId streamId, Payload payload, bool complete) override; - void writeCloseStream( + std::shared_ptr> response) override; + void onNewStreamReady( StreamId streamId, - StreamCompletionSignal signal, - Payload payload) override; - void onStreamClosed(StreamId streamId, StreamCompletionSignal signal) + StreamType streamType, + Payload payload, + std::shared_ptr> response) override; + void onStreamClosed(StreamId) override; + bool ensureOrAutodetectFrameSerializer(const folly::IOBuf& firstFrame); + bool ensureNotInResumption(); + + size_t getConsumerAllowance(StreamId) const; + + void setProtocolVersionOrThrow( + ProtocolVersion version, + const std::shared_ptr& transport); + + bool isNewStreamId(StreamId streamId); + bool registerNewPeerStreamId(StreamId streamId); + StreamId getNextStreamId(); - ReactiveSocketMode mode_; + void setNextStreamId(StreamId streamId); + + /// Client/server mode this state machine is operating in. + const RSocketMode mode_; + + /// Whether the connection was initialized as resumable. bool isResumable_{false}; - bool remoteResumeable_{false}; + + /// Whether the connection has closed. bool isClosed_{false}; - std::shared_ptr resumeCache_; - std::shared_ptr streamState_; - std::shared_ptr requestResponder_; - yarpl::Reference frameTransport_; + /// Whether a cold resume is currently in progress. + bool coldResumeInProgress_{false}; + + std::shared_ptr stats_; + + /// Map of all individual stream state machines. + std::unordered_map> + streams_; + StreamId nextStreamId_; + StreamId lastPeerStreamId_{0}; + + // Manages all state needed for warm/cold resumption. + std::shared_ptr resumeManager_; + + const std::shared_ptr requestResponder_; + std::shared_ptr frameTransport_; std::unique_ptr frameSerializer_; const std::unique_ptr keepaliveTimer_; std::unique_ptr resumeCallback_; + std::shared_ptr coldResumeHandler_; - StreamsFactory streamsFactory_; - - const std::shared_ptr stats_; std::shared_ptr connectionEvents_; - folly::Executor& executor_; + + CloseCallback* closeCallback_{nullptr}; + + friend class RSocketStateMachineTest; }; -} + +} // namespace rsocket diff --git a/rsocket/statemachine/RequestResponseRequester.cpp b/rsocket/statemachine/RequestResponseRequester.cpp index 7492a737f..2d39be17b 100644 --- a/rsocket/statemachine/RequestResponseRequester.cpp +++ b/rsocket/statemachine/RequestResponseRequester.cpp @@ -1,34 +1,41 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/statemachine/RequestResponseRequester.h" -#include - #include "rsocket/internal/Common.h" #include "rsocket/statemachine/RSocketStateMachine.h" namespace rsocket { -using namespace yarpl; -using namespace yarpl::flowable; - void RequestResponseRequester::subscribe( - yarpl::Reference> subscriber) { - DCHECK(!isTerminated()); + std::shared_ptr> subscriber) { + DCHECK(state_ != State::CLOSED); DCHECK(!consumingSubscriber_); consumingSubscriber_ = std::move(subscriber); - consumingSubscriber_->onSubscribe(Reference(this)); + consumingSubscriber_->onSubscribe(shared_from_this()); if (state_ == State::NEW) { state_ = State::REQUESTED; newStream(StreamType::REQUEST_RESPONSE, 1, std::move(initialPayload_)); - } else { - if (auto subscriber = std::move(consumingSubscriber_)) { - subscriber->onError(std::make_exception_ptr( - std::runtime_error("cannot request more than 1 item"))); - } - closeStream(StreamCompletionSignal::ERROR); + return; } + + if (auto subscriber = std::move(consumingSubscriber_)) { + subscriber->onError(std::runtime_error("cannot request more than 1 item")); + } + removeFromWriter(); } void RequestResponseRequester::cancel() noexcept { @@ -36,17 +43,17 @@ void RequestResponseRequester::cancel() noexcept { switch (state_) { case State::NEW: state_ = State::CLOSED; - closeStream(StreamCompletionSignal::CANCEL); + removeFromWriter(); break; case State::REQUESTED: { state_ = State::CLOSED; - cancelStream(); - closeStream(StreamCompletionSignal::CANCEL); + writeCancel(); + removeFromWriter(); } break; case State::CLOSED: break; } - consumingSubscriber_ = nullptr; + consumingSubscriber_.reset(); } void RequestResponseRequester::endStream(StreamCompletionSignal signal) { @@ -64,13 +71,11 @@ void RequestResponseRequester::endStream(StreamCompletionSignal signal) { if (auto subscriber = std::move(consumingSubscriber_)) { DCHECK(signal != StreamCompletionSignal::COMPLETE); DCHECK(signal != StreamCompletionSignal::CANCEL); - subscriber->onError(std::make_exception_ptr( - StreamInterruptedException(static_cast(signal)))); + subscriber->onError(StreamInterruptedException(static_cast(signal))); } } -void RequestResponseRequester::handleError( - folly::exception_wrapper errorPayload) { +void RequestResponseRequester::handleError(folly::exception_wrapper ew) { switch (state_) { case State::NEW: // Cannot receive a frame before sending the initial request. @@ -79,9 +84,9 @@ void RequestResponseRequester::handleError( case State::REQUESTED: state_ = State::CLOSED; if (auto subscriber = std::move(consumingSubscriber_)) { - subscriber->onError(errorPayload.to_exception_ptr()); + subscriber->onError(std::move(ew)); } - closeStream(StreamCompletionSignal::ERROR); + removeFromWriter(); break; case State::CLOSED: break; @@ -90,42 +95,41 @@ void RequestResponseRequester::handleError( void RequestResponseRequester::handlePayload( Payload&& payload, - bool complete, - bool flagsNext) { - switch (state_) { - case State::NEW: - // Cannot receive a frame before sending the initial request. - CHECK(false); - break; - case State::REQUESTED: - state_ = State::CLOSED; - break; - case State::CLOSED: - // should not be receiving frames when closed - // if we ended up here, we broke some internal invariant of the class - CHECK(false); - break; + bool /*flagsComplete*/, + bool flagsNext, + bool flagsFollows) { + // (State::NEW) Cannot receive a frame before sending the initial request. + // (State::CLOSED) should not be receiving frames when closed + // if we fail here, we broke some internal invariant of the class + CHECK(state_ == State::REQUESTED); + + payloadFragments_.addPayload(std::move(payload), flagsNext, false); + + if (flagsFollows) { + // there will be more fragments to come + return; } - if (payload || flagsNext) { - consumingSubscriber_->onSuccess(std::move(payload)); + bool finalFlagsNext, finalFlagsComplete; + Payload finalPayload; + + std::tie(finalPayload, finalFlagsNext, finalFlagsComplete) = + payloadFragments_.consumePayloadAndFlags(); + + state_ = State::CLOSED; + + if (finalPayload || finalFlagsNext) { + consumingSubscriber_->onSuccess(std::move(finalPayload)); consumingSubscriber_ = nullptr; - } else if (!complete) { - errorStream("payload, NEXT or COMPLETE flag expected"); - return; + } else if (!finalFlagsComplete) { + writeInvalidError("Payload, NEXT or COMPLETE flag expected"); + endStream(StreamCompletionSignal::ERROR); } - closeStream(StreamCompletionSignal::COMPLETE); + removeFromWriter(); } -//void RequestResponseRequester::pauseStream(RequestHandler& requestHandler) { -// if (consumingSubscriber_) { -// requestHandler.onSubscriberPaused(consumingSubscriber_); -// } -//} -// -//void RequestResponseRequester::resumeStream(RequestHandler& requestHandler) { -// if (consumingSubscriber_) { -// requestHandler.onSubscriberResumed(consumingSubscriber_); -// } -//} +size_t RequestResponseRequester::getConsumerAllowance() const { + return (state_ == State::REQUESTED) ? 1 : 0; } + +} // namespace rsocket diff --git a/rsocket/statemachine/RequestResponseRequester.h b/rsocket/statemachine/RequestResponseRequester.h index 2d362dfe2..be17cf546 100644 --- a/rsocket/statemachine/RequestResponseRequester.h +++ b/rsocket/statemachine/RequestResponseRequester.h @@ -1,8 +1,19 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include #include "rsocket/Payload.h" #include "rsocket/statemachine/StreamStateMachineBase.h" #include "yarpl/single/SingleObserver.h" @@ -12,39 +23,48 @@ namespace rsocket { /// Implementation of stream stateMachine that represents a RequestResponse /// requester -class RequestResponseRequester : public StreamStateMachineBase, - public yarpl::single::SingleSubscription { - using Base = StreamStateMachineBase; - +class RequestResponseRequester + : public StreamStateMachineBase, + public yarpl::single::SingleSubscription, + public std::enable_shared_from_this { public: - explicit RequestResponseRequester(const Parameters& params, Payload payload) - : Base(params), initialPayload_(std::move(payload)) {} + RequestResponseRequester( + std::shared_ptr writer, + StreamId streamId, + Payload payload) + : StreamStateMachineBase(std::move(writer), streamId), + initialPayload_(std::move(payload)) {} void subscribe( - yarpl::Reference> subscriber); + std::shared_ptr> subscriber); private: void cancel() noexcept override; - void handlePayload(Payload&& payload, bool complete, bool flagsNext) override; - void handleError(folly::exception_wrapper errorPayload) override; + void handlePayload( + Payload&& payload, + bool flagsComplete, + bool flagsNext, + bool flagsFollows) override; + void handleError(folly::exception_wrapper ew) override; void endStream(StreamCompletionSignal signal) override; -// void pauseStream(RequestHandler& requestHandler) override; -// void resumeStream(RequestHandler& requestHandler) override; + size_t getConsumerAllowance() const override; /// State of the Subscription requester. enum class State : uint8_t { NEW, REQUESTED, CLOSED, - } state_{State::NEW}; + }; + + State state_{State::NEW}; /// The observer that will consume payloads. - yarpl::Reference> consumingSubscriber_; + std::shared_ptr> consumingSubscriber_; /// Initial payload which has to be sent with 1st request. Payload initialPayload_; }; -} +} // namespace rsocket diff --git a/rsocket/statemachine/RequestResponseResponder.cpp b/rsocket/statemachine/RequestResponseResponder.cpp index d88747a74..ca51ff54b 100644 --- a/rsocket/statemachine/RequestResponseResponder.cpp +++ b/rsocket/statemachine/RequestResponseResponder.cpp @@ -1,90 +1,127 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/statemachine/RequestResponseResponder.h" -#include - -#include "rsocket/Payload.h" -#include "yarpl/utils/ExceptionString.h" - namespace rsocket { -using namespace yarpl; -using namespace yarpl::flowable; - void RequestResponseResponder::onSubscribe( - Reference subscription) noexcept { - if (StreamStateMachineBase::isTerminated()) { + std::shared_ptr subscription) { + DCHECK(State::NEW != state_); + if (state_ == State::CLOSED) { subscription->cancel(); return; } - DCHECK(!producingSubscription_); producingSubscription_ = std::move(subscription); } -void RequestResponseResponder::onSuccess(Payload response) noexcept { - DCHECK(producingSubscription_) << "didnt call onSubscribe"; +void RequestResponseResponder::onSuccess(Payload response) { + DCHECK(State::NEW != state_); + if (!producingSubscription_) { + return; + } + switch (state_) { case State::RESPONDING: { state_ = State::CLOSED; - writePayload(std::move(response), true); + writePayload(std::move(response), true /* complete */); producingSubscription_ = nullptr; - closeStream(StreamCompletionSignal::COMPLETE); + removeFromWriter(); break; } case State::CLOSED: break; + + case State::NEW: + default: + // class is internally misused + CHECK(false); } } -void RequestResponseResponder::onError(std::exception_ptr ex) noexcept { - DCHECK(producingSubscription_); +void RequestResponseResponder::onError(folly::exception_wrapper ex) { + DCHECK(State::NEW != state_); producingSubscription_ = nullptr; switch (state_) { case State::RESPONDING: { state_ = State::CLOSED; - applicationError(yarpl::exceptionStr(ex)); - closeStream(StreamCompletionSignal::APPLICATION_ERROR); + if (!ex.with_exception([this](rsocket::ErrorWithPayload& err) { + writeApplicationError(std::move(err.payload)); + })) { + writeApplicationError(ex.get_exception()->what()); + } + removeFromWriter(); } break; case State::CLOSED: break; + + case State::NEW: + default: + // class is internally misused + CHECK(false); } } -//void RequestResponseResponder::pauseStream(RequestHandler& requestHandler) { -// pausePublisherStream(requestHandler); -//} -// -//void RequestResponseResponder::resumeStream(RequestHandler& requestHandler) { -// resumePublisherStream(requestHandler); -//} - -void RequestResponseResponder::endStream(StreamCompletionSignal signal) { +void RequestResponseResponder::handleCancel() { switch (state_) { case State::RESPONDING: - // Spontaneous ::endStream signal means an error. - DCHECK(StreamCompletionSignal::COMPLETE != signal); - DCHECK(StreamCompletionSignal::CANCEL != signal); state_ = State::CLOSED; + removeFromWriter(); break; + case State::NEW: case State::CLOSED: break; } - if (auto subscription = std::move(producingSubscription_)) { - subscription->cancel(); +} + +void RequestResponseResponder::handlePayload( + Payload&& payload, + bool /*flagsComplete*/, + bool /*flagsNext*/, + bool flagsFollows) { + payloadFragments_.addPayloadIgnoreFlags(std::move(payload)); + + if (flagsFollows) { + // there will be more fragments to come + return; } - StreamStateMachineBase::endStream(signal); + + CHECK(state_ == State::NEW); + Payload finalPayload = payloadFragments_.consumePayloadIgnoreFlags(); + + state_ = State::RESPONDING; + onNewStreamReady( + StreamType::REQUEST_RESPONSE, + std::move(finalPayload), + shared_from_this()); } -void RequestResponseResponder::handleCancel() { +void RequestResponseResponder::endStream(StreamCompletionSignal signal) { switch (state_) { + case State::NEW: case State::RESPONDING: + // Spontaneous ::endStream signal means an error. + DCHECK(StreamCompletionSignal::COMPLETE != signal); + DCHECK(StreamCompletionSignal::CANCEL != signal); state_ = State::CLOSED; - closeStream(StreamCompletionSignal::CANCEL); break; case State::CLOSED: break; } + if (auto subscription = std::move(producingSubscription_)) { + subscription->cancel(); + } } -} // reactivesocket +} // namespace rsocket diff --git a/rsocket/statemachine/RequestResponseResponder.h b/rsocket/statemachine/RequestResponseResponder.h index df73f4a10..3e7a5e37b 100644 --- a/rsocket/statemachine/RequestResponseResponder.h +++ b/rsocket/statemachine/RequestResponseResponder.h @@ -1,9 +1,21 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once +#include "rsocket/Payload.h" #include "rsocket/statemachine/StreamStateMachineBase.h" -#include "yarpl/flowable/Subscriber.h" #include "yarpl/single/SingleObserver.h" #include "yarpl/single/SingleSubscription.h" @@ -11,31 +23,39 @@ namespace rsocket { /// Implementation of stream stateMachine that represents a RequestResponse /// responder -class RequestResponseResponder : public StreamStateMachineBase, - public yarpl::single::SingleObserver { +class RequestResponseResponder + : public StreamStateMachineBase, + public yarpl::single::SingleObserver, + public std::enable_shared_from_this { public: - explicit RequestResponseResponder(const Parameters& params) - : StreamStateMachineBase(params) {} - - private: - void onSubscribe(yarpl::Reference - subscription) noexcept override; - void onSuccess(Payload) noexcept override; - void onError(std::exception_ptr) noexcept override; - + RequestResponseResponder( + std::shared_ptr writer, + StreamId streamId) + : StreamStateMachineBase(std::move(writer), streamId) {} + + void onSubscribe(std::shared_ptr) override; + void onSuccess(Payload) override; + void onError(folly::exception_wrapper) override; + + void handlePayload( + Payload&& payload, + bool flagsComplete, + bool flagsNext, + bool flagsFollows) override; void handleCancel() override; -// void pauseStream(RequestHandler&) override; -// void resumeStream(RequestHandler&) override; void endStream(StreamCompletionSignal) override; + private: /// State of the Subscription responder. enum class State : uint8_t { + NEW, RESPONDING, CLOSED, - } state_{State::RESPONDING}; + }; - yarpl::Reference producingSubscription_; + std::shared_ptr producingSubscription_; + State state_{State::NEW}; }; -} // reactivesocket +} // namespace rsocket diff --git a/rsocket/statemachine/StreamFragmentAccumulator.cpp b/rsocket/statemachine/StreamFragmentAccumulator.cpp new file mode 100644 index 000000000..07c7a3986 --- /dev/null +++ b/rsocket/statemachine/StreamFragmentAccumulator.cpp @@ -0,0 +1,64 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/statemachine/StreamFragmentAccumulator.h" + +namespace rsocket { + +StreamFragmentAccumulator::StreamFragmentAccumulator() + : flagsComplete(false), flagsNext(false) {} + +void StreamFragmentAccumulator::addPayloadIgnoreFlags(Payload p) { + if (p.metadata) { + if (!fragments.metadata) { + fragments.metadata = std::move(p.metadata); + } else { + fragments.metadata->prev()->appendChain(std::move(p.metadata)); + } + } + + if (p.data) { + if (!fragments.data) { + fragments.data = std::move(p.data); + } else { + fragments.data->prev()->appendChain(std::move(p.data)); + } + } +} + +void StreamFragmentAccumulator::addPayload( + Payload p, + bool next, + bool complete) { + flagsNext |= next; + flagsComplete |= complete; + addPayloadIgnoreFlags(std::move(p)); +} + +Payload StreamFragmentAccumulator::consumePayloadIgnoreFlags() { + flagsComplete = false; + flagsNext = false; + return std::move(fragments); +} + +std::tuple +StreamFragmentAccumulator::consumePayloadAndFlags() { + auto ret = std::make_tuple( + std::move(fragments), bool(flagsNext), bool(flagsComplete)); + flagsComplete = false; + flagsNext = false; + return ret; +} + +} /* namespace rsocket */ diff --git a/rsocket/statemachine/StreamFragmentAccumulator.h b/rsocket/statemachine/StreamFragmentAccumulator.h new file mode 100644 index 000000000..0ed5227d8 --- /dev/null +++ b/rsocket/statemachine/StreamFragmentAccumulator.h @@ -0,0 +1,41 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "rsocket/Payload.h" + +namespace rsocket { + +class StreamFragmentAccumulator { + public: + StreamFragmentAccumulator(); + + void addPayloadIgnoreFlags(Payload p); + void addPayload(Payload p, bool next, bool complete); + + Payload consumePayloadIgnoreFlags(); + std::tuple consumePayloadAndFlags(); + + bool anyFragments() const { + return fragments.data || fragments.metadata; + } + + private: + bool flagsComplete : 1; + bool flagsNext : 1; + Payload fragments; +}; + +} /* namespace rsocket */ diff --git a/rsocket/statemachine/StreamRequester.cpp b/rsocket/statemachine/StreamRequester.cpp index b13e8f8d2..52e407be9 100644 --- a/rsocket/statemachine/StreamRequester.cpp +++ b/rsocket/statemachine/StreamRequester.cpp @@ -1,74 +1,87 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/statemachine/StreamRequester.h" namespace rsocket { -void StreamRequester::request(int64_t n) noexcept { - if (n == 0) { +void StreamRequester::setRequested(size_t n) { + VLOG(3) << "Setting allowance to " << n; + requested_ = true; + addImplicitAllowance(n); +} + +void StreamRequester::request(int64_t signedN) { + if (signedN <= 0 || consumerClosed()) { return; } - if(!requested_) { - requested_ = true; + const size_t n = signedN; - auto initialN = - n > Frame_REQUEST_N::kMaxRequestN ? Frame_REQUEST_N::kMaxRequestN : n; - auto remainingN = n > Frame_REQUEST_N::kMaxRequestN - ? n - Frame_REQUEST_N::kMaxRequestN - : 0; + if (requested_) { + generateRequest(n); + return; + } - // Send as much as possible with the initial request. - CHECK_GE(Frame_REQUEST_N::kMaxRequestN, initialN); + requested_ = true; - // We must inform ConsumerBase about an implicit allowance we have - // requested from the remote end. - addImplicitAllowance(initialN); - newStream( - StreamType::STREAM, - static_cast(initialN), - std::move(initialPayload_)); + // We must inform ConsumerBase about an implicit allowance we have requested + // from the remote end. + auto const initial = std::min(n, kMaxRequestN); + addImplicitAllowance(initial); + newStream(StreamType::STREAM, initial, std::move(initialPayload_)); - // Pump the remaining allowance into the ConsumerBase _after_ sending the - // initial request. - if (remainingN) { - Base::generateRequest(remainingN); - } - return; + // Pump the remaining allowance into the ConsumerBase _after_ sending the + // initial request. + if (n > initial) { + generateRequest(n - initial); } - - checkConsumerRequest(); - ConsumerBase::generateRequest(n); } -void StreamRequester::cancel() noexcept { +void StreamRequester::cancel() { + VLOG(5) << "StreamRequester::cancel(requested_=" << requested_ << ")"; + if (consumerClosed()) { + return; + } + cancelConsumer(); if (requested_) { - cancelConsumer(); - cancelStream(); + writeCancel(); } - closeStream(StreamCompletionSignal::CANCEL); -} - -void StreamRequester::endStream(StreamCompletionSignal signal) { - ConsumerBase::endStream(signal); + removeFromWriter(); } void StreamRequester::handlePayload( Payload&& payload, bool complete, - bool next) { - CHECK(requested_); - processPayload(std::move(payload), next); + bool next, + bool follows) { + if (!requested_) { + handleError(std::runtime_error("Haven't sent REQUEST_STREAM yet")); + return; + } + bool finalComplete = + processFragmentedPayload(std::move(payload), next, complete, follows); - if (complete) { + if (finalComplete) { completeConsumer(); - closeStream(StreamCompletionSignal::COMPLETE); + removeFromWriter(); } } -void StreamRequester::handleError(folly::exception_wrapper errorPayload) { - CHECK(requested_); - errorConsumer(std::move(errorPayload)); - closeStream(StreamCompletionSignal::ERROR); -} +void StreamRequester::handleError(folly::exception_wrapper ew) { + errorConsumer(std::move(ew)); + removeFromWriter(); } + +} // namespace rsocket diff --git a/rsocket/statemachine/StreamRequester.h b/rsocket/statemachine/StreamRequester.h index ecaa9fda1..696b81472 100644 --- a/rsocket/statemachine/StreamRequester.h +++ b/rsocket/statemachine/StreamRequester.h @@ -1,45 +1,51 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include -#include "rsocket/internal/AllowanceSemaphore.h" #include "rsocket/statemachine/ConsumerBase.h" -namespace folly { -class exception_wrapper; -} - namespace rsocket { -enum class StreamCompletionSignal; - /// Implementation of stream stateMachine that represents a Stream requester class StreamRequester : public ConsumerBase { - using Base = ConsumerBase; - public: - // initialization of the ExecutorBase will be ignored for any of the - // derived classes - explicit StreamRequester(const Base::Parameters& params, Payload payload) - : Base(params), initialPayload_(std::move(payload)) {} - - private: - // implementation from ConsumerBase::SubscriptionBase - void request(int64_t) noexcept override; - void cancel() noexcept override; + StreamRequester( + std::shared_ptr writer, + StreamId streamId, + Payload payload) + : ConsumerBase(std::move(writer), streamId), + initialPayload_(std::move(payload)) {} - void handlePayload(Payload&& payload, bool complete, bool flagsNext) override; - void handleError(folly::exception_wrapper errorPayload) override; + void setRequested(size_t); - void endStream(StreamCompletionSignal) override; + void request(int64_t) override; + void cancel() override; - /// An allowance accumulated before the stream is initialised. - /// Remaining part of the allowance is forwarded to the ConsumerBase. - AllowanceSemaphore initialResponseAllowance_; + void handlePayload( + Payload&& payload, + bool flagsComplete, + bool flagsNext, + bool flagsFollows) override; + void handleError(folly::exception_wrapper ew) override; - /// Initial payload which has to be sent with 1st request. + private: + /// Payload to be sent with the first request. Payload initialPayload_; + + /// Whether request() has been called. bool requested_{false}; }; -} // reactivesocket + +} // namespace rsocket diff --git a/rsocket/statemachine/StreamResponder.cpp b/rsocket/statemachine/StreamResponder.cpp index 6f2e8c0fb..9dfa8a6a3 100644 --- a/rsocket/statemachine/StreamResponder.cpp +++ b/rsocket/statemachine/StreamResponder.cpp @@ -1,54 +1,102 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/statemachine/StreamResponder.h" -#include "yarpl/utils/ExceptionString.h" namespace rsocket { -using namespace yarpl; -using namespace yarpl::flowable; - void StreamResponder::onSubscribe( - Reference subscription) noexcept { + std::shared_ptr subscription) { publisherSubscribe(std::move(subscription)); } -void StreamResponder::onNext(Payload response) noexcept { - checkPublisherOnNext(); - writePayload(std::move(response), false); +void StreamResponder::onNext(Payload response) { + if (publisherClosed()) { + return; + } + writePayload(std::move(response)); } -void StreamResponder::onComplete() noexcept { +void StreamResponder::onComplete() { + if (publisherClosed()) { + return; + } publisherComplete(); - completeStream(); - closeStream(StreamCompletionSignal::COMPLETE); + writeComplete(); + removeFromWriter(); } -void StreamResponder::onError(std::exception_ptr ex) noexcept { +void StreamResponder::onError(folly::exception_wrapper ew) { + if (publisherClosed()) { + return; + } publisherComplete(); - applicationError(yarpl::exceptionStr(ex)); - closeStream(StreamCompletionSignal::ERROR); + if (!ew.with_exception([this](rsocket::ErrorWithPayload& err) { + writeApplicationError(std::move(err.payload)); + })) { + writeApplicationError(ew.get_exception()->what()); + } + removeFromWriter(); } -//void StreamResponder::pauseStream(RequestHandler& requestHandler) { -// pausePublisherStream(requestHandler); -//} -// -//void StreamResponder::resumeStream(RequestHandler& requestHandler) { -// resumePublisherStream(requestHandler); -//} +void StreamResponder::handleRequestN(uint32_t n) { + processRequestN(n); +} -void StreamResponder::endStream(StreamCompletionSignal signal) { - terminatePublisher(); - StreamStateMachineBase::endStream(signal); +void StreamResponder::handleError(folly::exception_wrapper) { + handleCancel(); } -void StreamResponder::handleCancel() { - closeStream(StreamCompletionSignal::CANCEL); - publisherComplete(); +void StreamResponder::handlePayload( + Payload&& payload, + bool /*flagsComplete*/, + bool /*flagsNext*/, + bool flagsFollows) { + payloadFragments_.addPayloadIgnoreFlags(std::move(payload)); + + if (flagsFollows) { + // there will be more fragments to come + return; + } + + Payload finalPayload = payloadFragments_.consumePayloadIgnoreFlags(); + + if (newStream_) { + newStream_ = false; + onNewStreamReady( + StreamType::STREAM, std::move(finalPayload), shared_from_this()); + } else { + // per rsocket spec, ignore unexpected frame (payload) if it makes no sense + // in the semantic of the stream + } } -void StreamResponder::handleRequestN(uint32_t n) { - processRequestN(n); +void StreamResponder::handleCancel() { + if (publisherClosed()) { + return; + } + terminatePublisher(); + removeFromWriter(); } + +void StreamResponder::endStream(StreamCompletionSignal signal) { + if (publisherClosed()) { + return; + } + terminatePublisher(); + writeApplicationError(to_string(signal)); + removeFromWriter(); } + +} // namespace rsocket diff --git a/rsocket/statemachine/StreamResponder.h b/rsocket/statemachine/StreamResponder.h index b0164dad4..09b445eda 100644 --- a/rsocket/statemachine/StreamResponder.h +++ b/rsocket/statemachine/StreamResponder.h @@ -1,8 +1,19 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include #include "rsocket/statemachine/PublisherBase.h" #include "rsocket/statemachine/StreamStateMachineBase.h" #include "yarpl/flowable/Subscriber.h" @@ -12,26 +23,34 @@ namespace rsocket { /// Implementation of stream stateMachine that represents a Stream responder class StreamResponder : public StreamStateMachineBase, public PublisherBase, - public yarpl::flowable::Subscriber { + public yarpl::flowable::Subscriber, + public std::enable_shared_from_this { public: - // initialization of the ExecutorBase will be ignored for any of the - // derived classes - explicit StreamResponder(uint32_t initialRequestN, const Parameters& params) - : StreamStateMachineBase(params), PublisherBase(initialRequestN) {} + StreamResponder( + std::shared_ptr writer, + StreamId streamId, + uint32_t initialRequestN) + : StreamStateMachineBase(std::move(writer), streamId), + PublisherBase(initialRequestN) {} - protected: + void onSubscribe(std::shared_ptr) override; + void onNext(Payload) override; + void onComplete() override; + void onError(folly::exception_wrapper) override; + + void handlePayload( + Payload&& payload, + bool flagsComplete, + bool flagsNext, + bool flagsFollows) override; + void handleRequestN(uint32_t) override; + void handleError(folly::exception_wrapper) override; void handleCancel() override; - void handleRequestN(uint32_t n) override; - private: - void onSubscribe(yarpl::Reference - subscription) noexcept override; - void onNext(Payload) noexcept override; - void onComplete() noexcept override; - void onError(std::exception_ptr) noexcept override; - -// void pauseStream(RequestHandler&) override; -// void resumeStream(RequestHandler&) override; void endStream(StreamCompletionSignal) override; + + private: + bool newStream_{true}; }; -} // reactivesocket + +} // namespace rsocket diff --git a/rsocket/statemachine/StreamState.cpp b/rsocket/statemachine/StreamState.cpp deleted file mode 100644 index 7cdbd4edd..000000000 --- a/rsocket/statemachine/StreamState.cpp +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "rsocket/statemachine/StreamState.h" - -#include "rsocket/RSocketStats.h" - -namespace rsocket { - -StreamState::StreamState(RSocketStats& stats) : stats_(stats) {} - -StreamState::~StreamState() { - onClearFrames(); -} - -void StreamState::enqueueOutputPendingFrame( - std::unique_ptr frame) { - auto length = frame->computeChainDataLength(); - stats_.streamBufferChanged(1, static_cast(length)); - dataLength_ += length; - outputFrames_.push_back(std::move(frame)); -} - -std::deque> -StreamState::moveOutputPendingFrames() { - onClearFrames(); - return std::move(outputFrames_); -} - -void StreamState::onClearFrames() { - auto numFrames = outputFrames_.size(); - if (numFrames != 0) { - stats_.streamBufferChanged( - -static_cast(numFrames), -static_cast(dataLength_)); - dataLength_ = 0; - } -} -} diff --git a/rsocket/statemachine/StreamState.h b/rsocket/statemachine/StreamState.h deleted file mode 100644 index 14907aaf4..000000000 --- a/rsocket/statemachine/StreamState.h +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include -#include -#include - -#include "rsocket/statemachine/StreamStateMachineBase.h" -#include "yarpl/Refcounted.h" - -namespace rsocket { - -class RSocketStateMachine; -class RSocketStats; -class StreamStateMachineBase; - -class StreamState { - public: - explicit StreamState(RSocketStats& stats); - ~StreamState(); - - void enqueueOutputPendingFrame(std::unique_ptr frame); - - std::deque> moveOutputPendingFrames(); - - std::unordered_map> - streams_; - - private: - /// Called to update stats when outputFrames_ is about to be cleared. - void onClearFrames(); - - RSocketStats& stats_; - - /// Total data length of all IOBufs in outputFrames_. - uint64_t dataLength_{0}; - - std::deque> outputFrames_; -}; -} diff --git a/rsocket/statemachine/StreamStateMachineBase.cpp b/rsocket/statemachine/StreamStateMachineBase.cpp index 098c595f0..f0988fff7 100644 --- a/rsocket/statemachine/StreamStateMachineBase.cpp +++ b/rsocket/statemachine/StreamStateMachineBase.cpp @@ -1,80 +1,100 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/statemachine/StreamStateMachineBase.h" - #include - #include "rsocket/statemachine/RSocketStateMachine.h" #include "rsocket/statemachine/StreamsWriter.h" namespace rsocket { -void StreamStateMachineBase::handlePayload(Payload&&, bool, bool) { - VLOG(4) << "Unexpected handlePayload"; -} - void StreamStateMachineBase::handleRequestN(uint32_t) { VLOG(4) << "Unexpected handleRequestN"; } void StreamStateMachineBase::handleError(folly::exception_wrapper) { - closeStream(StreamCompletionSignal::ERROR); + endStream(StreamCompletionSignal::ERROR); + removeFromWriter(); } void StreamStateMachineBase::handleCancel() { VLOG(4) << "Unexpected handleCancel"; } -void StreamStateMachineBase::endStream(StreamCompletionSignal) { - isTerminated_ = true; +size_t StreamStateMachineBase::getConsumerAllowance() const { + return 0; } void StreamStateMachineBase::newStream( StreamType streamType, uint32_t initialRequestN, - Payload payload, - bool completed) { + Payload payload) { writer_->writeNewStream( - streamId_, streamType, initialRequestN, std::move(payload), completed); + streamId_, streamType, initialRequestN, std::move(payload)); } -void StreamStateMachineBase::writePayload(Payload&& payload, bool complete) { - writer_->writePayload(streamId_, std::move(payload), complete); +void StreamStateMachineBase::writeRequestN(uint32_t n) { + writer_->writeRequestN(Frame_REQUEST_N{streamId_, n}); } -void StreamStateMachineBase::writeRequestN(uint32_t n) { - writer_->writeRequestN(streamId_, n); +void StreamStateMachineBase::writeCancel() { + writer_->writeCancel(Frame_CANCEL{streamId_}); } -void StreamStateMachineBase::applicationError(std::string errorPayload) { - // TODO: a bad frame for a stream should not bring down the whole socket - // https://github.com/ReactiveSocket/reactivesocket-cpp/issues/311 - writer_->writeCloseStream( - streamId_, - StreamCompletionSignal::APPLICATION_ERROR, - Payload(std::move(errorPayload))); +void StreamStateMachineBase::writePayload(Payload&& payload, bool complete) { + auto const flags = + FrameFlags::NEXT | (complete ? FrameFlags::COMPLETE : FrameFlags::EMPTY_); + Frame_PAYLOAD frame{streamId_, flags, std::move(payload)}; + writer_->writePayload(std::move(frame)); } -void StreamStateMachineBase::errorStream(std::string errorPayload) { - writer_->writeCloseStream( - streamId_, - StreamCompletionSignal::ERROR, - Payload(std::move(errorPayload))); - closeStream(StreamCompletionSignal::ERROR); +void StreamStateMachineBase::writeComplete() { + writer_->writePayload(Frame_PAYLOAD::complete(streamId_)); } -void StreamStateMachineBase::cancelStream() { - writer_->writeCloseStream( - streamId_, StreamCompletionSignal::CANCEL, Payload()); +void StreamStateMachineBase::writeApplicationError(folly::StringPiece msg) { + writer_->writeError(Frame_ERROR::applicationError(streamId_, msg)); } -void StreamStateMachineBase::completeStream() { - writer_->writeCloseStream( - streamId_, StreamCompletionSignal::COMPLETE, Payload()); +void StreamStateMachineBase::writeApplicationError(Payload&& payload) { + writer_->writeError( + Frame_ERROR::applicationError(streamId_, std::move(payload))); } -void StreamStateMachineBase::closeStream(StreamCompletionSignal signal) { - writer_->onStreamClosed(streamId_, signal); +void StreamStateMachineBase::writeInvalidError(folly::StringPiece msg) { + writer_->writeError(Frame_ERROR::invalid(streamId_, msg)); +} + +void StreamStateMachineBase::removeFromWriter() { + writer_->onStreamClosed(streamId_); // TODO: set writer_ to nullptr } -} // reactivesocket + +std::shared_ptr> +StreamStateMachineBase::onNewStreamReady( + StreamType streamType, + Payload payload, + std::shared_ptr> response) { + return writer_->onNewStreamReady( + streamId_, streamType, std::move(payload), std::move(response)); +} + +void StreamStateMachineBase::onNewStreamReady( + StreamType streamType, + Payload payload, + std::shared_ptr> response) { + writer_->onNewStreamReady( + streamId_, streamType, std::move(payload), std::move(response)); +} +} // namespace rsocket diff --git a/rsocket/statemachine/StreamStateMachineBase.h b/rsocket/statemachine/StreamStateMachineBase.h index fb6037814..012c3b7fa 100644 --- a/rsocket/statemachine/StreamStateMachineBase.h +++ b/rsocket/statemachine/StreamStateMachineBase.h @@ -1,15 +1,26 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once -#include -#include -#include - #include +#include "rsocket/framing/FrameHeader.h" #include "rsocket/internal/Common.h" -#include "yarpl/Refcounted.h" +#include "rsocket/statemachine/StreamFragmentAccumulator.h" +#include "yarpl/Flowable.h" +#include "yarpl/Single.h" namespace folly { class IOBuf; @@ -20,34 +31,29 @@ namespace rsocket { class StreamsWriter; struct Payload; -/// /// A common base class of all state machines. /// /// The instances might be destroyed on a different thread than they were /// created. -class StreamStateMachineBase : public virtual yarpl::Refcounted { +class StreamStateMachineBase { public: - /// A dependent type which encapsulates all parameters needed to initialise - /// any of the classes and the final automata. Must be the only argument to - /// the - /// constructor of any class and must be passed by const& to class's Base. - struct Parameters { - Parameters(std::shared_ptr _writer, StreamId _streamId) - : writer(std::move(_writer)), streamId(_streamId) {} - - std::shared_ptr writer; - StreamId streamId{0}; - }; - - explicit StreamStateMachineBase(Parameters params) - : writer_(std::move(params.writer)), streamId_(params.streamId) {} + StreamStateMachineBase( + std::shared_ptr writer, + StreamId streamId) + : writer_(std::move(writer)), streamId_(streamId) {} virtual ~StreamStateMachineBase() = default; - virtual void handlePayload(Payload&& payload, bool complete, bool flagsNext); + virtual void handlePayload( + Payload&& payload, + bool complete, + bool flagsNext, + bool flagsFollows) = 0; virtual void handleRequestN(uint32_t n); - virtual void handleError(folly::exception_wrapper errorPayload); + virtual void handleError(folly::exception_wrapper); virtual void handleCancel(); + virtual size_t getConsumerAllowance() const; + /// Indicates a terminal signal from the connection. /// /// This signal corresponds to Subscriber::{onComplete,onError} and @@ -59,36 +65,41 @@ class StreamStateMachineBase : public virtual yarpl::Refcounted { /// 3. per "unsubscribe handshake", the state machine must deliver /// corresponding /// terminal signal to the connection. - virtual void endStream(StreamCompletionSignal signal); - /// @} - -// virtual void pauseStream(RequestHandler& requestHandler) = 0; -// virtual void resumeStream(RequestHandler& requestHandler) = 0; + virtual void endStream(StreamCompletionSignal) {} protected: - bool isTerminated() const { - return isTerminated_; - } + void + newStream(StreamType streamType, uint32_t initialRequestN, Payload payload); + + void writeRequestN(uint32_t); + void writeCancel(); + + void writePayload(Payload&& payload, bool complete = false); + void writeComplete(); + void writeApplicationError(folly::StringPiece); + void writeApplicationError(Payload&& payload); + void writeInvalidError(folly::StringPiece); + + void removeFromWriter(); - void newStream( + std::shared_ptr> onNewStreamReady( StreamType streamType, - uint32_t initialRequestN, Payload payload, - bool completed = false); - void writePayload(Payload&& payload, bool complete); - void writeRequestN(uint32_t n); - void applicationError(std::string errorPayload); - void errorStream(std::string errorPayload); - void cancelStream(); - void completeStream(); - void closeStream(StreamCompletionSignal signal); + std::shared_ptr> response); + + void onNewStreamReady( + StreamType streamType, + Payload payload, + std::shared_ptr> response); /// A partially-owning pointer to the connection, the stream runs on. /// It is declared as const to allow only ctor to initialize it for thread /// safety of the dtor. const std::shared_ptr writer_; + StreamFragmentAccumulator payloadFragments_; + + private: const StreamId streamId_; - // TODO: remove and nulify the writer_ instead - bool isTerminated_{false}; }; -} + +} // namespace rsocket diff --git a/rsocket/statemachine/StreamsFactory.cpp b/rsocket/statemachine/StreamsFactory.cpp deleted file mode 100644 index c3a417ee4..000000000 --- a/rsocket/statemachine/StreamsFactory.cpp +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "rsocket/statemachine/StreamsFactory.h" - -#include "rsocket/statemachine/ChannelRequester.h" -#include "rsocket/statemachine/ChannelResponder.h" -#include "rsocket/statemachine/RSocketStateMachine.h" -#include "rsocket/statemachine/RequestResponseRequester.h" -#include "rsocket/statemachine/RequestResponseResponder.h" -#include "rsocket/statemachine/StreamRequester.h" -#include "rsocket/statemachine/StreamResponder.h" - -namespace rsocket { - -using namespace yarpl; - -StreamsFactory::StreamsFactory( - RSocketStateMachine& connection, - ReactiveSocketMode mode) - : connection_(connection), - nextStreamId_( - mode == ReactiveSocketMode::CLIENT - ? 1 /*Streams initiated by a client MUST use - odd-numbered stream identifiers*/ - : 2 /*streams initiated by the server MUST use - even-numbered stream identifiers*/) {} - -Reference> -StreamsFactory::createChannelRequester( - Reference> responseSink) { - ChannelRequester::Parameters params( - connection_.shared_from_this(), getNextStreamId()); - auto stateMachine = yarpl::make_ref(params); - connection_.addStream(params.streamId, stateMachine); - stateMachine->subscribe(std::move(responseSink)); - return stateMachine; -} - -void StreamsFactory::createStreamRequester( - Payload request, - Reference> responseSink) { - StreamRequester::Parameters params( - connection_.shared_from_this(), getNextStreamId()); - auto stateMachine = - yarpl::make_ref(params, std::move(request)); - connection_.addStream(params.streamId, stateMachine); - stateMachine->subscribe(std::move(responseSink)); -} - -void StreamsFactory::createRequestResponseRequester( - Payload payload, - Reference> responseSink) { - RequestResponseRequester::Parameters params( - connection_.shared_from_this(), getNextStreamId()); - auto stateMachine = - yarpl::make_ref(params, std::move(payload)); - connection_.addStream(params.streamId, stateMachine); - stateMachine->subscribe(std::move(responseSink)); -} - -StreamId StreamsFactory::getNextStreamId() { - StreamId streamId = nextStreamId_; - CHECK(streamId <= std::numeric_limits::max() - 2); - nextStreamId_ += 2; - return streamId; -} - -bool StreamsFactory::registerNewPeerStreamId(StreamId streamId) { - DCHECK(streamId != 0); - if (nextStreamId_ % 2 == streamId % 2) { - // if this is an unknown stream to the socket and this socket is - // generating - // such stream ids, it is an incoming frame on the stream which no longer - // exist - return false; - } - if (streamId <= lastPeerStreamId_) { - // receiving frame for a stream which no longer exists - return false; - } - lastPeerStreamId_ = streamId; - return true; -} - -Reference StreamsFactory::createChannelResponder( - uint32_t initialRequestN, - StreamId streamId) { - ChannelResponder::Parameters params(connection_.shared_from_this(), streamId); - auto stateMachine = - yarpl::make_ref(initialRequestN, params); - connection_.addStream(streamId, stateMachine); - return stateMachine; -} - -Reference> -StreamsFactory::createStreamResponder( - uint32_t initialRequestN, - StreamId streamId) { - StreamResponder::Parameters params(connection_.shared_from_this(), streamId); - auto stateMachine = yarpl::make_ref(initialRequestN, params); - connection_.addStream(streamId, stateMachine); - return stateMachine; -} - -Reference> -StreamsFactory::createRequestResponseResponder(StreamId streamId) { - RequestResponseResponder::Parameters params( - connection_.shared_from_this(), streamId); - auto stateMachine = yarpl::make_ref(params); - connection_.addStream(streamId, stateMachine); - return stateMachine; -} - -} // reactivesocket diff --git a/rsocket/statemachine/StreamsFactory.h b/rsocket/statemachine/StreamsFactory.h deleted file mode 100644 index 90dcc4ccc..000000000 --- a/rsocket/statemachine/StreamsFactory.h +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "rsocket/internal/Common.h" -#include "yarpl/flowable/Subscriber.h" -#include "yarpl/flowable/Subscription.h" -#include "yarpl/single/SingleObserver.h" - -namespace folly { -class Executor; -} - -namespace rsocket { - -class RSocketStateMachine; -class ChannelResponder; -struct Payload; - -class StreamsFactory { - public: - StreamsFactory(RSocketStateMachine& connection, ReactiveSocketMode mode); - - yarpl::Reference> createChannelRequester( - yarpl::Reference> responseSink); - - void createStreamRequester( - Payload request, - yarpl::Reference> responseSink); - - void createRequestResponseRequester( - Payload payload, - yarpl::Reference> responseSink); - - // TODO: the return type should not be the stateMachine type, but something - // generic - yarpl::Reference createChannelResponder( - uint32_t initialRequestN, - StreamId streamId); - - yarpl::Reference> createStreamResponder( - uint32_t initialRequestN, - StreamId streamId); - - yarpl::Reference> - createRequestResponseResponder(StreamId streamId); - - bool registerNewPeerStreamId(StreamId streamId); - StreamId getNextStreamId(); - - private: - RSocketStateMachine& connection_; - StreamId nextStreamId_; - StreamId lastPeerStreamId_{0}; -}; -} // reactivesocket diff --git a/rsocket/statemachine/StreamsWriter.cpp b/rsocket/statemachine/StreamsWriter.cpp new file mode 100644 index 000000000..5e2279d70 --- /dev/null +++ b/rsocket/statemachine/StreamsWriter.cpp @@ -0,0 +1,197 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/statemachine/StreamsWriter.h" + +#include "rsocket/RSocketStats.h" +#include "rsocket/framing/FrameSerializer.h" + +namespace rsocket { + +void StreamsWriterImpl::outputFrameOrEnqueue( + std::unique_ptr frame) { + if (shouldQueue()) { + enqueuePendingOutputFrame(std::move(frame)); + } else { + outputFrame(std::move(frame)); + } +} + +void StreamsWriterImpl::sendPendingFrames() { + // We are free to try to send frames again. Not all frames might be sent if + // the connection breaks, the rest of them will queue up again. + auto frames = consumePendingOutputFrames(); + for (auto& frame : frames) { + outputFrameOrEnqueue(std::move(frame)); + } +} + +void StreamsWriterImpl::enqueuePendingOutputFrame( + std::unique_ptr frame) { + auto const length = frame->computeChainDataLength(); + stats().streamBufferChanged(1, static_cast(length)); + pendingSize_ += length; + pendingOutputFrames_.push_back(std::move(frame)); +} + +std::deque> +StreamsWriterImpl::consumePendingOutputFrames() { + if (auto const numFrames = pendingOutputFrames_.size()) { + stats().streamBufferChanged( + -static_cast(numFrames), -static_cast(pendingSize_)); + pendingSize_ = 0; + } + return std::move(pendingOutputFrames_); +} + +void StreamsWriterImpl::writeNewStream( + StreamId streamId, + StreamType streamType, + uint32_t initialRequestN, + Payload payload) { + // for simplicity, require that sent buffers don't consist of chains + writeFragmented( + [&](Payload p, FrameFlags flags) { + switch (streamType) { + case StreamType::CHANNEL: + outputFrameOrEnqueue( + serializer().serializeOut(Frame_REQUEST_CHANNEL( + streamId, flags, initialRequestN, std::move(p)))); + break; + case StreamType::STREAM: + outputFrameOrEnqueue(serializer().serializeOut(Frame_REQUEST_STREAM( + streamId, flags, initialRequestN, std::move(p)))); + break; + case StreamType::REQUEST_RESPONSE: + outputFrameOrEnqueue(serializer().serializeOut( + Frame_REQUEST_RESPONSE(streamId, flags, std::move(p)))); + break; + case StreamType::FNF: + outputFrameOrEnqueue(serializer().serializeOut( + Frame_REQUEST_FNF(streamId, flags, std::move(p)))); + break; + default: + CHECK(false) << "invalid stream type " << toString(streamType); + } + }, + streamId, + FrameFlags::EMPTY_, + std::move(payload)); +} + +void StreamsWriterImpl::writeRequestN(Frame_REQUEST_N&& frame) { + outputFrameOrEnqueue(serializer().serializeOut(std::move(frame))); +} + +void StreamsWriterImpl::writeCancel(Frame_CANCEL&& frame) { + outputFrameOrEnqueue(serializer().serializeOut(std::move(frame))); +} + +void StreamsWriterImpl::writePayload(Frame_PAYLOAD&& f) { + Frame_PAYLOAD frame = std::move(f); + auto const streamId = frame.header_.streamId; + auto const initialFlags = frame.header_.flags; + + writeFragmented( + [this, streamId](Payload p, FrameFlags flags) { + outputFrameOrEnqueue(serializer().serializeOut( + Frame_PAYLOAD(streamId, flags, std::move(p)))); + }, + streamId, + initialFlags, + std::move(frame.payload_)); +} + +void StreamsWriterImpl::writeError(Frame_ERROR&& frame) { + // TODO: implement fragmentation for writeError as well + outputFrameOrEnqueue(serializer().serializeOut(std::move(frame))); +} + +// The max amount of user data transmitted per frame - eg the size +// of the data and metadata combined, plus the size of the frame header. +// This assumes that the frame header will never be more than 512 bytes in +// size. A CHECK in FrameTransportImpl enforces this. The idea is that +// 16M is so much larger than the ~500 bytes possibly wasted that it won't +// be noticeable (0.003% wasted at most) +constexpr size_t GENEROUS_MAX_FRAME_SIZE = 0xFFFFFF - 512; + +// writeFragmented takes a `payload` and splits it up into chunks which +// are sent as fragmented requests. The first fragmented payload is +// given to writeInitialFrame, which is expected to write the initial +// "REQUEST_" or "PAYLOAD" frame of a stream or response. writeFragmented +// then writes the rest of the frames as payloads. +// +// writeInitialFrame +// - called with the payload of the first frame to send, and any additional +// flags (eg, addFlags with FOLLOWS, if there are more frames to write) +// streamId +// - The stream ID to write additional fragments with +// addFlags +// - All flags that writeInitialFrame wants to write the first frame with, +// and all flags that subsequent fragmented payloads will be sent with +// payload +// - The unsplit payload to send, possibly in multiple fragments +template +void StreamsWriterImpl::writeFragmented( + WriteInitialFrame writeInitialFrame, + StreamId const streamId, + FrameFlags const addFlags, + Payload payload) { + folly::IOBufQueue metaQueue{folly::IOBufQueue::cacheChainLength()}; + folly::IOBufQueue dataQueue{folly::IOBufQueue::cacheChainLength()}; + + // have to keep track of "did the full payload even have a metadata", because + // the rsocket protocol makes a distinction between a zero-length metadata + // and a null metadata. + bool const haveNonNullMeta = !!payload.metadata; + metaQueue.append(std::move(payload.metadata)); + dataQueue.append(std::move(payload.data)); + + bool isFirstFrame = true; + + while (true) { + Payload sendme; + + // chew off some metadata (splitAtMost will never return a null pointer, + // safe to compute length on it always) + if (haveNonNullMeta) { + sendme.metadata = metaQueue.splitAtMost(GENEROUS_MAX_FRAME_SIZE); + DCHECK_GE( + GENEROUS_MAX_FRAME_SIZE, sendme.metadata->computeChainDataLength()); + } + sendme.data = dataQueue.splitAtMost( + GENEROUS_MAX_FRAME_SIZE - + (haveNonNullMeta ? sendme.metadata->computeChainDataLength() : 0)); + + auto const metaLeft = metaQueue.chainLength(); + auto const dataLeft = dataQueue.chainLength(); + auto const moreFragments = metaLeft || dataLeft; + auto const flags = + (moreFragments ? FrameFlags::FOLLOWS : FrameFlags::EMPTY_) | addFlags; + + if (isFirstFrame) { + isFirstFrame = false; + writeInitialFrame(std::move(sendme), flags); + } else { + outputFrameOrEnqueue(serializer().serializeOut( + Frame_PAYLOAD(streamId, flags, std::move(sendme)))); + } + + if (!moreFragments) { + break; + } + } +} + +} // namespace rsocket diff --git a/rsocket/statemachine/StreamsWriter.h b/rsocket/statemachine/StreamsWriter.h index 50be835d7..7ecf1da87 100644 --- a/rsocket/statemachine/StreamsWriter.h +++ b/rsocket/statemachine/StreamsWriter.h @@ -1,18 +1,34 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once +#include + +#include +#include #include "rsocket/Payload.h" +#include "rsocket/framing/Frame.h" +#include "rsocket/framing/FrameType.h" #include "rsocket/internal/Common.h" namespace rsocket { +class RSocketStats; class FrameSerializer; -/// -/// StreamsWriter is the interface for writing stream related frames -/// on the wire. -/// +/// The interface for writing stream related frames on the wire. class StreamsWriter { public: virtual ~StreamsWriter() = default; @@ -21,21 +37,71 @@ class StreamsWriter { StreamId streamId, StreamType streamType, uint32_t initialRequestN, - Payload payload, - bool TEMP_completed) = 0; + Payload payload) = 0; - virtual void writeRequestN(StreamId streamId, uint32_t n) = 0; + virtual void writeRequestN(Frame_REQUEST_N&&) = 0; + virtual void writeCancel(Frame_CANCEL&&) = 0; - virtual void - writePayload(StreamId streamId, Payload payload, bool complete) = 0; + virtual void writePayload(Frame_PAYLOAD&&) = 0; + virtual void writeError(Frame_ERROR&&) = 0; - virtual void writeCloseStream( + virtual void onStreamClosed(StreamId) = 0; + + virtual std::shared_ptr> + onNewStreamReady( StreamId streamId, - StreamCompletionSignal signal, - Payload payload) = 0; + StreamType streamType, + Payload payload, + std::shared_ptr> response) = 0; + virtual void onNewStreamReady( + StreamId streamId, + StreamType streamType, + Payload payload, + std::shared_ptr> response) = 0; +}; - virtual void onStreamClosed( +class StreamsWriterImpl : public StreamsWriter { + public: + void writeNewStream( StreamId streamId, - StreamCompletionSignal signal) = 0; + StreamType streamType, + uint32_t initialRequestN, + Payload payload) override; + + void writeRequestN(Frame_REQUEST_N&&) override; + void writeCancel(Frame_CANCEL&&) override; + + void writePayload(Frame_PAYLOAD&&) override; + + // TODO: writeFragmentedError + void writeError(Frame_ERROR&&) override; + + protected: + // note: onStreamClosed() method is also still pure + virtual void outputFrame(std::unique_ptr) = 0; + virtual FrameSerializer& serializer() = 0; + virtual RSocketStats& stats() = 0; + virtual bool shouldQueue() = 0; + + template + void writeFragmented( + WriteInitialFrame, + StreamId const, + FrameFlags const, + Payload payload); + + /// Send a frame to the output, or queue it if shouldQueue() + virtual void sendPendingFrames(); + void outputFrameOrEnqueue(std::unique_ptr); + void enqueuePendingOutputFrame(std::unique_ptr frame); + std::deque> consumePendingOutputFrames(); + + private: + /// A queue of frames that are slated to be sent out. + std::deque> pendingOutputFrames_; + + /// The byte size of all pending output frames. + size_t pendingSize_{0}; }; -} + +} // namespace rsocket diff --git a/tck-test/BaseSubscriber.cpp b/rsocket/tck-test/BaseSubscriber.cpp similarity index 73% rename from tck-test/BaseSubscriber.cpp rename to rsocket/tck-test/BaseSubscriber.cpp index 24dbb1510..c0df54613 100644 --- a/tck-test/BaseSubscriber.cpp +++ b/rsocket/tck-test/BaseSubscriber.cpp @@ -1,6 +1,18 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "tck-test/BaseSubscriber.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/tck-test/BaseSubscriber.h" #include @@ -30,9 +42,9 @@ void BaseSubscriber::awaitAtLeast(int numItems) { } void BaseSubscriber::awaitNoEvents(int waitTime) { - int valuesCount = valuesCount_; - bool completed = completed_; - bool errored = errored_; + const int valuesCount = valuesCount_; + const bool completed = completed_; + const bool errored = errored_; /* sleep override */ std::this_thread::sleep_for(std::chrono::milliseconds(waitTime)); if (valuesCount != valuesCount_ || completed != completed_ || @@ -57,7 +69,7 @@ void BaseSubscriber::assertError() { void BaseSubscriber::assertValues( const std::vector>& values) { assertValueCount(values.size()); - std::unique_lock lock(mutex_); + const std::unique_lock lock(mutex_); for (size_t i = 0; i < values.size(); i++) { if (values_[i] != values[i]) { throw std::runtime_error(folly::sformat( @@ -71,7 +83,7 @@ void BaseSubscriber::assertValues( } void BaseSubscriber::assertValueCount(size_t valueCount) { - std::unique_lock lock(mutex_); + const std::unique_lock lock(mutex_); if (values_.size() != valueCount) { throw std::runtime_error(folly::sformat( "Did not receive expected number of values! Expected={} Actual={}", @@ -81,7 +93,7 @@ void BaseSubscriber::assertValueCount(size_t valueCount) { } void BaseSubscriber::assertReceivedAtLeast(size_t valueCount) { - std::unique_lock lock(mutex_); + const std::unique_lock lock(mutex_); if (values_.size() < valueCount) { throw std::runtime_error(folly::sformat( "Did not receive the minimum number of values! Expected={} Actual={}", @@ -108,5 +120,5 @@ void BaseSubscriber::assertCanceled() { } } -} // tck -} // reactivesocket +} // namespace tck +} // namespace rsocket diff --git a/tck-test/BaseSubscriber.h b/rsocket/tck-test/BaseSubscriber.h similarity index 59% rename from tck-test/BaseSubscriber.h rename to rsocket/tck-test/BaseSubscriber.h index 1fcf32fc6..fdda649ff 100644 --- a/tck-test/BaseSubscriber.h +++ b/rsocket/tck-test/BaseSubscriber.h @@ -1,3 +1,17 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #pragma once #include @@ -10,7 +24,7 @@ namespace rsocket { namespace tck { -class BaseSubscriber : public virtual yarpl::Refcounted { +class BaseSubscriber { public: virtual void request(int n) = 0; virtual void cancel() = 0; @@ -31,13 +45,14 @@ class BaseSubscriber : public virtual yarpl::Refcounted { std::atomic canceled_{false}; //////////////////////////////////////////////////////////////////////////// - std::mutex mutex_; // all variables below has to be protected with the mutex + mutable std::mutex + mutex_; // all variables below has to be protected with the mutex std::vector> values_; std::condition_variable valuesCV_; std::atomic valuesCount_{0}; - std::vector errors_; + std::vector errors_; std::condition_variable terminatedCV_; std::atomic completed_{false}; // by onComplete @@ -45,5 +60,5 @@ class BaseSubscriber : public virtual yarpl::Refcounted { //////////////////////////////////////////////////////////////////////////// }; -} // tck -} // reactivesocket +} // namespace tck +} // namespace rsocket diff --git a/tck-test/FlowableSubscriber.cpp b/rsocket/tck-test/FlowableSubscriber.cpp similarity index 57% rename from tck-test/FlowableSubscriber.cpp rename to rsocket/tck-test/FlowableSubscriber.cpp index 4389fea32..33b72b7fb 100644 --- a/tck-test/FlowableSubscriber.cpp +++ b/rsocket/tck-test/FlowableSubscriber.cpp @@ -1,4 +1,18 @@ -#include "tck-test/FlowableSubscriber.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/tck-test/FlowableSubscriber.h" #include @@ -29,7 +43,7 @@ void FlowableSubscriber::cancel() { } void FlowableSubscriber::onSubscribe( - yarpl::Reference subscription) noexcept { + std::shared_ptr subscription) noexcept { VLOG(4) << "OnSubscribe in FlowableSubscriber"; subscription_ = subscription; if (initialRequestN_ > 0) { @@ -40,10 +54,10 @@ void FlowableSubscriber::onSubscribe( void FlowableSubscriber::onNext(Payload element) noexcept { LOG(INFO) << "... received onNext from Publisher: " << element; { - std::unique_lock lock(mutex_); - std::string data = + const std::unique_lock lock(mutex_); + const std::string data = element.data ? element.data->moveToFbString().toStdString() : ""; - std::string metadata = element.metadata + const std::string metadata = element.metadata ? element.metadata->moveToFbString().toStdString() : ""; values_.push_back(std::make_pair(data, metadata)); @@ -55,22 +69,22 @@ void FlowableSubscriber::onNext(Payload element) noexcept { void FlowableSubscriber::onComplete() noexcept { LOG(INFO) << "... received onComplete from Publisher"; { - std::unique_lock lock(mutex_); + const std::unique_lock lock(mutex_); completed_ = true; } terminatedCV_.notify_one(); } -void FlowableSubscriber::onError(std::exception_ptr ex) noexcept { +void FlowableSubscriber::onError(folly::exception_wrapper ex) noexcept { LOG(INFO) << "... received onError from Publisher"; { - std::unique_lock lock(mutex_); + const std::unique_lock lock(mutex_); errors_.push_back(std::move(ex)); errored_ = true; } terminatedCV_.notify_one(); } -} // tck -} // reactivesocket +} // namespace tck +} // namespace rsocket diff --git a/rsocket/tck-test/FlowableSubscriber.h b/rsocket/tck-test/FlowableSubscriber.h new file mode 100644 index 000000000..3de091023 --- /dev/null +++ b/rsocket/tck-test/FlowableSubscriber.h @@ -0,0 +1,47 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "rsocket/tck-test/BaseSubscriber.h" + +#include "yarpl/Flowable.h" + +namespace rsocket { +namespace tck { + +class FlowableSubscriber : public BaseSubscriber, + public yarpl::flowable::Subscriber { + public: + explicit FlowableSubscriber(int initialRequestN = 0); + + // Inherited from BaseSubscriber + void request(int n) override; + void cancel() override; + + protected: + // Inherited from flowable::Subscriber + void onSubscribe(std::shared_ptr + subscription) noexcept override; + void onNext(Payload element) noexcept override; + void onComplete() noexcept override; + void onError(folly::exception_wrapper ex) noexcept override; + + private: + std::shared_ptr subscription_; + int initialRequestN_{0}; +}; + +} // namespace tck +} // namespace rsocket diff --git a/tck-test/MarbleProcessor.cpp b/rsocket/tck-test/MarbleProcessor.cpp similarity index 73% rename from tck-test/MarbleProcessor.cpp rename to rsocket/tck-test/MarbleProcessor.cpp index c30e693dd..62038cac1 100644 --- a/tck-test/MarbleProcessor.cpp +++ b/rsocket/tck-test/MarbleProcessor.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "MarbleProcessor.h" @@ -44,7 +56,7 @@ std::map> getArgMap( } return argMap; } -} +} // namespace namespace rsocket { namespace tck { @@ -67,30 +79,26 @@ MarbleProcessor::MarbleProcessor(const std::string marble) } } -std::tuple MarbleProcessor::run( +void MarbleProcessor::run( yarpl::flowable::Subscriber& subscriber, int64_t requested) { canSend_ += requested; - if (index_ > marble_.size()) { - return std::make_tuple(requested, true); - } - while (true) { - auto c = marble_[index_]; + while (canSend_ > 0 && index_ < marble_.size()) { + const auto c = marble_[index_]; switch (c) { case '#': LOG(INFO) << "Sending onError"; - subscriber.onError( - std::make_exception_ptr(std::runtime_error("Marble Error"))); - return std::make_tuple(requested, true); + subscriber.onError(std::runtime_error("Marble Error")); + break; case '|': LOG(INFO) << "Sending onComplete"; subscriber.onComplete(); - return std::make_tuple(requested, true); - default: { + break; + default: if (canSend_ > 0) { Payload payload; - auto it = argMap_.find(folly::to(c)); + const auto it = argMap_.find(folly::to(c)); LOG(INFO) << "Sending data " << c; if (it != argMap_.end()) { LOG(INFO) << folly::sformat( @@ -105,34 +113,30 @@ std::tuple MarbleProcessor::run( } subscriber.onNext(std::move(payload)); canSend_--; - } else { - return std::make_tuple(requested, false); } - } + break; } index_++; } } void MarbleProcessor::run( - yarpl::Reference> + std::shared_ptr> subscriber) { while (true) { - auto c = marble_[index_]; + const auto c = marble_[index_]; switch (c) { case '#': LOG(INFO) << "Sending onError"; - subscriber->onError( - std::make_exception_ptr(std::runtime_error("Marble Error"))); + subscriber->onError(std::runtime_error("Marble Error")); return; case '|': LOG(INFO) << "Sending onComplete"; - subscriber->onError( - std::make_exception_ptr(std::runtime_error("No Response found"))); + subscriber->onError(std::runtime_error("No Response found")); return; default: { Payload payload; - auto it = argMap_.find(folly::to(c)); + const auto it = argMap_.find(folly::to(c)); LOG(INFO) << "Sending data " << c; if (it != argMap_.end()) { LOG(INFO) << folly::sformat( @@ -153,5 +157,5 @@ void MarbleProcessor::run( } } -} // tck -} // reactivesocket +} // namespace tck +} // namespace rsocket diff --git a/rsocket/tck-test/MarbleProcessor.h b/rsocket/tck-test/MarbleProcessor.h new file mode 100644 index 000000000..77217b63b --- /dev/null +++ b/rsocket/tck-test/MarbleProcessor.h @@ -0,0 +1,50 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "rsocket/Payload.h" +#include "yarpl/Flowable.h" +#include "yarpl/Single.h" + +namespace rsocket { +namespace tck { + +class MarbleProcessor { + public: + explicit MarbleProcessor(const std::string /* marble */); + + void run( + yarpl::flowable::Subscriber& subscriber, + int64_t requested); + + void run(std::shared_ptr> + subscriber); + + private: + std::string marble_; + + // Stores a mapping from marble character to Payload (data, metadata) + std::map> argMap_; + + // Keeps an account of how many messages can be sent. This could be done + // with Allowance + std::atomic canSend_{0}; + + size_t index_{0}; +}; + +} // namespace tck +} // namespace rsocket diff --git a/tck-test/SingleSubscriber.cpp b/rsocket/tck-test/SingleSubscriber.cpp similarity index 50% rename from tck-test/SingleSubscriber.cpp rename to rsocket/tck-test/SingleSubscriber.cpp index ff3246e36..2da91f5b1 100644 --- a/tck-test/SingleSubscriber.cpp +++ b/rsocket/tck-test/SingleSubscriber.cpp @@ -1,6 +1,18 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -#include "tck-test/SingleSubscriber.h" +#include "rsocket/tck-test/SingleSubscriber.h" #include @@ -24,7 +36,7 @@ void SingleSubscriber::cancel() { } void SingleSubscriber::onSubscribe( - yarpl::Reference subscription) noexcept { + std::shared_ptr subscription) noexcept { VLOG(4) << "OnSubscribe in SingleSubscriber"; subscription_ = subscription; } @@ -32,10 +44,10 @@ void SingleSubscriber::onSubscribe( void SingleSubscriber::onSuccess(Payload element) noexcept { LOG(INFO) << "... received onSuccess from Publisher: " << element; { - std::unique_lock lock(mutex_); - std::string data = + const std::unique_lock lock(mutex_); + const std::string data = element.data ? element.data->moveToFbString().toStdString() : ""; - std::string metadata = element.metadata + const std::string metadata = element.metadata ? element.metadata->moveToFbString().toStdString() : ""; values_.push_back(std::make_pair(data, metadata)); @@ -43,21 +55,21 @@ void SingleSubscriber::onSuccess(Payload element) noexcept { } valuesCV_.notify_one(); { - std::unique_lock lock(mutex_); + const std::unique_lock lock(mutex_); completed_ = true; } terminatedCV_.notify_one(); } -void SingleSubscriber::onError(std::exception_ptr ex) noexcept { +void SingleSubscriber::onError(folly::exception_wrapper ex) noexcept { LOG(INFO) << "... received onError from Publisher"; { - std::unique_lock lock(mutex_); + const std::unique_lock lock(mutex_); errors_.push_back(std::move(ex)); errored_ = true; } terminatedCV_.notify_one(); } -} // tck -} // reactivesocket +} // namespace tck +} // namespace rsocket diff --git a/rsocket/tck-test/SingleSubscriber.h b/rsocket/tck-test/SingleSubscriber.h new file mode 100644 index 000000000..8b8b8556f --- /dev/null +++ b/rsocket/tck-test/SingleSubscriber.h @@ -0,0 +1,43 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "rsocket/tck-test/BaseSubscriber.h" + +#include "yarpl/Single.h" + +namespace rsocket { +namespace tck { + +class SingleSubscriber : public BaseSubscriber, + public yarpl::single::SingleObserver { + public: + // Inherited from BaseSubscriber + void request(int n) override; + void cancel() override; + + protected: + // Inherited from flowable::Subscriber + void onSubscribe(std::shared_ptr + subscription) noexcept override; + void onSuccess(Payload element) noexcept override; + void onError(folly::exception_wrapper ex) noexcept override; + + private: + std::shared_ptr subscription_; +}; + +} // namespace tck +} // namespace rsocket diff --git a/tck-test/TestFileParser.cpp b/rsocket/tck-test/TestFileParser.cpp similarity index 60% rename from tck-test/TestFileParser.cpp rename to rsocket/tck-test/TestFileParser.cpp index 5d0717cc5..9306960c3 100644 --- a/tck-test/TestFileParser.cpp +++ b/rsocket/tck-test/TestFileParser.cpp @@ -1,6 +1,18 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -#include "tck-test/TestFileParser.h" +#include "rsocket/tck-test/TestFileParser.h" #include #include @@ -43,6 +55,7 @@ void TestFileParser::parseCommand(const std::string& command) { if (parameters.size() == 2 && parameters[0] == "name") { currentTest_.setName(parameters[1]); + currentTest_.setResumption(false); return; } @@ -51,6 +64,11 @@ void TestFileParser::parseCommand(const std::string& command) { LOG(ERROR) << "invalid command on line " << currentLine_ << ": " << command; throw std::runtime_error("unknown command in the test"); } else { + // if test contain resumption related command. + if ("disconnect" == newCommand.name() || "resume" == newCommand.name()) { + currentTest_.setResumption(true); + } + currentTest_.addCommand(std::move(newCommand)); } } @@ -62,5 +80,5 @@ void TestFileParser::addCurrentTest() { } } -} // tck -} // reactivesocket +} // namespace tck +} // namespace rsocket diff --git a/rsocket/tck-test/TestFileParser.h b/rsocket/tck-test/TestFileParser.h new file mode 100644 index 000000000..7830934fb --- /dev/null +++ b/rsocket/tck-test/TestFileParser.h @@ -0,0 +1,42 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "rsocket/tck-test/TestSuite.h" + +namespace rsocket { +namespace tck { + +class TestFileParser { + public: + explicit TestFileParser(const std::string& fileName); + + TestSuite parse(); + + private: + void parseCommand(const std::string& command); + void addCurrentTest(); + + std::ifstream input_; + int currentLine_; + + TestSuite testSuite_; + Test currentTest_; +}; + +} // namespace tck +} // namespace rsocket diff --git a/tck-test/TestInterpreter.cpp b/rsocket/tck-test/TestInterpreter.cpp similarity index 68% rename from tck-test/TestInterpreter.cpp rename to rsocket/tck-test/TestInterpreter.cpp index b613508b8..f74eab68c 100644 --- a/tck-test/TestInterpreter.cpp +++ b/rsocket/tck-test/TestInterpreter.cpp @@ -1,16 +1,28 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "tck-test/TestInterpreter.h" +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/tck-test/TestInterpreter.h" #include #include #include #include "rsocket/RSocket.h" +#include "rsocket/tck-test/FlowableSubscriber.h" +#include "rsocket/tck-test/SingleSubscriber.h" +#include "rsocket/tck-test/TypedCommands.h" #include "rsocket/transports/tcp/TcpConnectionFactory.h" -#include "tck-test/FlowableSubscriber.h" -#include "tck-test/SingleSubscriber.h" -#include "tck-test/TypedCommands.h" using namespace folly; using namespace yarpl; @@ -34,20 +46,26 @@ bool TestInterpreter::run() { "Executing command: [{}] {}", i, command.name()); ++i; if (command.name() == "subscribe") { - auto subscribe = command.as(); + const auto subscribe = command.as(); handleSubscribe(subscribe); } else if (command.name() == "request") { - auto request = command.as(); + const auto request = command.as(); handleRequest(request); } else if (command.name() == "await") { - auto await = command.as(); + const auto await = command.as(); handleAwait(await); } else if (command.name() == "cancel") { - auto cancel = command.as(); + const auto cancel = command.as(); handleCancel(cancel); } else if (command.name() == "assert") { - auto assert = command.as(); + const auto assert = command.as(); handleAssert(assert); + } else if (command.name() == "disconnect") { + const auto disconnect = command.as(); + handleDisconnect(disconnect); + } else if (command.name() == "resume") { + const auto resume = command.as(); + handleResume(resume); } else { LOG(ERROR) << "unknown command " << command.name(); throw std::runtime_error("unknown command"); @@ -65,12 +83,33 @@ bool TestInterpreter::run() { return true; } +void TestInterpreter::handleDisconnect(const DisconnectCommand& command) { + if (testClient_.find(command.clientId()) != testClient_.end()) { + LOG(INFO) << "Disconnecting the client"; + testClient_[command.clientId()]->client->disconnect( + std::runtime_error("disconnect triggered from client")); + } +} + +void TestInterpreter::handleResume(const ResumeCommand& command) { + if (testClient_.find(command.clientId()) != testClient_.end()) { + LOG(INFO) << "Resuming the client"; + testClient_[command.clientId()]->client->resume().get(); + } +} + void TestInterpreter::handleSubscribe(const SubscribeCommand& command) { // If client does not exist, create a new client. if (testClient_.find(command.clientId()) == testClient_.end()) { + SetupParameters setupParameters; + if (test_.resumption()) { + setupParameters.resumable = true; + } auto client = RSocket::createConnectedClient( - std::make_unique(std::move(address_))) - .get(); + std::make_unique( + *worker_.getEventBase(), std::move(address_)), + std::move(setupParameters)) + .get(); testClient_[command.clientId()] = std::make_shared(move(client)); } @@ -80,7 +119,7 @@ void TestInterpreter::handleSubscribe(const SubscribeCommand& command) { testSubscribers_.end()); if (command.isRequestResponseType()) { - auto testSubscriber = make_ref(); + auto testSubscriber = std::make_shared(); testSubscribers_[command.clientId() + command.id()] = testSubscriber; testClient_[command.clientId()] ->requester @@ -88,7 +127,7 @@ void TestInterpreter::handleSubscribe(const SubscribeCommand& command) { Payload(command.payloadData(), command.payloadMetadata())) ->subscribe(std::move(testSubscriber)); } else if (command.isRequestStreamType()) { - auto testSubscriber = make_ref(); + auto testSubscriber = std::make_shared(); testSubscribers_[command.clientId() + command.id()] = testSubscriber; testClient_[command.clientId()] ->requester @@ -98,7 +137,7 @@ void TestInterpreter::handleSubscribe(const SubscribeCommand& command) { } else { throw std::runtime_error("unsupported interaction type"); } -} +} void TestInterpreter::handleRequest(const RequestCommand& command) { getSubscriber(command.clientId() + command.id())->request(command.n()); @@ -158,9 +197,9 @@ void TestInterpreter::handleAssert(const AssertCommand& command) { } } -yarpl::Reference TestInterpreter::getSubscriber( +std::shared_ptr TestInterpreter::getSubscriber( const std::string& id) { - auto found = testSubscribers_.find(id); + const auto found = testSubscribers_.find(id); if (found == testSubscribers_.end()) { throw std::runtime_error("unable to find test subscriber with provided id"); } diff --git a/tck-test/TestInterpreter.h b/rsocket/tck-test/TestInterpreter.h similarity index 51% rename from tck-test/TestInterpreter.h rename to rsocket/tck-test/TestInterpreter.h index 94d787873..d57fb76c6 100644 --- a/tck-test/TestInterpreter.h +++ b/rsocket/tck-test/TestInterpreter.h @@ -1,16 +1,29 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include #include +#include #include "rsocket/Payload.h" #include "rsocket/RSocket.h" #include "rsocket/RSocketRequester.h" -#include "tck-test/BaseSubscriber.h" -#include "tck-test/TestSuite.h" +#include "rsocket/tck-test/BaseSubscriber.h" +#include "rsocket/tck-test/TestSuite.h" namespace folly { class EventBase; @@ -27,11 +40,13 @@ class RequestCommand; class AwaitCommand; class CancelCommand; class AssertCommand; +class ResumeCommand; +class DisconnectCommand; class TestInterpreter { class TestClient { public: - TestClient(std::shared_ptr c) + explicit TestClient(std::shared_ptr c) : client(std::move(c)) { auto rs = client->getRequester(); requester = std::move(rs); @@ -51,13 +66,16 @@ class TestInterpreter { void handleAwait(const AwaitCommand& command); void handleCancel(const CancelCommand& command); void handleAssert(const AssertCommand& command); + void handleDisconnect(const DisconnectCommand& command); + void handleResume(const ResumeCommand& command); - yarpl::Reference getSubscriber(const std::string& id); + std::shared_ptr getSubscriber(const std::string& id); + folly::ScopedEventBaseThread worker_; folly::SocketAddress address_; const Test& test_; std::map interactionIdToType_; - std::map> testSubscribers_; + std::map> testSubscribers_; std::map> testClient_; }; diff --git a/rsocket/tck-test/TestSuite.cpp b/rsocket/tck-test/TestSuite.cpp new file mode 100644 index 000000000..8e921f347 --- /dev/null +++ b/rsocket/tck-test/TestSuite.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/tck-test/TestSuite.h" + +#include + +namespace rsocket { +namespace tck { + +bool TestCommand::valid() const { + // there has to be a name to the test and at least 1 param + return params_.size() >= 1; +} + +void Test::addCommand(TestCommand command) { + CHECK(command.valid()); + commands_.push_back(std::move(command)); +} + +} // namespace tck +} // namespace rsocket diff --git a/tck-test/TestSuite.h b/rsocket/tck-test/TestSuite.h similarity index 59% rename from tck-test/TestSuite.h rename to rsocket/tck-test/TestSuite.h index aef0e0795..f705e0d13 100644 --- a/tck-test/TestSuite.h +++ b/rsocket/tck-test/TestSuite.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -42,6 +54,14 @@ class Test { name_ = name; } + bool resumption() const { + return resumption_; + } + + void setResumption(bool resumption) { + resumption_ = resumption; + } + void addCommand(TestCommand command); const std::vector& commands() const { @@ -54,6 +74,7 @@ class Test { private: std::string name_; + bool resumption_{false}; std::vector commands_; }; diff --git a/tck-test/TypedCommands.h b/rsocket/tck-test/TypedCommands.h similarity index 81% rename from tck-test/TypedCommands.h rename to rsocket/tck-test/TypedCommands.h index 5d5f0c0d9..d32b144ac 100644 --- a/tck-test/TypedCommands.h +++ b/rsocket/tck-test/TypedCommands.h @@ -1,11 +1,23 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include #include -#include "tck-test/TestSuite.h" +#include "rsocket/tck-test/TestSuite.h" namespace rsocket { namespace tck { @@ -71,6 +83,16 @@ class CancelCommand : public TypedTestCommand { } }; +class ResumeCommand : public TypedTestCommand { + public: + using TypedTestCommand::TypedTestCommand; +}; + +class DisconnectCommand : public TypedTestCommand { + public: + using TypedTestCommand::TypedTestCommand; +}; + class AwaitCommand : public TypedTestCommand { public: using TypedTestCommand::TypedTestCommand; diff --git a/tck-test/client.cpp b/rsocket/tck-test/client.cpp similarity index 72% rename from tck-test/client.cpp rename to rsocket/tck-test/client.cpp index 901d7b336..acb5068c3 100644 --- a/tck-test/client.cpp +++ b/rsocket/tck-test/client.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -8,8 +20,8 @@ #include "rsocket/RSocket.h" -#include "tck-test/TestFileParser.h" -#include "tck-test/TestInterpreter.h" +#include "rsocket/tck-test/TestFileParser.h" +#include "rsocket/tck-test/TestInterpreter.h" #include "rsocket/transports/tcp/TcpConnectionFactory.h" diff --git a/rsocket/tck-test/clientResumptiontest.txt b/rsocket/tck-test/clientResumptiontest.txt new file mode 100644 index 000000000..db13f3f66 --- /dev/null +++ b/rsocket/tck-test/clientResumptiontest.txt @@ -0,0 +1,49 @@ +! +name%%streamResumptionTest4 +c1%%subscribe%%rs%%1%%a%%b +c1%%request%%3%%1 +c1%%await%%atLeast%%1%%3%%100 +c1%%assert%%received%%1%%a,b&&c,d&&e,f +c1%%disconnect +c1%%resume +c1%%request%%3%%1 +c1%%await%%terminal%%1 +c1%%assert%%completed%%1 +c1%%assert%%no_error%%1 +c1%%assert%%received_n%%1%%6 +! +name%%streamResumptionTest3 +c2%%subscribe%%rs%%2%%a%%b +c2%%request%%3%%2 +c2%%await%%atLeast%%2%%3%%100 +c2%%assert%%received%%2%%a,b&&c,d&&e,f +c2%%disconnect +c2%%request%%1%%2 +c2%%resume +c2%%request%%2%%2 +c2%%await%%terminal%%2 +c2%%assert%%received_n%%2%%6 +! +name%%streamResumptionTest2 +c3%%subscribe%%rs%%3%%a%%b +c3%%request%%3%%3 +c3%%await%%atLeast%%3%%3%%100 +c3%%assert%%received%%3%%a,b&&c,d&&e,f +c3%%disconnect +c3%%request%%3%%3 +c3%%await%%no_events%%3%%100 +c3%%assert%%received_n%%3%%3 +! +name%%streamResumptionTest1 +c4%%subscribe%%rs%%4%%a%%b +c4%%request%%3%%4 +c4%%await%%atLeast%%4%%3%%100 +c4%%assert%%received%%4%%a,b&&c,d&&e,f +c4%%disconnect +c4%%request%%3%%4 +c4%%resume +c4%%await%%terminal%%4 +c4%%assert%%completed%%4 +c4%%assert%%no_error%%4 +c4%%assert%%received_n%%4%%6 +EOF diff --git a/tck-test/clienttest.txt b/rsocket/tck-test/clienttest.txt similarity index 100% rename from tck-test/clienttest.txt rename to rsocket/tck-test/clienttest.txt diff --git a/tck-test/server.cpp b/rsocket/tck-test/server.cpp similarity index 52% rename from tck-test/server.cpp rename to rsocket/tck-test/server.cpp index a8fc4aece..2988cfe28 100644 --- a/tck-test/server.cpp +++ b/rsocket/tck-test/server.cpp @@ -1,21 +1,34 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include #include #include -#include #include #include #include #include +#include "rsocket/RSocket.h" +#include "rsocket/RSocketServiceHandler.h" #include "rsocket/framing/FramedDuplexConnection.h" #include "rsocket/transports/tcp/TcpDuplexConnection.h" -#include "rsocket/RSocket.h" #include "rsocket/transports/tcp/TcpConnectionAcceptor.h" -#include "tck-test/MarbleProcessor.h" +#include "rsocket/tck-test/MarbleProcessor.h" using namespace folly; using namespace rsocket; @@ -66,7 +79,7 @@ MarbleStore parseMarbles(const std::string& fileName) { } return ms; } -} +} // namespace class ServerResponder : public RSocketResponder { public: @@ -74,45 +87,50 @@ class ServerResponder : public RSocketResponder { marbles_ = parseMarbles(FLAGS_test_file); } - yarpl::Reference> handleRequestStream( + std::shared_ptr> handleRequestStream( Payload request, StreamId) override { LOG(INFO) << "handleRequestStream " << request; - std::string data = request.data->moveToFbString().toStdString(); - std::string metadata = request.metadata->moveToFbString().toStdString(); - auto it = marbles_.streamMarbles.find(std::make_pair(data, metadata)); + const std::string data = request.data->moveToFbString().toStdString(); + const std::string metadata = + request.metadata->moveToFbString().toStdString(); + const auto it = marbles_.streamMarbles.find(std::make_pair(data, metadata)); if (it == marbles_.streamMarbles.end()) { - return yarpl::flowable::Flowables::error( + return yarpl::flowable::Flowable::error( std::logic_error("No MarbleHandler found")); } else { - auto marbleProcessor = std::make_shared(it->second); + const auto marbleProcessor = + std::make_shared(it->second); auto lambda = [marbleProcessor]( - yarpl::flowable::Subscriber& subscriber, - int64_t requested) mutable { + auto& subscriber, int64_t requested) mutable { return marbleProcessor->run(subscriber, requested); }; return Flowable::create(std::move(lambda)); } } - yarpl::Reference> handleRequestResponse( + std::shared_ptr> handleRequestResponse( Payload request, StreamId) override { LOG(INFO) << "handleRequestResponse " << request; - std::string data = request.data->moveToFbString().toStdString(); - std::string metadata = request.metadata->moveToFbString().toStdString(); - auto it = marbles_.reqRespMarbles.find(std::make_pair(data, metadata)); + const std::string data = request.data->moveToFbString().toStdString(); + const std::string metadata = + request.metadata->moveToFbString().toStdString(); + const auto it = + marbles_.reqRespMarbles.find(std::make_pair(data, metadata)); if (it == marbles_.reqRespMarbles.end()) { return yarpl::single::Singles::error( std::logic_error("No MarbleHandler found")); } else { - auto marbleProcessor = std::make_shared(it->second); - auto lambda = [marbleProcessor]( - yarpl::Reference> - subscriber) { - subscriber->onSubscribe(SingleSubscriptions::empty()); - return marbleProcessor->run(subscriber); - }; + const auto marbleProcessor = + std::make_shared(it->second); + auto lambda = + [marbleProcessor]( + std::shared_ptr> + subscriber) { + subscriber->onSubscribe(SingleSubscriptions::empty()); + return marbleProcessor->run(subscriber); + }; return Single::create(std::move(lambda)); } } @@ -121,6 +139,33 @@ class ServerResponder : public RSocketResponder { MarbleStore marbles_; }; +class ServiceHandler : public RSocketServiceHandler { + public: + folly::Expected onNewSetup( + const SetupParameters&) override { + return RSocketConnectionParams(std::make_shared()); + } + + void onNewRSocketState( + std::shared_ptr state, + ResumeIdentificationToken token) override { + store_.lock()->insert({token, std::move(state)}); + } + + folly::Expected, RSocketException> + onResume(ResumeIdentificationToken token) override { + const auto itr = store_->find(token); + CHECK(itr != store_->end()); + return itr->second; + }; + + private: + folly::Synchronized< + std::map>, + std::mutex> + store_; +}; + std::promise terminate; static void signal_handler(int signal) { @@ -136,19 +181,16 @@ int main(int argc, char* argv[]) { signal(SIGTERM, signal_handler); TcpConnectionAcceptor::Options opts; - opts.address = folly::SocketAddress("::", FLAGS_port); - opts.threads = 2; + opts.address = folly::SocketAddress(FLAGS_ip, FLAGS_port); + opts.threads = 1; // RSocket server accepting on TCP - auto rs = RSocket::createServer( + const auto rs = RSocket::createServer( std::make_unique(std::move(opts))); - auto rawRs = rs.get(); - auto serverThread = std::thread([=] { - rawRs->startAndPark([](const SetupParameters&) { - return std::make_shared(); - }); - }); + const auto rawRs = rs.get(); + auto serverThread = std::thread( + [=] { rawRs->startAndPark(std::make_shared()); }); terminate.get_future().wait(); rs->unpark(); diff --git a/rsocket/tck-test/serverResumptiontest.txt b/rsocket/tck-test/serverResumptiontest.txt new file mode 100644 index 000000000..960a0d78f --- /dev/null +++ b/rsocket/tck-test/serverResumptiontest.txt @@ -0,0 +1 @@ +rs%%a%%b%%---a-----b-----c-----d--e--f---|&&{"a":{"a":"b"},"b":{"c":"d"},"c":{"e":"f"}} diff --git a/tck-test/servertest.txt b/rsocket/tck-test/servertest.txt similarity index 100% rename from tck-test/servertest.txt rename to rsocket/tck-test/servertest.txt diff --git a/rsocket/test/ColdResumptionTest.cpp b/rsocket/test/ColdResumptionTest.cpp new file mode 100644 index 000000000..d07fcb7b4 --- /dev/null +++ b/rsocket/test/ColdResumptionTest.cpp @@ -0,0 +1,406 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include +#include +#include + +#include "RSocketTests.h" + +#include "rsocket/test/handlers/HelloServiceHandler.h" +#include "rsocket/test/test_utils/ColdResumeManager.h" + +DEFINE_int32(num_clients, 5, "Number of clients to parallely cold-resume"); + +using namespace rsocket; +using namespace rsocket::tests; +using namespace rsocket::tests::client_server; +using namespace yarpl::flowable; + +typedef std::map>> + HelloSubscribers; + +namespace { +class HelloSubscriber : public BaseSubscriber { + public: + explicit HelloSubscriber(size_t latestValue) : latestValue_(latestValue) {} + + void requestWhenSubscribed(int n) { + subscribedBaton_.wait(); + this->request(n); + } + + void awaitLatestValue(size_t value) { + auto count = 50; + while (value != latestValue_ && count > 0) { + VLOG(1) << "Waiting " << count << " ticks for latest value..."; + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + count--; + std::this_thread::yield(); + } + EXPECT_EQ(value, latestValue_); + } + + size_t valueCount() const { + return count_; + } + + size_t getLatestValue() const { + return latestValue_; + } + + protected: + void onSubscribeImpl() noexcept override { + subscribedBaton_.post(); + } + + void onNextImpl(Payload p) noexcept override { + auto currValue = folly::to(p.data->moveToFbString().toStdString()); + EXPECT_EQ(latestValue_, currValue - 1); + latestValue_ = currValue; + count_++; + } + + void onCompleteImpl() override {} + void onErrorImpl(folly::exception_wrapper) override {} + + private: + std::atomic latestValue_; + std::atomic count_{0}; + folly::Baton<> subscribedBaton_; +}; + +class HelloResumeHandler : public ColdResumeHandler { + public: + explicit HelloResumeHandler(HelloSubscribers subscribers) + : subscribers_(std::move(subscribers)) {} + + std::string generateStreamToken(const Payload& payload, StreamId, StreamType) + const override { + const auto streamToken = + payload.data->cloneAsValue().moveToFbString().toStdString(); + VLOG(3) << "Generated token: " << streamToken; + return streamToken; + } + + std::shared_ptr> handleRequesterResumeStream( + std::string streamToken, + size_t consumerAllowance) override { + CHECK(subscribers_.find(streamToken) != subscribers_.end()); + VLOG(1) << "Resuming " << streamToken << " stream with allowance " + << consumerAllowance; + return subscribers_[streamToken]; + } + + private: + HelloSubscribers subscribers_; +}; +} // namespace + +std::unique_ptr createResumedClient( + folly::EventBase* evb, + uint32_t port, + ResumeIdentificationToken token, + std::shared_ptr resumeManager, + std::shared_ptr coldResumeHandler, + folly::EventBase* stateMachineEvb = nullptr) { + auto retries = 10; + while (true) { + try { + return RSocket::createResumedClient( + getConnFactory(evb, port), + token, + resumeManager, + coldResumeHandler, + nullptr, /* responder */ + kDefaultKeepaliveInterval, + nullptr, /* stats */ + nullptr, /* connectionEvents */ + ProtocolVersion::Latest, + stateMachineEvb) + .get(); + } catch (const RSocketException& ex) { + retries--; + VLOG(1) << "Creation of resumed client failed. Exception " << ex.what() + << ". Retries Left: " << retries; + if (retries <= 0) { + throw ex; + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + } +} + +// There are three sessions and three streams. +// There is cold-resumption between the three sessions. +// +// The first stream lasts through all three sessions. +// The second stream lasts through the second and third session. +// The third stream lives only in the third session. +// +// The first stream requests 10 frames +// The second stream requests 10 frames +// The third stream requests 5 frames +void coldResumer(uint32_t port, uint32_t client_num) { + auto firstPayload = folly::sformat("client{}_first", client_num); + auto secondPayload = folly::sformat("client{}_second", client_num); + auto thirdPayload = folly::sformat("client{}_third", client_num); + size_t firstLatestValue, secondLatestValue; + + folly::ScopedEventBaseThread worker; + auto token = ResumeIdentificationToken::generateNew(); + auto resumeManager = + std::make_shared(RSocketStats::noop()); + { + auto firstSub = std::make_shared(0); + { + auto coldResumeHandler = std::make_shared( + HelloSubscribers({{firstPayload, firstSub}})); + std::shared_ptr firstClient; + EXPECT_NO_THROW( + firstClient = makeColdResumableClient( + worker.getEventBase(), + port, + token, + resumeManager, + coldResumeHandler)); + firstClient->getRequester() + ->requestStream(Payload(firstPayload)) + ->subscribe(firstSub); + firstSub->requestWhenSubscribed(4); + // Ensure reception of few frames before resuming. + while (firstSub->valueCount() < 1) { + std::this_thread::yield(); + } + } + worker.getEventBase()->runInEventBaseThreadAndWait( + [client_num, &firstLatestValue, firstSub = std::move(firstSub)]() { + firstLatestValue = firstSub->getLatestValue(); + VLOG(1) << folly::sformat( + "client{} {}", client_num, firstLatestValue); + VLOG(1) << folly::sformat("client{} First Resume", client_num); + }); + } + + { + auto firstSub = std::make_shared(firstLatestValue); + auto secondSub = std::make_shared(0); + { + auto coldResumeHandler = std::make_shared( + HelloSubscribers({{firstPayload, firstSub}})); + std::shared_ptr secondClient; + EXPECT_NO_THROW( + secondClient = createResumedClient( + worker.getEventBase(), + port, + token, + resumeManager, + coldResumeHandler)); + + // Create another stream to verify StreamIds are set properly after + // resumption + secondClient->getRequester() + ->requestStream(Payload(secondPayload)) + ->subscribe(secondSub); + firstSub->requestWhenSubscribed(3); + secondSub->requestWhenSubscribed(5); + // Ensure reception of few frames before resuming. + while (secondSub->valueCount() < 1) { + std::this_thread::yield(); + } + } + worker.getEventBase()->runInEventBaseThreadAndWait( + [client_num, + &firstLatestValue, + firstSub = std::move(firstSub), + &secondLatestValue, + secondSub = std::move(secondSub)]() { + firstLatestValue = firstSub->getLatestValue(); + secondLatestValue = secondSub->getLatestValue(); + VLOG(1) << folly::sformat( + "client{} {}", client_num, firstLatestValue); + VLOG(1) << folly::sformat( + "client{} {}", client_num, secondLatestValue); + VLOG(1) << folly::sformat("client{} Second Resume", client_num); + }); + } + + { + auto firstSub = std::make_shared(firstLatestValue); + auto secondSub = std::make_shared(secondLatestValue); + auto thirdSub = std::make_shared(0); + auto coldResumeHandler = + std::make_shared(HelloSubscribers( + {{firstPayload, firstSub}, {secondPayload, secondSub}})); + std::shared_ptr thirdClient; + + EXPECT_NO_THROW( + thirdClient = createResumedClient( + worker.getEventBase(), + port, + token, + resumeManager, + coldResumeHandler)); + + // Create another stream to verify StreamIds are set properly after + // resumption + thirdClient->getRequester() + ->requestStream(Payload(thirdPayload)) + ->subscribe(thirdSub); + firstSub->requestWhenSubscribed(3); + secondSub->requestWhenSubscribed(5); + thirdSub->requestWhenSubscribed(5); + + firstSub->awaitLatestValue(10); + secondSub->awaitLatestValue(10); + thirdSub->awaitLatestValue(5); + } +} + +TEST(ColdResumptionTest, DISABLED_SuccessfulResumption) { + auto server = makeResumableServer(std::make_shared()); + auto port = *server->listeningPort(); + + std::vector clients; + + for (int i = 0; i < FLAGS_num_clients; i++) { + auto client = std::thread([port, i]() { coldResumer(port, i); }); + clients.push_back(std::move(client)); + } + + for (auto& client : clients) { + client.join(); + } +} + +TEST(ColdResumptionTest, DifferentEvb) { + auto server = makeResumableServer(std::make_shared()); + auto port = *server->listeningPort(); + + auto payload = "InitialPayload"; + size_t latestValue; + + folly::ScopedEventBaseThread transportWorker{"transportWorker"}; + folly::ScopedEventBaseThread SMWorker{"SMWorker"}; + + auto token = ResumeIdentificationToken::generateNew(); + auto resumeManager = + std::make_shared(RSocketStats::noop()); + { + auto firstSub = std::make_shared(0); + { + auto coldResumeHandler = std::make_shared( + HelloSubscribers({{payload, firstSub}})); + std::shared_ptr firstClient; + EXPECT_NO_THROW( + firstClient = makeColdResumableClient( + transportWorker.getEventBase(), + port, + token, + resumeManager, + coldResumeHandler, + SMWorker.getEventBase())); + firstClient->getRequester() + ->requestStream(Payload(payload)) + ->subscribe(firstSub); + firstSub->requestWhenSubscribed(7); + // Ensure reception of few frames before resuming. + while (firstSub->valueCount() < 1) { + std::this_thread::yield(); + } + } + SMWorker.getEventBase()->runInEventBaseThreadAndWait( + [&latestValue, firstSub = std::move(firstSub)]() { + latestValue = firstSub->getLatestValue(); + VLOG(1) << latestValue; + VLOG(1) << "First Resume"; + }); + } + + { + auto firstSub = std::make_shared(latestValue); + { + auto coldResumeHandler = std::make_shared( + HelloSubscribers({{payload, firstSub}})); + std::shared_ptr secondClient; + EXPECT_NO_THROW( + secondClient = createResumedClient( + transportWorker.getEventBase(), + port, + token, + resumeManager, + coldResumeHandler, + SMWorker.getEventBase())); + + firstSub->requestWhenSubscribed(3); + // Ensure reception of few frames before resuming. + while (firstSub->valueCount() < 1) { + std::this_thread::yield(); + } + firstSub->awaitLatestValue(10); + } + } + + server->shutdownAndWait(); +} + +// Attempt a resumption when the previous transport/client hasn't +// disconnected it. Verify resumption succeeds after the previous +// transport is disconnected. +TEST(ColdResumptionTest, DisconnectResumption) { + auto server = makeResumableServer(std::make_shared()); + auto port = *server->listeningPort(); + + auto payload = "InitialPayload"; + + folly::ScopedEventBaseThread transportWorker{"transportWorker"}; + + auto token = ResumeIdentificationToken::generateNew(); + auto resumeManager = + std::make_shared(RSocketStats::noop()); + auto sub = std::make_shared(0); + auto crh = + std::make_shared(HelloSubscribers({{payload, sub}})); + std::shared_ptr client; + EXPECT_NO_THROW( + client = makeColdResumableClient( + transportWorker.getEventBase(), port, token, resumeManager, crh)); + client->getRequester()->requestStream(Payload(payload))->subscribe(sub); + sub->requestWhenSubscribed(7); + // Ensure reception of few frames before resuming. + while (sub->valueCount() < 7) { + std::this_thread::yield(); + } + + auto resumedSub = std::make_shared(7); + auto resumedCrh = std::make_shared( + HelloSubscribers({{payload, resumedSub}})); + + std::shared_ptr resumedClient; + EXPECT_NO_THROW( + resumedClient = createResumedClient( + transportWorker.getEventBase(), + port, + token, + resumeManager, + resumedCrh)); + + resumedSub->requestWhenSubscribed(3); + resumedSub->awaitLatestValue(10); + + server->shutdownAndWait(); +} diff --git a/rsocket/test/ConnectionEventsTest.cpp b/rsocket/test/ConnectionEventsTest.cpp new file mode 100644 index 000000000..c6ec2ba27 --- /dev/null +++ b/rsocket/test/ConnectionEventsTest.cpp @@ -0,0 +1,178 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "RSocketTests.h" + +#include "rsocket/test/handlers/HelloServiceHandler.h" + +#include "yarpl/flowable/TestSubscriber.h" + +using namespace rsocket; +using namespace rsocket::tests; +using namespace testing; +using namespace rsocket::tests::client_server; +using namespace yarpl::flowable; + +namespace { + +class MockConnEvents : public RSocketConnectionEvents { + public: + MOCK_METHOD0(onConnected, void()); + MOCK_METHOD1(onDisconnected, void(const folly::exception_wrapper&)); + MOCK_METHOD0(onStreamsPaused, void()); + MOCK_METHOD0(onStreamsResumed, void()); + MOCK_METHOD1(onClosed, void(const folly::exception_wrapper&)); +}; + +} // anonymous namespace + +TEST(ConnectionEventsTest, SimpleStream) { + folly::ScopedEventBaseThread worker; + auto serverConnEvents = std::make_shared>(); + auto clientConnEvents = std::make_shared>(); + + EXPECT_CALL(*clientConnEvents, onConnected()); + EXPECT_CALL(*serverConnEvents, onConnected()); + + // create server supporting resumption + auto server = makeResumableServer( + std::make_shared(serverConnEvents)); + + // create resumable client + auto client = makeWarmResumableClient( + worker.getEventBase(), *server->listeningPort(), clientConnEvents); + + // request stream + auto requester = client->getRequester(); + auto ts = TestSubscriber::create(7 /* initialRequestN */); + requester->requestStream(Payload("Bob")) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(ts); + // Wait for a few frames before disconnecting. + while (ts->getValueCount() < 3) { + std::this_thread::yield(); + } + + // disconnect + EXPECT_CALL(*clientConnEvents, onDisconnected(_)); + EXPECT_CALL(*clientConnEvents, onStreamsPaused()); + EXPECT_CALL(*serverConnEvents, onDisconnected(_)); + EXPECT_CALL(*serverConnEvents, onStreamsPaused()); + client->disconnect(std::runtime_error("Test triggered disconnect")); + + // resume + EXPECT_CALL(*clientConnEvents, onConnected()); + EXPECT_CALL(*clientConnEvents, onStreamsResumed()); + EXPECT_CALL(*serverConnEvents, onConnected()); + EXPECT_CALL(*serverConnEvents, onStreamsResumed()); + EXPECT_NO_THROW(client->resume().get()); + + ts->request(3); + ts->awaitTerminalEvent(); + ts->assertSuccess(); + ts->assertValueCount(10); + + // disconnect + EXPECT_CALL(*clientConnEvents, onDisconnected(_)); + EXPECT_CALL(*clientConnEvents, onStreamsPaused()); + EXPECT_CALL(*serverConnEvents, onDisconnected(_)); + EXPECT_CALL(*serverConnEvents, onStreamsPaused()); + client->disconnect(std::runtime_error("Test triggered disconnect")); + + // relinquish resources + EXPECT_CALL(*clientConnEvents, onClosed(_)); + EXPECT_CALL(*serverConnEvents, onClosed(_)); +} + +// Verify the ConnectionEvents are called back on the right EventBase. +TEST(ConnectionEventsTest, DifferentEvb) { + folly::ScopedEventBaseThread transportWorker{"TransportWkr"}; + folly::ScopedEventBaseThread SMWorker{"SMWorker"}; + + auto clientConnEvents = std::make_shared>(); + + EXPECT_CALL(*clientConnEvents, onConnected()) + .WillOnce(Invoke([evb = SMWorker.getEventBase()]() { + EXPECT_TRUE(evb->isInEventBaseThread()); + })); + + // create server supporting resumption + auto server = makeResumableServer(std::make_shared()); + + // create resumable client + auto client = makeWarmResumableClient( + transportWorker.getEventBase(), + *server->listeningPort(), + clientConnEvents, + SMWorker.getEventBase()); + + // request stream + auto requester = client->getRequester(); + auto ts = TestSubscriber::create(7 /* initialRequestN */); + requester->requestStream(Payload("Bob")) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(ts); + // Wait for a few frames before disconnecting. + while (ts->getValueCount() < 3) { + std::this_thread::yield(); + } + + // disconnect + EXPECT_CALL(*clientConnEvents, onDisconnected(_)) + .WillOnce(InvokeWithoutArgs([evb = SMWorker.getEventBase()]() { + EXPECT_TRUE(evb->isInEventBaseThread()); + })); + EXPECT_CALL(*clientConnEvents, onStreamsPaused()) + .WillOnce(Invoke([evb = SMWorker.getEventBase()]() { + EXPECT_TRUE(evb->isInEventBaseThread()); + })); + client->disconnect(std::runtime_error("Test triggered disconnect")); + + // resume + EXPECT_CALL(*clientConnEvents, onConnected()) + .WillOnce(Invoke([evb = SMWorker.getEventBase()]() { + EXPECT_TRUE(evb->isInEventBaseThread()); + })); + EXPECT_CALL(*clientConnEvents, onStreamsResumed()) + .WillOnce(Invoke([evb = SMWorker.getEventBase()]() { + EXPECT_TRUE(evb->isInEventBaseThread()); + })); + EXPECT_NO_THROW(client->resume().get()); + + ts->request(3); + ts->awaitTerminalEvent(); + ts->assertSuccess(); + ts->assertValueCount(10); + + // disconnect + EXPECT_CALL(*clientConnEvents, onDisconnected(_)) + .WillOnce(InvokeWithoutArgs([evb = SMWorker.getEventBase()]() { + EXPECT_TRUE(evb->isInEventBaseThread()); + })); + EXPECT_CALL(*clientConnEvents, onStreamsPaused()) + .WillOnce(Invoke([evb = SMWorker.getEventBase()]() { + EXPECT_TRUE(evb->isInEventBaseThread()); + })); + client->disconnect(std::runtime_error("Test triggered disconnect")); + + // relinquish resources + EXPECT_CALL(*clientConnEvents, onClosed(_)) + .WillOnce(InvokeWithoutArgs([evb = SMWorker.getEventBase()]() { + EXPECT_TRUE(evb->isInEventBaseThread()); + })); +} diff --git a/test/PayloadTest.cpp b/rsocket/test/PayloadTest.cpp similarity index 71% rename from test/PayloadTest.cpp rename to rsocket/test/PayloadTest.cpp index c9435ee61..3b9e8496b 100644 --- a/test/PayloadTest.cpp +++ b/rsocket/test/PayloadTest.cpp @@ -1,13 +1,26 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include + #include + #include "rsocket/Payload.h" #include "rsocket/framing/Frame.h" -#include "rsocket/framing/FrameSerializer_v0_1.h" +#include "rsocket/framing/FrameSerializer_v1_0.h" -using namespace ::testing; using namespace ::rsocket; TEST(PayloadTest, EmptyMetadata) { @@ -24,17 +37,6 @@ TEST(PayloadTest, Clear) { ASSERT_FALSE(p); } -TEST(PayloadTest, GiantMetadata) { - constexpr auto metadataSize = std::numeric_limits::max(); - - auto metadata = folly::IOBuf::wrapBuffer(&metadataSize, sizeof(metadataSize)); - folly::io::Cursor cur(metadata.get()); - - EXPECT_THROW( - FrameSerializerV0_1::deserializeMetadataFrom(cur, FrameFlags::METADATA), - std::runtime_error); -} - TEST(PayloadTest, Clone) { Payload orig("data", "metadata"); diff --git a/rsocket/test/RSocketClientServerTest.cpp b/rsocket/test/RSocketClientServerTest.cpp new file mode 100644 index 000000000..13c4a9218 --- /dev/null +++ b/rsocket/test/RSocketClientServerTest.cpp @@ -0,0 +1,138 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "RSocketTests.h" + +#include +#include +#include +#include "rsocket/test/handlers/HelloStreamRequestHandler.h" + +using namespace rsocket; +using namespace rsocket::tests; +using namespace rsocket::tests::client_server; + +TEST(RSocketClientServer, StartAndShutdown) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared()); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); +} + +TEST(RSocketClientServer, ConnectOne) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared()); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); +} + +TEST(RSocketClientServer, ConnectManySync) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared()); + + for (size_t i = 0; i < 100; ++i) { + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + } +} + +TEST(RSocketClientServer, ConnectManyAsync) { + auto server = makeServer(std::make_shared()); + + constexpr size_t connectionCount = 100; + constexpr size_t workerCount = 10; + std::vector workers(workerCount); + std::vector>> clients; + + std::atomic executed{0}; + for (size_t i = 0; i < connectionCount; ++i) { + int workerId = folly::Random::rand32(workerCount); + auto clientFuture = + makeClientAsync( + workers[workerId].getEventBase(), *server->listeningPort()) + .thenValue( + [&executed](std::shared_ptr client) { + ++executed; + return client; + }) + .thenError([&](folly::exception_wrapper ex) { + LOG(ERROR) << "error: " << ex.what(); + ++executed; + return std::shared_ptr(nullptr); + }); + clients.emplace_back(std::move(clientFuture)); + } + + CHECK_EQ(clients.size(), connectionCount); + auto results = folly::collectAll(clients).get(std::chrono::minutes{1}); + CHECK_EQ(results.size(), connectionCount); + + results.clear(); + clients.clear(); + CHECK_EQ(executed, connectionCount); + workers.clear(); +} + +TEST(RSocketClientServer, ConnectOnDifferentEvb) { + folly::ScopedEventBaseThread transportWorker{"transportWorker"}; + folly::ScopedEventBaseThread stateMachineWorker{"stateMachineWorker"}; + auto server = makeServer(std::make_shared()); + auto client = makeClient( + transportWorker.getEventBase(), + *server->listeningPort(), + stateMachineWorker.getEventBase()); +} + +/// Test destroying a client with an open connection on the same worker thread +/// as that connection. +TEST(RSocketClientServer, ClientClosesOnWorker) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared()); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + + // Move the client to the worker thread. + worker.getEventBase()->runInEventBaseThread([c = std::move(client)] {}); +} + +/// Test that sending garbage to the server doesn't crash it. +TEST(RSocketClientServer, ServerGetsGarbage) { + auto server = makeServer(std::make_shared()); + folly::SocketAddress address{"127.0.0.1", *server->listeningPort()}; + + folly::ScopedEventBaseThread worker; + auto factory = + std::make_shared(*worker.getEventBase(), address); + + auto result = + factory->connect(ProtocolVersion::Latest, ResumeStatus::NEW_SESSION) + .get(); + auto connection = std::move(result.connection); + auto evb = &result.eventBase; + + evb->runInEventBaseThreadAndWait([conn = std::move(connection)]() mutable { + conn->send(folly::IOBuf::copyBuffer("ABCDEFGHIJKLMNOP")); + conn.reset(); + }); +} + +/// Test closing a server with a bunch of open connections. +TEST(RSocketClientServer, CloseServerWithConnections) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared()); + std::vector> clients; + + for (size_t i = 0; i < 100; ++i) { + clients.push_back( + makeClient(worker.getEventBase(), *server->listeningPort())); + } + + server.reset(); +} diff --git a/rsocket/test/RSocketClientTest.cpp b/rsocket/test/RSocketClientTest.cpp new file mode 100644 index 000000000..5a96d09ae --- /dev/null +++ b/rsocket/test/RSocketClientTest.cpp @@ -0,0 +1,72 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "RSocketTests.h" + +#include +#include + +#include "rsocket/test/test_utils/MockDuplexConnection.h" +#include "rsocket/transports/tcp/TcpConnectionFactory.h" + +using namespace rsocket; +using namespace testing; +using namespace yarpl::single; + +TEST(RSocketClient, ConnectFails) { + folly::ScopedEventBaseThread worker; + + folly::SocketAddress address; + address.setFromHostPort("localhost", 1); + auto client = + RSocket::createConnectedClient(std::make_unique( + *worker.getEventBase(), std::move(address))); + + std::move(client) + .thenValue([&](auto&&) { FAIL() << "the test needs to fail"; }) + .thenError( + folly::tag_t{}, + [&](const std::exception&) { + LOG(INFO) << "connection failed as expected"; + }) + .get(); +} + +TEST(RSocketClient, PreallocatedBytesInFrames) { + auto connection = std::make_unique(); + EXPECT_CALL(*connection, isFramed()).WillRepeatedly(Return(true)); + + // SETUP frame and FIRE_N_FORGET frame send + EXPECT_CALL(*connection, send_(_)) + .Times(2) + .WillRepeatedly( + Invoke([](std::unique_ptr& serializedFrame) { + // we should have headroom preallocated for the frame size field + EXPECT_EQ( + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest) + ->frameLengthFieldSize(), + serializedFrame->headroom()); + })); + + folly::ScopedEventBaseThread worker; + + worker.getEventBase()->runInEventBaseThread([&] { + auto client = RSocket::createClientFromConnection( + std::move(connection), *worker.getEventBase()); + + client->getRequester() + ->fireAndForget(Payload("hello")) + ->subscribe(SingleObservers::create()); + }); +} diff --git a/rsocket/test/RSocketTests.cpp b/rsocket/test/RSocketTests.cpp new file mode 100644 index 000000000..c3d309f0e --- /dev/null +++ b/rsocket/test/RSocketTests.cpp @@ -0,0 +1,179 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/test/RSocketTests.h" + +#include "rsocket/internal/WarmResumeManager.h" +#include "rsocket/test/test_utils/GenericRequestResponseHandler.h" +#include "rsocket/transports/tcp/TcpConnectionAcceptor.h" + +namespace rsocket { +namespace tests { +namespace client_server { + +std::unique_ptr getConnFactory( + folly::EventBase* eventBase, + uint16_t port) { + folly::SocketAddress address{"127.0.0.1", port}; + return std::make_unique(*eventBase, std::move(address)); +} + +std::unique_ptr makeServer( + std::shared_ptr responder, + std::shared_ptr stats) { + TcpConnectionAcceptor::Options opts; + opts.threads = 2; + opts.address = folly::SocketAddress("0.0.0.0", 0); + + // RSocket server accepting on TCP. + auto rs = RSocket::createServer( + std::make_unique(std::move(opts)), + std::move(stats)); + + rs->start([r = std::move(responder)](const SetupParameters&) { return r; }); + return rs; +} + +std::unique_ptr makeResumableServer( + std::shared_ptr serviceHandler) { + TcpConnectionAcceptor::Options opts; + opts.threads = 10; + opts.backlog = 200; + opts.address = folly::SocketAddress("0.0.0.0", 0); + auto rs = RSocket::createServer( + std::make_unique(std::move(opts))); + rs->start(std::move(serviceHandler)); + return rs; +} + +folly::Future> makeClientAsync( + folly::EventBase* eventBase, + uint16_t port, + folly::EventBase* stateMachineEvb, + std::shared_ptr stats) { + CHECK(eventBase); + return RSocket::createConnectedClient( + getConnFactory(eventBase, port), + SetupParameters(), + std::make_shared(), + kDefaultKeepaliveInterval, + std::move(stats), + std::shared_ptr(), + ResumeManager::makeEmpty(), + std::shared_ptr(), + stateMachineEvb); +} + +std::unique_ptr makeClient( + folly::EventBase* eventBase, + uint16_t port, + folly::EventBase* stateMachineEvb, + std::shared_ptr stats) { + return makeClientAsync(eventBase, port, stateMachineEvb, std::move(stats)) + .get(); +} + +namespace { +struct DisconnectedResponder : public rsocket::RSocketResponder { + DisconnectedResponder() {} + + std::shared_ptr> + handleRequestResponse(rsocket::Payload, rsocket::StreamId) override { + CHECK(false); + return nullptr; + } + + std::shared_ptr> + handleRequestStream(rsocket::Payload, rsocket::StreamId) override { + CHECK(false); + return nullptr; + } + + std::shared_ptr> + handleRequestChannel( + rsocket::Payload, + std::shared_ptr>, + rsocket::StreamId) override { + CHECK(false); + return nullptr; + } + + void handleFireAndForget(rsocket::Payload, rsocket::StreamId) override { + CHECK(false); + } + + void handleMetadataPush(std::unique_ptr) override { + CHECK(false); + } + + ~DisconnectedResponder() override {} +}; +} // namespace + +std::unique_ptr makeDisconnectedClient( + folly::EventBase* eventBase) { + auto server = makeServer(std::make_shared()); + + auto client = makeClient(eventBase, *server->listeningPort()); + client->disconnect().get(); + return client; +} + +std::unique_ptr makeWarmResumableClient( + folly::EventBase* eventBase, + uint16_t port, + std::shared_ptr connectionEvents, + folly::EventBase* stateMachineEvb) { + CHECK(eventBase); + SetupParameters setupParameters; + setupParameters.resumable = true; + return RSocket::createConnectedClient( + getConnFactory(eventBase, port), + std::move(setupParameters), + std::make_shared(), + kDefaultKeepaliveInterval, + RSocketStats::noop(), + std::move(connectionEvents), + std::make_shared(RSocketStats::noop()), + std::shared_ptr(), + stateMachineEvb) + .get(); +} + +std::unique_ptr makeColdResumableClient( + folly::EventBase* eventBase, + uint16_t port, + ResumeIdentificationToken token, + std::shared_ptr resumeManager, + std::shared_ptr coldResumeHandler, + folly::EventBase* stateMachineEvb) { + SetupParameters setupParameters; + setupParameters.resumable = true; + setupParameters.token = token; + return RSocket::createConnectedClient( + getConnFactory(eventBase, port), + std::move(setupParameters), + nullptr, // responder + kDefaultKeepaliveInterval, + nullptr, // stats + nullptr, // connectionEvents + resumeManager, + coldResumeHandler, + stateMachineEvb) + .get(); +} + +} // namespace client_server +} // namespace tests +} // namespace rsocket diff --git a/rsocket/test/RSocketTests.h b/rsocket/test/RSocketTests.h new file mode 100644 index 000000000..147238901 --- /dev/null +++ b/rsocket/test/RSocketTests.h @@ -0,0 +1,178 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "rsocket/RSocket.h" + +#include "rsocket/transports/tcp/TcpConnectionFactory.h" + +namespace rsocket { +namespace tests { +namespace client_server { + +class RSocketStatsFlowControl : public RSocketStats { + public: + void frameWritten(FrameType frameType) { + if (frameType == FrameType::REQUEST_N) { + ++writeRequestN_; + } + } + + void frameRead(FrameType frameType) { + if (frameType == FrameType::REQUEST_N) { + ++readRequestN_; + } + } + + public: + int writeRequestN_{0}; + int readRequestN_{0}; +}; + +std::unique_ptr getConnFactory( + folly::EventBase* eventBase, + uint16_t port); + +std::unique_ptr makeServer( + std::shared_ptr responder, + std::shared_ptr stats = RSocketStats::noop()); + +std::unique_ptr makeResumableServer( + std::shared_ptr serviceHandler); + +std::unique_ptr makeClient( + folly::EventBase* eventBase, + uint16_t port, + folly::EventBase* stateMachineEvb = nullptr, + std::shared_ptr stats = RSocketStats::noop()); + +std::unique_ptr makeDisconnectedClient( + folly::EventBase* eventBase); + +folly::Future> makeClientAsync( + folly::EventBase* eventBase, + uint16_t port, + folly::EventBase* stateMachineEvb = nullptr, + std::shared_ptr stats = RSocketStats::noop()); + +std::unique_ptr makeWarmResumableClient( + folly::EventBase* eventBase, + uint16_t port, + std::shared_ptr connectionEvents = nullptr, + folly::EventBase* stateMachineEvb = nullptr); + +std::unique_ptr makeColdResumableClient( + folly::EventBase* eventBase, + uint16_t port, + ResumeIdentificationToken token, + std::shared_ptr resumeManager, + std::shared_ptr resumeHandler, + folly::EventBase* stateMachineEvb = nullptr); + +} // namespace client_server + +struct RSocketPayloadUtils { + // ~30 megabytes, for metadata+data + static constexpr size_t LargeRequestSize = 15 * 1024 * 1024; + static std::string makeLongString(size_t size, std::string pattern) { + while (pattern.size() < size) { + pattern += pattern; + } + return pattern; + } + + // Builds up an IOBuf consisting of chunks with the following sizes, and then + // the rest tacked on the end in one big iobuf chunk + static std::unique_ptr buildIOBufFromString( + std::vector const& sizes, + std::string const& from) { + folly::IOBufQueue bufQueue{folly::IOBufQueue::cacheChainLength()}; + size_t fromCursor = 0; + size_t remaining = from.size(); + for (auto size : sizes) { + if (remaining == 0) + break; + if (size > remaining) { + size = remaining; + } + + bufQueue.append( + folly::IOBuf::copyBuffer(from.c_str() + fromCursor, size)); + + fromCursor += size; + remaining -= size; + } + + if (remaining) { + bufQueue.append( + folly::IOBuf::copyBuffer(from.c_str() + fromCursor, remaining)); + } + + CHECK_EQ(bufQueue.chainLength(), from.size()); + + auto ret = bufQueue.move(); + int numChainElems = 1; + auto currentChainElem = ret.get()->next(); + while (currentChainElem != ret.get()) { + numChainElems++; + currentChainElem = currentChainElem->next(); + } + CHECK_GE(numChainElems, sizes.size()); + + // verify that the returned buffer has identical data + auto str = ret->cloneAsValue().moveToFbString().toStdString(); + CHECK_EQ(str.size(), from.size()); + CHECK(str == from); + + return ret; + } + + static void checkSameStrings( + std::string const& got, + std::string const& expect, + std::string const& context) { + CHECK_EQ(got.size(), expect.size()) + << "Got mismatched size " << context << " string (" << got.size() + << " vs " << expect.size() << ")"; + CHECK(got == expect) << context << " mismatch between got and expected"; + } + + static void checkSameStrings( + std::unique_ptr const& got, + std::string const& expect, + std::string const& context) { + CHECK_EQ(got->computeChainDataLength(), expect.size()) + << "Mismatched size " << context << ", got " + << got->computeChainDataLength() << " vs expect " << expect.size(); + + size_t expect_cursor = 0; + + for (auto range : *got) { + for (auto got_chr : range) { + // perform redundant check to avoid gtest's CHECK overhead + if (got_chr != expect[expect_cursor]) { + CHECK_EQ(got_chr, expect[expect_cursor]) + << "mismatch at byte " << expect_cursor; + } + expect_cursor++; + } + } + } +}; + +} // namespace tests +} // namespace rsocket diff --git a/rsocket/test/RequestChannelTest.cpp b/rsocket/test/RequestChannelTest.cpp new file mode 100644 index 000000000..4a815166a --- /dev/null +++ b/rsocket/test/RequestChannelTest.cpp @@ -0,0 +1,510 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "RSocketTests.h" +#include "rsocket/test/test_utils/GenericRequestResponseHandler.h" +#include "yarpl/Flowable.h" +#include "yarpl/flowable/TestSubscriber.h" + +using namespace yarpl; +using namespace yarpl::flowable; +using namespace rsocket; +using namespace rsocket::tests; +using namespace rsocket::tests::client_server; + +/** + * Test a finite stream both directions. + */ +class TestHandlerHello : public rsocket::RSocketResponder { + public: + /// Handles a new inbound Stream requested by the other end. + std::shared_ptr> + handleRequestChannel( + rsocket::Payload initialPayload, + std::shared_ptr> stream, + rsocket::StreamId /*streamId*/) override { + // say "Hello" to each name on the input stream + return stream->map([initialPayload = std::move(initialPayload)](Payload p) { + std::stringstream ss; + ss << "[" << initialPayload.cloneDataToString() << "] " + << "Hello " << p.moveDataToString() << "!"; + std::string s = ss.str(); + + return Payload(s); + }); + } +}; + +TEST(RequestChannelTest, Hello) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared()); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + + auto ts = TestSubscriber::create(); + requester + ->requestChannel( + Payload("/hello"), + Flowable<>::justN({"Bob", "Jane"})->map([](std::string v) { + return Payload(v); + })) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(ts); + + ts->awaitTerminalEvent(); + ts->assertSuccess(); + ts->assertValueCount(2); + // assert that we echo back the 2nd and 3rd request values + // with the 1st initial payload prepended to each + ts->assertValueAt(0, "[/hello] Hello Bob!"); + ts->assertValueAt(1, "[/hello] Hello Jane!"); +} + +TEST(RequestChannelTest, HelloNoFlowControl) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared()); + auto stats = std::make_shared(); + auto client = makeClient( + worker.getEventBase(), *server->listeningPort(), nullptr, stats); + auto requester = client->getRequester(); + + auto ts = TestSubscriber::create(1000); + requester + ->requestChannel( + Payload("/hello"), + Flowable<>::justN({"Bob", "Jane"})->map([](std::string v) { + return Payload(v); + })) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(ts); + + ts->awaitTerminalEvent(); + ts->assertSuccess(); + ts->assertValueCount(2); + // assert that we echo back the 2nd and 3rd request values + // with the 1st initial payload prepended to each + ts->assertValueAt(0, "[/hello] Hello Bob!"); + ts->assertValueAt(1, "[/hello] Hello Jane!"); + + // Make sure that the initial requestN in the Stream Request Frame + // is already enough and no other requestN messages are sent. + EXPECT_EQ(stats->writeRequestN_, 0); +} + +TEST(RequestChannelTest, RequestOnDisconnectedClient) { + folly::ScopedEventBaseThread worker; + auto client = makeDisconnectedClient(worker.getEventBase()); + auto requester = client->getRequester(); + + bool did_call_on_error = false; + folly::Baton<> wait_for_on_error; + + auto instream = Flowable::empty(); + requester->requestChannel(instream)->subscribe( + [](auto /* payload */) { + // onNext shouldn't be called + FAIL(); + }, + [&](folly::exception_wrapper) { + did_call_on_error = true; + wait_for_on_error.post(); + }, + []() { + // onComplete shouldn't be called + FAIL(); + }); + + wait_for_on_error.timed_wait(std::chrono::milliseconds(100)); + ASSERT_TRUE(did_call_on_error); +} + +class TestChannelResponder : public rsocket::RSocketResponder { + public: + TestChannelResponder( + int64_t rangeEnd = 10, + int64_t initialSubReq = credits::kNoFlowControl) + : rangeEnd_{rangeEnd}, + testSubscriber_{TestSubscriber::create(initialSubReq)} {} + + std::shared_ptr> handleRequestChannel( + rsocket::Payload initialPayload, + std::shared_ptr> requestStream, + rsocket::StreamId) override { + // add initial payload to testSubscriber values list + testSubscriber_->manuallyPush(initialPayload.moveDataToString()); + + requestStream->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(testSubscriber_); + + return Flowable<>::range(1, rangeEnd_)->map([&](int64_t v) { + std::stringstream ss; + ss << "Responder stream: " << v << " of " << rangeEnd_; + std::string s = ss.str(); + return Payload(s, "metadata"); + }); + } + + std::shared_ptr> getChannelSubscriber() { + return testSubscriber_; + } + + private: + int64_t rangeEnd_; + std::shared_ptr> testSubscriber_; +}; + +TEST(RequestChannelTest, CompleteRequesterResponderContinues) { + int64_t responderRange = 100; + int64_t responderSubscriberInitialRequest = credits::kNoFlowControl; + + auto responder = std::make_shared( + responderRange, responderSubscriberInitialRequest); + folly::ScopedEventBaseThread worker; + + auto server = makeServer(responder); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + + auto requestSubscriber = TestSubscriber::create(50); + auto responderSubscriber = responder->getChannelSubscriber(); + + int64_t requesterRangeEnd = 10; + + auto requesterFlowable = + Flowable<>::range(1, requesterRangeEnd)->map([=](int64_t v) { + std::stringstream ss; + ss << "Requester stream: " << v << " of " << requesterRangeEnd; + std::string s = ss.str(); + return Payload(s, "metadata"); + }); + + requester->requestChannel(Payload("Initial Request"), requesterFlowable) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(requestSubscriber); + + // finish streaming from Requester + responderSubscriber->awaitTerminalEvent(); + responderSubscriber->assertSuccess(); + responderSubscriber->assertValueCount(11); + responderSubscriber->assertValueAt(0, "Initial Request"); + responderSubscriber->assertValueAt(1, "Requester stream: 1 of 10"); + responderSubscriber->assertValueAt(10, "Requester stream: 10 of 10"); + + // Requester stream is closed, Responder continues + requestSubscriber->request(50); + requestSubscriber->awaitTerminalEvent(); + requestSubscriber->assertSuccess(); + requestSubscriber->assertValueCount(100); + requestSubscriber->assertValueAt(0, "Responder stream: 1 of 100"); + requestSubscriber->assertValueAt(99, "Responder stream: 100 of 100"); +} + +TEST(RequestChannelTest, CompleteResponderRequesterContinues) { + int64_t responderRange = 10; + int64_t responderSubscriberInitialRequest = 50; + + auto responder = std::make_shared( + responderRange, responderSubscriberInitialRequest); + + folly::ScopedEventBaseThread worker; + auto server = makeServer(responder); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + + auto requestSubscriber = TestSubscriber::create(); + auto responderSubscriber = responder->getChannelSubscriber(); + + int64_t requesterRangeEnd = 100; + + auto requesterFlowable = + Flowable<>::range(1, requesterRangeEnd)->map([=](int64_t v) { + std::stringstream ss; + ss << "Requester stream: " << v << " of " << requesterRangeEnd; + std::string s = ss.str(); + return Payload(s, "metadata"); + }); + + requester->requestChannel(requesterFlowable) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(requestSubscriber); + + // finish streaming from Responder + requestSubscriber->awaitTerminalEvent(); + requestSubscriber->assertSuccess(); + requestSubscriber->assertValueCount(10); + requestSubscriber->assertValueAt(0, "Responder stream: 1 of 10"); + requestSubscriber->assertValueAt(9, "Responder stream: 10 of 10"); + + // Responder stream is closed, Requester continues + responderSubscriber->request(50); + responderSubscriber->awaitTerminalEvent(); + responderSubscriber->assertSuccess(); + responderSubscriber->assertValueCount(100); + responderSubscriber->assertValueAt(0, "Requester stream: 1 of 100"); + responderSubscriber->assertValueAt(99, "Requester stream: 100 of 100"); +} + +TEST(RequestChannelTest, FlowControl) { + constexpr int64_t responderRange = 10; + constexpr int64_t responderSubscriberInitialRequest = 0; + + auto responder = std::make_shared( + responderRange, responderSubscriberInitialRequest); + + folly::ScopedEventBaseThread worker; + auto server = makeServer(responder); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + + auto requestSubscriber = TestSubscriber::create(0); + auto responderSubscriber = responder->getChannelSubscriber(); + + constexpr int64_t requesterRangeEnd = 10; + + auto requesterFlowable = + Flowable<>::range(1, requesterRangeEnd)->map([&](int64_t v) { + std::stringstream ss; + ss << "Requester stream: " << v << " of " << requesterRangeEnd; + std::string s = ss.str(); + return Payload(s, "metadata"); + }); + + requester->requestChannel(Payload("Initial Request"), requesterFlowable) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(requestSubscriber); + + // Wait till the Channel is created + responderSubscriber->awaitValueCount(1); + + for (int i = 1; i <= 10; i++) { + requestSubscriber->request(1); + requestSubscriber->awaitValueCount(i); + requestSubscriber->assertValueCount(i); + } + + for (int i = 1; i <= 10; i++) { + responderSubscriber->request(1); + // the channel initial payload was pushed to responderSubscriber so we + // need to add this one item to expected + responderSubscriber->awaitValueCount(i + 1); + responderSubscriber->assertValueCount(i + 1); + } + + requestSubscriber->awaitTerminalEvent(); + responderSubscriber->awaitTerminalEvent(); + + requestSubscriber->assertSuccess(); + responderSubscriber->assertSuccess(); + + requestSubscriber->assertValueAt(0, "Responder stream: 1 of 10"); + requestSubscriber->assertValueAt(9, "Responder stream: 10 of 10"); + + responderSubscriber->assertValueAt(0, "Initial Request"); + responderSubscriber->assertValueAt(1, "Requester stream: 1 of 10"); + responderSubscriber->assertValueAt(10, "Requester stream: 10 of 10"); +} + +class TestChannelResponderFailure : public rsocket::RSocketResponder { + public: + TestChannelResponderFailure() + : testSubscriber_{TestSubscriber::create()} {} + + std::shared_ptr> handleRequestChannel( + rsocket::Payload initialPayload, + std::shared_ptr> requestStream, + rsocket::StreamId) override { + // add initial payload to testSubscriber values list + testSubscriber_->manuallyPush(initialPayload.moveDataToString()); + + requestStream->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(testSubscriber_); + + return Flowable::error( + std::runtime_error("A wild Error appeared!")); + } + + std::shared_ptr> getChannelSubscriber() { + return testSubscriber_; + } + + private: + std::shared_ptr> testSubscriber_; +}; + +TEST(RequestChannelTest, FailureOnResponderRequesterSees) { + auto responder = std::make_shared(); + + folly::ScopedEventBaseThread worker; + auto server = makeServer(responder); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + + auto requestSubscriber = TestSubscriber::create(); + auto responderSubscriber = responder->getChannelSubscriber(); + + int64_t requesterRangeEnd = 10; + + auto requesterFlowable = + Flowable<>::range(1, requesterRangeEnd)->map([&](int64_t v) { + std::stringstream ss; + ss << "Requester stream: " << v << " of " << requesterRangeEnd; + std::string s = ss.str(); + return Payload(s, "metadata"); + }); + + requester->requestChannel(Payload("Initial Request"), requesterFlowable) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(requestSubscriber); + + // failure streaming from Responder + requestSubscriber->awaitTerminalEvent(); + requestSubscriber->assertOnErrorMessage("ErrorWithPayload"); + EXPECT_TRUE(requestSubscriber->getException().with_exception( + [](ErrorWithPayload& err) { + EXPECT_STREQ( + "A wild Error appeared!", err.payload.moveDataToString().c_str()); + })); + + responderSubscriber->awaitTerminalEvent(); + responderSubscriber->assertSuccess(); + responderSubscriber->assertValueCount(1); + responderSubscriber->assertValueAt(0, "Initial Request"); +} + +struct LargePayloadChannelHandler : public rsocket::RSocketResponder { + LargePayloadChannelHandler(std::string const& data, std::string const& meta) + : data(data), meta(meta) {} + + std::shared_ptr> handleRequestChannel( + Payload initialPayload, + std::shared_ptr> stream, + StreamId) override { + RSocketPayloadUtils::checkSameStrings( + initialPayload.data, data, "data received in initial payload"); + RSocketPayloadUtils::checkSameStrings( + initialPayload.metadata, meta, "metadata received in initial payload"); + + return stream->map([&](Payload payload) { + RSocketPayloadUtils::checkSameStrings( + payload.data, data, "data received in server stream"); + RSocketPayloadUtils::checkSameStrings( + payload.metadata, meta, "metadata received in server stream"); + return payload; + }); + } + + std::string const& data; + std::string const& meta; +}; + +TEST(RequestChannelTest, TestLargePayload) { + LOG(INFO) << "Building up large data/metadata, this may take a moment..."; + std::string const niceLongData = RSocketPayloadUtils::makeLongString( + RSocketPayloadUtils::LargeRequestSize, "ABCDEFGH"); + std::string const niceLongMeta = RSocketPayloadUtils::makeLongString( + RSocketPayloadUtils::LargeRequestSize, "12345678"); + + LOG(INFO) << "Built meta size: " << niceLongMeta.size() + << " data size: " << niceLongData.size(); + + folly::ScopedEventBaseThread worker; + auto handler = + std::make_shared(niceLongData, niceLongMeta); + auto server = makeServer(handler); + + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + + auto checkForSizePattern = [&](std::vector const& meta_sizes, + std::vector const& data_sizes) { + auto to = TestSubscriber::create(); + + auto seedPayload = Payload( + RSocketPayloadUtils::buildIOBufFromString(data_sizes, niceLongData), + RSocketPayloadUtils::buildIOBufFromString(meta_sizes, niceLongMeta)); + + auto makePayload = [&] { + return Payload(seedPayload.data->clone(), seedPayload.metadata->clone()); + }; + + auto requests = + yarpl::flowable::Flowable::create([&](auto& subscriber, + int64_t num) { + while (num--) { + subscriber.onNext(makePayload()); + } + })->take(3); + + requester->requestChannel(std::move(requests)) + ->map([&](Payload p) { + RSocketPayloadUtils::checkSameStrings( + p.data, niceLongData, "data received on client"); + RSocketPayloadUtils::checkSameStrings( + p.metadata, niceLongMeta, "metadata received on client"); + return 0; + }) + ->subscribe(to); + to->awaitTerminalEvent(std::chrono::seconds{20}); + to->assertValueCount(2); + to->assertSuccess(); + }; + + // All in one big chunk + checkForSizePattern({}, {}); + + // Small chunk, big chunk, small chunk + checkForSizePattern({100, 5 * 1024 * 1024, 100}, {100, 5 * 1024 * 1024, 100}); +} + +TEST(RequestChannelTest, MultiSubscribe) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared()); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + + auto ts = TestSubscriber::create(); + auto stream = + requester + ->requestChannel( + Payload("/hello"), + Flowable<>::justN({"Bob", "Jane"})->map([](std::string v) { + return Payload(v); + })) + ->map([](auto p) { return p.moveDataToString(); }); + + // First subscribe + stream->subscribe(ts); + ts->awaitTerminalEvent(); + ts->assertSuccess(); + ts->assertValueCount(2); + // assert that we echo back the 2nd and 3rd request values + // with the 1st initial payload prepended to each + ts->assertValueAt(0, "[/hello] Hello Bob!"); + ts->assertValueAt(1, "[/hello] Hello Jane!"); + + // Second subscribe + ts = TestSubscriber::create(); + stream->subscribe(ts); + ts->awaitTerminalEvent(); + ts->assertSuccess(); + ts->assertValueCount(2); + // assert that we echo back the 2nd and 3rd request values + // with the 1st initial payload prepended to each + ts->assertValueAt(0, "[/hello] Hello Bob!"); + ts->assertValueAt(1, "[/hello] Hello Jane!"); +} diff --git a/rsocket/test/RequestResponseTest.cpp b/rsocket/test/RequestResponseTest.cpp new file mode 100644 index 000000000..313ff124d --- /dev/null +++ b/rsocket/test/RequestResponseTest.cpp @@ -0,0 +1,305 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "RSocketTests.h" +#include "rsocket/test/test_utils/GenericRequestResponseHandler.h" +#include "yarpl/Single.h" +#include "yarpl/single/SingleTestObserver.h" + +using namespace yarpl::single; +using namespace rsocket; +using namespace rsocket::tests; +using namespace rsocket::tests::client_server; + +namespace { +class TestHandlerCancel : public rsocket::RSocketResponder { + public: + TestHandlerCancel( + std::shared_ptr> onCancel, + std::shared_ptr> onSubscribe) + : onCancel_(std::move(onCancel)), onSubscribe_(std::move(onSubscribe)) {} + std::shared_ptr> handleRequestResponse( + Payload request, + StreamId) override { + // used to signal to the client when the subscribe is received + onSubscribe_->post(); + // used to block this responder thread until a cancel is sent from client + // over network + auto cancelFromClient = std::make_shared>(); + // used to signal to the client once we receive a cancel + auto onCancel = onCancel_; + auto requestString = request.moveDataToString(); + return Single::create([name = std::move(requestString), + cancelFromClient, + onCancel](auto subscriber) mutable { + std::thread([subscriber = std::move(subscriber), + name = std::move(name), + cancelFromClient, + onCancel]() { + auto subscription = SingleSubscriptions::create( + [cancelFromClient] { cancelFromClient->post(); }); + subscriber->onSubscribe(subscription); + // simulate slow processing or IO being done + // and block this current background thread + // until we are cancelled + cancelFromClient->timed_wait(std::chrono::seconds(1)); + if (subscription->isCancelled()) { + // this is used by the unit test to assert the cancel was + // received + onCancel->post(); + } else { + // if not cancelled would do work and emit here + } + }).detach(); + }); + } + + private: + std::shared_ptr> onCancel_; + std::shared_ptr> onSubscribe_; +}; +} // namespace + +TEST(RequestResponseTest, Cancel) { + folly::ScopedEventBaseThread worker; + auto onCancel = std::make_shared>(); + auto onSubscribe = std::make_shared>(); + auto server = + makeServer(std::make_shared(onCancel, onSubscribe)); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + + auto to = SingleTestObserver::create(); + requester->requestResponse(Payload("Jane")) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(to); + // NOTE: wait for server to receive request/subscribe + // otherwise the cancellation will all happen locally + onSubscribe->wait(); + // now cancel the local subscription + to->cancel(); + // wait for cancel to propagate to server + onCancel->wait(); + // assert no signals received on client + to->assertNoTerminalEvent(); +} + +// response creation usage +TEST(RequestResponseTest, CanCtorTypes) { + Response r1 = payload_response("foo", "bar"); + Response r2 = error_response(std::runtime_error("whew!")); +} + +TEST(RequestResponseTest, Hello) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared( + [](StringPair const& request) { + return payload_response( + "Hello, " + request.first + " " + request.second + "!", ":)"); + })); + + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + + auto to = SingleTestObserver::create(); + requester->requestResponse(Payload("Jane", "Doe")) + ->map(payload_to_stringpair) + ->subscribe(to); + to->awaitTerminalEvent(); + to->assertOnSuccessValue({"Hello, Jane Doe!", ":)"}); +} + +TEST(RequestResponseTest, FailureInResponse) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared( + [](StringPair const& request) { + EXPECT_EQ(request.first, "foo"); + EXPECT_EQ(request.second, "bar"); + return error_response(std::runtime_error("whew!")); + })); + + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + + auto to = SingleTestObserver::create(); + requester->requestResponse(Payload("foo", "bar")) + ->map(payload_to_stringpair) + ->subscribe(to); + to->awaitTerminalEvent(); + to->assertOnErrorMessage("ErrorWithPayload"); + EXPECT_TRUE(to->getException().with_exception([](ErrorWithPayload& err) { + EXPECT_STREQ("whew!", err.payload.moveDataToString().c_str()); + })); +} + +TEST(RequestResponseTest, RequestOnDisconnectedClient) { + folly::ScopedEventBaseThread worker; + auto client = makeDisconnectedClient(worker.getEventBase()); + + auto requester = client->getRequester(); + bool did_call_on_error = false; + folly::Baton<> wait_for_on_error; + requester->requestResponse(Payload("foo", "bar")) + ->subscribe( + [](auto) { + // should not call onSuccess + FAIL(); + }, + [&](folly::exception_wrapper) { + did_call_on_error = true; + wait_for_on_error.post(); + }); + + wait_for_on_error.timed_wait(std::chrono::milliseconds(100)); + ASSERT_TRUE(did_call_on_error); +} + +// TODO: test that multiple requests on a requestResponse +// fail in a well-defined way (right now it'd nullptr deref) +TEST(RequestResponseTest, MultipleRequestsError) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared( + [](StringPair const& request) { + EXPECT_EQ(request.first, "foo"); + EXPECT_EQ(request.second, "bar"); + return payload_response("baz", "quix"); + })); + + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + + auto flowable = requester->requestResponse(Payload("foo", "bar")); +} + +TEST(RequestResponseTest, FailureOnRequest) { + folly::ScopedEventBaseThread worker; + auto server = makeServer( + std::make_shared([](auto const&) { + ADD_FAILURE(); + return payload_response("", ""); + })); + + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + + VLOG(0) << "Shutting down server so client request fails"; + server->shutdownAndWait(); + server.reset(); + VLOG(0) << "Done"; + + auto to = SingleTestObserver::create(); + requester->requestResponse(Payload("foo", "bar")) + ->map(payload_to_stringpair) + ->subscribe(to); + to->awaitTerminalEvent(); + EXPECT_TRUE(to->getError()); +} + +struct LargePayloadReqRespHandler : public rsocket::RSocketResponder { + LargePayloadReqRespHandler(std::string const& data, std::string const& meta) + : data(data), meta(meta) {} + + std::shared_ptr> handleRequestResponse( + Payload payload, + StreamId) override { + RSocketPayloadUtils::checkSameStrings( + payload.data, data, "data received in payload"); + RSocketPayloadUtils::checkSameStrings( + payload.metadata, meta, "metadata received in payload"); + + return yarpl::single::Single::create( + [p = std::move(payload)](auto sub) mutable { + sub->onSubscribe(yarpl::single::SingleSubscriptions::empty()); + sub->onSuccess(std::move(p)); + }); + } + + std::string const& data; + std::string const& meta; +}; + +TEST(RequestResponseTest, TestLargePayload) { + VLOG(1) << "Building up large data/metadata, this may take a moment..."; + std::string niceLongData = RSocketPayloadUtils::makeLongString( + RSocketPayloadUtils::LargeRequestSize, "ABCDEFGH"); + std::string niceLongMeta = RSocketPayloadUtils::makeLongString( + RSocketPayloadUtils::LargeRequestSize, "12345678"); + VLOG(1) << "Built meta size: " << niceLongMeta.size() + << " data size: " << niceLongData.size(); + + auto checkForSizePattern = [&](std::vector const& meta_sizes, + std::vector const& data_sizes) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared( + niceLongData, niceLongMeta)); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + auto to = SingleTestObserver::create(); + + requester + ->requestResponse(Payload( + RSocketPayloadUtils::buildIOBufFromString(data_sizes, niceLongData), + RSocketPayloadUtils::buildIOBufFromString( + meta_sizes, niceLongMeta))) + ->map([&](Payload p) { + RSocketPayloadUtils::checkSameStrings( + p.data, niceLongData, "data (received on client)"); + RSocketPayloadUtils::checkSameStrings( + p.metadata, niceLongMeta, "metadata (received on client)"); + return 0; + }) + ->subscribe(to); + to->awaitTerminalEvent(); + to->assertSuccess(); + }; + + // All in one big chunk + checkForSizePattern({}, {}); + + // Small chunk, big chunk, small chunk + checkForSizePattern( + {100, 10 * 1024 * 1024, 100}, {100, 10 * 1024 * 1024, 100}); +} + +TEST(RequestResponseTest, MultiSubscribe) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared( + [](StringPair const& request) { + return payload_response( + "Hello, " + request.first + " " + request.second + "!", ":)"); + })); + + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + + auto to = SingleTestObserver::create(); + auto single = requester->requestResponse(Payload("Jane", "Doe")) + ->map(payload_to_stringpair); + + // Subscribe once + single->subscribe(to); + to->awaitTerminalEvent(); + to->assertOnSuccessValue({"Hello, Jane Doe!", ":)"}); + + // Subscribe twice + to = SingleTestObserver::create(); + single->subscribe(to); + to->awaitTerminalEvent(); + to->assertOnSuccessValue({"Hello, Jane Doe!", ":)"}); +} diff --git a/rsocket/test/RequestStreamTest.cpp b/rsocket/test/RequestStreamTest.cpp new file mode 100644 index 000000000..22633697c --- /dev/null +++ b/rsocket/test/RequestStreamTest.cpp @@ -0,0 +1,357 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "RSocketTests.h" +#include "yarpl/Flowable.h" +#include "yarpl/flowable/TestSubscriber.h" + +using namespace yarpl::flowable; +using namespace rsocket; +using namespace rsocket::tests; +using namespace rsocket::tests::client_server; + +namespace { +class TestHandlerSync : public rsocket::RSocketResponder { + public: + std::shared_ptr> handleRequestStream( + Payload request, + StreamId) override { + // string from payload data + auto requestString = request.moveDataToString(); + + return Flowable<>::range(1, 10)->map( + [name = std::move(requestString)](int64_t v) { + std::stringstream ss; + ss << "Hello " << name << " " << v << "!"; + std::string s = ss.str(); + return Payload(s, "metadata"); + }); + } +}; + +TEST(RequestStreamTest, HelloSync) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared()); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + auto ts = TestSubscriber::create(); + requester->requestStream(Payload("Bob")) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(ts); + ts->awaitTerminalEvent(); + ts->assertSuccess(); + ts->assertValueCount(10); + ts->assertValueAt(0, "Hello Bob 1!"); + ts->assertValueAt(9, "Hello Bob 10!"); +} + +TEST(RequestStreamTest, HelloFlowControl) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared()); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + auto ts = TestSubscriber::create(5); + requester->requestStream(Payload("Bob")) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(ts); + + ts->awaitValueCount(5); + + ts->assertValueCount(5); + ts->assertValueAt(0, "Hello Bob 1!"); + ts->assertValueAt(4, "Hello Bob 5!"); + + ts->request(5); + + ts->awaitValueCount(10); + + ts->assertValueCount(10); + ts->assertValueAt(5, "Hello Bob 6!"); + ts->assertValueAt(9, "Hello Bob 10!"); + + ts->awaitTerminalEvent(); + ts->assertSuccess(); +} + +TEST(RequestStreamTest, HelloNoFlowControl) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared()); + auto stats = std::make_shared(); + auto client = makeClient( + worker.getEventBase(), *server->listeningPort(), nullptr, stats); + auto requester = client->getRequester(); + auto ts = TestSubscriber::create(1000); + requester->requestStream(Payload("Bob")) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(ts); + ts->awaitTerminalEvent(); + ts->assertSuccess(); + ts->assertValueCount(10); + ts->assertValueAt(0, "Hello Bob 1!"); + ts->assertValueAt(9, "Hello Bob 10!"); + + // Make sure that the initial requestN in the Stream Request Frame + // is already enough and no other requestN messages are sent. + EXPECT_EQ(stats->writeRequestN_, 0); +} + +class TestHandlerAsync : public rsocket::RSocketResponder { + public: + explicit TestHandlerAsync(folly::Executor& executor) : executor_(executor) {} + + std::shared_ptr> handleRequestStream( + Payload request, + StreamId) override { + // string from payload data + auto requestString = request.moveDataToString(); + + return Flowable<>::range(1, 40) + ->map([name = std::move(requestString)](int64_t v) { + std::stringstream ss; + ss << "Hello " << name << " " << v << "!"; + std::string s = ss.str(); + return Payload(s, "metadata"); + }) + ->subscribeOn(executor_); + } + + private: + folly::Executor& executor_; +}; +} // namespace + +TEST(RequestStreamTest, HelloAsync) { + folly::ScopedEventBaseThread worker; + folly::ScopedEventBaseThread worker2; + auto server = + makeServer(std::make_shared(*worker2.getEventBase())); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + auto ts = TestSubscriber::create(); + requester->requestStream(Payload("Bob")) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(ts); + ts->awaitTerminalEvent(); + ts->assertSuccess(); + ts->assertValueCount(40); + ts->assertValueAt(0, "Hello Bob 1!"); + ts->assertValueAt(39, "Hello Bob 40!"); +} + +TEST(RequestStreamTest, RequestOnDisconnectedClient) { + folly::ScopedEventBaseThread worker; + auto client = makeDisconnectedClient(worker.getEventBase()); + auto requester = client->getRequester(); + + bool did_call_on_error = false; + folly::Baton<> wait_for_on_error; + + requester->requestStream(Payload("foo", "bar")) + ->subscribe( + [](auto /* payload */) { + // onNext shouldn't be called + FAIL(); + }, + [&](folly::exception_wrapper) { + did_call_on_error = true; + wait_for_on_error.post(); + }, + []() { + // onComplete shouldn't be called + FAIL(); + }); + + wait_for_on_error.timed_wait(std::chrono::milliseconds(100)); + ASSERT_TRUE(did_call_on_error); +} + +class TestHandlerResponder : public rsocket::RSocketResponder { + public: + std::shared_ptr> handleRequestStream(Payload, StreamId) + override { + return Flowable::error( + std::runtime_error("A wild Error appeared!")); + } +}; + +TEST(RequestStreamTest, HandleError) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared()); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + auto ts = TestSubscriber::create(); + requester->requestStream(Payload("Bob")) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(ts); + ts->awaitTerminalEvent(); + // Hide the user error from the logs + ts->assertOnErrorMessage("ErrorWithPayload"); + EXPECT_TRUE(ts->getException().with_exception([](ErrorWithPayload& err) { + EXPECT_STREQ( + "A wild Error appeared!", err.payload.moveDataToString().c_str()); + })); +} + +class TestErrorAfterOnNextResponder : public rsocket::RSocketResponder { + public: + std::shared_ptr> handleRequestStream( + Payload request, + StreamId) override { + // string from payload data + auto requestString = request.moveDataToString(); + + return Flowable::create( + [name = std::move(requestString)]( + Subscriber& subscriber, int64_t requested) { + EXPECT_GT(requested, 1); + subscriber.onNext(Payload(name, "meta")); + subscriber.onNext(Payload(name, "meta")); + subscriber.onNext(Payload(name, "meta")); + subscriber.onNext(Payload(name, "meta")); + subscriber.onError(std::runtime_error("A wild Error appeared!")); + }); + } +}; + +TEST(RequestStreamTest, HandleErrorMidStream) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared()); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + auto ts = TestSubscriber::create(); + requester->requestStream(Payload("Bob")) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(ts); + ts->awaitTerminalEvent(); + ts->assertValueCount(4); + ts->assertOnErrorMessage("ErrorWithPayload"); + EXPECT_TRUE(ts->getException().with_exception([](ErrorWithPayload& err) { + EXPECT_STREQ( + "A wild Error appeared!", err.payload.moveDataToString().c_str()); + })); +} + +struct LargePayloadStreamHandler : public rsocket::RSocketResponder { + LargePayloadStreamHandler( + std::string const& data, + std::string const& meta, + Payload const& seedPayload) + : data(data), meta(meta), seedPayload(seedPayload) {} + + std::shared_ptr> handleRequestStream( + Payload initialPayload, + StreamId) override { + RSocketPayloadUtils::checkSameStrings( + initialPayload.data, data, "data received in initial payload"); + RSocketPayloadUtils::checkSameStrings( + initialPayload.metadata, meta, "metadata received in initial payload"); + + return yarpl::flowable::Flowable::create([&](auto& subscriber, + int64_t num) { + while (num--) { + auto p = Payload( + seedPayload.data->clone(), seedPayload.metadata->clone()); + subscriber.onNext(std::move(p)); + } + }) + ->take(3); + } + + std::string const& data; + std::string const& meta; + Payload const& seedPayload; +}; + +TEST(RequestStreamTest, TestLargePayload) { + LOG(INFO) << "Building up large data/metadata, this may take a moment..."; + // ~20 megabytes per frame (metadata + data) + std::string const niceLongData = RSocketPayloadUtils::makeLongString( + RSocketPayloadUtils::LargeRequestSize, "ABCDEFGH"); + std::string const niceLongMeta = RSocketPayloadUtils::makeLongString( + RSocketPayloadUtils::LargeRequestSize, "12345678"); + + LOG(INFO) << "Built meta size: " << niceLongMeta.size() + << " data size: " << niceLongData.size(); + + auto checkForSizePattern = [&](std::vector const& meta_sizes, + std::vector const& data_sizes) { + folly::ScopedEventBaseThread worker; + auto seedPayload = Payload( + RSocketPayloadUtils::buildIOBufFromString(data_sizes, niceLongData), + RSocketPayloadUtils::buildIOBufFromString(meta_sizes, niceLongMeta)); + auto makePayload = [&] { + return Payload(seedPayload.data->clone(), seedPayload.metadata->clone()); + }; + + auto handler = std::make_shared( + niceLongData, niceLongMeta, seedPayload); + auto server = makeServer(handler); + + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + + auto to = TestSubscriber::create(); + + requester->requestStream(makePayload()) + ->map([&](Payload p) { + RSocketPayloadUtils::checkSameStrings( + p.data, niceLongData, "data received on client"); + RSocketPayloadUtils::checkSameStrings( + p.metadata, niceLongMeta, "metadata received on client"); + return 0; + }) + ->subscribe(to); + to->awaitTerminalEvent(std::chrono::seconds{20}); + to->assertValueCount(3); + to->assertSuccess(); + }; + + // All in one big chunk + checkForSizePattern({}, {}); + + // Small chunk, big chunk, small chunk + checkForSizePattern({100, 5 * 1024 * 1024, 100}, {100, 5 * 1024 * 1024, 100}); +} + +TEST(RequestStreamTest, MultiSubscribe) { + folly::ScopedEventBaseThread worker; + auto server = makeServer(std::make_shared()); + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + auto ts = TestSubscriber::create(); + auto stream = requester->requestStream(Payload("Bob"))->map([](auto p) { + return p.moveDataToString(); + }); + + // First subscribe + stream->subscribe(ts); + ts->awaitTerminalEvent(); + ts->assertSuccess(); + ts->assertValueCount(10); + ts->assertValueAt(0, "Hello Bob 1!"); + ts->assertValueAt(9, "Hello Bob 10!"); + + // Second subscribe + ts = TestSubscriber::create(); + stream->subscribe(ts); + ts->awaitTerminalEvent(); + ts->assertSuccess(); + ts->assertValueCount(10); + ts->assertValueAt(0, "Hello Bob 1!"); + ts->assertValueAt(9, "Hello Bob 10!"); +} diff --git a/rsocket/test/RequestStreamTest_concurrency.cpp b/rsocket/test/RequestStreamTest_concurrency.cpp new file mode 100644 index 000000000..ce9b7e0b6 --- /dev/null +++ b/rsocket/test/RequestStreamTest_concurrency.cpp @@ -0,0 +1,153 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "RSocketTests.h" +#include "yarpl/Flowable.h" +#include "yarpl/flowable/TestSubscriber.h" + +#include "yarpl/test_utils/Mocks.h" + +using namespace yarpl::flowable; +using namespace rsocket; +using namespace rsocket::tests::client_server; + +struct LockstepBatons { + folly::Baton<> onSecondPayloadSent; + folly::Baton<> onCancelSent; + folly::Baton<> onCancelReceivedToserver; + folly::Baton<> onCancelReceivedToclient; + folly::Baton<> onRequestReceived; + folly::Baton<> clientFinished; + folly::Baton<> serverFinished; +}; + +using namespace ::testing; + +constexpr std::chrono::milliseconds timeout{100}; + +class LockstepAsyncHandler : public rsocket::RSocketResponder { + LockstepBatons& batons_; + Sequence& subscription_seq_; + folly::ScopedEventBaseThread worker_; + + public: + LockstepAsyncHandler(LockstepBatons& batons, Sequence& subscription_seq) + : batons_(batons), subscription_seq_(subscription_seq) {} + + std::shared_ptr> handleRequestStream(Payload p, StreamId) + override { + EXPECT_EQ(p.moveDataToString(), "initial"); + + auto step1 = Flowable::empty()->doOnComplete([this]() { + this->batons_.onRequestReceived.timed_wait(timeout); + VLOG(3) << "SERVER: sending onNext(foo)"; + }); + + auto step2 = Flowable<>::justOnce(Payload("foo"))->doOnComplete([this]() { + this->batons_.onCancelSent.timed_wait(timeout); + this->batons_.onCancelReceivedToserver.timed_wait(timeout); + VLOG(3) << "SERVER: sending onNext(bar)"; + }); + + auto step3 = Flowable<>::justOnce(Payload("bar"))->doOnComplete([this]() { + this->batons_.onSecondPayloadSent.post(); + VLOG(3) << "SERVER: sending onComplete()"; + }); + + auto generator = Flowable<>::concat(step1, step2, step3) + ->doOnComplete([this]() { + VLOG(3) << "SERVER: posting serverFinished"; + this->batons_.serverFinished.post(); + }) + ->subscribeOn(*worker_.getEventBase()); + + // checked once the subscription is destroyed + auto requestCheckpoint = std::make_shared>(); + EXPECT_CALL(*requestCheckpoint, Call(2)) + .InSequence(this->subscription_seq_) + .WillOnce(Invoke([=](auto n) { + VLOG(3) << "SERVER: got request(" << n << ")"; + EXPECT_EQ(n, 2); + this->batons_.onRequestReceived.post(); + })); + + auto cancelCheckpoint = std::make_shared>(); + EXPECT_CALL(*cancelCheckpoint, Call()) + .InSequence(this->subscription_seq_) + .WillOnce(Invoke([=] { + VLOG(3) << "SERVER: received cancel()"; + this->batons_.onCancelReceivedToclient.post(); + this->batons_.onCancelReceivedToserver.post(); + })); + + return generator + ->doOnRequest( + [requestCheckpoint](auto n) { requestCheckpoint->Call(n); }) + ->doOnCancel([cancelCheckpoint] { cancelCheckpoint->Call(); }); + } +}; + +TEST(RequestStreamTest, OperationsAfterCancel) { + LockstepBatons batons; + Sequence server_seq; + Sequence client_seq; + + auto server = + makeServer(std::make_shared(batons, server_seq)); + folly::ScopedEventBaseThread worker; + auto client = makeClient(worker.getEventBase(), *server->listeningPort()); + auto requester = client->getRequester(); + + auto subscriber_mock = std::make_shared< + testing::StrictMock>>(0); + + std::shared_ptr subscription; + EXPECT_CALL(*subscriber_mock, onSubscribe_(_)) + .InSequence(client_seq) + .WillOnce(Invoke([&](auto s) { + VLOG(3) << "CLIENT: got onSubscribe(), sending request(2)"; + EXPECT_NE(s, nullptr); + subscription = s; + subscription->request(2); + })); + EXPECT_CALL(*subscriber_mock, onNext_("foo")) + .InSequence(client_seq) + .WillOnce(Invoke([&](auto) { + EXPECT_NE(subscription, nullptr); + VLOG(3) << "CLIENT: got onNext(foo), sending cancel()"; + subscription->cancel(); + batons.onCancelSent.post(); + batons.onCancelReceivedToclient.timed_wait(timeout); + batons.onSecondPayloadSent.timed_wait(timeout); + batons.clientFinished.post(); + })); + + // shouldn't receive 'bar', we canceled syncronously with the Subscriber + // had 'cancel' been called in a different thread with no synchronization, + // the client's Subscriber _could_ have received 'bar' + + VLOG(3) << "RUNNER: doing requestStream()"; + requester->requestStream(Payload("initial")) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(subscriber_mock); + + batons.clientFinished.timed_wait(timeout); + batons.serverFinished.timed_wait(timeout); + VLOG(3) << "RUNNER: finished!"; +} diff --git a/rsocket/test/Test.cpp b/rsocket/test/Test.cpp new file mode 100644 index 000000000..512a2281e --- /dev/null +++ b/rsocket/test/Test.cpp @@ -0,0 +1,24 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +int main(int argc, char** argv) { + FLAGS_logtostderr = true; + testing::InitGoogleMock(&argc, argv); + folly::init(&argc, &argv); + return RUN_ALL_TESTS(); +} diff --git a/rsocket/test/WarmResumeManagerTest.cpp b/rsocket/test/WarmResumeManagerTest.cpp new file mode 100644 index 000000000..861860595 --- /dev/null +++ b/rsocket/test/WarmResumeManagerTest.cpp @@ -0,0 +1,345 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "rsocket/framing/Frame.h" +#include "rsocket/framing/FrameSerializer.h" +#include "rsocket/framing/FrameTransportImpl.h" +#include "rsocket/internal/WarmResumeManager.h" +#include "rsocket/test/test_utils/MockDuplexConnection.h" +#include "rsocket/test/test_utils/MockStats.h" + +using namespace ::testing; +using namespace ::rsocket; + +namespace { + +class FrameTransportMock : public FrameTransportImpl { + public: + FrameTransportMock() + : FrameTransportImpl(std::make_unique()) {} + + MOCK_METHOD1(outputFrameOrDrop_, void(std::unique_ptr&)); + + void outputFrameOrDrop(std::unique_ptr frame) override { + outputFrameOrDrop_(frame); + } +}; + +} // namespace + +class WarmResumeManagerTest : public Test { + protected: + std::unique_ptr frameSerializer_{ + FrameSerializer::createFrameSerializer(ProtocolVersion(1, 0))}; +}; + +TEST_F(WarmResumeManagerTest, EmptyCache) { + WarmResumeManager cache(RSocketStats::noop()); + FrameTransportMock transport; + + EXPECT_CALL(transport, outputFrameOrDrop_(_)).Times(0); + + EXPECT_EQ(0, cache.firstSentPosition()); + EXPECT_EQ(0, cache.lastSentPosition()); + EXPECT_TRUE(cache.isPositionAvailable(0)); + EXPECT_FALSE(cache.isPositionAvailable(1)); + cache.sendFramesFromPosition(0, transport); + + cache.resetUpToPosition(0); + + EXPECT_EQ(0, cache.firstSentPosition()); + EXPECT_EQ(0, cache.lastSentPosition()); + EXPECT_TRUE(cache.isPositionAvailable(0)); + EXPECT_FALSE(cache.isPositionAvailable(1)); + cache.sendFramesFromPosition(0, transport); +} + +TEST_F(WarmResumeManagerTest, OneFrame) { + WarmResumeManager cache(RSocketStats::noop()); + FrameTransportMock transport; + + auto frame1 = frameSerializer_->serializeOut(Frame_CANCEL(0)); + const auto frame1Size = frame1->computeChainDataLength(); + + cache.trackSentFrame(*frame1, FrameType::CANCEL, 1, 0); + + EXPECT_EQ(0, cache.firstSentPosition()); + EXPECT_EQ((ResumePosition)frame1Size, cache.lastSentPosition()); + EXPECT_TRUE(cache.isPositionAvailable(0)); + EXPECT_TRUE(cache.isPositionAvailable(frame1Size)); + + cache.resetUpToPosition(0); + + EXPECT_EQ(0, cache.firstSentPosition()); + EXPECT_EQ((ResumePosition)frame1Size, cache.lastSentPosition()); + EXPECT_TRUE(cache.isPositionAvailable(0)); + EXPECT_TRUE(cache.isPositionAvailable(frame1Size)); + + EXPECT_FALSE(cache.isPositionAvailable(frame1Size - 1)); // misaligned + + EXPECT_CALL(transport, outputFrameOrDrop_(_)) + .WillOnce(Invoke([=](std::unique_ptr& buf) { + EXPECT_EQ(frame1Size, buf->computeChainDataLength()); + })); + + cache.sendFramesFromPosition(0, transport); + cache.sendFramesFromPosition(frame1Size, transport); + + cache.resetUpToPosition(frame1Size); + + EXPECT_EQ((ResumePosition)frame1Size, cache.firstSentPosition()); + EXPECT_EQ((ResumePosition)frame1Size, cache.lastSentPosition()); + EXPECT_FALSE(cache.isPositionAvailable(0)); + EXPECT_TRUE(cache.isPositionAvailable(frame1Size)); + + cache.sendFramesFromPosition(frame1Size, transport); +} + +TEST_F(WarmResumeManagerTest, TwoFrames) { + WarmResumeManager cache(RSocketStats::noop()); + FrameTransportMock transport; + + auto frame1 = frameSerializer_->serializeOut(Frame_CANCEL(0)); + const auto frame1Size = frame1->computeChainDataLength(); + + auto frame2 = frameSerializer_->serializeOut(Frame_REQUEST_N(0, 2)); + const auto frame2Size = frame2->computeChainDataLength(); + + cache.trackSentFrame(*frame1, FrameType::CANCEL, 1, 0); + cache.trackSentFrame(*frame2, FrameType::REQUEST_N, 1, 0); + + EXPECT_EQ(0, cache.firstSentPosition()); + EXPECT_EQ( + (ResumePosition)(frame1Size + frame2Size), cache.lastSentPosition()); + EXPECT_TRUE(cache.isPositionAvailable(0)); + EXPECT_TRUE(cache.isPositionAvailable(frame1Size)); + EXPECT_TRUE(cache.isPositionAvailable(frame1Size + frame2Size)); + + EXPECT_CALL(transport, outputFrameOrDrop_(_)) + .WillOnce(Invoke([&](std::unique_ptr& buf) { + EXPECT_EQ(frame1Size, buf->computeChainDataLength()); + })) + .WillOnce(Invoke([&](std::unique_ptr& buf) { + EXPECT_EQ(frame2Size, buf->computeChainDataLength()); + })); + + cache.sendFramesFromPosition(0, transport); + + cache.resetUpToPosition(frame1Size); + + EXPECT_EQ((ResumePosition)frame1Size, cache.firstSentPosition()); + EXPECT_EQ( + (ResumePosition)(frame1Size + frame2Size), cache.lastSentPosition()); + EXPECT_FALSE(cache.isPositionAvailable(0)); + EXPECT_TRUE(cache.isPositionAvailable(frame1Size)); + EXPECT_TRUE(cache.isPositionAvailable(frame1Size + frame2Size)); + + EXPECT_CALL(transport, outputFrameOrDrop_(_)) + .WillOnce(Invoke([&](std::unique_ptr& buf) { + EXPECT_EQ(frame2Size, buf->computeChainDataLength()); + })); + + cache.sendFramesFromPosition(frame1Size, transport); +} + +TEST_F(WarmResumeManagerTest, Stats) { + auto stats = std::make_shared>(); + WarmResumeManager cache(stats); + + auto frame1 = frameSerializer_->serializeOut(Frame_CANCEL(0)); + auto frame1Size = frame1->computeChainDataLength(); + EXPECT_CALL(*stats, resumeBufferChanged(1, frame1Size)); + cache.trackSentFrame(*frame1, FrameType::CANCEL, 1, 0); + + auto frame2 = frameSerializer_->serializeOut(Frame_REQUEST_N(0, 3)); + auto frame2Size = frame2->computeChainDataLength(); + EXPECT_CALL(*stats, resumeBufferChanged(1, frame2Size)).Times(2); + cache.trackSentFrame(*frame2, FrameType::REQUEST_N, 1, 0); + cache.trackSentFrame(*frame2, FrameType::REQUEST_N, 1, 0); + + EXPECT_CALL(*stats, resumeBufferChanged(-1, -frame1Size)); + cache.resetUpToPosition(frame1Size); + EXPECT_CALL(*stats, resumeBufferChanged(-2, -2 * frame2Size)); +} + +TEST_F(WarmResumeManagerTest, EvictFIFO) { + auto frame = frameSerializer_->serializeOut(Frame_CANCEL(0)); + const auto frameSize = frame->computeChainDataLength(); + + // construct cache with capacity of 2 frameSize + WarmResumeManager cache(RSocketStats::noop(), frameSize * 2); + + cache.trackSentFrame(*frame, FrameType::CANCEL, 1, 0); + cache.trackSentFrame(*frame, FrameType::CANCEL, 1, 0); + + // first 2 frames should be present in the cache + EXPECT_TRUE(cache.isPositionAvailable(0)); + EXPECT_TRUE(cache.isPositionAvailable(frameSize)); + EXPECT_TRUE(cache.isPositionAvailable(frameSize * 2)); + + // add third frame, and this frame should evict first frame + cache.trackSentFrame(*frame, FrameType::CANCEL, 1, 0); + EXPECT_FALSE(cache.isPositionAvailable(0)); + EXPECT_TRUE(cache.isPositionAvailable(frameSize)); + EXPECT_TRUE(cache.isPositionAvailable(frameSize * 2)); + EXPECT_TRUE(cache.isPositionAvailable(frameSize * 3)); + + // cache size should also be adjusted by resetUpToPosition + cache.resetUpToPosition(frameSize * 2); + EXPECT_FALSE(cache.isPositionAvailable(frameSize)); + EXPECT_TRUE(cache.isPositionAvailable(frameSize * 2)); + EXPECT_TRUE(cache.isPositionAvailable(frameSize * 3)); + + // add fourth frame, this should evict second frame + cache.trackSentFrame(*frame, FrameType::CANCEL, 1, 0); + EXPECT_FALSE(cache.isPositionAvailable(0)); + EXPECT_FALSE(cache.isPositionAvailable(frameSize)); + EXPECT_TRUE(cache.isPositionAvailable(frameSize * 2)); + EXPECT_TRUE(cache.isPositionAvailable(frameSize * 3)); + EXPECT_TRUE(cache.isPositionAvailable(frameSize * 4)); + + // create a huge frame and try to cache it + auto hugeFrame = folly::IOBuf::createChain(frameSize * 3, frameSize * 3); + for (int i = 0; i < 3; i++) { + hugeFrame->appendChain(frame->clone()); + } + auto hugeFrameSize = hugeFrame->computeChainDataLength(); + EXPECT_EQ(hugeFrameSize, frameSize * 3); + cache.trackSentFrame(*hugeFrame, FrameType::CANCEL, 1, 0); + + // cache should be cleared + EXPECT_EQ(cache.size(), (size_t)0); + EXPECT_FALSE(cache.isPositionAvailable(0)); + EXPECT_FALSE(cache.isPositionAvailable(frameSize)); + EXPECT_FALSE(cache.isPositionAvailable(frameSize * 2)); + EXPECT_FALSE(cache.isPositionAvailable(frameSize * 3)); + EXPECT_FALSE(cache.isPositionAvailable(frameSize * 4)); + EXPECT_TRUE(cache.isPositionAvailable(frameSize * 4 + hugeFrameSize)); + EXPECT_EQ( + (ResumePosition)(frameSize * 4 + hugeFrameSize), + cache.firstSentPosition()); + EXPECT_EQ( + (ResumePosition)(frameSize * 4 + hugeFrameSize), + cache.lastSentPosition()); + + // caching small frames shouldn't be affected + // Adding one small frame to cache + cache.trackSentFrame(*frame, FrameType::CANCEL, 1, 0); + EXPECT_EQ(cache.size(), frameSize); + EXPECT_FALSE(cache.isPositionAvailable(0)); + EXPECT_FALSE(cache.isPositionAvailable(frameSize)); + EXPECT_FALSE(cache.isPositionAvailable(frameSize * 2)); + EXPECT_FALSE(cache.isPositionAvailable(frameSize * 3)); + EXPECT_FALSE(cache.isPositionAvailable(frameSize * 4)); + EXPECT_TRUE(cache.isPositionAvailable(frameSize * 4 + hugeFrameSize)); + EXPECT_TRUE(cache.isPositionAvailable(frameSize * 5 + hugeFrameSize)); + EXPECT_EQ( + (ResumePosition)(frameSize * 4 + hugeFrameSize), + cache.firstSentPosition()); + EXPECT_EQ( + (ResumePosition)(frameSize * 5 + hugeFrameSize), + cache.lastSentPosition()); + + // Adding second small frame to cache + cache.trackSentFrame(*frame, FrameType::CANCEL, 1, 0); + EXPECT_EQ(cache.size(), frameSize * 2); + EXPECT_FALSE(cache.isPositionAvailable(0)); + EXPECT_FALSE(cache.isPositionAvailable(frameSize)); + EXPECT_FALSE(cache.isPositionAvailable(frameSize * 2)); + EXPECT_FALSE(cache.isPositionAvailable(frameSize * 3)); + EXPECT_FALSE(cache.isPositionAvailable(frameSize * 4)); + EXPECT_TRUE(cache.isPositionAvailable(frameSize * 4 + hugeFrameSize)); + EXPECT_TRUE(cache.isPositionAvailable(frameSize * 5 + hugeFrameSize)); + EXPECT_TRUE(cache.isPositionAvailable(frameSize * 6 + hugeFrameSize)); + EXPECT_EQ( + (ResumePosition)(frameSize * 4 + hugeFrameSize), + cache.firstSentPosition()); + EXPECT_EQ( + (ResumePosition)(frameSize * 6 + hugeFrameSize), + cache.lastSentPosition()); + + // Adding third small frame to cache. Should result in first frame eviction + cache.trackSentFrame(*frame, FrameType::CANCEL, 1, 0); + EXPECT_EQ(cache.size(), frameSize * 2); + EXPECT_FALSE(cache.isPositionAvailable(0)); + EXPECT_FALSE(cache.isPositionAvailable(frameSize)); + EXPECT_FALSE(cache.isPositionAvailable(frameSize * 2)); + EXPECT_FALSE(cache.isPositionAvailable(frameSize * 3)); + EXPECT_FALSE(cache.isPositionAvailable(frameSize * 4)); + EXPECT_FALSE(cache.isPositionAvailable(frameSize * 4 + hugeFrameSize)); + EXPECT_TRUE(cache.isPositionAvailable(frameSize * 5 + hugeFrameSize)); + EXPECT_TRUE(cache.isPositionAvailable(frameSize * 6 + hugeFrameSize)); + EXPECT_TRUE(cache.isPositionAvailable(frameSize * 7 + hugeFrameSize)); + EXPECT_EQ( + (ResumePosition)(frameSize * 5 + hugeFrameSize), + cache.firstSentPosition()); + EXPECT_EQ( + (ResumePosition)(frameSize * 7 + hugeFrameSize), + cache.lastSentPosition()); +} + +TEST_F(WarmResumeManagerTest, EvictStats) { + auto stats = std::make_shared>(); + + auto frame = frameSerializer_->serializeOut(Frame_CANCEL(0)); + const auto frameSize = frame->computeChainDataLength(); + + // construct cache with capacity of 2 frameSize + WarmResumeManager cache(stats, frameSize * 2); + + { + InSequence dummy; + // Two added + EXPECT_CALL(*stats, resumeBufferChanged(1, frameSize)); + EXPECT_CALL(*stats, resumeBufferChanged(1, frameSize)); + // One evicted, one added + EXPECT_CALL(*stats, resumeBufferChanged(-1, -frameSize)); + EXPECT_CALL(*stats, resumeBufferChanged(1, frameSize)); + // Destruction + EXPECT_CALL(*stats, resumeBufferChanged(-2, -frameSize * 2)); + } + + cache.trackSentFrame(*frame, FrameType::CANCEL, 1, 0); + cache.trackSentFrame(*frame, FrameType::CANCEL, 1, 0); + cache.trackSentFrame(*frame, FrameType::CANCEL, 1, 0); + + EXPECT_EQ(frameSize * 2, cache.size()); +} + +TEST_F(WarmResumeManagerTest, PositionSmallFrame) { + auto frame = frameSerializer_->serializeOut(Frame_CANCEL(0)); + const auto frameSize = frame->computeChainDataLength(); + + // Cache is larger than frame + WarmResumeManager cache(RSocketStats::noop(), frameSize * 2); + cache.trackSentFrame(*frame, FrameType::CANCEL, 1, 0); + EXPECT_EQ( + frame->computeChainDataLength(), + static_cast(cache.lastSentPosition())); +} + +TEST_F(WarmResumeManagerTest, PositionLargeFrame) { + auto frame = frameSerializer_->serializeOut(Frame_CANCEL(0)); + const auto frameSize = frame->computeChainDataLength(); + + // Cache is smaller than frame + WarmResumeManager cache(RSocketStats::noop(), frameSize / 2); + cache.trackSentFrame(*frame, FrameType::CANCEL, 1, 0); + EXPECT_EQ( + frame->computeChainDataLength(), + static_cast(cache.lastSentPosition())); +} diff --git a/rsocket/test/WarmResumptionTest.cpp b/rsocket/test/WarmResumptionTest.cpp new file mode 100644 index 000000000..2481456db --- /dev/null +++ b/rsocket/test/WarmResumptionTest.cpp @@ -0,0 +1,173 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "RSocketTests.h" + +#include "rsocket/test/handlers/HelloServiceHandler.h" +#include "rsocket/test/handlers/HelloStreamRequestHandler.h" + +#include "yarpl/flowable/TestSubscriber.h" + +using namespace rsocket; +using namespace rsocket::tests; +using namespace rsocket::tests::client_server; +using namespace yarpl::flowable; + +TEST(WarmResumptionTest, SuccessfulResumption) { + folly::ScopedEventBaseThread worker; + auto server = makeResumableServer(std::make_shared()); + auto client = + makeWarmResumableClient(worker.getEventBase(), *server->listeningPort()); + auto ts = TestSubscriber::create(7 /* initialRequestN */); + client->getRequester() + ->requestStream(Payload("Bob")) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(ts); + // Wait for a few frames before disconnecting. + while (ts->getValueCount() < 3) { + std::this_thread::yield(); + } + auto result = + client->disconnect(std::runtime_error("Test triggered disconnect")) + .thenValue([&](auto&&) { return client->resume(); }); + EXPECT_NO_THROW(std::move(result).get()); + ts->request(3); + ts->awaitTerminalEvent(); + ts->assertSuccess(); + ts->assertValueCount(10); +} + +// Verify after resumption the client is able to consume stream +// from within onError() context +TEST(WarmResumptionTest, FailedResumption1) { + folly::ScopedEventBaseThread worker; + auto server = + makeServer(std::make_shared()); + auto listeningPort = *server->listeningPort(); + auto client = makeWarmResumableClient(worker.getEventBase(), listeningPort); + auto ts = TestSubscriber::create(7 /* initialRequestN */); + client->getRequester() + ->requestStream(Payload("Bob")) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(ts); + // Wait for a few frames before disconnecting. + while (ts->getValueCount() < 3) { + std::this_thread::yield(); + } + + client->disconnect(std::runtime_error("Test triggered disconnect")) + .thenValue([&](auto&&) { return client->resume(); }) + .thenValue( + [](auto&&) { FAIL() << "Resumption succeeded when it should not"; }) + .thenError([listeningPort, &worker](folly::exception_wrapper) { + folly::ScopedEventBaseThread worker2; + auto newClient = + makeWarmResumableClient(worker2.getEventBase(), listeningPort); + auto newTs = + TestSubscriber::create(6 /* initialRequestN */); + newClient->getRequester() + ->requestStream(Payload("Alice")) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(newTs); + while (newTs->getValueCount() < 3) { + std::this_thread::yield(); + } + newTs->request(2); + newTs->request(2); + newTs->awaitTerminalEvent(); + newTs->assertSuccess(); + newTs->assertValueCount(10); + }) + .wait(); +} + +// Verify after resumption, the client is able to consume stream +// from within and outside of onError() context +TEST(WarmResumptionTest, FailedResumption2) { + folly::ScopedEventBaseThread worker; + folly::ScopedEventBaseThread worker2; + auto server = + makeServer(std::make_shared()); + auto listeningPort = *server->listeningPort(); + auto client = makeWarmResumableClient(worker.getEventBase(), listeningPort); + auto ts = TestSubscriber::create(7 /* initialRequestN */); + client->getRequester() + ->requestStream(Payload("Bob")) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(ts); + // Wait for a few frames before disconnecting. + while (ts->getValueCount() < 3) { + std::this_thread::yield(); + } + + auto newTs = TestSubscriber::create(6 /* initialRequestN */); + std::shared_ptr newClient; + + client->disconnect(std::runtime_error("Test triggered disconnect")) + .thenValue([&](auto&&) { return client->resume(); }) + .thenValue( + [](auto&&) { FAIL() << "Resumption succeeded when it should not"; }) + .thenError([listeningPort, newTs, &newClient, &worker2]( + folly::exception_wrapper) { + newClient = + makeWarmResumableClient(worker2.getEventBase(), listeningPort); + newClient->getRequester() + ->requestStream(Payload("Alice")) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(newTs); + while (newTs->getValueCount() < 3) { + std::this_thread::yield(); + } + newTs->request(2); + }) + .wait(); + newTs->request(2); + newTs->awaitTerminalEvent(); + newTs->assertSuccess(); + newTs->assertValueCount(10); +} + +// Verify resumption when the stateMachine and Transport run on different +// EventBase +TEST(WarmResumptionTest, DifferentEvb) { + folly::ScopedEventBaseThread transportWorker; + folly::ScopedEventBaseThread SMWorker; + auto server = makeResumableServer(std::make_shared()); + auto client = makeWarmResumableClient( + transportWorker.getEventBase(), + *server->listeningPort(), + nullptr, // connectionEvents + SMWorker.getEventBase()); + auto ts = TestSubscriber::create(7 /* initialRequestN */); + client->getRequester() + ->requestStream(Payload("Bob")) + ->map([](auto p) { return p.moveDataToString(); }) + ->subscribe(ts); + // Wait for a few frames before disconnecting. + while (ts->getValueCount() < 3) { + std::this_thread::yield(); + } + auto result = + client->disconnect(std::runtime_error("Test triggered disconnect")) + .thenValue([&](auto&&) { return client->resume(); }); + EXPECT_NO_THROW(std::move(result).get()); + ts->request(3); + ts->awaitTerminalEvent(); + ts->assertSuccess(); + ts->assertValueCount(10); +} diff --git a/test/framing/FrameTest.cpp b/rsocket/test/framing/FrameTest.cpp similarity index 73% rename from test/framing/FrameTest.cpp rename to rsocket/test/framing/FrameTest.cpp index c54746ecd..3efb79449 100644 --- a/test/framing/FrameTest.cpp +++ b/rsocket/test/framing/FrameTest.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include @@ -8,27 +20,13 @@ #include "rsocket/framing/Frame.h" #include "rsocket/framing/FrameSerializer.h" -using namespace ::testing; using namespace ::rsocket; -// TODO(stupaq): tests with malformed frames - -template -Frame reserialize_resume(bool resumable, Args... args) { - Frame givenFrame, newFrame; - givenFrame = Frame(std::forward(args)...); - auto frameSerializer = FrameSerializer::createCurrentVersion(); - EXPECT_TRUE(frameSerializer->deserializeFrom( - newFrame, - frameSerializer->serializeOut(std::move(givenFrame), resumable), - resumable)); - return newFrame; -} - template Frame reserialize(Args... args) { Frame givenFrame = Frame(std::forward(args)...); - auto frameSerializer = FrameSerializer::createCurrentVersion(); + auto frameSerializer = + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest); auto serializedFrame = frameSerializer->serializeOut(std::move(givenFrame)); Frame newFrame; EXPECT_TRUE( @@ -42,9 +40,9 @@ void expectHeader( FrameFlags flags, StreamId streamId, const Frame& frame) { - EXPECT_EQ(type, frame.header_.type_); - EXPECT_EQ(streamId, frame.header_.streamId_); - EXPECT_EQ(flags, frame.header_.flags_); + EXPECT_EQ(type, frame.header_.type); + EXPECT_EQ(streamId, frame.header_.streamId); + EXPECT_EQ(flags, frame.header_.flags); } TEST(FrameTest, Frame_REQUEST_STREAM) { @@ -58,8 +56,8 @@ TEST(FrameTest, Frame_REQUEST_STREAM) { expectHeader(FrameType::REQUEST_STREAM, flags, streamId, frame); EXPECT_EQ(requestN, frame.requestN_); - EXPECT_TRUE(folly::IOBufEqual()(*metadata, *frame.payload_.metadata)); - EXPECT_TRUE(folly::IOBufEqual()(*data, *frame.payload_.data)); + EXPECT_TRUE(folly::IOBufEqualTo()(*metadata, *frame.payload_.metadata)); + EXPECT_TRUE(folly::IOBufEqualTo()(*data, *frame.payload_.data)); } TEST(FrameTest, Frame_REQUEST_CHANNEL) { @@ -73,8 +71,8 @@ TEST(FrameTest, Frame_REQUEST_CHANNEL) { expectHeader(FrameType::REQUEST_CHANNEL, flags, streamId, frame); EXPECT_EQ(requestN, frame.requestN_); - EXPECT_TRUE(folly::IOBufEqual()(*metadata, *frame.payload_.metadata)); - EXPECT_TRUE(folly::IOBufEqual()(*data, *frame.payload_.data)); + EXPECT_TRUE(folly::IOBufEqualTo()(*metadata, *frame.payload_.metadata)); + EXPECT_TRUE(folly::IOBufEqualTo()(*data, *frame.payload_.data)); } TEST(FrameTest, Frame_REQUEST_N) { @@ -82,14 +80,14 @@ TEST(FrameTest, Frame_REQUEST_N) { uint32_t requestN = 24; auto frame = reserialize(streamId, requestN); - expectHeader(FrameType::REQUEST_N, FrameFlags::EMPTY, streamId, frame); + expectHeader(FrameType::REQUEST_N, FrameFlags::EMPTY_, streamId, frame); EXPECT_EQ(requestN, frame.requestN_); } TEST(FrameTest, Frame_CANCEL) { uint32_t streamId = 42; auto frame = reserialize(streamId); - expectHeader(FrameType::CANCEL, FrameFlags::EMPTY, streamId, frame); + expectHeader(FrameType::CANCEL, FrameFlags::EMPTY_, streamId, frame); } TEST(FrameTest, Frame_PAYLOAD) { @@ -101,8 +99,8 @@ TEST(FrameTest, Frame_PAYLOAD) { streamId, flags, Payload(data->clone(), metadata->clone())); expectHeader(FrameType::PAYLOAD, flags, streamId, frame); - EXPECT_TRUE(folly::IOBufEqual()(*metadata, *frame.payload_.metadata)); - EXPECT_TRUE(folly::IOBufEqual()(*data, *frame.payload_.data)); + EXPECT_TRUE(folly::IOBufEqualTo()(*metadata, *frame.payload_.metadata)); + EXPECT_TRUE(folly::IOBufEqualTo()(*data, *frame.payload_.data)); } TEST(FrameTest, Frame_PAYLOAD_NoMeta) { @@ -114,7 +112,7 @@ TEST(FrameTest, Frame_PAYLOAD_NoMeta) { expectHeader(FrameType::PAYLOAD, flags, streamId, frame); EXPECT_FALSE(frame.payload_.metadata); - EXPECT_TRUE(folly::IOBufEqual()(*data, *frame.payload_.data)); + EXPECT_TRUE(folly::IOBufEqualTo()(*data, *frame.payload_.data)); } TEST(FrameTest, Frame_ERROR) { @@ -128,22 +126,8 @@ TEST(FrameTest, Frame_ERROR) { expectHeader(FrameType::ERROR, flags, streamId, frame); EXPECT_EQ(errorCode, frame.errorCode_); - EXPECT_TRUE(folly::IOBufEqual()(*metadata, *frame.payload_.metadata)); - EXPECT_TRUE(folly::IOBufEqual()(*data, *frame.payload_.data)); -} - -TEST(FrameTest, Frame_KEEPALIVE_resume) { - uint32_t streamId = 0; - ResumePosition position = 101; - auto flags = FrameFlags::KEEPALIVE_RESPOND; - auto data = folly::IOBuf::copyBuffer("424242"); - auto frame = - reserialize_resume(true, flags, position, data->clone()); - - expectHeader( - FrameType::KEEPALIVE, FrameFlags::KEEPALIVE_RESPOND, streamId, frame); - EXPECT_EQ(position, frame.position_); - EXPECT_TRUE(folly::IOBufEqual()(*data, *frame.data_)); + EXPECT_TRUE(folly::IOBufEqualTo()(*metadata, *frame.payload_.metadata)); + EXPECT_TRUE(folly::IOBufEqualTo()(*data, *frame.payload_.data)); } TEST(FrameTest, Frame_KEEPALIVE) { @@ -151,23 +135,22 @@ TEST(FrameTest, Frame_KEEPALIVE) { ResumePosition position = 101; auto flags = FrameFlags::KEEPALIVE_RESPOND; auto data = folly::IOBuf::copyBuffer("424242"); - auto frame = reserialize_resume( - false, flags, position, data->clone()); + auto frame = reserialize(flags, position, data->clone()); expectHeader( FrameType::KEEPALIVE, FrameFlags::KEEPALIVE_RESPOND, streamId, frame); // Default position - auto currProtVersion = FrameSerializer::getCurrentProtocolVersion(); + auto currProtVersion = ProtocolVersion::Latest; if (currProtVersion == ProtocolVersion(0, 1)) { EXPECT_EQ(0, frame.position_); } else if (currProtVersion == ProtocolVersion(1, 0)) { EXPECT_EQ(position, frame.position_); } - EXPECT_TRUE(folly::IOBufEqual()(*data, *frame.data_)); + EXPECT_TRUE(folly::IOBufEqualTo()(*data, *frame.data_)); } TEST(FrameTest, Frame_SETUP) { - FrameFlags flags = FrameFlags::EMPTY; + FrameFlags flags = FrameFlags::EMPTY_; uint16_t versionMajor = 4; uint16_t versionMinor = 5; uint32_t keepaliveTime = Frame_SETUP::kMaxKeepaliveTime; @@ -194,11 +177,11 @@ TEST(FrameTest, Frame_SETUP) { EXPECT_EQ(ResumeIdentificationToken(), frame.token_); EXPECT_EQ("md", frame.metadataMimeType_); EXPECT_EQ("d", frame.dataMimeType_); - EXPECT_TRUE(folly::IOBufEqual()(*data, *frame.payload_.data)); + EXPECT_TRUE(folly::IOBufEqualTo()(*data, *frame.payload_.data)); } TEST(FrameTest, Frame_SETUP_resume) { - FrameFlags flags = FrameFlags::EMPTY | FrameFlags::RESUME_ENABLE; + FrameFlags flags = FrameFlags::EMPTY_ | FrameFlags::RESUME_ENABLE; uint16_t versionMajor = 0; uint16_t versionMinor = 0; uint32_t keepaliveTime = Frame_SETUP::kMaxKeepaliveTime; @@ -224,11 +207,11 @@ TEST(FrameTest, Frame_SETUP_resume) { EXPECT_EQ(token, frame.token_); EXPECT_EQ("md", frame.metadataMimeType_); EXPECT_EQ("d", frame.dataMimeType_); - EXPECT_TRUE(folly::IOBufEqual()(*data, *frame.payload_.data)); + EXPECT_TRUE(folly::IOBufEqualTo()(*data, *frame.payload_.data)); } TEST(FrameTest, Frame_LEASE) { - FrameFlags flags = FrameFlags::EMPTY; + FrameFlags flags = FrameFlags::EMPTY_; uint32_t ttl = Frame_LEASE::kMaxTtl; auto numberOfRequests = Frame_LEASE::kMaxNumRequests; auto frame = reserialize(ttl, numberOfRequests); @@ -247,8 +230,8 @@ TEST(FrameTest, Frame_REQUEST_RESPONSE) { streamId, flags, Payload(data->clone(), metadata->clone())); expectHeader(FrameType::REQUEST_RESPONSE, flags, streamId, frame); - EXPECT_TRUE(folly::IOBufEqual()(*metadata, *frame.payload_.metadata)); - EXPECT_TRUE(folly::IOBufEqual()(*data, *frame.payload_.data)); + EXPECT_TRUE(folly::IOBufEqualTo()(*metadata, *frame.payload_.metadata)); + EXPECT_TRUE(folly::IOBufEqualTo()(*data, *frame.payload_.data)); } TEST(FrameTest, Frame_REQUEST_FNF) { @@ -260,8 +243,8 @@ TEST(FrameTest, Frame_REQUEST_FNF) { streamId, flags, Payload(data->clone(), metadata->clone())); expectHeader(FrameType::REQUEST_FNF, flags, streamId, frame); - EXPECT_TRUE(folly::IOBufEqual()(*metadata, *frame.payload_.metadata)); - EXPECT_TRUE(folly::IOBufEqual()(*data, *frame.payload_.data)); + EXPECT_TRUE(folly::IOBufEqualTo()(*metadata, *frame.payload_.metadata)); + EXPECT_TRUE(folly::IOBufEqualTo()(*data, *frame.payload_.data)); } TEST(FrameTest, Frame_METADATA_PUSH) { @@ -270,11 +253,11 @@ TEST(FrameTest, Frame_METADATA_PUSH) { auto frame = reserialize(metadata->clone()); expectHeader(FrameType::METADATA_PUSH, flags, 0, frame); - EXPECT_TRUE(folly::IOBufEqual()(*metadata, *frame.metadata_)); + EXPECT_TRUE(folly::IOBufEqualTo()(*metadata, *frame.metadata_)); } TEST(FrameTest, Frame_RESUME) { - FrameFlags flags = FrameFlags::EMPTY; + FrameFlags flags = FrameFlags::EMPTY_; uint16_t versionMajor = 4; uint16_t versionMinor = 5; ResumeIdentificationToken token = ResumeIdentificationToken::generateNew(); @@ -298,10 +281,24 @@ TEST(FrameTest, Frame_RESUME) { } TEST(FrameTest, Frame_RESUME_OK) { - FrameFlags flags = FrameFlags::EMPTY; + FrameFlags flags = FrameFlags::EMPTY_; ResumePosition position = 6; auto frame = reserialize(position); expectHeader(FrameType::RESUME_OK, flags, 0, frame); EXPECT_EQ(position, frame.position_); } + +TEST(FrameTest, Frame_PreallocatedFrameLengthField) { + uint32_t streamId = 42; + FrameFlags flags = FrameFlags::COMPLETE; + auto data = folly::IOBuf::copyBuffer("424242"); + auto frameSerializer = + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest); + frameSerializer->preallocateFrameSizeField() = true; + + auto frame = Frame_PAYLOAD(streamId, flags, Payload(data->clone())); + auto serializedFrame = frameSerializer->serializeOut(std::move(frame)); + + EXPECT_LT(0, serializedFrame->headroom()); +} diff --git a/rsocket/test/framing/FrameTransportTest.cpp b/rsocket/test/framing/FrameTransportTest.cpp new file mode 100644 index 000000000..48017d682 --- /dev/null +++ b/rsocket/test/framing/FrameTransportTest.cpp @@ -0,0 +1,82 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "rsocket/framing/FrameTransportImpl.h" +#include "rsocket/test/test_utils/MockDuplexConnection.h" +#include "rsocket/test/test_utils/MockFrameProcessor.h" + +using namespace rsocket; +using namespace testing; + +namespace { + +/* + * Compare a `const folly::IOBuf&` against a `const std::string&`. + */ +MATCHER_P(IOBufStringEq, s, "") { + return folly::IOBufEqualTo()(*arg, *folly::IOBuf::copyBuffer(s)); +} + +} // namespace + +TEST(FrameTransport, Close) { + auto connection = std::make_unique>(); + EXPECT_CALL(*connection, setInput_(_)); + + auto transport = std::make_shared(std::move(connection)); + transport->setFrameProcessor( + std::make_shared>()); + transport->close(); +} + +TEST(FrameTransport, SimpleNoQueue) { + auto connection = std::make_unique>(); + EXPECT_CALL(*connection, setInput_(_)); + + EXPECT_CALL(*connection, send_(IOBufStringEq("Hello"))); + EXPECT_CALL(*connection, send_(IOBufStringEq("World"))); + + auto transport = std::make_shared(std::move(connection)); + + transport->setFrameProcessor( + std::make_shared>()); + + transport->outputFrameOrDrop(folly::IOBuf::copyBuffer("Hello")); + transport->outputFrameOrDrop(folly::IOBuf::copyBuffer("World")); + + transport->close(); +} + +TEST(FrameTransport, InputSendsError) { + auto connection = + std::make_unique>([](auto input) { + auto subscription = + std::make_shared>(); + EXPECT_CALL(*subscription, request_(_)); + EXPECT_CALL(*subscription, cancel_()); + + input->onSubscribe(std::move(subscription)); + input->onError(std::runtime_error("Oops")); + }); + + auto transport = std::make_shared(std::move(connection)); + + auto processor = std::make_shared>(); + EXPECT_CALL(*processor, onTerminal_(_)); + + transport->setFrameProcessor(std::move(processor)); + transport->close(); +} diff --git a/rsocket/test/framing/FramedReaderTest.cpp b/rsocket/test/framing/FramedReaderTest.cpp new file mode 100644 index 000000000..d3b6f9e0c --- /dev/null +++ b/rsocket/test/framing/FramedReaderTest.cpp @@ -0,0 +1,113 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "rsocket/framing/FramedReader.h" +#include "rsocket/test/test_utils/MockDuplexConnection.h" + +using namespace rsocket; +using namespace testing; +using namespace yarpl::mocks; + +TEST(FramedReader, TinyFrame) { + auto version = std::make_shared(ProtocolVersion::Latest); + auto reader = std::make_shared(version); + + // Not using hex string-literal as std::string ctor hits '\x00' and stops + // reading. + auto buf = folly::IOBuf::createCombined(4); + buf->append(4); + buf->writableData()[0] = '\x00'; + buf->writableData()[1] = '\x00'; + buf->writableData()[2] = '\x00'; + buf->writableData()[3] = '\x02'; + + reader->onSubscribe(yarpl::flowable::Subscription::create()); + reader->onNext(std::move(buf)); + + auto subscriber = std::make_shared< + StrictMock>>>(); + EXPECT_CALL(*subscriber, onSubscribe_(_)); + EXPECT_CALL(*subscriber, onError_(_)); + + reader->setInput(subscriber); + subscriber->awaitTerminalEvent(); + reader->onComplete(); +} + +TEST(FramedReader, CantDetectVersion) { + auto version = std::make_shared(ProtocolVersion::Unknown); + auto reader = std::make_shared(version); + + auto buf = folly::IOBuf::copyBuffer("ABCDEFGHIJKLMNOP"); + + reader->onSubscribe(yarpl::flowable::Subscription::create()); + reader->onNext(std::move(buf)); + + auto subscriber = std::make_shared< + StrictMock>>>(); + EXPECT_CALL(*subscriber, onSubscribe_(_)); + EXPECT_CALL(*subscriber, onError_(_)); + + reader->setInput(subscriber); + subscriber->awaitTerminalEvent(); + reader->onComplete(); +} + +TEST(FramedReader, SubscriberCompleteAfterError) { + auto version = std::make_shared(ProtocolVersion::Latest); + auto reader = std::make_shared(version); + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, request_(_)); + EXPECT_CALL(*subscription, cancel_()); + + reader->onSubscribe(subscription); + + auto subscriber = std::make_shared< + StrictMock>>>(); + EXPECT_CALL(*subscriber, onSubscribe_(_)); + EXPECT_CALL(*subscriber, onError_(_)) + .WillOnce(Invoke([](folly::exception_wrapper ew) { + EXPECT_EQ(ew.get_exception()->what(), std::string{"Oops"}); + })); + + reader->setInput(subscriber); + reader->error("Oops"); + reader->onComplete(); +} + +TEST(FramedReader, SubscriberErrorAfterError) { + auto version = std::make_shared(ProtocolVersion::Latest); + auto reader = std::make_shared(version); + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, request_(_)); + EXPECT_CALL(*subscription, cancel_()); + + reader->onSubscribe(subscription); + + auto subscriber = std::make_shared< + StrictMock>>>(); + EXPECT_CALL(*subscriber, onSubscribe_(_)); + EXPECT_CALL(*subscriber, onError_(_)) + .WillOnce(Invoke([](folly::exception_wrapper ew) { + EXPECT_EQ(ew.get_exception()->what(), std::string{"Oops"}); + })); + + reader->setInput(subscriber); + reader->error("Oops"); + reader->onError(std::runtime_error{"Not oops"}); +} diff --git a/rsocket/test/framing/FramerTest.cpp b/rsocket/test/framing/FramerTest.cpp new file mode 100644 index 000000000..81fe97ffd --- /dev/null +++ b/rsocket/test/framing/FramerTest.cpp @@ -0,0 +1,71 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/framing/Framer.h" +#include +#include + +using namespace rsocket; +using namespace testing; + +class FramerMock : public Framer { + public: + explicit FramerMock(ProtocolVersion protocolVersion = ProtocolVersion::Latest) + : Framer(protocolVersion, true) {} + + MOCK_METHOD1(error, void(const char*)); + MOCK_METHOD1(onFrame_, void(std::unique_ptr&)); + + void onFrame(std::unique_ptr frame) override { + onFrame_(frame); + } +}; + +MATCHER_P(isIOBuffEq, n, "") { + return folly::IOBufEqualTo()(arg, n); +} + +TEST(Framer, TinyFrame) { + FramerMock framer; + + // Not using hex string-literal as std::string ctor hits '\x00' and stops + // reading. + auto buf = folly::IOBuf::createCombined(4); + buf->append(4); + buf->writableData()[0] = '\x00'; + buf->writableData()[1] = '\x00'; + buf->writableData()[2] = '\x00'; + buf->writableData()[3] = '\x02'; + + EXPECT_CALL(framer, error(_)); + framer.addFrameChunk(std::move(buf)); +} + +TEST(Framer, CantDetectVersion) { + FramerMock framer(ProtocolVersion::Unknown); + + EXPECT_CALL(framer, error(_)); + + auto buf = folly::IOBuf::copyBuffer("ABCDEFGHIJKLMNOP"); + framer.addFrameChunk(std::move(buf)); +} + +TEST(Framer, ParseFrame) { + FramerMock framer; + + auto buf = folly::IOBuf::copyBuffer("ABCDEFGHIJKLMNOP"); + EXPECT_CALL(framer, onFrame_(Pointee(isIOBuffEq(*buf)))); + + framer.addFrameChunk(framer.prependSize(std::move(buf))); +} diff --git a/rsocket/test/fuzzer_testcases/frame_fuzzer/id_000000,sig_11,src_000000,op_havoc,rep_2 b/rsocket/test/fuzzer_testcases/frame_fuzzer/id_000000,sig_11,src_000000,op_havoc,rep_2 new file mode 100644 index 000000000..a4cdbe62e Binary files /dev/null and b/rsocket/test/fuzzer_testcases/frame_fuzzer/id_000000,sig_11,src_000000,op_havoc,rep_2 differ diff --git a/rsocket/test/fuzzer_testcases/frame_fuzzer/id_000001,sig_11,src_000000,op_havoc,rep_16 b/rsocket/test/fuzzer_testcases/frame_fuzzer/id_000001,sig_11,src_000000,op_havoc,rep_16 new file mode 100644 index 000000000..0a4742b1d Binary files /dev/null and b/rsocket/test/fuzzer_testcases/frame_fuzzer/id_000001,sig_11,src_000000,op_havoc,rep_16 differ diff --git a/rsocket/test/fuzzers/frame_fuzzer.cpp b/rsocket/test/fuzzers/frame_fuzzer.cpp new file mode 100644 index 000000000..eb9d04efc --- /dev/null +++ b/rsocket/test/fuzzers/frame_fuzzer.cpp @@ -0,0 +1,119 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include +#include + +#include +#include +#include + +#include "rsocket/ConnectionAcceptor.h" +#include "rsocket/DuplexConnection.h" +#include "rsocket/RSocketServer.h" + +struct FuzzerConnectionAcceptor : rsocket::ConnectionAcceptor { + void start(rsocket::OnDuplexConnectionAccept func_) override { + VLOG(1) << "FuzzerConnectionAcceptor::start()" << std::endl; + func = func_; + } + + void stop() override { + VLOG(1) << "FuzzerConnectionAcceptor::stop()" << std::endl; + } + + folly::Optional listeningPort() const override { + return 0; + } + + rsocket::OnDuplexConnectionAccept func; +}; + +struct FuzzerDuplexConnection : rsocket::DuplexConnection { + using Subscriber = rsocket::DuplexConnection::Subscriber; + + FuzzerDuplexConnection() {} + + void setInput(std::shared_ptr sub) override { + VLOG(1) << "FuzzerDuplexConnection::setInput()" << std::endl; + input_sub = sub; + } + + void send(std::unique_ptr buf) override { + VLOG(1) << "FuzzerDuplexConnection::send(\"" + << folly::humanify(buf->moveToFbString()) << "\")" << std::endl; + } + + std::shared_ptr input_sub; +}; + +struct NoopSubscription : yarpl::flowable::Subscription { + void request(int64_t n) override { + VLOG(1) << "NoopSubscription::request(" << n << ")"; + } + void cancel() override { + VLOG(1) << "NoopSubscription::cancel()"; + } +}; + +struct NoopResponder : rsocket::RSocketResponder {}; + +std::string get_stdin() { + std::cin >> std::noskipws; + std::istream_iterator it(std::cin); + std::istream_iterator end; + std::string input(it, end); + return input; +} + +int main(int argc, char* argv[]) { + folly::init(&argc, &argv); + FLAGS_logtostderr = 1; + + folly::EventBase evb; + folly::EventBaseManager::get()->setEventBase(&evb, false); + + auto feed_conn = std::make_unique(); + auto acceptor = std::make_unique(); + + // grab references while we still own the duplex connection + auto& input_sub = feed_conn->input_sub; + auto& acceptor_func_ptr = acceptor->func; + + rsocket::RSocketServer server(std::move(acceptor)); + + auto responder = std::make_shared(); + server.start( + [responder](const rsocket::SetupParameters&) { return responder; }); + + CHECK(acceptor_func_ptr); + acceptor_func_ptr(std::move(feed_conn), evb); + evb.loopOnce(); + + CHECK(input_sub); + auto input_subscription = std::make_shared(); + input_sub->onSubscribe(input_subscription); + + std::string fuzz_input = get_stdin(); + std::unique_ptr buf = + folly::IOBuf::wrapBuffer(fuzz_input.c_str(), fuzz_input.size()); + + VLOG(1) << "fuzz input: " << std::endl; + VLOG(1) << folly::humanify(buf->cloneAsValue().moveToFbString()) << std::endl; + + input_sub->onNext(std::move(buf)); + evb.loopOnce(); + + return 0; +} diff --git a/rsocket/test/handlers/HelloServiceHandler.cpp b/rsocket/test/handlers/HelloServiceHandler.cpp new file mode 100644 index 000000000..aa7e07ca5 --- /dev/null +++ b/rsocket/test/handlers/HelloServiceHandler.cpp @@ -0,0 +1,43 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/test/handlers/HelloServiceHandler.h" +#include "rsocket/test/handlers/HelloStreamRequestHandler.h" + +namespace rsocket { +namespace tests { + +folly::Expected +HelloServiceHandler::onNewSetup(const SetupParameters&) { + return RSocketConnectionParams( + std::make_shared(), + RSocketStats::noop(), + connectionEvents_); +} + +void HelloServiceHandler::onNewRSocketState( + std::shared_ptr state, + ResumeIdentificationToken token) { + store_.lock()->insert({token, std::move(state)}); +} + +folly::Expected, RSocketException> +HelloServiceHandler::onResume(ResumeIdentificationToken token) { + auto itr = store_->find(token); + CHECK(itr != store_->end()); + return itr->second; +} + +} // namespace tests +} // namespace rsocket diff --git a/rsocket/test/handlers/HelloServiceHandler.h b/rsocket/test/handlers/HelloServiceHandler.h new file mode 100644 index 000000000..55cf86d53 --- /dev/null +++ b/rsocket/test/handlers/HelloServiceHandler.h @@ -0,0 +1,50 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "rsocket/RSocketServiceHandler.h" + +namespace rsocket { +namespace tests { + +// A minimal RSocketServiceHandler which supports resumption. + +class HelloServiceHandler : public RSocketServiceHandler { + public: + explicit HelloServiceHandler( + std::shared_ptr connEvents = nullptr) + : connectionEvents_(connEvents) {} + + folly::Expected onNewSetup( + const SetupParameters&) override; + + void onNewRSocketState( + std::shared_ptr state, + ResumeIdentificationToken token) override; + + folly::Expected, RSocketException> + onResume(ResumeIdentificationToken token) override; + + private: + std::shared_ptr connectionEvents_; + folly::Synchronized< + std::map>, + std::mutex> + store_; +}; + +} // namespace tests +} // namespace rsocket diff --git a/rsocket/test/handlers/HelloStreamRequestHandler.cpp b/rsocket/test/handlers/HelloStreamRequestHandler.cpp new file mode 100644 index 000000000..a4bf8a34d --- /dev/null +++ b/rsocket/test/handlers/HelloStreamRequestHandler.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "HelloStreamRequestHandler.h" +#include +#include +#include "yarpl/Flowable.h" + +using namespace yarpl::flowable; + +namespace rsocket { +namespace tests { +/// Handles a new inbound Stream requested by the other end. +std::shared_ptr> +HelloStreamRequestHandler::handleRequestStream( + rsocket::Payload request, + rsocket::StreamId) { + VLOG(3) << "HelloStreamRequestHandler.handleRequestStream " << request; + + // string from payload data + auto requestString = request.moveDataToString(); + + return Flowable<>::range(1, 10)->map( + [name = std::move(requestString)](int64_t v) { + return Payload(folly::to(v), "metadata"); + }); +} +} // namespace tests +} // namespace rsocket diff --git a/rsocket/test/handlers/HelloStreamRequestHandler.h b/rsocket/test/handlers/HelloStreamRequestHandler.h new file mode 100644 index 000000000..3aa48fb08 --- /dev/null +++ b/rsocket/test/handlers/HelloStreamRequestHandler.h @@ -0,0 +1,31 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "rsocket/RSocketResponder.h" +#include "yarpl/Flowable.h" + +namespace rsocket { +namespace tests { + +class HelloStreamRequestHandler : public RSocketResponder { + public: + /// Handles a new inbound Stream requested by the other end. + std::shared_ptr> + handleRequestStream(rsocket::Payload request, rsocket::StreamId streamId) + override; +}; +} // namespace tests +} // namespace rsocket diff --git a/rsocket/test/internal/AllowanceTest.cpp b/rsocket/test/internal/AllowanceTest.cpp new file mode 100644 index 000000000..d2e77a36d --- /dev/null +++ b/rsocket/test/internal/AllowanceTest.cpp @@ -0,0 +1,52 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/internal/Allowance.h" +#include +#include + +using namespace ::rsocket; + +TEST(AllowanceTest, Finite) { + Allowance allowance; + + ASSERT_FALSE(allowance.canConsume(1)); + ASSERT_FALSE(allowance.tryConsume(1)); + + ASSERT_EQ(0U, allowance.add(1)); + ASSERT_FALSE(allowance.canConsume(2)); + ASSERT_TRUE(allowance.canConsume(1)); + ASSERT_TRUE(allowance.tryConsume(1)); + + ASSERT_EQ(0U, allowance.add(2)); + ASSERT_EQ(2U, allowance.add(1)); + ASSERT_EQ(3U, allowance.consumeAll()); + ASSERT_EQ(0U, allowance.consumeAll()); + + ASSERT_EQ(0U, allowance.add(2)); + ASSERT_FALSE(allowance.canConsume(3)); + ASSERT_FALSE(allowance.tryConsume(3)); + ASSERT_TRUE(allowance.canConsume(2)); + ASSERT_TRUE(allowance.tryConsume(2)); + ASSERT_FALSE(allowance.canConsume(1)); +} + +TEST(AllowanceTest, ConsumeWithLimit) { + Allowance allowance; + + ASSERT_EQ(0U, allowance.add(9)); + ASSERT_EQ(4U, allowance.consumeUpTo(4)); + ASSERT_EQ(1U, allowance.consumeUpTo(1)); + ASSERT_EQ(4U, allowance.consumeUpTo(100)); +} diff --git a/rsocket/test/internal/ConnectionSetTest.cpp b/rsocket/test/internal/ConnectionSetTest.cpp new file mode 100644 index 000000000..ea0e215aa --- /dev/null +++ b/rsocket/test/internal/ConnectionSetTest.cpp @@ -0,0 +1,65 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include + +#include "rsocket/RSocketConnectionEvents.h" +#include "rsocket/RSocketResponder.h" +#include "rsocket/RSocketStats.h" +#include "rsocket/internal/ConnectionSet.h" +#include "rsocket/internal/KeepaliveTimer.h" +#include "rsocket/statemachine/RSocketStateMachine.h" + +using namespace rsocket; + +namespace { + +std::shared_ptr makeStateMachine(folly::EventBase* evb) { + return std::make_shared( + std::make_shared(), + std::make_unique(std::chrono::seconds{10}, *evb), + RSocketMode::SERVER, + RSocketStats::noop(), + std::make_shared(), + ResumeManager::makeEmpty(), + nullptr /* coldResumeHandler */ + ); +} +} // namespace + +TEST(ConnectionSet, ImmediateDtor) { + ConnectionSet set; +} + +TEST(ConnectionSet, CloseViaMachine) { + folly::EventBase evb; + auto machine = makeStateMachine(&evb); + + ConnectionSet set; + set.insert(machine, &evb); + machine->registerCloseCallback(&set); + + machine->close({}, StreamCompletionSignal::CANCEL); +} + +TEST(ConnectionSet, CloseViaSetDtor) { + folly::EventBase evb; + auto machine = makeStateMachine(&evb); + + ConnectionSet set; + set.insert(machine, &evb); + machine->registerCloseCallback(&set); +} diff --git a/test/internal/FollyKeepaliveTimerTest.cpp b/rsocket/test/internal/KeepaliveTimerTest.cpp similarity index 58% rename from test/internal/FollyKeepaliveTimerTest.cpp rename to rsocket/test/internal/KeepaliveTimerTest.cpp index ae49c1124..e721691a3 100644 --- a/test/internal/FollyKeepaliveTimerTest.cpp +++ b/rsocket/test/internal/KeepaliveTimerTest.cpp @@ -1,6 +1,17 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -#include #include #include #include @@ -9,7 +20,7 @@ #include "rsocket/framing/Frame.h" #include "rsocket/framing/FramedDuplexConnection.h" -#include "rsocket/internal/FollyKeepaliveTimer.h" +#include "rsocket/internal/KeepaliveTimer.h" using namespace ::testing; using namespace ::rsocket; @@ -19,7 +30,7 @@ class MockConnectionAutomaton : public FrameSink { public: // MOCK_METHOD doesn't take functions with unique_ptr args. // A workaround for sendKeepalive method. - virtual void sendKeepalive(std::unique_ptr b) override { + void sendKeepalive(std::unique_ptr b) override { sendKeepalive_(b); } MOCK_METHOD1(sendKeepalive_, void(std::unique_ptr&)); @@ -30,7 +41,7 @@ class MockConnectionAutomaton : public FrameSink { disconnectOrCloseWithError_(error); } }; -} +} // namespace TEST(FollyKeepaliveTimerTest, StartStopWithResponse) { auto connectionAutomaton = @@ -40,15 +51,15 @@ TEST(FollyKeepaliveTimerTest, StartStopWithResponse) { folly::EventBase eventBase; - FollyKeepaliveTimer timer(eventBase, std::chrono::milliseconds(100)); + KeepaliveTimer timer(std::chrono::milliseconds(100), eventBase); timer.start(connectionAutomaton); - timer.sendKeepalive(); + timer.sendKeepalive(*connectionAutomaton); timer.keepaliveReceived(); - timer.sendKeepalive(); + timer.sendKeepalive(*connectionAutomaton); timer.stop(); } @@ -62,13 +73,13 @@ TEST(FollyKeepaliveTimerTest, NoResponse) { folly::EventBase eventBase; - FollyKeepaliveTimer timer(eventBase, std::chrono::milliseconds(100)); + KeepaliveTimer timer(std::chrono::milliseconds(100), eventBase); timer.start(connectionAutomaton); - timer.sendKeepalive(); + timer.sendKeepalive(*connectionAutomaton); - timer.sendKeepalive(); + timer.sendKeepalive(*connectionAutomaton); timer.stop(); } diff --git a/rsocket/test/internal/ResumeIdentificationToken.cpp b/rsocket/test/internal/ResumeIdentificationToken.cpp new file mode 100644 index 000000000..48d4b172c --- /dev/null +++ b/rsocket/test/internal/ResumeIdentificationToken.cpp @@ -0,0 +1,29 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "rsocket/framing/ResumeIdentificationToken.h" + +using namespace rsocket; + +TEST(ResumeIdentificationTokenTest, Conversion) { + for (int i = 0; i < 10; i++) { + auto token = ResumeIdentificationToken::generateNew(); + auto token2 = ResumeIdentificationToken(token.str()); + CHECK_EQ(token, token2); + CHECK_EQ(token.str(), token2.str()); + } +} diff --git a/rsocket/test/internal/SetupResumeAcceptorTest.cpp b/rsocket/test/internal/SetupResumeAcceptorTest.cpp new file mode 100644 index 000000000..5365f4964 --- /dev/null +++ b/rsocket/test/internal/SetupResumeAcceptorTest.cpp @@ -0,0 +1,314 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include + +#include "rsocket/framing/FrameSerializer.h" +#include "rsocket/framing/FrameTransportImpl.h" +#include "rsocket/internal/SetupResumeAcceptor.h" +#include "rsocket/test/test_utils/MockDuplexConnection.h" +#include "rsocket/test/test_utils/MockFrameProcessor.h" +#include "yarpl/test_utils/Mocks.h" + +using namespace rsocket; +using namespace testing; + +namespace { + +/* + * Make a legitimate-looking SETUP frame. + */ +Frame_SETUP makeSetup() { + auto version = ProtocolVersion::Latest; + + Frame_SETUP frame; + frame.header_ = FrameHeader{FrameType::SETUP, FrameFlags::EMPTY_, 0}; + frame.versionMajor_ = version.major; + frame.versionMinor_ = version.minor; + frame.keepaliveTime_ = Frame_SETUP::kMaxKeepaliveTime; + frame.maxLifetime_ = Frame_SETUP::kMaxLifetime; + frame.token_ = ResumeIdentificationToken::generateNew(); + frame.metadataMimeType_ = "application/olive+oil"; + frame.dataMimeType_ = "json/vorhees"; + frame.payload_ = Payload("Test SETUP data", "Test SETUP metadata"); + return frame; +} + +/* + * Make a legitimate-looking RESUME frame. + */ +Frame_RESUME makeResume() { + Frame_RESUME frame; + frame.header_ = FrameHeader{FrameType::RESUME, FrameFlags::EMPTY_, 0}; + frame.versionMajor_ = 1; + frame.versionMinor_ = 0; + frame.token_ = ResumeIdentificationToken::generateNew(); + frame.lastReceivedServerPosition_ = 500; + frame.clientPosition_ = 300; + return frame; +} + +void setupFail(std::unique_ptr, SetupParameters) { + FAIL() << "setupFail() was called"; +} + +void resumeFail(std::unique_ptr, ResumeParameters) { + FAIL() << "resumeFail() was called"; +} +} // namespace + +TEST(SetupResumeAcceptor, ImmediateDtor) { + folly::EventBase evb; + SetupResumeAcceptor acceptor{&evb}; +} + +TEST(SetupResumeAcceptor, ImmediateClose) { + folly::EventBase evb; + SetupResumeAcceptor acceptor{&evb}; + acceptor.close().get(); +} + +TEST(SetupResumeAcceptor, CloseWithActiveConnection) { + folly::EventBase evb; + SetupResumeAcceptor acceptor{&evb}; + + std::shared_ptr outerInput; + + auto connection = + std::make_unique>([&](auto input) { + outerInput = input; + input->onSubscribe(yarpl::flowable::Subscription::create()); + }); + + ON_CALL(*connection, send_(_)).WillByDefault(Invoke([](auto&) { FAIL(); })); + + acceptor.accept(std::move(connection), setupFail, resumeFail); + acceptor.close(); + + evb.loop(); + + // Normally a DuplexConnection impl would complete/error its input subscriber + // in the destructor. Do that manually here. + outerInput->onComplete(); +} + +TEST(SetupResumeAcceptor, EarlyComplete) { + folly::EventBase evb; + SetupResumeAcceptor acceptor{&evb}; + + auto connection = + std::make_unique>([](auto input) { + input->onSubscribe(yarpl::flowable::Subscription::create()); + input->onComplete(); + }); + + acceptor.accept(std::move(connection), setupFail, resumeFail); + + evb.loop(); +} + +TEST(SetupResumeAcceptor, EarlyError) { + folly::EventBase evb; + SetupResumeAcceptor acceptor{&evb}; + + auto connection = + std::make_unique>([](auto input) { + input->onSubscribe(yarpl::flowable::Subscription::create()); + input->onError(std::runtime_error("Whoops")); + }); + + acceptor.accept(std::move(connection), setupFail, resumeFail); + + evb.loop(); +} + +TEST(SetupResumeAcceptor, SingleSetup) { + folly::EventBase evb; + SetupResumeAcceptor acceptor{&evb}; + + auto connection = + std::make_unique>([](auto input) { + auto serializer = + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest); + input->onSubscribe(yarpl::flowable::Subscription::create()); + input->onNext(serializer->serializeOut(makeSetup())); + input->onComplete(); + }); + + bool setupCalled = false; + + acceptor.accept( + std::move(connection), + [&](auto, auto) { setupCalled = true; }, + resumeFail); + + evb.loop(); + + EXPECT_TRUE(setupCalled); +} + +TEST(SetupResumeAcceptor, InvalidSetup) { + folly::EventBase evb; + SetupResumeAcceptor acceptor{&evb}; + + auto connection = + std::make_unique>([](auto input) { + auto serializer = + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest); + + // Bogus keepalive time that can't be deserialized. + auto setup = makeSetup(); + setup.keepaliveTime_ = -5; + + input->onSubscribe(yarpl::flowable::Subscription::create()); + input->onNext(serializer->serializeOut(std::move(setup))); + input->onComplete(); + }); + + EXPECT_CALL(*connection, send_(_)).WillOnce(Invoke([](auto& buf) { + auto serializer = + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest); + Frame_ERROR frame; + EXPECT_TRUE(serializer->deserializeFrom(frame, buf->clone())); + EXPECT_EQ(frame.errorCode_, ErrorCode::CONNECTION_ERROR); + })); + + acceptor.accept(std::move(connection), setupFail, resumeFail); + + evb.loop(); +} + +TEST(SetupResumeAcceptor, RejectedSetup) { + folly::EventBase evb; + SetupResumeAcceptor acceptor{&evb}; + + auto serializer = + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest); + + auto connection = + std::make_unique>([&](auto input) { + input->onSubscribe(yarpl::flowable::Subscription::create()); + input->onNext(serializer->serializeOut(makeSetup())); + input->onComplete(); + }); + + EXPECT_CALL(*connection, send_(_)).WillOnce(Invoke([](auto& buf) { + auto serializer = + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest); + Frame_ERROR frame; + EXPECT_TRUE(serializer->deserializeFrom(frame, buf->clone())); + EXPECT_EQ(frame.errorCode_, ErrorCode::REJECTED_SETUP); + })); + + bool setupCalled = false; + + acceptor.accept( + std::move(connection), + [&](std::unique_ptr connection, auto) { + setupCalled = true; + connection->send( + serializer->serializeOut(Frame_ERROR::rejectedSetup("Oops"))); + }, + resumeFail); + + evb.loop(); + + EXPECT_TRUE(setupCalled); +} + +TEST(SetupResumeAcceptor, RejectedResume) { + folly::EventBase evb; + SetupResumeAcceptor acceptor{&evb}; + + auto serializer = + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest); + + auto connection = + std::make_unique>([&](auto input) { + input->onSubscribe(yarpl::flowable::Subscription::create()); + input->onNext(serializer->serializeOut(makeResume())); + input->onComplete(); + }); + + EXPECT_CALL(*connection, send_(_)).WillOnce(Invoke([](auto& buf) { + auto serializer = + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest); + Frame_ERROR frame; + EXPECT_TRUE(serializer->deserializeFrom(frame, buf->clone())); + EXPECT_EQ(frame.errorCode_, ErrorCode::REJECTED_RESUME); + })); + + bool resumeCalled = false; + + acceptor.accept( + std::move(connection), + setupFail, + [&](std::unique_ptr connection, auto) { + resumeCalled = true; + connection->send(serializer->serializeOut( + Frame_ERROR::rejectedResume("Cant resume"))); + }); + + evb.loop(); + + EXPECT_TRUE(resumeCalled); +} + +TEST(SetupResumeAcceptor, SetupBadVersion) { + folly::EventBase evb; + SetupResumeAcceptor acceptor{&evb}; + + auto serializer = + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest); + + auto connection = + std::make_unique>([&](auto input) { + input->onSubscribe(yarpl::flowable::Subscription::create()); + + auto setup = makeSetup(); + setup.versionMajor_ = 57; + setup.versionMinor_ = 39; + + input->onNext(serializer->serializeOut(std::move(setup))); + input->onComplete(); + }); + + acceptor.accept(std::move(connection), setupFail, resumeFail); + evb.loop(); +} + +TEST(SetupResumeAcceptor, ResumeBadVersion) { + folly::EventBase evb; + SetupResumeAcceptor acceptor{&evb}; + + auto serializer = + FrameSerializer::createFrameSerializer(ProtocolVersion::Latest); + + auto connection = + std::make_unique>([&](auto input) { + input->onSubscribe(yarpl::flowable::Subscription::create()); + + auto resume = makeResume(); + resume.versionMajor_ = 57; + resume.versionMinor_ = 39; + + input->onNext(serializer->serializeOut(std::move(resume))); + input->onComplete(); + }); + + acceptor.accept(std::move(connection), setupFail, resumeFail); + evb.loop(); +} diff --git a/test/internal/SwappableEventBaseTest.cpp b/rsocket/test/internal/SwappableEventBaseTest.cpp similarity index 76% rename from test/internal/SwappableEventBaseTest.cpp rename to rsocket/test/internal/SwappableEventBaseTest.cpp index f80a0eb9d..b02c7f391 100644 --- a/test/internal/SwappableEventBaseTest.cpp +++ b/rsocket/test/internal/SwappableEventBaseTest.cpp @@ -1,6 +1,17 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -#include #include #include @@ -20,7 +31,7 @@ struct DidExecTracker { const std::string file; const std::string name; DidExecTracker(int line, std::string file, std::string name) - : line(line), file(file), name(name) {} + : line(line), file(file), name(name) {} MOCK_METHOD0(mark, void()); }; @@ -28,10 +39,18 @@ struct DETMarkedOnce : public ::testing::CardinalityInterface { explicit DETMarkedOnce(DidExecTracker const& det) : det(det) {} DidExecTracker const& det; - int ConservativeLowerBound() const override { return 1; } - int ConservativeUpperBound() const override { return 1; } - bool IsSatisfiedByCallCount(int cc) const override { return cc == 1; } - bool IsSaturatedByCallCount(int cc) const override { return cc == 1; } + int ConservativeLowerBound() const override { + return 1; + } + int ConservativeUpperBound() const override { + return 1; + } + bool IsSatisfiedByCallCount(int cc) const override { + return cc == 1; + } + bool IsSaturatedByCallCount(int cc) const override { + return cc == 1; + } void DescribeTo(std::ostream* os) const override { *os << "is called exactly once on "; @@ -43,19 +62,19 @@ ::testing::Cardinality MarkedOnce(DidExecTracker const& det) { } class SwappableEbTest : public ::testing::Test { -public: + public: std::vector> ebs; std::vector> did_exec_trackers; void loop_ebs() { { ::testing::InSequence s; - for(auto tracker : did_exec_trackers) { + for (auto tracker : did_exec_trackers) { EXPECT_CALL(*tracker, mark()).Times(MarkedOnce(*tracker)); } } - for(auto& eb : ebs) { + for (auto& eb : ebs) { ASSERT_TRUE(eb->loop()); } @@ -64,10 +83,9 @@ class SwappableEbTest : public ::testing::Test { } std::shared_ptr make_did_exec_tracker_impl( - int line, - std::string const& file, - std::string const& name - ) { + int line, + std::string const& file, + std::string const& name) { did_exec_trackers.emplace_back(new DidExecTracker(line, file, name)); return did_exec_trackers.back(); } diff --git a/rsocket/test/statemachine/RSocketStateMachineTest.cpp b/rsocket/test/statemachine/RSocketStateMachineTest.cpp new file mode 100644 index 000000000..9d9909398 --- /dev/null +++ b/rsocket/test/statemachine/RSocketStateMachineTest.cpp @@ -0,0 +1,430 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "rsocket/statemachine/RSocketStateMachine.h" +#include +#include +#include +#include +#include +#include "rsocket/RSocketConnectionEvents.h" +#include "rsocket/RSocketResponder.h" +#include "rsocket/framing/FrameSerializer_v1_0.h" +#include "rsocket/framing/FrameTransportImpl.h" +#include "rsocket/internal/Common.h" +#include "rsocket/statemachine/ChannelRequester.h" +#include "rsocket/statemachine/ChannelResponder.h" +#include "rsocket/statemachine/RequestResponseResponder.h" +#include "rsocket/test/test_utils/MockDuplexConnection.h" +#include "rsocket/test/test_utils/MockStreamsWriter.h" + +using namespace testing; +using namespace yarpl::mocks; +using namespace yarpl::single; + +namespace rsocket { + +class ResponderMock : public RSocketResponder { + public: + MOCK_METHOD1( + handleRequestResponse_, + std::shared_ptr>(StreamId)); + MOCK_METHOD1( + handleRequestStream_, + std::shared_ptr>(StreamId)); + MOCK_METHOD2( + handleRequestChannel_, + std::shared_ptr>( + std::shared_ptr> requestStream, + StreamId streamId)); + + std::shared_ptr> handleRequestResponse(Payload, StreamId id) + override { + return handleRequestResponse_(id); + } + + std::shared_ptr> handleRequestStream( + Payload, + StreamId id) override { + return handleRequestStream_(id); + } + + std::shared_ptr> handleRequestChannel( + Payload, + std::shared_ptr> requestStream, + StreamId streamId) override { + return handleRequestChannel_(requestStream, streamId); + } +}; + +struct ConnectionEventsMock : public RSocketConnectionEvents { + MOCK_METHOD1(onDisconnected, void(const folly::exception_wrapper&)); + MOCK_METHOD0(onStreamsPaused, void()); +}; + +class RSocketStateMachineTest : public Test { + public: + auto createClient( + std::unique_ptr connection, + std::shared_ptr responder) { + EXPECT_CALL(*connection, setInput_(_)); + EXPECT_CALL(*connection, isFramed()); + + auto transport = + std::make_shared(std::move(connection)); + + auto stateMachine = std::make_shared( + std::move(responder), + nullptr, + RSocketMode::CLIENT, + nullptr, + nullptr, + ResumeManager::makeEmpty(), + nullptr); + + SetupParameters setupParameters; + setupParameters.resumable = false; // Not resumable! + stateMachine->connectClient( + std::move(transport), std::move(setupParameters)); + + return stateMachine; + } + + auto createServer( + std::unique_ptr connection, + std::shared_ptr responder, + folly::Optional resumeToken = folly::none, + std::shared_ptr connectionEvents = nullptr) { + auto transport = + std::make_shared(std::move(connection)); + + auto stateMachine = std::make_shared( + std::move(responder), + nullptr, + RSocketMode::SERVER, + nullptr, + std::move(connectionEvents), + ResumeManager::makeEmpty(), + nullptr); + + if (resumeToken) { + SetupParameters setupParameters; + setupParameters.resumable = true; + setupParameters.token = *resumeToken; + stateMachine->connectServer(std::move(transport), setupParameters); + } else { + SetupParameters setupParameters; + setupParameters.resumable = false; + stateMachine->connectServer(std::move(transport), setupParameters); + } + + return stateMachine; + } + + const std::unordered_map>& + getStreams(RSocketStateMachine& stateMachine) { + return stateMachine.streams_; + } + + void setupRequestStream( + RSocketStateMachine& stateMachine, + StreamId streamId, + uint32_t requestN, + Payload payload) { + stateMachine.onRequestStreamFrame( + streamId, requestN, std::move(payload), false); + } + + void setupRequestChannel( + RSocketStateMachine& stateMachine, + StreamId streamId, + uint32_t requestN, + Payload payload) { + stateMachine.onRequestChannelFrame( + streamId, requestN, std::move(payload), false, true, false); + } + + void setupRequestResponse( + RSocketStateMachine& stateMachine, + StreamId streamId, + Payload payload) { + stateMachine.onRequestResponseFrame(streamId, std::move(payload), false); + } + + void setupFireAndForget( + RSocketStateMachine& stateMachine, + StreamId streamId, + Payload payload) { + stateMachine.onFireAndForgetFrame(streamId, std::move(payload), false); + } +}; + +TEST_F(RSocketStateMachineTest, RequestStream) { + auto connection = std::make_unique>(); + // Setup frame and request stream frame + EXPECT_CALL(*connection, send_(_)).Times(2); + + auto stateMachine = + createClient(std::move(connection), std::make_shared()); + + auto subscriber = std::make_shared>>(1000); + EXPECT_CALL(*subscriber, onSubscribe_(_)); + EXPECT_CALL(*subscriber, onComplete_()); + + stateMachine->requestStream(Payload{}, subscriber); + + auto& streams = getStreams(*stateMachine); + ASSERT_EQ(1, streams.size()); + + // This line causes: subscriber.onComplete() + streams.at(1)->endStream(StreamCompletionSignal::CANCEL); + + stateMachine->close({}, StreamCompletionSignal::CONNECTION_END); +} + +TEST_F(RSocketStateMachineTest, RequestStream_EarlyClose) { + auto connection = std::make_unique>(); + // Setup frame, two request stream frames, one extra frame + EXPECT_CALL(*connection, send_(_)).Times(3); + + auto stateMachine = + createClient(std::move(connection), std::make_shared()); + + auto subscriber = std::make_shared>>(1000); + EXPECT_CALL(*subscriber, onSubscribe_(_)).Times(2); + EXPECT_CALL(*subscriber, onComplete_()); + + stateMachine->requestStream(Payload{}, subscriber); + + // Second stream + stateMachine->requestStream(Payload{}, subscriber); + + auto& streams = getStreams(*stateMachine); + ASSERT_EQ(2, streams.size()); + + // Close the stream + auto writer = std::dynamic_pointer_cast(stateMachine); + writer->onStreamClosed(1); + + // Push more data to the closed stream + auto processor = std::dynamic_pointer_cast(stateMachine); + FrameSerializerV1_0 serializer; + processor->processFrame( + serializer.serializeOut(Frame_PAYLOAD(1, FrameFlags::COMPLETE, {}))); + + // Second stream should still be valid + ASSERT_EQ(1, streams.size()); + + streams.at(3)->endStream(StreamCompletionSignal::CANCEL); + stateMachine->close({}, StreamCompletionSignal::CONNECTION_END); +} + +TEST_F(RSocketStateMachineTest, RequestChannel) { + auto connection = std::make_unique>(); + // Setup frame and request channel frame + EXPECT_CALL(*connection, send_(_)).Times(2); + + auto stateMachine = + createClient(std::move(connection), std::make_shared()); + + auto in = std::make_shared>>(1000); + EXPECT_CALL(*in, onSubscribe_(_)); + EXPECT_CALL(*in, onComplete_()); + + auto out = stateMachine->requestChannel(Payload{}, true, in); + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, cancel_()); + out->onSubscribe(subscription); + + auto& streams = getStreams(*stateMachine); + ASSERT_EQ(1, streams.size()); + + // This line causes: in.onComplete() and outSubscription.cancel() + streams.at(1)->endStream(StreamCompletionSignal::CANCEL); + + stateMachine->close({}, StreamCompletionSignal::CONNECTION_END); +} + +TEST_F(RSocketStateMachineTest, RequestResponse) { + auto connection = std::make_unique>(); + // Setup frame and request channel frame + EXPECT_CALL(*connection, send_(_)).Times(2); + + auto stateMachine = + createClient(std::move(connection), std::make_shared()); + + auto in = std::make_shared>(); + stateMachine->requestResponse(Payload{}, in); + + auto& streams = getStreams(*stateMachine); + ASSERT_EQ(1, streams.size()); + + // This line closes the stream + streams.at(1)->handlePayload(Payload{"test", "123"}, true, false, false); + + stateMachine->close({}, StreamCompletionSignal::CONNECTION_END); +} + +TEST_F(RSocketStateMachineTest, RespondStream) { + auto connection = std::make_unique>(); + int requestCount = 5; + // Payload frames plus a SETUP frame and an ERROR frame + EXPECT_CALL(*connection, send_(_)).Times(requestCount + 2); + + int sendCount = 0; + auto responder = std::make_shared>(); + EXPECT_CALL(*responder, handleRequestStream_(_)) + .WillOnce(Return( + yarpl::flowable::Flowable::fromGenerator([&sendCount]() { + ++sendCount; + return Payload{}; + }))); + + auto stateMachine = createClient(std::move(connection), responder); + setupRequestStream(*stateMachine, 2, requestCount, Payload{}); + EXPECT_EQ(requestCount, sendCount); + + auto& streams = getStreams(*stateMachine); + EXPECT_EQ(1, streams.size()); + + // releases connection and the responder + stateMachine->close({}, StreamCompletionSignal::CONNECTION_END); +} + +TEST_F(RSocketStateMachineTest, RespondChannel) { + auto connection = std::make_unique>(); + int requestCount = 5; + // + the cancel frame when the stateMachine gets destroyed + EXPECT_CALL(*connection, send_(_)).Times(requestCount + 1); + + int sendCount = 0; + auto responder = std::make_shared>(); + EXPECT_CALL(*responder, handleRequestChannel_(_, _)) + .WillOnce(Return( + yarpl::flowable::Flowable::fromGenerator([&sendCount]() { + ++sendCount; + return Payload{}; + }))); + + auto stateMachine = createClient(std::move(connection), responder); + setupRequestChannel(*stateMachine, 2, requestCount, Payload{}); + EXPECT_EQ(requestCount, sendCount); + + auto& streams = getStreams(*stateMachine); + EXPECT_EQ(1, streams.size()); + + // releases connection and the responder + stateMachine->close({}, StreamCompletionSignal::CONNECTION_END); +} + +TEST_F(RSocketStateMachineTest, RespondRequest) { + auto connection = std::make_unique>(); + EXPECT_CALL(*connection, send_(_)).Times(2); + + int sendCount = 0; + auto responder = std::make_shared>(); + EXPECT_CALL(*responder, handleRequestResponse_(_)) + .WillOnce(Return(Singles::fromGenerator([&sendCount]() { + ++sendCount; + return Payload{}; + }))); + + auto stateMachine = createClient(std::move(connection), responder); + setupRequestResponse(*stateMachine, 2, Payload{}); + EXPECT_EQ(sendCount, 1); + + auto& streams = getStreams(*stateMachine); + EXPECT_EQ(0, streams.size()); // already completed + + // releases connection and the responder + stateMachine->close({}, StreamCompletionSignal::CONNECTION_END); +} + +TEST_F(RSocketStateMachineTest, StreamImmediateCancel) { + auto connection = std::make_unique>(); + // Only send a SETUP frame. A REQUEST_STREAM frame should never be sent. + EXPECT_CALL(*connection, send_(_)); + + auto stateMachine = + createClient(std::move(connection), std::make_shared()); + + auto subscriber = std::make_shared>>(); + EXPECT_CALL(*subscriber, onSubscribe_(_)) + .WillOnce(Invoke( + [](std::shared_ptr subscription) { + subscription->cancel(); + })); + + stateMachine->requestStream(Payload{}, subscriber); + + auto& streams = getStreams(*stateMachine); + ASSERT_EQ(0, streams.size()); + + stateMachine->close({}, StreamCompletionSignal::CONNECTION_END); +} + +TEST_F(RSocketStateMachineTest, TransportOnNextClose) { + auto connection = std::make_unique>(); + // Only SETUP frame gets sent. + EXPECT_CALL(*connection, setInput_(_)); + EXPECT_CALL(*connection, isFramed()); + EXPECT_CALL(*connection, send_(_)); + + auto transport = std::make_shared(std::move(connection)); + auto stateMachine = std::make_shared( + std::make_shared>(), + nullptr, + RSocketMode::CLIENT, + nullptr, + nullptr, + ResumeManager::makeEmpty(), + nullptr); + + SetupParameters params; + params.resumable = false; + stateMachine->connectClient(transport, std::move(params)); + + auto rawTransport = transport.get(); + + // Leak the cycle. + stateMachine.reset(); + transport.reset(); + + FrameSerializerV1_0 serializer; + auto buf = serializer.serializeOut(Frame_ERROR::connectionError("Hah!")); + rawTransport->onNext(std::move(buf)); +} + +TEST_F(RSocketStateMachineTest, ResumeWithCurrentConnection) { + auto resumeToken = ResumeIdentificationToken::generateNew(); + + auto eventsMock = std::make_shared(); + auto stateMachine = createServer( + std::make_unique>(), + std::make_shared(), + resumeToken, + eventsMock); + + EXPECT_CALL(*eventsMock, onDisconnected(_)).Times(0); + EXPECT_CALL(*eventsMock, onStreamsPaused()).Times(0); + + ResumeParameters resumeParams{resumeToken, 0, 0, ProtocolVersion::Latest}; + auto transport = std::make_shared( + std::make_unique>()); + stateMachine->resumeServer(transport, resumeParams); + + stateMachine->close({}, StreamCompletionSignal::CONNECTION_END); +} + +} // namespace rsocket diff --git a/rsocket/test/statemachine/StreamResponderTest.cpp b/rsocket/test/statemachine/StreamResponderTest.cpp new file mode 100644 index 000000000..5c57937ff --- /dev/null +++ b/rsocket/test/statemachine/StreamResponderTest.cpp @@ -0,0 +1,154 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "rsocket/statemachine/StreamResponder.h" +#include "rsocket/test/test_utils/MockStreamsWriter.h" + +using namespace rsocket; +using namespace testing; +using namespace yarpl::mocks; + +TEST(StreamResponder, OnComplete) { + auto writer = std::make_shared>(); + auto responder = std::make_shared(writer, 1u, 0); + + EXPECT_CALL(*writer, writePayload_(_)).Times(3); + EXPECT_CALL(*writer, onStreamClosed(1u)); + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, request_(_)); + + responder->onSubscribe(subscription); + ASSERT_FALSE(responder->publisherClosed()); + + subscription->request(2); + + responder->onNext(Payload{}); + ASSERT_FALSE(responder->publisherClosed()); + + responder->onNext(Payload{}); + ASSERT_FALSE(responder->publisherClosed()); + + responder->onComplete(); + ASSERT_TRUE(responder->publisherClosed()); +} + +TEST(StreamResponder, OnError) { + auto writer = std::make_shared>(); + auto responder = std::make_shared(writer, 1u, 0); + + EXPECT_CALL(*writer, writePayload_(_)).Times(2); + EXPECT_CALL(*writer, writeError_(_)); + EXPECT_CALL(*writer, onStreamClosed(1u)); + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, request_(_)); + + responder->onSubscribe(subscription); + ASSERT_FALSE(responder->publisherClosed()); + + subscription->request(2); + + responder->onNext(Payload{}); + ASSERT_FALSE(responder->publisherClosed()); + + responder->onNext(Payload{}); + ASSERT_FALSE(responder->publisherClosed()); + + responder->onError(std::runtime_error{"Test"}); + ASSERT_TRUE(responder->publisherClosed()); +} + +TEST(StreamResponder, HandleError) { + auto writer = std::make_shared>(); + auto responder = std::make_shared(writer, 1u, 0); + + EXPECT_CALL(*writer, writePayload_(_)).Times(2); + EXPECT_CALL(*writer, onStreamClosed(1u)); + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, request_(_)); + EXPECT_CALL(*subscription, cancel_()); + + responder->onSubscribe(subscription); + ASSERT_FALSE(responder->publisherClosed()); + + subscription->request(2); + + responder->onNext(Payload{}); + ASSERT_FALSE(responder->publisherClosed()); + + responder->onNext(Payload{}); + ASSERT_FALSE(responder->publisherClosed()); + + responder->handleError(std::runtime_error("Test")); + ASSERT_TRUE(responder->publisherClosed()); +} + +TEST(StreamResponder, HandleCancel) { + auto writer = std::make_shared>(); + auto responder = std::make_shared(writer, 1u, 0); + + EXPECT_CALL(*writer, writePayload_(_)).Times(2); + EXPECT_CALL(*writer, onStreamClosed(1u)); + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, request_(_)); + EXPECT_CALL(*subscription, cancel_()); + + responder->onSubscribe(subscription); + ASSERT_FALSE(responder->publisherClosed()); + + subscription->request(2); + + responder->onNext(Payload{}); + ASSERT_FALSE(responder->publisherClosed()); + + responder->onNext(Payload{}); + ASSERT_FALSE(responder->publisherClosed()); + + responder->handleCancel(); + ASSERT_TRUE(responder->publisherClosed()); +} + +TEST(StreamResponder, EndStream) { + auto writer = std::make_shared>(); + auto responder = std::make_shared(writer, 1u, 0); + + EXPECT_CALL(*writer, writePayload_(_)).Times(2); + EXPECT_CALL(*writer, writeError_(_)); + EXPECT_CALL(*writer, onStreamClosed(1u)); + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, request_(_)); + EXPECT_CALL(*subscription, cancel_()); + + responder->onSubscribe(subscription); + ASSERT_FALSE(responder->publisherClosed()); + + subscription->request(2); + + responder->onNext(Payload{}); + ASSERT_FALSE(responder->publisherClosed()); + + responder->onNext(Payload{}); + ASSERT_FALSE(responder->publisherClosed()); + + responder->endStream(StreamCompletionSignal::SOCKET_CLOSED); + ASSERT_TRUE(responder->publisherClosed()); +} diff --git a/rsocket/test/statemachine/StreamStateTest.cpp b/rsocket/test/statemachine/StreamStateTest.cpp new file mode 100644 index 000000000..3786708a9 --- /dev/null +++ b/rsocket/test/statemachine/StreamStateTest.cpp @@ -0,0 +1,332 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "rsocket/internal/Common.h" +#include "rsocket/statemachine/ChannelRequester.h" +#include "rsocket/statemachine/ChannelResponder.h" +#include "rsocket/statemachine/StreamStateMachineBase.h" +#include "rsocket/test/test_utils/MockStreamsWriter.h" + +using namespace rsocket; +using namespace testing; +using namespace yarpl::mocks; + +class TestStreamStateMachineBase : public StreamStateMachineBase { + public: + using StreamStateMachineBase::StreamStateMachineBase; + void handlePayload(Payload&&, bool, bool, bool) override { + // ignore... + } +}; + +// @see github.com/rsocket/rsocket/blob/master/Protocol.md#request-channel +TEST(StreamState, NewStateMachineBase) { + auto writer = std::make_shared>(); + EXPECT_CALL(*writer, onStreamClosed(_)); + + TestStreamStateMachineBase ssm(writer, 1u); + ssm.getConsumerAllowance(); + ssm.handleCancel(); + ssm.handleError(std::runtime_error("test")); + ssm.handlePayload(Payload{}, false, true, false); + ssm.handleRequestN(1); +} + +TEST(StreamState, ChannelRequesterOnError) { + auto writer = std::make_shared>(); + auto requester = std::make_shared(writer, 1u); + + EXPECT_CALL(*writer, writeNewStream_(1u, _, _, _)); + EXPECT_CALL(*writer, writeError_(_)); + EXPECT_CALL(*writer, onStreamClosed(1u)); + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, cancel_()).Times(0); + EXPECT_CALL(*subscription, request_(1)); + + auto mockSubscriber = + std::make_shared>>(1000); + EXPECT_CALL(*mockSubscriber, onSubscribe_(_)); + EXPECT_CALL(*mockSubscriber, onError_(_)); + requester->subscribe(mockSubscriber); + + yarpl::flowable::Subscriber* subscriber = requester.get(); + subscriber->onSubscribe(subscription); + + // Initial request to activate the channel + subscriber->onNext(Payload()); + + ASSERT_FALSE(requester->consumerClosed()); + ASSERT_FALSE(requester->publisherClosed()); + + subscriber->onError(std::runtime_error("test")); + + ASSERT_TRUE(requester->consumerClosed()); + ASSERT_TRUE(requester->publisherClosed()); +} + +TEST(StreamState, ChannelResponderOnError) { + auto writer = std::make_shared>(); + auto responder = std::make_shared(writer, 1u, 0u); + + EXPECT_CALL(*writer, writeError_(_)); + EXPECT_CALL(*writer, onStreamClosed(1u)); + EXPECT_CALL(*writer, writeRequestN_(_)); + + auto mockSubscriber = + std::make_shared>>(); + EXPECT_CALL(*mockSubscriber, onSubscribe_(_)); + EXPECT_CALL(*mockSubscriber, onError_(_)); + responder->subscribe(mockSubscriber); + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, cancel_()).Times(0); + yarpl::flowable::Subscriber* subscriber = responder.get(); + subscriber->onSubscribe(subscription); + + ASSERT_FALSE(responder->consumerClosed()); + ASSERT_FALSE(responder->publisherClosed()); + + subscriber->onError(std::runtime_error("test")); + + ASSERT_TRUE(responder->consumerClosed()); + ASSERT_TRUE(responder->publisherClosed()); +} + +TEST(StreamState, ChannelRequesterHandleError) { + auto writer = std::make_shared>(); + auto requester = std::make_shared(writer, 1u); + + EXPECT_CALL(*writer, writeNewStream_(1u, _, _, _)); + EXPECT_CALL(*writer, writeError_(_)).Times(0); + EXPECT_CALL(*writer, onStreamClosed(1u)).Times(0); + + auto mockSubscriber = + std::make_shared>>(1000); + EXPECT_CALL(*mockSubscriber, onSubscribe_(_)); + EXPECT_CALL(*mockSubscriber, onError_(_)); + requester->subscribe(mockSubscriber); + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, cancel_()); + EXPECT_CALL(*subscription, request_(1)); + + yarpl::flowable::Subscriber* subscriber = requester.get(); + subscriber->onSubscribe(subscription); + // Initial request to activate the channel + subscriber->onNext(Payload()); + + ASSERT_FALSE(requester->consumerClosed()); + ASSERT_FALSE(requester->publisherClosed()); + + ConsumerBase* consumer = requester.get(); + consumer->handleError(std::runtime_error("test")); + + ASSERT_TRUE(requester->consumerClosed()); + ASSERT_TRUE(requester->publisherClosed()); +} + +TEST(StreamState, ChannelResponderHandleError) { + auto writer = std::make_shared>(); + auto responder = std::make_shared(writer, 1u, 0u); + + EXPECT_CALL(*writer, writeError_(_)).Times(0); + EXPECT_CALL(*writer, onStreamClosed(1u)).Times(0); + EXPECT_CALL(*writer, writeRequestN_(_)); + + auto mockSubscriber = + std::make_shared>>(); + EXPECT_CALL(*mockSubscriber, onSubscribe_(_)); + EXPECT_CALL(*mockSubscriber, onError_(_)); + + responder->subscribe(mockSubscriber); + + // Initialize the responder + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, cancel_()); + EXPECT_CALL(*subscription, request_(1)).Times(0); + + yarpl::flowable::Subscriber* subscriber = responder.get(); + subscriber->onSubscribe(subscription); + + ASSERT_FALSE(responder->consumerClosed()); + ASSERT_FALSE(responder->publisherClosed()); + + ConsumerBase* consumer = responder.get(); + consumer->handleError(std::runtime_error("test")); + + ASSERT_TRUE(responder->consumerClosed()); + ASSERT_TRUE(responder->publisherClosed()); +} + +// https://github.com/rsocket/rsocket/blob/master/Protocol.md#cancel-from-requester-responder-terminates +TEST(StreamState, ChannelRequesterCancel) { + auto writer = std::make_shared>(); + auto requester = std::make_shared(writer, 1u); + + EXPECT_CALL(*writer, writeNewStream_(1u, _, _, _)); + EXPECT_CALL(*writer, writePayload_(_)).Times(2); + EXPECT_CALL(*writer, writeCancel_(_)); + EXPECT_CALL(*writer, onStreamClosed(1u)).Times(0); + + auto mockSubscriber = + std::make_shared>>(1000); + EXPECT_CALL(*mockSubscriber, onSubscribe_(_)); + requester->subscribe(mockSubscriber); + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, cancel_()).Times(0); + EXPECT_CALL(*subscription, request_(1)); + EXPECT_CALL(*subscription, request_(2)); + + yarpl::flowable::Subscriber* subscriber = requester.get(); + subscriber->onSubscribe(subscription); + // Initial request to activate the channel + subscriber->onNext(Payload()); + + ASSERT_FALSE(requester->consumerClosed()); + ASSERT_FALSE(requester->publisherClosed()); + + ConsumerBase* consumer = requester.get(); + consumer->cancel(); + + ASSERT_TRUE(requester->consumerClosed()); + ASSERT_FALSE(requester->publisherClosed()); + + // Still capable of using the producer side + StreamStateMachineBase* base = requester.get(); + base->handleRequestN(2u); + subscriber->onNext(Payload()); + subscriber->onNext(Payload()); +} + +TEST(StreamState, ChannelResponderCancel) { + auto writer = std::make_shared>(); + auto responder = std::make_shared(writer, 1u, 0u); + + EXPECT_CALL(*writer, writePayload_(_)).Times(2); + EXPECT_CALL(*writer, writeCancel_(_)); + EXPECT_CALL(*writer, writeRequestN_(_)); + + auto mockSubscriber = + std::make_shared>>(); + EXPECT_CALL(*mockSubscriber, onSubscribe_(_)); + + responder->subscribe(mockSubscriber); + + // Initialize the responder + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, cancel_()).Times(0); + EXPECT_CALL(*subscription, request_(2)); + + yarpl::flowable::Subscriber* subscriber = responder.get(); + subscriber->onSubscribe(subscription); + + ASSERT_FALSE(responder->consumerClosed()); + ASSERT_FALSE(responder->publisherClosed()); + + ConsumerBase* consumer = responder.get(); + consumer->cancel(); + + ASSERT_TRUE(responder->consumerClosed()); + ASSERT_FALSE(responder->publisherClosed()); + + // Still capable of using the producer side + StreamStateMachineBase* base = responder.get(); + base->handleRequestN(2u); + subscriber->onNext(Payload()); + subscriber->onNext(Payload()); +} + +TEST(StreamState, ChannelRequesterHandleCancel) { + auto writer = std::make_shared>(); + auto requester = std::make_shared(writer, 1u); + + EXPECT_CALL(*writer, writeNewStream_(1u, _, _, _)); + EXPECT_CALL(*writer, writePayload_(_)).Times(0); + EXPECT_CALL(*writer, onStreamClosed(1u)); + + auto mockSubscriber = + std::make_shared>>(1000); + EXPECT_CALL(*mockSubscriber, onSubscribe_(_)); + requester->subscribe(mockSubscriber); // cycle: requester <-> mockSubscriber + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, cancel_()); + EXPECT_CALL(*subscription, request_(1)); + + yarpl::flowable::Subscriber* subscriber = requester.get(); + subscriber->onSubscribe(subscription); + // Initial request to activate the channel + subscriber->onNext(Payload()); + + ASSERT_FALSE(requester->consumerClosed()); + ASSERT_FALSE(requester->publisherClosed()); + + ConsumerBase* consumer = requester.get(); + consumer->handleCancel(); + + ASSERT_TRUE(requester->publisherClosed()); + ASSERT_FALSE(requester->consumerClosed()); + + // As the publisher is closed, this payload will be dropped + subscriber->onNext(Payload()); + subscriber->onNext(Payload()); + + // Break the cycle: requester <-> mockSubscriber + EXPECT_CALL(*writer, writeCancel_(_)); + auto consumerSubscription = mockSubscriber->subscription(); + consumerSubscription->cancel(); +} + +TEST(StreamState, ChannelResponderHandleCancel) { + auto writer = std::make_shared>(); + auto responder = std::make_shared(writer, 1u, 0u); + + EXPECT_CALL(*writer, writePayload_(_)).Times(0); + EXPECT_CALL(*writer, writeRequestN_(_)); + EXPECT_CALL(*writer, onStreamClosed(1u)); + + auto mockSubscriber = + std::make_shared>>(); + EXPECT_CALL(*mockSubscriber, onSubscribe_(_)); + responder->subscribe(mockSubscriber); // cycle: responder <-> mockSubscriber + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, cancel_()); + + yarpl::flowable::Subscriber* subscriber = responder.get(); + subscriber->onSubscribe(subscription); + + ASSERT_FALSE(responder->consumerClosed()); + ASSERT_FALSE(responder->publisherClosed()); + + ConsumerBase* consumer = responder.get(); + consumer->handleCancel(); + + ASSERT_TRUE(responder->publisherClosed()); + ASSERT_FALSE(responder->consumerClosed()); + + // As the publisher is closed, this payload will be dropped + subscriber->onNext(Payload()); + subscriber->onNext(Payload()); + + // Break the cycle: responder <-> mockSubscriber + EXPECT_CALL(*writer, writeCancel_(_)); + auto consumerSubscription = mockSubscriber->subscription(); + consumerSubscription->cancel(); +} diff --git a/rsocket/test/statemachine/StreamsWriterTest.cpp b/rsocket/test/statemachine/StreamsWriterTest.cpp new file mode 100644 index 000000000..e764df8c6 --- /dev/null +++ b/rsocket/test/statemachine/StreamsWriterTest.cpp @@ -0,0 +1,86 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "rsocket/statemachine/ChannelRequester.h" +#include "rsocket/test/test_utils/MockStreamsWriter.h" + +using namespace rsocket; +using namespace testing; + +TEST(StreamsWriterTest, DelegateMock) { + auto writer = std::make_shared>(); + auto& impl = writer->delegateToImpl(); + EXPECT_CALL(impl, outputFrame_(_)); + EXPECT_CALL(impl, shouldQueue()).WillOnce(Return(false)); + EXPECT_CALL(*writer, writeNewStream_(_, _, _, _)); + + auto requester = std::make_shared(writer, 1u); + yarpl::flowable::Subscriber* subscriber = requester.get(); + subscriber->onSubscribe(yarpl::flowable::Subscription::create()); + subscriber->onNext(Payload()); +} + +TEST(StreamsWriterTest, NewStreamsMockWriterImpl) { + auto writer = std::make_shared>(); + EXPECT_CALL(*writer, outputFrame_(_)); + EXPECT_CALL(*writer, shouldQueue()).WillOnce(Return(false)); + + auto requester = std::make_shared(writer, 1u); + yarpl::flowable::Subscriber* subscriber = requester.get(); + subscriber->onSubscribe(yarpl::flowable::Subscription::create()); + subscriber->onNext(Payload()); +} + +TEST(StreamsWriterTest, QueueFrames) { + auto writer = std::make_shared>(); + auto& impl = writer->delegateToImpl(); + impl.shouldQueue_ = true; + + EXPECT_CALL(impl, outputFrame_(_)).Times(0); + EXPECT_CALL(impl, shouldQueue()).WillOnce(Return(true)); + EXPECT_CALL(*writer, writeNewStream_(_, _, _, _)); + + auto requester = std::make_shared(writer, 1u); + yarpl::flowable::Subscriber* subscriber = requester.get(); + subscriber->onSubscribe(yarpl::flowable::Subscription::create()); + subscriber->onNext(Payload()); +} + +TEST(StreamsWriterTest, FlushQueuedFrames) { + auto writer = std::make_shared>(); + auto& impl = writer->delegateToImpl(); + impl.shouldQueue_ = true; + + EXPECT_CALL(impl, outputFrame_(_)).Times(1); + EXPECT_CALL(impl, shouldQueue()).Times(3); + EXPECT_CALL(*writer, writeNewStream_(_, _, _, _)); + + auto requester = std::make_shared(writer, 1u); + yarpl::flowable::Subscriber* subscriber = requester.get(); + subscriber->onSubscribe(yarpl::flowable::Subscription::create()); + subscriber->onNext(Payload()); + + // Will queue again + impl.sendPendingFrames(); + + // Now send them actually + impl.shouldQueue_ = false; + impl.sendPendingFrames(); + // it will not send the pending frames twice + impl.sendPendingFrames(); +} diff --git a/rsocket/test/test_utils/ColdResumeManager.cpp b/rsocket/test/test_utils/ColdResumeManager.cpp new file mode 100644 index 000000000..cf53d1b40 --- /dev/null +++ b/rsocket/test/test_utils/ColdResumeManager.cpp @@ -0,0 +1,213 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ColdResumeManager.h" + +#include +#include + +#include + +namespace { +constexpr folly::StringPiece FIRST_SENT_POSITION = "FirstSentPosition"; +constexpr folly::StringPiece LAST_SENT_POSITION = "LastSentPosition"; +constexpr folly::StringPiece IMPLIED_POSITION = "ImpliedPosition"; +constexpr folly::StringPiece LARGEST_USED_STREAMID = "LargestUsedStreamId"; +constexpr folly::StringPiece STREAM_RESUME_INFOS = "StreamResumeInfos"; +constexpr folly::StringPiece FRAMES = "Frames"; +constexpr folly::StringPiece STREAM_TYPE = "StreamType"; +constexpr folly::StringPiece REQUESTER = "Requester"; +constexpr folly::StringPiece STREAM_TOKEN = "StreamToken"; +constexpr folly::StringPiece PROD_ALLOWANCE = "ProducerAllowance"; +constexpr folly::StringPiece CONS_ALLOWANCE = "ConsumerAllowance"; +} // namespace + +namespace rsocket { + +ColdResumeManager::ColdResumeManager( + std::shared_ptr stats, + std::string inputFile) + : WarmResumeManager(std::move(stats)) { + if (inputFile.empty()) { + return; + } + LOG(INFO) << "Reading state from " << inputFile; + + try { + std::ifstream f(inputFile); + std::stringstream buffer; + buffer << f.rdbuf(); + auto state = folly::parseJson(buffer.str()); + f.close(); + + if (!state.isObject() || state.size() != 6) { + throw std::runtime_error( + "Invalid file content. Expected dynamic object of 6 elements"); + } + + if (state.count(FIRST_SENT_POSITION) != 1 || + state.count(LAST_SENT_POSITION) != 1 || + state.count(IMPLIED_POSITION) != 1 || + state.count(LARGEST_USED_STREAMID) != 1 || + state.count(STREAM_RESUME_INFOS) != 1 || state.count(FRAMES) != 1) { + throw std::runtime_error("Invalid file content. Keys Missing"); + } + + firstSentPosition_ = state[FIRST_SENT_POSITION].getInt(); + lastSentPosition_ = state[LAST_SENT_POSITION].getInt(); + impliedPosition_ = state[IMPLIED_POSITION].getInt(); + largestUsedStreamId_ = state[LARGEST_USED_STREAMID].getInt(); + + for (const auto& item : state[STREAM_RESUME_INFOS].items()) { + auto streamId = folly::to(item.first.getString()); + auto streamResumeInfoObj = item.second; + if (streamResumeInfoObj.count(STREAM_TYPE) != 1 || + streamResumeInfoObj.count(STREAM_TOKEN) != 1 || + streamResumeInfoObj.count(PROD_ALLOWANCE) != 1 || + streamResumeInfoObj.count(CONS_ALLOWANCE) != 1 || + streamResumeInfoObj.count(CONS_ALLOWANCE) != 1) { + throw std::runtime_error( + "Invalid file content. StreamResumeInfo Keys Missing"); + } + StreamResumeInfo streamResumeInfo( + static_cast(streamResumeInfoObj[STREAM_TYPE].getInt()), + static_cast( + streamResumeInfoObj[REQUESTER].getInt()), + streamResumeInfoObj[STREAM_TOKEN].getString()); + streamResumeInfo.producerAllowance = + streamResumeInfoObj[PROD_ALLOWANCE].getInt(); + streamResumeInfo.consumerAllowance = + streamResumeInfoObj[CONS_ALLOWANCE].getInt(); + streamResumeInfos_.emplace(streamId, std::move(streamResumeInfo)); + } + + auto framesObj = state[FRAMES]; + if (!framesObj.isArray()) { + throw std::runtime_error( + "Invalid file content. Frames not in right format"); + } + + for (const auto& item : framesObj) { + if (!item.isObject() || item.size() != 1) { + throw std::runtime_error( + "Invalid file content. Expected dynamic object of 1 element"); + } + auto ioBuf = folly::IOBuf::copyBuffer( + item.values().begin()->getString().c_str(), + item.values().begin()->getString().size()); + frames_.emplace_back( + folly::to(item.keys().begin()->getString()), + std::move(ioBuf)); + } + } catch (const std::exception& ex) { + throw std::runtime_error( + folly::sformat("Failed parsing file {}. {}", inputFile, ex.what())); + } +} + +void ColdResumeManager::persistState(std::string outputFile) { + VLOG(1) << "~ColdResumeManager"; + if (outputFile.empty()) { + throw std::runtime_error("Persisting to file failed. Empty filename"); + } + LOG(INFO) << "Persisting state to " << outputFile; + try { + folly::dynamic state = folly::dynamic::object(); + state[FIRST_SENT_POSITION] = firstSentPosition_; + state[LAST_SENT_POSITION] = lastSentPosition_; + state[IMPLIED_POSITION] = impliedPosition_; + state[LARGEST_USED_STREAMID] = largestUsedStreamId_; + state[STREAM_RESUME_INFOS] = folly::dynamic::object(); + for (const auto& streamResumeInfo : streamResumeInfos_) { + folly::dynamic val = folly::dynamic::object(); + val[STREAM_TYPE] = folly::to(streamResumeInfo.second.streamType); + val[STREAM_TOKEN] = streamResumeInfo.second.streamToken; + val[REQUESTER] = folly::to(streamResumeInfo.second.requester); + val[CONS_ALLOWANCE] = streamResumeInfo.second.consumerAllowance; + val[PROD_ALLOWANCE] = streamResumeInfo.second.producerAllowance; + state[STREAM_RESUME_INFOS].insert( + folly::to(streamResumeInfo.first), val); + } + state[FRAMES] = folly::dynamic::array(); + for (const auto& frame : frames_) { + state[FRAMES].push_back(folly::dynamic::object( + folly::to(frame.first), + frame.second->moveToFbString().toStdString())); + } + std::string jsonState = folly::toPrettyJson(state); + std::ofstream f(outputFile); + f << jsonState; + f.close(); + } catch (const std::exception& ex) { + throw std::runtime_error(folly::sformat( + "Persisting state to {} failed. {}", outputFile, ex.what())); + } + LOG(INFO) << "Done persisting state to " << outputFile; +} + +void ColdResumeManager::trackReceivedFrame( + size_t frameLength, + FrameType frameType, + StreamId streamId, + size_t consumerAllowance) { + if (!shouldTrackFrame(frameType)) { + return; + } + auto it = streamResumeInfos_.find(streamId); + // If streamId is not present in streamResumeInfo it likely means a + // COMPLETE/CANCEL/ERROR was received in this frame and the + // ResumeMananger::onCloseStream() was already invoked resulting in the entry + // being deleted. + if (it != streamResumeInfos_.end()) { + it->second.consumerAllowance = consumerAllowance; + } + WarmResumeManager::trackReceivedFrame( + frameLength, frameType, streamId, consumerAllowance); +} + +void ColdResumeManager::trackSentFrame( + const folly::IOBuf& serializedFrame, + FrameType frameType, + StreamId streamId, + size_t consumerAllowance) { + if (!shouldTrackFrame(frameType)) { + return; + } + auto it = streamResumeInfos_.find(streamId); + CHECK(it != streamResumeInfos_.end()); + it->second.consumerAllowance = consumerAllowance; + WarmResumeManager::trackSentFrame( + std::move(serializedFrame), frameType, streamId, consumerAllowance); +} + +void ColdResumeManager::onStreamClosed(StreamId streamId) { + streamResumeInfos_.erase(streamId); +} + +void ColdResumeManager::onStreamOpen( + StreamId streamId, + RequestOriginator requester, + std::string streamToken, + StreamType streamType) { + CHECK(streamType != StreamType::FNF); + CHECK(streamResumeInfos_.find(streamId) == streamResumeInfos_.end()); + if (requester == RequestOriginator::LOCAL && + streamId > largestUsedStreamId_) { + largestUsedStreamId_ = streamId; + } + streamResumeInfos_.emplace( + streamId, StreamResumeInfo(streamType, requester, streamToken)); +} + +} // namespace rsocket diff --git a/rsocket/test/test_utils/ColdResumeManager.h b/rsocket/test/test_utils/ColdResumeManager.h new file mode 100644 index 000000000..17cee3afb --- /dev/null +++ b/rsocket/test/test_utils/ColdResumeManager.h @@ -0,0 +1,76 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "rsocket/internal/WarmResumeManager.h" + +namespace folly { +class IOBuf; +} + +namespace rsocket { + +class RSocketStateMachine; +class FrameTransport; + +// In-memory ResumeManager for cold-resumption (for prototyping and +// testing purposes) +class ColdResumeManager : public WarmResumeManager { + public: + // If inputFile is provided, the ColdResumeManager will read state from the + // file, else it will start with a clean state. + // The constructor will throw if there is an error reading from the inputFile. + ColdResumeManager( + std::shared_ptr stats, + std::string inputFile = ""); + + void trackReceivedFrame( + size_t frameLength, + FrameType frameType, + StreamId streamId, + size_t consumerAllowance) override; + + void trackSentFrame( + const folly::IOBuf& serializedFrame, + FrameType frameType, + StreamId streamIdPtr, + size_t consumerAllowance) override; + + void onStreamOpen( + StreamId, + RequestOriginator, + std::string streamToken, + StreamType) override; + + void onStreamClosed(StreamId streamId) override; + + const StreamResumeInfos& getStreamResumeInfos() const override { + return streamResumeInfos_; + } + + StreamId getLargestUsedStreamId() const override { + return largestUsedStreamId_; + } + + // Persist resumption state to outputFile. Will throw if write fails. + void persistState(std::string outputFile); + + private: + StreamResumeInfos streamResumeInfos_; + + // Largest used StreamId so far. + StreamId largestUsedStreamId_{0}; +}; +} // namespace rsocket diff --git a/rsocket/test/test_utils/GenericRequestResponseHandler.h b/rsocket/test/test_utils/GenericRequestResponseHandler.h new file mode 100644 index 000000000..f3f79b6d1 --- /dev/null +++ b/rsocket/test/test_utils/GenericRequestResponseHandler.h @@ -0,0 +1,117 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/Single.h" + +#include "folly/ExceptionWrapper.h" + +namespace rsocket { +namespace tests { + +using StringPair = std::pair; + +struct ResponseImpl { + enum class Type { PAYLOAD, EXCEPTION }; + + StringPair p; + folly::exception_wrapper e; + Type type; + + explicit ResponseImpl(StringPair const& p) : p(p), type(Type::PAYLOAD) {} + explicit ResponseImpl(folly::exception_wrapper e) + : e(std::move(e)), type(Type::EXCEPTION) {} + + ~ResponseImpl() {} +}; + +using Response = std::unique_ptr; + +// Type that maps a request (data/metadata) to a response +// (data/metadata or exception) +using HandlerFunc = folly::Function; + +struct GenericRequestResponseHandler : public rsocket::RSocketResponder { + explicit GenericRequestResponseHandler(HandlerFunc&& func) + : handler_(std::make_unique(std::move(func))) {} + + std::shared_ptr> handleRequestResponse( + Payload request, + StreamId) override { + auto ioBufChainToString = [](std::unique_ptr buf) { + folly::IOBufQueue queue; + queue.append(std::move(buf)); + + std::string ret; + while (auto elem = queue.pop_front()) { + auto part = elem->moveToFbString(); + ret += part.toStdString(); + } + + return ret; + }; + + std::string data = ioBufChainToString(std::move(request.data)); + std::string meta = ioBufChainToString(std::move(request.metadata)); + + StringPair req(data, meta); + Response resp = (*handler_)(req); + + return yarpl::single::Single::create( + [resp = std::move(resp), this](auto subscriber) { + subscriber->onSubscribe(yarpl::single::SingleSubscriptions::empty()); + + if (resp->type == ResponseImpl::Type::PAYLOAD) { + subscriber->onSuccess(Payload(resp->p.first, resp->p.second)); + } else if (resp->type == ResponseImpl::Type::EXCEPTION) { + subscriber->onError(resp->e); + } else { + throw std::runtime_error("unknown response type"); + } + }); + } + + ~GenericRequestResponseHandler() {} + + private: + std::unique_ptr handler_; +}; + +inline Response payload_response(StringPair const& sp) { + return std::make_unique(sp); +} + +inline Response payload_response(std::string const& a, std::string const& b) { + return payload_response({a, b}); +} + +template +Response error_response(T const& err) { + return std::make_unique(err); +} + +inline StringPair payload_to_stringpair(Payload p) { + return StringPair(p.moveDataToString(), p.moveMetadataToString()); +} +} // namespace tests +} // namespace rsocket + +namespace std { +inline ostream& operator<<( + std::ostream& os, + rsocket::tests::StringPair const& payload) { + return os << "('" << payload.first << "', '" << payload.second << "')"; +} +} // namespace std diff --git a/rsocket/test/test_utils/MockDuplexConnection.h b/rsocket/test/test_utils/MockDuplexConnection.h new file mode 100644 index 000000000..a64c64436 --- /dev/null +++ b/rsocket/test/test_utils/MockDuplexConnection.h @@ -0,0 +1,54 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "rsocket/DuplexConnection.h" +#include "yarpl/test_utils/Mocks.h" + +namespace rsocket { + +class MockDuplexConnection : public DuplexConnection { + public: + using Subscriber = DuplexConnection::Subscriber; + + MockDuplexConnection() {} + + /// Creates a DuplexConnection that always runs `in` on the input subscriber. + template + MockDuplexConnection(InputFn in) { + EXPECT_CALL(*this, setInput_(testing::_)) + .WillRepeatedly(testing::Invoke(std::move(in))); + } + + // DuplexConnection. + + void setInput(std::shared_ptr in) override { + setInput_(std::move(in)); + } + + void send(std::unique_ptr buf) override { + send_(buf); + } + + // Mocks. + + MOCK_METHOD1(setInput_, void(std::shared_ptr)); + MOCK_METHOD1(send_, void(std::unique_ptr&)); + MOCK_CONST_METHOD0(isFramed, bool()); +}; + +} // namespace rsocket diff --git a/rsocket/test/test_utils/MockFrameProcessor.h b/rsocket/test/test_utils/MockFrameProcessor.h new file mode 100644 index 000000000..385a143f1 --- /dev/null +++ b/rsocket/test/test_utils/MockFrameProcessor.h @@ -0,0 +1,40 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include +#include + +#include "rsocket/framing/FrameProcessor.h" + +namespace rsocket { + +class MockFrameProcessor : public FrameProcessor { + public: + void processFrame(std::unique_ptr buf) override { + processFrame_(buf); + } + + void onTerminal(folly::exception_wrapper ew) override { + onTerminal_(std::move(ew)); + } + + MOCK_METHOD1(processFrame_, void(std::unique_ptr&)); + MOCK_METHOD1(onTerminal_, void(folly::exception_wrapper)); +}; + +} // namespace rsocket diff --git a/test/test_utils/MockStats.h b/rsocket/test/test_utils/MockStats.h similarity index 52% rename from test/test_utils/MockStats.h rename to rsocket/test/test_utils/MockStats.h index 717e8c3fa..c1707299f 100644 --- a/test/test_utils/MockStats.h +++ b/rsocket/test/test_utils/MockStats.h @@ -1,14 +1,26 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include #include -#include "rsocket/RSocketStats.h> -#include "rsocket/transports/tcp/TcpDuplexConnection.h> #include "rsocket/Payload.h" +#include "rsocket/RSocketStats.h" +#include "rsocket/transports/tcp/TcpDuplexConnection.h" namespace rsocket { @@ -32,4 +44,4 @@ class MockStats : public RSocketStats { MOCK_METHOD2(resumeBufferChanged, void(int, int)); MOCK_METHOD2(streamBufferChanged, void(int64_t, int64_t)); }; -} +} // namespace rsocket diff --git a/rsocket/test/test_utils/MockStreamsWriter.h b/rsocket/test/test_utils/MockStreamsWriter.h new file mode 100644 index 000000000..4d6593c48 --- /dev/null +++ b/rsocket/test/test_utils/MockStreamsWriter.h @@ -0,0 +1,152 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "rsocket/RSocketStats.h" +#include "rsocket/framing/FrameSerializer_v1_0.h" +#include "rsocket/statemachine/StreamsWriter.h" + +namespace rsocket { + +class MockStreamsWriterImpl : public StreamsWriterImpl { + public: + MOCK_METHOD1(onStreamClosed, void(StreamId)); + MOCK_METHOD1(outputFrame_, void(folly::IOBuf*)); + MOCK_METHOD0(shouldQueue, bool()); + + MockStreamsWriterImpl() { + using namespace testing; + ON_CALL(*this, shouldQueue()).WillByDefault(Invoke([this]() { + return this->shouldQueue_; + })); + } + + void outputFrame(std::unique_ptr buf) override { + outputFrame_(buf.get()); + } + + FrameSerializer& serializer() override { + return frameSerializer; + } + + RSocketStats& stats() override { + return *stats_; + } + + std::shared_ptr> onNewStreamReady( + StreamId streamId, + StreamType streamType, + Payload payload, + std::shared_ptr> response) override { + // ignoring... + return nullptr; + } + + void onNewStreamReady( + StreamId streamId, + StreamType streamType, + Payload payload, + std::shared_ptr> response) + override { + // ignoring... + } + + using StreamsWriterImpl::sendPendingFrames; + + bool shouldQueue_{false}; + std::shared_ptr stats_ = RSocketStats::noop(); + FrameSerializerV1_0 frameSerializer; +}; + +class MockStreamsWriter : public StreamsWriter { + public: + MOCK_METHOD4(writeNewStream_, void(StreamId, StreamType, uint32_t, Payload&)); + MOCK_METHOD1(writeRequestN_, void(rsocket::Frame_REQUEST_N)); + MOCK_METHOD1(writeCancel_, void(rsocket::Frame_CANCEL)); + MOCK_METHOD1(writePayload_, void(rsocket::Frame_PAYLOAD&)); + MOCK_METHOD1(writeError_, void(rsocket::Frame_ERROR&)); + MOCK_METHOD1(onStreamClosed, void(rsocket::StreamId)); + + // Delegate the Mock calls to the implementation in StreamsWriterImpl. + MockStreamsWriterImpl& delegateToImpl() { + delegateToImpl_ = true; + using namespace testing; + ON_CALL(*this, onStreamClosed(_)) + .WillByDefault(Invoke(&impl_, &StreamsWriter::onStreamClosed)); + return impl_; + } + + void writeNewStream(StreamId id, StreamType type, uint32_t i, Payload p) + override { + writeNewStream_(id, type, i, p); + if (delegateToImpl_) { + impl_.writeNewStream(id, type, i, std::move(p)); + } + } + + void writeRequestN(rsocket::Frame_REQUEST_N&& request) override { + if (delegateToImpl_) { + impl_.writeRequestN(std::move(request)); + } + writeRequestN_(request); + } + + void writeCancel(rsocket::Frame_CANCEL&& cancel) override { + writeCancel_(cancel); + if (delegateToImpl_) { + impl_.writeCancel(std::move(cancel)); + } + } + + void writePayload(rsocket::Frame_PAYLOAD&& payload) override { + writePayload_(payload); + if (delegateToImpl_) { + impl_.writePayload(std::move(payload)); + } + } + + void writeError(rsocket::Frame_ERROR&& error) override { + writeError_(error); + if (delegateToImpl_) { + impl_.writeError(std::move(error)); + } + } + + std::shared_ptr> onNewStreamReady( + StreamId streamId, + StreamType streamType, + Payload payload, + std::shared_ptr> response) override { + // ignoring... + return nullptr; + } + + void onNewStreamReady( + StreamId streamId, + StreamType streamType, + Payload payload, + std::shared_ptr> response) + override { + // ignoring... + } + + protected: + MockStreamsWriterImpl impl_; + bool delegateToImpl_{false}; +}; + +} // namespace rsocket diff --git a/rsocket/test/test_utils/PrintSubscriber.cpp b/rsocket/test/test_utils/PrintSubscriber.cpp new file mode 100644 index 000000000..baf3c14ef --- /dev/null +++ b/rsocket/test/test_utils/PrintSubscriber.cpp @@ -0,0 +1,43 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "PrintSubscriber.h" +#include +#include +#include + +namespace rsocket { + +PrintSubscriber::~PrintSubscriber() { + LOG(INFO) << "~PrintSubscriber " << this; +} + +void PrintSubscriber::onSubscribe( + std::shared_ptr subscription) noexcept { + LOG(INFO) << "PrintSubscriber " << this << " onSubscribe"; + subscription->request(std::numeric_limits::max()); +} + +void PrintSubscriber::onNext(Payload element) noexcept { + LOG(INFO) << "PrintSubscriber " << this << " onNext " << element; +} + +void PrintSubscriber::onComplete() noexcept { + LOG(INFO) << "PrintSubscriber " << this << " onComplete"; +} + +void PrintSubscriber::onError(folly::exception_wrapper ex) noexcept { + LOG(INFO) << "PrintSubscriber " << this << " onError " << ex; +} +} // namespace rsocket diff --git a/rsocket/test/test_utils/PrintSubscriber.h b/rsocket/test/test_utils/PrintSubscriber.h new file mode 100644 index 000000000..5a392c1dd --- /dev/null +++ b/rsocket/test/test_utils/PrintSubscriber.h @@ -0,0 +1,31 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "rsocket/Payload.h" +#include "yarpl/flowable/Subscriber.h" + +namespace rsocket { +class PrintSubscriber : public yarpl::flowable::Subscriber { + public: + ~PrintSubscriber(); + + void onSubscribe(std::shared_ptr + subscription) noexcept override; + void onNext(Payload element) noexcept override; + void onComplete() noexcept override; + void onError(folly::exception_wrapper ex) noexcept override; +}; +} // namespace rsocket diff --git a/test/test_utils/StatsPrinter.cpp b/rsocket/test/test_utils/StatsPrinter.cpp similarity index 64% rename from test/test_utils/StatsPrinter.cpp rename to rsocket/test/test_utils/StatsPrinter.cpp index 9324fe2da..90f4f0511 100644 --- a/test/test_utils/StatsPrinter.cpp +++ b/rsocket/test/test_utils/StatsPrinter.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "StatsPrinter.h" #include @@ -57,4 +69,12 @@ void StatsPrinter::streamBufferChanged( LOG(INFO) << "streamBufferChanged framesCountDelta=" << framesCountDelta << " dataSizeDelta=" << dataSizeDelta; } + +void StatsPrinter::keepaliveSent() { + LOG(INFO) << "keepalive sent"; +} + +void StatsPrinter::keepaliveReceived() { + LOG(INFO) << "keepalive response received"; } +} // namespace rsocket diff --git a/test/test_utils/StatsPrinter.h b/rsocket/test/test_utils/StatsPrinter.h similarity index 55% rename from test/test_utils/StatsPrinter.h rename to rsocket/test/test_utils/StatsPrinter.h index 8c793b631..afe2fa8a0 100644 --- a/test/test_utils/StatsPrinter.h +++ b/rsocket/test/test_utils/StatsPrinter.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -26,5 +38,8 @@ class StatsPrinter : public RSocketStats { void resumeBufferChanged(int framesCountDelta, int dataSizeDelta) override; void streamBufferChanged(int64_t framesCountDelta, int64_t dataSizeDelta) override; + + void keepaliveSent() override; + void keepaliveReceived() override; }; -} +} // namespace rsocket diff --git a/rsocket/test/transport/DuplexConnectionTest.cpp b/rsocket/test/transport/DuplexConnectionTest.cpp new file mode 100644 index 000000000..d8bcd33b7 --- /dev/null +++ b/rsocket/test/transport/DuplexConnectionTest.cpp @@ -0,0 +1,214 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "DuplexConnectionTest.h" + +#include +#include "yarpl/test_utils/Mocks.h" + +namespace rsocket { +namespace tests { + +using namespace folly; +using namespace rsocket; +using namespace ::testing; + +void makeMultipleSetInputGetOutputCalls( + std::unique_ptr serverConnection, + EventBase* serverEvb, + std::unique_ptr clientConnection, + EventBase* clientEvb) { + auto serverSubscriber = std::make_shared< + yarpl::mocks::MockSubscriber>>(); + EXPECT_CALL(*serverSubscriber, onSubscribe_(_)); + EXPECT_CALL(*serverSubscriber, onNext_(_)).Times(10); + + serverEvb->runInEventBaseThreadAndWait([&] { + // Keep receiving messages from different subscribers + serverConnection->setInput(serverSubscriber); + }); + + for (int i = 0; i < 10; ++i) { + auto clientSubscriber = std::make_shared< + yarpl::mocks::MockSubscriber>>(); + EXPECT_CALL(*clientSubscriber, onSubscribe_(_)); + EXPECT_CALL(*clientSubscriber, onNext_(_)); + + clientEvb->runInEventBaseThreadAndWait([&] { + // Set another subscriber and receive messages + clientConnection->setInput(clientSubscriber); + // Get another subscriber and send messages + clientConnection->send(folly::IOBuf::copyBuffer("0123456")); + }); + serverSubscriber->awaitFrames(1); + + serverEvb->runInEventBaseThreadAndWait( + [&] { serverConnection->send(folly::IOBuf::copyBuffer("6543210")); }); + clientSubscriber->awaitFrames(1); + + clientEvb->runInEventBaseThreadAndWait( + [subscriber = std::move(clientSubscriber)]() { + // Enables calling setInput again with another subscriber. + subscriber->subscription()->cancel(); + }); + } + + // Cleanup + serverEvb->runInEventBaseThreadAndWait( + [subscriber = std::move(serverSubscriber)] { + subscriber->subscription()->cancel(); + }); + clientEvb->runInEventBaseThreadAndWait( + [connection = std::move(clientConnection)] {}); + serverEvb->runInEventBaseThreadAndWait( + [connection = std::move(serverConnection)] {}); +} + +/** + * Closing an Input or Output should not effect the other. + */ +void verifyInputAndOutputIsUntied( + std::unique_ptr serverConnection, + EventBase* serverEvb, + std::unique_ptr clientConnection, + EventBase* clientEvb) { + auto serverSubscriber = std::make_shared< + yarpl::mocks::MockSubscriber>>(); + EXPECT_CALL(*serverSubscriber, onSubscribe_(_)); + EXPECT_CALL(*serverSubscriber, onNext_(_)).Times(3); + + serverEvb->runInEventBaseThreadAndWait( + [&] { serverConnection->setInput(serverSubscriber); }); + + auto clientSubscriber = std::make_shared< + yarpl::mocks::MockSubscriber>>(); + EXPECT_CALL(*clientSubscriber, onSubscribe_(_)); + + clientEvb->runInEventBaseThreadAndWait([&] { + clientConnection->setInput(clientSubscriber); + clientConnection->send(folly::IOBuf::copyBuffer("0123456")); + }); + serverSubscriber->awaitFrames(1); + + clientEvb->runInEventBaseThreadAndWait([&] { + // Close the client subscriber + { + clientSubscriber->subscription()->cancel(); + auto deleteSubscriber = std::move(clientSubscriber); + } + // Output is still active + clientConnection->send(folly::IOBuf::copyBuffer("0123456")); + }); + serverSubscriber->awaitFrames(1); + + // Another client subscriber + clientSubscriber = std::make_shared< + yarpl::mocks::MockSubscriber>>(); + EXPECT_CALL(*clientSubscriber, onSubscribe_(_)); + EXPECT_CALL(*clientSubscriber, onNext_(_)); + clientEvb->runInEventBaseThreadAndWait([&] { + // Set new input subscriber + clientConnection->setInput(clientSubscriber); + clientConnection->send(folly::IOBuf::copyBuffer("0123456")); + }); + serverSubscriber->awaitFrames(1); + + // Still sending message from server to the client. + serverEvb->runInEventBaseThreadAndWait( + [&] { serverConnection->send(folly::IOBuf::copyBuffer("6543210")); }); + clientSubscriber->awaitFrames(1); + + // Cleanup + clientEvb->runInEventBaseThreadAndWait( + [subscriber = std::move(clientSubscriber)] { + subscriber->subscription()->cancel(); + }); + serverEvb->runInEventBaseThreadAndWait( + [subscriber = std::move(serverSubscriber)] { + subscriber->subscription()->cancel(); + }); + clientEvb->runInEventBaseThreadAndWait( + [connection = std::move(clientConnection)] {}); + serverEvb->runInEventBaseThreadAndWait( + [connection = std::move(serverConnection)] {}); +} + +void verifyClosingInputAndOutputDoesntCloseConnection( + std::unique_ptr serverConnection, + folly::EventBase* serverEvb, + std::unique_ptr clientConnection, + folly::EventBase* clientEvb) { + auto serverSubscriber = std::make_shared< + yarpl::mocks::MockSubscriber>>(); + EXPECT_CALL(*serverSubscriber, onSubscribe_(_)); + + serverEvb->runInEventBaseThreadAndWait( + [&] { serverConnection->setInput(serverSubscriber); }); + + auto clientSubscriber = std::make_shared< + yarpl::mocks::MockSubscriber>>(); + EXPECT_CALL(*clientSubscriber, onSubscribe_(_)); + + clientEvb->runInEventBaseThreadAndWait( + [&] { clientConnection->setInput(clientSubscriber); }); + + // Close all subscribers + clientEvb->runInEventBaseThreadAndWait([input = std::move(clientSubscriber)] { + input->subscription()->cancel(); + }); + + serverEvb->runInEventBaseThreadAndWait([input = std::move(serverSubscriber)] { + input->subscription()->cancel(); + }); + + // Set new subscribers as the connection is not closed + serverSubscriber = std::make_shared< + yarpl::mocks::MockSubscriber>>(); + EXPECT_CALL(*serverSubscriber, onSubscribe_(_)); + EXPECT_CALL(*serverSubscriber, onNext_(_)).Times(1); + // The subscriber is to be closed, as the subscription is not cancelled + // but the connection is closed at the end + EXPECT_CALL(*serverSubscriber, onComplete_()); + + serverEvb->runInEventBaseThreadAndWait( + [&] { serverConnection->setInput(serverSubscriber); }); + + clientSubscriber = std::make_shared< + yarpl::mocks::MockSubscriber>>(); + EXPECT_CALL(*clientSubscriber, onSubscribe_(_)); + EXPECT_CALL(*clientSubscriber, onNext_(_)).Times(1); + // The subscriber is to be closed, as the subscription is not cancelled + // but the connection is closed at the end + EXPECT_CALL(*clientSubscriber, onComplete_()); + + clientEvb->runInEventBaseThreadAndWait([&] { + clientConnection->setInput(clientSubscriber); + clientConnection->send(folly::IOBuf::copyBuffer("0123456")); + }); + serverSubscriber->awaitFrames(1); + + // Wait till client is ready before sending message from server. + serverEvb->runInEventBaseThreadAndWait( + [&] { serverConnection->send(folly::IOBuf::copyBuffer("6543210")); }); + clientSubscriber->awaitFrames(1); + + // Cleanup + clientEvb->runInEventBaseThreadAndWait( + [connection = std::move(clientConnection)] {}); + serverEvb->runInEventBaseThreadAndWait( + [connection = std::move(serverConnection)] {}); +} + +} // namespace tests +} // namespace rsocket diff --git a/test/transport/DuplexConnectionTest.h b/rsocket/test/transport/DuplexConnectionTest.h similarity index 59% rename from test/transport/DuplexConnectionTest.h rename to rsocket/test/transport/DuplexConnectionTest.h index e013b6ffb..c370975e9 100644 --- a/test/transport/DuplexConnectionTest.h +++ b/rsocket/test/transport/DuplexConnectionTest.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once diff --git a/test/transport/TcpDuplexConnectionTest.cpp b/rsocket/test/transport/TcpDuplexConnectionTest.cpp similarity index 50% rename from test/transport/TcpDuplexConnectionTest.cpp rename to rsocket/test/transport/TcpDuplexConnectionTest.cpp index b9600383e..ae17a51dd 100644 --- a/test/transport/TcpDuplexConnectionTest.cpp +++ b/rsocket/test/transport/TcpDuplexConnectionTest.cpp @@ -1,19 +1,32 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include -#include +#include +#include +#include #include +#include "rsocket/test/transport/DuplexConnectionTest.h" #include "rsocket/transports/tcp/TcpConnectionAcceptor.h" #include "rsocket/transports/tcp/TcpConnectionFactory.h" -#include "test/transport/DuplexConnectionTest.h" namespace rsocket { namespace tests { using namespace folly; using namespace rsocket; -using namespace ::testing; /** * Synchronously create a server and a client. @@ -25,12 +38,15 @@ makeSingleClientServer( std::unique_ptr& serverConnection, EventBase** serverEvb, std::unique_ptr& clientConnection, - EventBase** clientEvb) { - Promise serverPromise, clientPromise; + EventBase* clientEvb) { + Promise serverPromise; - TcpConnectionAcceptor::Options options( - 0 /*port*/, 1 /*threads*/, 0 /*backlog*/); - auto server = std::make_unique(options); + TcpConnectionAcceptor::Options options; + options.address = folly::SocketAddress{"::", 0}; + options.threads = 1; + options.backlog = 0; + + auto server = std::make_unique(std::move(options)); server->start( [&serverPromise, &serverConnection, &serverEvb]( std::unique_ptr connection, EventBase& eventBase) { @@ -42,54 +58,55 @@ makeSingleClientServer( int16_t port = server->listeningPort().value(); auto client = std::make_unique( - SocketAddress("localhost", port, true)); - client->connect( - [&clientPromise, &clientConnection, &clientEvb]( - std::unique_ptr connection, EventBase& eventBase) { - clientConnection = std::move(connection); - *clientEvb = &eventBase; - clientPromise.setValue(); - }); + *clientEvb, SocketAddress("localhost", port, true)); + client->connect(ProtocolVersion::Latest, ResumeStatus::NEW_SESSION) + .thenValue([&clientConnection]( + ConnectionFactory::ConnectedDuplexConnection connection) { + clientConnection = std::move(connection.connection); + }) + .wait(); - serverPromise.getFuture().wait(); - clientPromise.getFuture().wait(); + serverPromise.getSemiFuture().wait(); return std::make_pair(std::move(server), std::move(client)); } TEST(TcpDuplexConnection, MultipleSetInputGetOutputCalls) { + folly::ScopedEventBaseThread worker; std::unique_ptr serverConnection, clientConnection; - EventBase *serverEvb = nullptr, *clientEvb = nullptr; + EventBase* serverEvb = nullptr; auto keepAlive = makeSingleClientServer( - serverConnection, &serverEvb, clientConnection, &clientEvb); + serverConnection, &serverEvb, clientConnection, worker.getEventBase()); makeMultipleSetInputGetOutputCalls( std::move(serverConnection), serverEvb, std::move(clientConnection), - clientEvb); + worker.getEventBase()); } TEST(TcpDuplexConnection, InputAndOutputIsUntied) { + folly::ScopedEventBaseThread worker; std::unique_ptr serverConnection, clientConnection; - EventBase *serverEvb = nullptr, *clientEvb = nullptr; + EventBase* serverEvb = nullptr; auto keepAlive = makeSingleClientServer( - serverConnection, &serverEvb, clientConnection, &clientEvb); + serverConnection, &serverEvb, clientConnection, worker.getEventBase()); verifyInputAndOutputIsUntied( std::move(serverConnection), serverEvb, std::move(clientConnection), - clientEvb); + worker.getEventBase()); } TEST(TcpDuplexConnection, ConnectionAndSubscribersAreUntied) { + folly::ScopedEventBaseThread worker; std::unique_ptr serverConnection, clientConnection; - EventBase *serverEvb = nullptr, *clientEvb = nullptr; + EventBase* serverEvb = nullptr; auto keepAlive = makeSingleClientServer( - serverConnection, &serverEvb, clientConnection, &clientEvb); + serverConnection, &serverEvb, clientConnection, worker.getEventBase()); verifyClosingInputAndOutputDoesntCloseConnection( std::move(serverConnection), serverEvb, std::move(clientConnection), - clientEvb); + worker.getEventBase()); } } // namespace tests diff --git a/rsocket/transports/RSocketTransport.h b/rsocket/transports/RSocketTransport.h new file mode 100644 index 000000000..d86a4669a --- /dev/null +++ b/rsocket/transports/RSocketTransport.h @@ -0,0 +1,49 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace rsocket { +class RSocketTransportHandler { + public: + virtual ~RSocketTransportHandler() = default; + + // connection scope signals + virtual void onKeepAlive( + ResumePosition resumePosition, + std::unique_ptr data, + bool keepAliveRespond) = 0; + virtual void onMetadataPush(std::unique_ptr metadata) = 0; + virtual void onResumeOk(ResumePosition resumePosition); + virtual void onError(ErrorCode errorCode, Payload payload) = 0; + + // stream scope signals + virtual void onStreamRequestN(StreamId streamId, uint32_t requestN) = 0; + virtual void onStreamCancel(StreamId streamId) = 0; + virtual void onStreamError(StreamId streamId, Payload payload) = 0; + virtual void onStreamPayload( + StreamId streamId, + Payload payload, + bool flagsFollows, + bool flagsComplete, + bool flagsNext) = 0; +}; + +class RSocketTransport { + public: + virtual ~RSocketTransport() = default; + + // TODO: +}; +} // namespace rsocket diff --git a/rsocket/transports/tcp/TcpConnectionAcceptor.cpp b/rsocket/transports/tcp/TcpConnectionAcceptor.cpp index d18a95e94..12ac289f9 100644 --- a/rsocket/transports/tcp/TcpConnectionAcceptor.cpp +++ b/rsocket/transports/tcp/TcpConnectionAcceptor.cpp @@ -1,12 +1,24 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/transports/tcp/TcpConnectionAcceptor.h" -#include +#include #include -#include +#include +#include -#include "rsocket/framing/FramedDuplexConnection.h" #include "rsocket/transports/tcp/TcpDuplexConnection.h" namespace rsocket { @@ -15,23 +27,24 @@ class TcpConnectionAcceptor::SocketCallback : public folly::AsyncServerSocket::AcceptCallback { public: explicit SocketCallback(OnDuplexConnectionAccept& onAccept) - : onAccept_{onAccept} {} + : thread_{folly::sformat("rstcp-acceptor")}, onAccept_{onAccept} {} void connectionAccepted( - int fd, + folly::NetworkSocket fdNetworkSocket, const folly::SocketAddress& address) noexcept override { - VLOG(1) << "Accepting TCP connection from " << address << " on FD " << fd; + int fd = fdNetworkSocket.toFd(); - folly::AsyncSocket::UniquePtr socket( - new folly::AsyncSocket(eventBase(), fd)); + VLOG(2) << "Accepting TCP connection from " << address << " on FD " << fd; - auto connection = std::make_unique( - std::move(socket)); + folly::AsyncTransportWrapper::UniquePtr socket( + new folly::AsyncSocket(eventBase(), folly::NetworkSocket::fromFd(fd))); + + auto connection = std::make_unique(std::move(socket)); onAccept_(std::move(connection), *eventBase()); } - void acceptError(const std::exception& ex) noexcept override { - VLOG(1) << "TCP error: " << ex.what(); + void acceptError(folly::exception_wrapper ex) noexcept override { + VLOG(2) << "TCP error: " << ex; } folly::EventBase* eventBase() const { @@ -46,70 +59,59 @@ class TcpConnectionAcceptor::SocketCallback OnDuplexConnectionAccept& onAccept_; }; -//////////////////////////////////////////////////////////////////////////////// - TcpConnectionAcceptor::TcpConnectionAcceptor(Options options) : options_(std::move(options)) {} TcpConnectionAcceptor::~TcpConnectionAcceptor() { if (serverThread_) { stop(); + serverThread_.reset(); } } -//////////////////////////////////////////////////////////////////////////////// - void TcpConnectionAcceptor::start(OnDuplexConnectionAccept onAccept) { if (onAccept_ != nullptr) { throw std::runtime_error("TcpConnectionAcceptor::start() already called"); } onAccept_ = std::move(onAccept); - serverThread_ = std::make_unique(); - serverThread_->getEventBase()->runInEventBaseThread( - [] { folly::setThreadName("TcpConnectionAcceptor.Listener"); }); + serverThread_ = + std::make_unique("rstcp-listener"); callbacks_.reserve(options_.threads); for (size_t i = 0; i < options_.threads; ++i) { callbacks_.push_back(std::make_unique(onAccept_)); - callbacks_[i]->eventBase()->runInEventBaseThread( - [] { folly::setThreadName("TcpConnectionAcceptor.Worker"); }); } - LOG(INFO) << "Starting TCP listener on port " << options_.address.getPort() << " with " - << options_.threads << " request threads"; + VLOG(1) << "Starting TCP listener on port " << options_.address.getPort() + << " with " << options_.threads << " request threads"; serverSocket_.reset( new folly::AsyncServerSocket(serverThread_->getEventBase())); // The AsyncServerSocket needs to be accessed from the listener thread only. // This will propagate out any exceptions the listener throws. - folly::via( - serverThread_->getEventBase(), - [this] { - serverSocket_->bind(options_.address); - - for (auto const& callback : callbacks_) { - serverSocket_->addAcceptCallback( - callback.get(), callback->eventBase()); - } - - serverSocket_->listen(options_.backlog); - serverSocket_->startAccepting(); - - for (auto& i : serverSocket_->getAddresses()) { - LOG(INFO) << "Listening on " << i.describe(); - } - }) - .get(); + folly::via(serverThread_->getEventBase(), [this] { + serverSocket_->bind(options_.address); + + for (auto const& callback : callbacks_) { + serverSocket_->addAcceptCallback(callback.get(), callback->eventBase()); + } + + serverSocket_->listen(options_.backlog); + serverSocket_->startAccepting(); + + for (const auto& i : serverSocket_->getAddresses()) { + VLOG(1) << "Listening on " << i.describe(); + } + }).get(); } void TcpConnectionAcceptor::stop() { - LOG(INFO) << "Shutting down TCP listener"; + VLOG(1) << "Shutting down TCP listener"; - serverThread_->getEventBase()->runInEventBaseThread( - [this] { serverSocket_.reset(); }); - serverThread_.reset(); + serverThread_->getEventBase()->runInEventBaseThreadAndWait( + [serverSocket = std::move(serverSocket_)]() {}); } folly::Optional TcpConnectionAcceptor::listeningPort() const { @@ -119,4 +121,4 @@ folly::Optional TcpConnectionAcceptor::listeningPort() const { return serverSocket_->getAddress().getPort(); } -} +} // namespace rsocket diff --git a/rsocket/transports/tcp/TcpConnectionAcceptor.h b/rsocket/transports/tcp/TcpConnectionAcceptor.h index 94ae5e598..5d922d06e 100644 --- a/rsocket/transports/tcp/TcpConnectionAcceptor.h +++ b/rsocket/transports/tcp/TcpConnectionAcceptor.h @@ -1,15 +1,24 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include +#include #include "rsocket/ConnectionAcceptor.h" -namespace folly { -class ScopedEventBaseThread; -} - namespace rsocket { /** @@ -20,27 +29,19 @@ namespace rsocket { class TcpConnectionAcceptor : public ConnectionAcceptor { public: struct Options { - explicit Options(uint16_t port_ = 8080, size_t threads_ = 2, - int backlog_ = 10) : address("::", port_), threads(threads_), - backlog(backlog_) {} - /// Address to listen on - folly::SocketAddress address; + folly::SocketAddress address{"::", 8080}; /// Number of worker threads processing requests. - size_t threads; + size_t threads{2}; /// Number of connections to buffer before accept handlers process them. - int backlog; + int backlog{10}; }; - ////////////////////////////////////////////////////////////////////////////// - explicit TcpConnectionAcceptor(Options); ~TcpConnectionAcceptor(); - ////////////////////////////////////////////////////////////////////////////// - // ConnectionAcceptor overrides. /** @@ -61,6 +62,9 @@ class TcpConnectionAcceptor : public ConnectionAcceptor { private: class SocketCallback; + /// Options this acceptor has been configured with. + const Options options_; + /// The thread driving the AsyncServerSocket. std::unique_ptr serverThread_; @@ -73,8 +77,6 @@ class TcpConnectionAcceptor : public ConnectionAcceptor { /// The socket listening for new connections. folly::AsyncServerSocket::UniquePtr serverSocket_; - - /// Options this acceptor has been configured with. - Options options_; }; -} + +} // namespace rsocket diff --git a/rsocket/transports/tcp/TcpConnectionFactory.cpp b/rsocket/transports/tcp/TcpConnectionFactory.cpp index d6f25076b..b970cd756 100644 --- a/rsocket/transports/tcp/TcpConnectionFactory.cpp +++ b/rsocket/transports/tcp/TcpConnectionFactory.cpp @@ -1,15 +1,27 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/transports/tcp/TcpConnectionFactory.h" +#include #include +#include #include #include #include "rsocket/transports/tcp/TcpDuplexConnection.h" -using namespace rsocket; - namespace rsocket { namespace { @@ -18,73 +30,94 @@ class ConnectCallback : public folly::AsyncSocket::ConnectCallback { public: ConnectCallback( folly::SocketAddress address, - OnDuplexConnectionConnect onConnect) - : address_(address), onConnect_{std::move(onConnect)} { + const std::shared_ptr& sslContext, + folly::Promise + connectPromise) + : address_(address), connectPromise_(std::move(connectPromise)) { VLOG(2) << "Constructing ConnectCallback"; // Set up by ScopedEventBaseThread. auto evb = folly::EventBaseManager::get()->getExistingEventBase(); DCHECK(evb); - VLOG(3) << "Starting socket"; - socket_.reset(new folly::AsyncSocket(evb)); + if (sslContext) { +#if !FOLLY_OPENSSL_HAS_ALPN + // setAdvertisedNextProtocols() is unavailable +#error ALPN is required for rsockets. \ + Your version of OpenSSL is likely too old. +#else + VLOG(3) << "Starting SSL socket"; + sslContext->setAdvertisedNextProtocols({"rs"}); +#endif + socket_.reset(new folly::AsyncSSLSocket(sslContext, evb)); + } else { + VLOG(3) << "Starting socket"; + socket_.reset(new folly::AsyncSocket(evb)); + } VLOG(3) << "Attempting connection to " << address_; socket_->connect(this, address_); } - ~ConnectCallback() { + ~ConnectCallback() override { VLOG(2) << "Destroying ConnectCallback"; } void connectSuccess() noexcept override { std::unique_ptr deleter(this); - VLOG(4) << "connectSuccess() on " << address_; auto connection = TcpConnectionFactory::createDuplexConnectionFromSocket( std::move(socket_), RSocketStats::noop()); auto evb = folly::EventBaseManager::get()->getExistingEventBase(); CHECK(evb); - onConnect_(std::move(connection), *evb); + connectPromise_.setValue(ConnectionFactory::ConnectedDuplexConnection{ + std::move(connection), *evb}); } void connectErr(const folly::AsyncSocketException& ex) noexcept override { std::unique_ptr deleter(this); - VLOG(4) << "connectErr(" << ex.what() << ") on " << address_; + connectPromise_.setException(ex); } private: - folly::SocketAddress address_; + const folly::SocketAddress address_; folly::AsyncSocket::UniquePtr socket_; - OnDuplexConnectionConnect onConnect_; + folly::Promise connectPromise_; }; } // namespace -TcpConnectionFactory::TcpConnectionFactory(folly::SocketAddress address) - : address_{std::move(address)} { - VLOG(1) << "Constructing TcpConnectionFactory"; -} +TcpConnectionFactory::TcpConnectionFactory( + folly::EventBase& eventBase, + folly::SocketAddress address, + std::shared_ptr sslContext) + : eventBase_(&eventBase), + address_(std::move(address)), + sslContext_(std::move(sslContext)) {} -TcpConnectionFactory::~TcpConnectionFactory() { - VLOG(1) << "Destroying TcpConnectionFactory"; -} +TcpConnectionFactory::~TcpConnectionFactory() = default; + +folly::Future +TcpConnectionFactory::connect(ProtocolVersion, ResumeStatus /* unused */) { + folly::Promise connectPromise; + auto connectFuture = connectPromise.getFuture(); -void TcpConnectionFactory::connect(OnDuplexConnectionConnect cb) { - worker_.getEventBase()->runInEventBaseThread( - [ this, fn = std::move(cb) ]() mutable { - new ConnectCallback(address_, std::move(fn)); + eventBase_->runInEventBaseThread( + [this, promise = std::move(connectPromise)]() mutable { + new ConnectCallback(address_, sslContext_, std::move(promise)); }); + return connectFuture; } std::unique_ptr TcpConnectionFactory::createDuplexConnectionFromSocket( - folly::AsyncSocket::UniquePtr socket, + folly::AsyncTransportWrapper::UniquePtr socket, std::shared_ptr stats) { - return std::make_unique(std::move(socket), std::move(stats)); + return std::make_unique( + std::move(socket), std::move(stats)); } } // namespace rsocket diff --git a/rsocket/transports/tcp/TcpConnectionFactory.h b/rsocket/transports/tcp/TcpConnectionFactory.h index 994593421..283b50eb5 100644 --- a/rsocket/transports/tcp/TcpConnectionFactory.h +++ b/rsocket/transports/tcp/TcpConnectionFactory.h @@ -1,14 +1,30 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include -#include -#include +#include #include "rsocket/ConnectionFactory.h" #include "rsocket/DuplexConnection.h" +namespace folly { + +class SSLContext; +} + namespace rsocket { class RSocketStats; @@ -20,7 +36,10 @@ class RSocketStats; */ class TcpConnectionFactory : public ConnectionFactory { public: - explicit TcpConnectionFactory(folly::SocketAddress); + TcpConnectionFactory( + folly::EventBase& eventBase, + folly::SocketAddress address, + std::shared_ptr sslContext = nullptr); virtual ~TcpConnectionFactory(); /** @@ -28,14 +47,17 @@ class TcpConnectionFactory : public ConnectionFactory { * * Each call to connect() creates a new AsyncSocket. */ - void connect(OnDuplexConnectionConnect) override; + folly::Future connect( + ProtocolVersion, + ResumeStatus resume) override; static std::unique_ptr createDuplexConnectionFromSocket( - folly::AsyncSocket::UniquePtr socket, + folly::AsyncTransportWrapper::UniquePtr socket, std::shared_ptr stats = std::shared_ptr()); private: - folly::SocketAddress address_; - folly::ScopedEventBaseThread worker_; + folly::EventBase* eventBase_; + const folly::SocketAddress address_; + std::shared_ptr sslContext_; }; } // namespace rsocket diff --git a/rsocket/transports/tcp/TcpDuplexConnection.cpp b/rsocket/transports/tcp/TcpDuplexConnection.cpp index d0268dc73..054e31768 100644 --- a/rsocket/transports/tcp/TcpDuplexConnection.cpp +++ b/rsocket/transports/tcp/TcpDuplexConnection.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "rsocket/transports/tcp/TcpDuplexConnection.h" @@ -10,32 +22,35 @@ namespace rsocket { -using namespace ::folly; using namespace yarpl::flowable; -class TcpReaderWriter : public ::folly::AsyncTransportWrapper::WriteCallback, - public ::folly::AsyncTransportWrapper::ReadCallback, - public std::enable_shared_from_this { +class TcpReaderWriter : public folly::AsyncTransportWrapper::WriteCallback, + public folly::AsyncTransportWrapper::ReadCallback { + friend void intrusive_ptr_add_ref(TcpReaderWriter* x); + friend void intrusive_ptr_release(TcpReaderWriter* x); + public: explicit TcpReaderWriter( - folly::AsyncSocket::UniquePtr&& socket, + folly::AsyncTransportWrapper::UniquePtr&& socket, std::shared_ptr stats) : socket_(std::move(socket)), stats_(std::move(stats)) {} - ~TcpReaderWriter() { + ~TcpReaderWriter() override { CHECK(isClosed()); DCHECK(!inputSubscriber_); } - void setInput( - yarpl::Reference>> - inputSubscriber) { + folly::AsyncTransportWrapper* getTransport() { + return socket_.get(); + } + + void setInput(std::shared_ptr inputSubscriber) { if (inputSubscriber && isClosed()) { inputSubscriber->onComplete(); return; } - if(!inputSubscriber) { + if (!inputSubscriber) { inputSubscriber_ = nullptr; return; } @@ -43,26 +58,12 @@ class TcpReaderWriter : public ::folly::AsyncTransportWrapper::WriteCallback, CHECK(!inputSubscriber_); inputSubscriber_ = std::move(inputSubscriber); - self_ = shared_from_this(); - - // safe to call repeatedly - socket_->setReadCB(this); - } - - void setOutputSubscription(yarpl::Reference subscription) { - if (!subscription) { - outputSubscription_ = nullptr; - return; - } - - if (isClosed()) { - subscription->cancel(); - return; + if (!socket_->getReadCallback()) { + // The AsyncSocket will hold a reference to this instance until it calls + // readEOF or readErr. + intrusive_ptr_add_ref(this); + socket_->setReadCB(this); } - - // No flow control at TCP level, since we can't know the size of messages. - subscription->request(std::numeric_limits::max()); - outputSubscription_ = std::move(subscription); } void send(std::unique_ptr element) { @@ -73,6 +74,9 @@ class TcpReaderWriter : public ::folly::AsyncTransportWrapper::WriteCallback, if (stats_) { stats_->bytesWritten(element->computeChainDataLength()); } + // now AsyncSocket will hold a reference to this instance as a writer until + // they call writeComplete or writeErr + intrusive_ptr_add_ref(this); socket_->writeChain(this, std::move(element)); } @@ -80,9 +84,6 @@ class TcpReaderWriter : public ::folly::AsyncTransportWrapper::WriteCallback, if (auto socket = std::move(socket_)) { socket->close(); } - if (auto outputSubscription = std::move(outputSubscription_)) { - outputSubscription->cancel(); - } if (auto subscriber = std::move(inputSubscriber_)) { subscriber->onComplete(); } @@ -92,11 +93,8 @@ class TcpReaderWriter : public ::folly::AsyncTransportWrapper::WriteCallback, if (auto socket = std::move(socket_)) { socket->close(); } - if (auto subscription = std::move(outputSubscription_)) { - subscription->cancel(); - } if (auto subscriber = std::move(inputSubscriber_)) { - subscriber->onError(ew.to_exception_ptr()); + subscriber->onError(std::move(ew)); } } @@ -105,13 +103,14 @@ class TcpReaderWriter : public ::folly::AsyncTransportWrapper::WriteCallback, return !socket_; } - void writeSuccess() noexcept override {} + void writeSuccess() noexcept override { + intrusive_ptr_release(this); + } - void writeErr( - size_t, - const folly::AsyncSocketException& exn) noexcept override { - closeErr(exn); - self_ = nullptr; + void writeErr(size_t, const folly::AsyncSocketException& exn) noexcept + override { + closeErr(folly::exception_wrapper{folly::copy(exn)}); + intrusive_ptr_release(this); } void getReadBuffer(void** bufReturn, size_t* lenReturn) noexcept override { @@ -131,12 +130,12 @@ class TcpReaderWriter : public ::folly::AsyncTransportWrapper::WriteCallback, void readEOF() noexcept override { close(); - self_ = nullptr; + intrusive_ptr_release(this); } void readErr(const folly::AsyncSocketException& exn) noexcept override { - closeErr(exn); - self_ = nullptr; + closeErr(folly::exception_wrapper{folly::copy(exn)}); + intrusive_ptr_release(this); } bool isBufferMovable() noexcept override { @@ -150,75 +149,38 @@ class TcpReaderWriter : public ::folly::AsyncTransportWrapper::WriteCallback, } folly::IOBufQueue readBuffer_{folly::IOBufQueue::cacheChainLength()}; - folly::AsyncSocket::UniquePtr socket_; + folly::AsyncTransportWrapper::UniquePtr socket_; const std::shared_ptr stats_; - yarpl::Reference>> - inputSubscriber_; - yarpl::Reference outputSubscription_; - - // self reference is used to keep the instance alive for the AsyncSocket - // callbacks even after DuplexConnection releases references to this - std::shared_ptr self_; + std::shared_ptr inputSubscriber_; + int refCount_{0}; }; -class TcpOutputSubscriber - : public Subscriber> { - public: - explicit TcpOutputSubscriber( - std::shared_ptr tcpReaderWriter) - : tcpReaderWriter_(std::move(tcpReaderWriter)) { - CHECK(tcpReaderWriter_); - } - - void onSubscribe( - yarpl::Reference subscription) noexcept override { - CHECK(subscription); - if (!tcpReaderWriter_) { - LOG(ERROR) << "trying to resubscribe on a closed subscriber"; - subscription->cancel(); - return; - } - tcpReaderWriter_->setOutputSubscription(std::move(subscription)); - } - - void onNext(std::unique_ptr element) noexcept override { - CHECK(tcpReaderWriter_); - tcpReaderWriter_->send(std::move(element)); - } +void intrusive_ptr_add_ref(TcpReaderWriter* x); +void intrusive_ptr_release(TcpReaderWriter* x); - void onComplete() noexcept override { - CHECK(tcpReaderWriter_); - tcpReaderWriter_->setOutputSubscription(nullptr); - } - - void onError(std::exception_ptr eptr) noexcept override { - CHECK(tcpReaderWriter_); - auto tcpReaderWriter = std::move(tcpReaderWriter_); +inline void intrusive_ptr_add_ref(TcpReaderWriter* x) { + ++x->refCount_; +} - try { - std::rethrow_exception(eptr); - } catch (const std::exception& exn) { - folly::exception_wrapper ew{eptr, exn}; - tcpReaderWriter->closeErr(std::move(ew)); - } - } +inline void intrusive_ptr_release(TcpReaderWriter* x) { + if (--x->refCount_ == 0) + delete x; +} - private: - std::shared_ptr tcpReaderWriter_; -}; +namespace { class TcpInputSubscription : public Subscription { public: explicit TcpInputSubscription( - std::shared_ptr tcpReaderWriter) + boost::intrusive_ptr tcpReaderWriter) : tcpReaderWriter_(std::move(tcpReaderWriter)) { CHECK(tcpReaderWriter_); } void request(int64_t n) noexcept override { DCHECK(tcpReaderWriter_); - DCHECK(n == kMaxRequestN) + DCHECK_EQ(n, std::numeric_limits::max()) << "TcpDuplexConnection doesnt support proper flow control"; } @@ -228,14 +190,15 @@ class TcpInputSubscription : public Subscription { } private: - std::shared_ptr tcpReaderWriter_; + boost::intrusive_ptr tcpReaderWriter_; }; +} // namespace + TcpDuplexConnection::TcpDuplexConnection( - folly::AsyncSocket::UniquePtr&& socket, + folly::AsyncTransportWrapper::UniquePtr&& socket, std::shared_ptr stats) - : tcpReaderWriter_( - std::make_shared(std::move(socket), stats)), + : tcpReaderWriter_(new TcpReaderWriter(std::move(socket), stats)), stats_(stats) { if (stats_) { stats_->duplexConnectionCreated("tcp", this); @@ -249,18 +212,21 @@ TcpDuplexConnection::~TcpDuplexConnection() { tcpReaderWriter_->close(); } -yarpl::Reference>> -TcpDuplexConnection::getOutput() { - return yarpl::make_ref(tcpReaderWriter_); +folly::AsyncTransportWrapper* TcpDuplexConnection::getTransport() { + return tcpReaderWriter_ ? tcpReaderWriter_->getTransport() : nullptr; +} + +void TcpDuplexConnection::send(std::unique_ptr buf) { + if (tcpReaderWriter_) { + tcpReaderWriter_->send(std::move(buf)); + } } void TcpDuplexConnection::setInput( - yarpl::Reference>> - inputSubscriber) { + std::shared_ptr inputSubscriber) { // we don't care if the subscriber will call request synchronously inputSubscriber->onSubscribe( - yarpl::make_ref(tcpReaderWriter_)); + std::make_shared(tcpReaderWriter_)); tcpReaderWriter_->setInput(std::move(inputSubscriber)); } - -} // rsocket +} // namespace rsocket diff --git a/rsocket/transports/tcp/TcpDuplexConnection.h b/rsocket/transports/tcp/TcpDuplexConnection.h index 2cae61dd8..5bfa9adec 100644 --- a/rsocket/transports/tcp/TcpDuplexConnection.h +++ b/rsocket/transports/tcp/TcpDuplexConnection.h @@ -1,8 +1,22 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once +#include #include +#include #include "rsocket/DuplexConnection.h" #include "rsocket/RSocketStats.h" @@ -15,24 +29,19 @@ class TcpReaderWriter; class TcpDuplexConnection : public DuplexConnection { public: explicit TcpDuplexConnection( - folly::AsyncSocket::UniquePtr&& socket, + folly::AsyncTransportWrapper::UniquePtr&& socket, std::shared_ptr stats = RSocketStats::noop()); ~TcpDuplexConnection(); - // - // both getOutput and setOutput are ok to be called multiple times - // on a single instance of TcpDuplexConnection - // the latest input/output will be used - // + void send(std::unique_ptr) override; - yarpl::Reference>> - getOutput() override; + void setInput(std::shared_ptr) override; - void setInput(yarpl::Reference>> framesSink) override; + // Only to be used for observation purposes. + folly::AsyncTransportWrapper* getTransport(); private: - std::shared_ptr tcpReaderWriter_; + boost::intrusive_ptr tcpReaderWriter_; std::shared_ptr stats_; }; } // namespace rsocket diff --git a/scripts/build_folly.sh b/scripts/build_folly.sh new file mode 100755 index 000000000..ebe67aa00 --- /dev/null +++ b/scripts/build_folly.sh @@ -0,0 +1,48 @@ +#!/bin/bash +# +# Copyright 2004-present Facebook. All Rights Reserved. +# +CHECKOUT_DIR=$1 +INSTALL_DIR=$2 +if [[ -z $INSTALL_DIR ]]; then + echo "usage: $0 CHECKOUT_DIR INSTALL_DIR" >&2 + exit 1 +fi + +# Convert INSTALL_DIR to an absolute path so it still refers to the same +# location after we cd into the build directory. +case "$INSTALL_DIR" in + /*) ;; + *) INSTALL_DIR="$PWD/$INSTALL_DIR" +esac + +# If folly was already installed, just return early +INSTALL_MARKER_FILE="$INSTALL_DIR/folly.installed" +if [[ -f $INSTALL_MARKER_FILE ]]; then + echo "folly was previously built" + exit 0 +fi + +set -e +set -x + +if [[ -d "$CHECKOUT_DIR" ]]; then + git -C "$CHECKOUT_DIR" fetch + git -C "$CHECKOUT_DIR" checkout master +else + git clone https://github.com/facebook/folly "$CHECKOUT_DIR" +fi + +mkdir -p "$CHECKOUT_DIR/_build" +cd "$CHECKOUT_DIR/_build" +if ! cmake \ + "-DCMAKE_PREFIX_PATH=${INSTALL_DIR}" \ + "-DCMAKE_INSTALL_PREFIX=${INSTALL_DIR}" \ + ..; then + echo "error configuring folly" >&2 + tail -n 100 CMakeFiles/CMakeError.log >&2 + exit 1 +fi +make -j4 +make install +touch "$INSTALL_MARKER_FILE" diff --git a/scripts/frame_fuzzer_test.sh b/scripts/frame_fuzzer_test.sh new file mode 100755 index 000000000..785cfa8d1 --- /dev/null +++ b/scripts/frame_fuzzer_test.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash +# +# Copyright 2004-present Facebook. All Rights Reserved. +# +if [ ! -s ./build/frame_fuzzer ]; then + echo "./build/frame_fuzzer binary not found!" + exit 1 +fi + +shopt -s nullglob +for fuzzcase in ./test/fuzzer_testcases/frame_fuzzer/*; do + echo "testing with $fuzzcase..." + ./build/frame_fuzzer --v=100 < $fuzzcase +done diff --git a/scripts/tck_test.sh b/scripts/tck_test.sh index 8a7ca7b50..814b6e309 100755 --- a/scripts/tck_test.sh +++ b/scripts/tck_test.sh @@ -1,3 +1,7 @@ +#!/bin/bash +# +# Copyright 2004-present Facebook. All Rights Reserved. +# if [ "$#" -ne 4 ]; then echo "Illegal number of parameters - $#" exit 1 @@ -41,23 +45,28 @@ if [ ! -s ./build/tckclient ] && [ "$client_lang" = cpp ]; then exit 1 fi -java_server="java -cp rsocket-tck-drivers-0.9-SNAPSHOT.jar io/rsocket/tckdrivers/main/Main --server --host localhost --port 9898 --file tck-test/servertest.txt" -java_client="java -cp rsocket-tck-drivers-0.9-SNAPSHOT.jar io/rsocket/tckdrivers/main/Main --client --host localhost --port 9898 --file tck-test/clienttest.txt" +timeout='timeout' +if [[ "$OSTYPE" == "darwin"* ]]; then + timeout='gtimeout' +fi + +java_server="java -jar rsocket-tck-drivers-0.9.10.jar --server --host localhost --port 9898 --file rsocket/tck-test/servertest.txt" +java_client="java -jar rsocket-tck-drivers-0.9.10.jar --client --host localhost --port 9898 --file rsocket/tck-test/clienttest.txt" -cpp_server="./build/tckserver -test_file tck-test/servertest.txt -rs_use_protocol_version 1.0" -cpp_client="./build/tckclient -test_file tck-test/clienttest.txt -rs_use_protocol_version 1.0" +cpp_server="./build/tckserver -test_file rsocket/tck-test/servertest.txt -rs_use_protocol_version 1.0" +cpp_client="./build/tckclient -test_file rsocket/tck-test/clienttest.txt -rs_use_protocol_version 1.0" server="${server_lang}_server" client="${client_lang}_client" # run server in the background -timeout 60 ${!server} & +$timeout 60 ${!server} & # wait for the server to listen sleep 2 # run client -timeout 60 ${!client} +$timeout 60 ${!client} ret=$? # terminate server diff --git a/tck-test/FlowableSubscriber.h b/tck-test/FlowableSubscriber.h deleted file mode 100644 index a3717abaa..000000000 --- a/tck-test/FlowableSubscriber.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "tck-test/BaseSubscriber.h" - -#include "yarpl/Flowable.h" - -namespace rsocket { -namespace tck { - -class FlowableSubscriber : public BaseSubscriber, - public yarpl::flowable::Subscriber { - public: - explicit FlowableSubscriber(int initialRequestN = 0); - - // Inherited from BaseSubscriber - void request(int n) override; - void cancel() override; - - protected: - // Inherited from flowable::Subscriber - void onSubscribe(yarpl::Reference - subscription) noexcept override; - void onNext(Payload element) noexcept override; - void onComplete() noexcept override; - void onError(std::exception_ptr ex) noexcept override; - - private: - yarpl::Reference subscription_; - int initialRequestN_{0}; -}; - -} // tck -} // reactivesocket diff --git a/tck-test/MarbleProcessor.h b/tck-test/MarbleProcessor.h deleted file mode 100644 index 7f6972299..000000000 --- a/tck-test/MarbleProcessor.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include "rsocket/Payload.h" -#include "yarpl/Flowable.h" -#include "yarpl/Single.h" - -namespace rsocket { -namespace tck { - -class MarbleProcessor { - public: - explicit MarbleProcessor(const std::string /* marble */); - - std::tuple run( - yarpl::flowable::Subscriber& subscriber, - int64_t requested); - - void run(yarpl::Reference> - subscriber); - - private: - std::string marble_; - - // Stores a mapping from marble character to Payload (data, metadata) - std::map> argMap_; - - // Keeps an account of how many messages can be sent. This could be done - // with Semaphores (AllowanceSemaphore) - std::atomic canSend_{0}; - - size_t index_{0}; -}; - -} // tck -} // reactivesocket diff --git a/tck-test/SingleSubscriber.h b/tck-test/SingleSubscriber.h deleted file mode 100644 index cab2bd7ef..000000000 --- a/tck-test/SingleSubscriber.h +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "tck-test/BaseSubscriber.h" - -#include "yarpl/Single.h" - -namespace rsocket { -namespace tck { - -class SingleSubscriber : public BaseSubscriber, - public yarpl::single::SingleObserver { - public: - // Inherited from BaseSubscriber - void request(int n) override; - void cancel() override; - - protected: - // Inherited from flowable::Subscriber - void onSubscribe(yarpl::Reference - subscription) noexcept override; - void onSuccess(Payload element) noexcept override; - void onError(std::exception_ptr ex) noexcept override; - - private: - yarpl::Reference subscription_; -}; - -} // tck -} // reactivesocket diff --git a/tck-test/TestFileParser.h b/tck-test/TestFileParser.h deleted file mode 100644 index cd0166010..000000000 --- a/tck-test/TestFileParser.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include - -#include "tck-test/TestSuite.h" - -namespace rsocket { -namespace tck { - -class TestFileParser { - public: - explicit TestFileParser(const std::string& fileName); - - TestSuite parse(); - - private: - void parseCommand(const std::string& command); - void addCurrentTest(); - - std::ifstream input_; - int currentLine_; - - TestSuite testSuite_; - Test currentTest_; -}; - -} // tck -} // reactivesocket diff --git a/tck-test/TestSuite.cpp b/tck-test/TestSuite.cpp deleted file mode 100644 index 0e0dad89e..000000000 --- a/tck-test/TestSuite.cpp +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "tck-test/TestSuite.h" - -#include - -namespace rsocket { -namespace tck { - -bool TestCommand::valid() const { - // there has to be a name to the test and at least 1 param - return params_.size() >= 1; -} - -void Test::addCommand(TestCommand command) { - CHECK(command.valid()); - commands_.push_back(std::move(command)); -} - -} // tck -} // reactivesocket diff --git a/test/RSocketClientServerTest.cpp b/test/RSocketClientServerTest.cpp deleted file mode 100644 index 77d393ec3..000000000 --- a/test/RSocketClientServerTest.cpp +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "RSocketTests.h" - -using namespace rsocket; -using namespace rsocket::tests; -using namespace rsocket::tests::client_server; - -TEST(RSocketClientServer, StartAndShutdown) { - auto server = makeServer(std::make_shared()); - auto client = makeClient(*server->listeningPort()); -} - -TEST(RSocketClientServer, ConnectOne) { - auto server = makeServer(std::make_shared()); - auto client = makeClient(*server->listeningPort()); - auto requester = client->getRequester(); -} - -TEST(RSocketClientServer, ConnectManySync) { - auto server = makeServer(std::make_shared()); - - for (size_t i = 0; i < 100; ++i) { - auto client = makeClient(*server->listeningPort()); - auto requester = client->getRequester(); - } -} - -TEST(RSocketClientServer, ConnectManyAsync) { - auto server = makeServer(std::make_shared()); - - for (size_t i = 0; i < 100; ++i) { - auto client = makeClient(*server->listeningPort()); - auto requester = client->getRequester(); - } - -} diff --git a/test/RSocketTests.h b/test/RSocketTests.h deleted file mode 100644 index 11c77aa70..000000000 --- a/test/RSocketTests.h +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include - -#include - -#include "rsocket/RSocket.h" -#include "rsocket/transports/tcp/TcpConnectionAcceptor.h" -#include "rsocket/transports/tcp/TcpConnectionFactory.h" -#include "test/handlers/HelloStreamRequestHandler.h" - -namespace rsocket { -namespace tests { -namespace client_server { - -inline std::unique_ptr makeServer( - std::shared_ptr responder) { - TcpConnectionAcceptor::Options opts; - opts.threads = 2; - opts.address = folly::SocketAddress("::", 0); - - // RSocket server accepting on TCP. - auto rs = RSocket::createServer( - std::make_unique(std::move(opts))); - - rs->start([r = std::move(responder)](const SetupParameters&) { return r; }); - - return rs; -} - -inline std::unique_ptr makeResumableServer( - std::shared_ptr serviceHandler) { - TcpConnectionAcceptor::Options opts; - opts.threads = 1; - opts.address = folly::SocketAddress("::", 0); - auto rs = RSocket::createServer( - std::make_unique(std::move(opts))); - rs->start(std::move(serviceHandler)); - return rs; -} - -inline std::shared_ptr makeClient(uint16_t port) { - folly::SocketAddress address; - address.setFromHostPort("localhost", port); - return RSocket::createConnectedClient( - std::make_unique(std::move(address))) - .get(); -} - -inline std::shared_ptr makeResumableClient(uint16_t port) { - folly::SocketAddress address; - address.setFromHostPort("localhost", port); - SetupParameters setupParameters; - setupParameters.resumable = true; - return RSocket::createConnectedClient( - std::make_unique(std::move(address)), - std::move(setupParameters)) - .get(); -} -} -} -} // namespace diff --git a/test/RequestChannelTest.cpp b/test/RequestChannelTest.cpp deleted file mode 100644 index 006b3c36e..000000000 --- a/test/RequestChannelTest.cpp +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include - -#include "RSocketTests.h" -#include "yarpl/Flowable.h" -#include "yarpl/flowable/TestSubscriber.h" - -using namespace yarpl; -using namespace yarpl::flowable; -using namespace rsocket; -using namespace rsocket::tests; -using namespace rsocket::tests::client_server; - -/** - * Test a finite stream both directions. - */ -class TestHandlerHello : public rsocket::RSocketResponder { - public: - /// Handles a new inbound Stream requested by the other end. - yarpl::Reference> handleRequestChannel( - rsocket::Payload initialPayload, - yarpl::Reference> request, - rsocket::StreamId) override { - // say "Hello" to each name on the input stream - return request->map([initialPayload = std::move(initialPayload)]( - Payload p) { - std::stringstream ss; - ss << "[" << initialPayload.cloneDataToString() << "] " - << "Hello " << p.moveDataToString() << "!"; - std::string s = ss.str(); - - return Payload(s); - }); - } -}; - -TEST(RequestChannelTest, Hello) { - auto server = makeServer(std::make_shared()); - auto client = makeClient(*server->listeningPort()); - auto requester = client->getRequester(); - - auto ts = TestSubscriber::create(); - requester - ->requestChannel( - Flowables::justN({"/hello", "Bob", "Jane"})->map([](std::string v) { - return Payload(v); - })) - ->map([](auto p) { return p.moveDataToString(); }) - ->subscribe(ts); - - ts->awaitTerminalEvent(); - ts->assertSuccess(); - ts->assertValueCount(2); - // assert that we echo back the 2nd and 3rd request values - // with the 1st initial payload prepended to each - ts->assertValueAt(0, "[/hello] Hello Bob!"); - ts->assertValueAt(1, "[/hello] Hello Jane!"); -} - -// TODO complete from requester, responder continues -// TODO complete from responder, requester continues -// TODO cancel from requester, shuts down -// TODO flow control from requester to responder -// TODO flow control from responder to requester -// TODO failure on responder, requester sees -// TODO failure on request, requester sees -// TODO failure from requester ... what happens? diff --git a/test/RequestResponseTest.cpp b/test/RequestResponseTest.cpp deleted file mode 100644 index 6a217b063..000000000 --- a/test/RequestResponseTest.cpp +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include -#include - -#include "RSocketTests.h" -#include "yarpl/Single.h" -#include "yarpl/single/SingleTestObserver.h" - -using namespace yarpl; -using namespace yarpl::single; -using namespace rsocket; -using namespace rsocket::tests; -using namespace rsocket::tests::client_server; - -namespace { -class TestHandlerHello : public rsocket::RSocketResponder { - public: - Reference> handleRequestResponse(Payload request, StreamId) - override { - auto requestString = request.moveDataToString(); - return Single::create([name = std::move(requestString)]( - auto subscriber) { - subscriber->onSubscribe(SingleSubscriptions::empty()); - std::stringstream ss; - ss << "Hello " << name << "!"; - std::string s = ss.str(); - subscriber->onSuccess(Payload(s, "metadata")); - }); - } -}; -} - -TEST(RequestResponseTest, Hello) { - auto server = makeServer(std::make_shared()); - auto client = makeClient(*server->listeningPort()); - auto requester = client->getRequester(); - - auto to = SingleTestObserver::create(); - requester->requestResponse(Payload("Jane")) - ->map([](auto p) { return p.moveDataToString(); }) - ->subscribe(to); - to->awaitTerminalEvent(); - to->assertOnSuccessValue("Hello Jane!"); -} - -namespace { -class TestHandlerCancel : public rsocket::RSocketResponder { - public: - TestHandlerCancel( - std::shared_ptr> onCancel, - std::shared_ptr> onSubscribe) - : onCancel_(std::move(onCancel)), onSubscribe_(std::move(onSubscribe)) {} - Reference> handleRequestResponse(Payload request, StreamId) - override { - // used to signal to the client when the subscribe is received - onSubscribe_->post(); - // used to block this responder thread until a cancel is sent from client - // over network - auto cancelFromClient = std::make_shared>(); - // used to signal to the client once we receive a cancel - auto onCancel = onCancel_; - auto requestString = request.moveDataToString(); - return Single::create( - [ name = std::move(requestString), cancelFromClient, onCancel ]( - auto subscriber) mutable { - std::thread([ - subscriber = std::move(subscriber), - name = std::move(name), - cancelFromClient, - onCancel - ]() { - auto subscription = SingleSubscriptions::create( - [cancelFromClient] { cancelFromClient->post(); }); - subscriber->onSubscribe(subscription); - // simulate slow processing or IO being done - // and block this current background thread - // until we are cancelled - cancelFromClient->wait(); - if (subscription->isCancelled()) { - // this is used by the unit test to assert the cancel was - // received - onCancel->post(); - } else { - // if not cancelled would do work and emit here - } - }).detach(); - }); - } - - private: - std::shared_ptr> onCancel_; - std::shared_ptr> onSubscribe_; -}; -} - -TEST(RequestResponseTest, Cancel) { - auto onCancel = std::make_shared>(); - auto onSubscribe = std::make_shared>(); - auto server = - makeServer(std::make_shared(onCancel, onSubscribe)); - auto client = makeClient(*server->listeningPort()); - auto requester = client->getRequester(); - - auto to = SingleTestObserver::create(); - requester->requestResponse(Payload("Jane")) - ->map([](auto p) { return p.moveDataToString(); }) - ->subscribe(to); - // NOTE: wait for server to receive request/subscribe - // otherwise the cancellation will all happen locally - onSubscribe->wait(); - // now cancel the local subscription - to->cancel(); - // wait for cancel to propagate to server - onCancel->wait(); - // assert no signals received on client - to->assertNoTerminalEvent(); -} - -// TODO failure on responder, requester sees -// TODO failure on request, requester sees diff --git a/test/RequestStreamTest.cpp b/test/RequestStreamTest.cpp deleted file mode 100644 index 1930397b0..000000000 --- a/test/RequestStreamTest.cpp +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include - -#include "RSocketTests.h" -#include "yarpl/Flowable.h" -#include "yarpl/flowable/TestSubscriber.h" - -using namespace yarpl; -using namespace yarpl::flowable; -using namespace rsocket; -using namespace rsocket::tests; -using namespace rsocket::tests::client_server; - -namespace { -class TestHandlerSync : public rsocket::RSocketResponder { - public: - Reference> handleRequestStream( - Payload request, - StreamId) override { - // string from payload data - auto requestString = request.moveDataToString(); - - return Flowables::range(1, 10)->map([name = std::move(requestString)]( - int64_t v) { - std::stringstream ss; - ss << "Hello " << name << " " << v << "!"; - std::string s = ss.str(); - return Payload(s, "metadata"); - }); - } -}; - -TEST(RequestStreamTest, HelloSync) { - auto server = makeServer(std::make_shared()); - auto client = makeClient(*server->listeningPort()); - auto requester = client->getRequester(); - auto ts = TestSubscriber::create(); - requester->requestStream(Payload("Bob")) - ->map([](auto p) { return p.moveDataToString(); }) - ->subscribe(ts); - ts->awaitTerminalEvent(); - ts->assertSuccess(); - ts->assertValueCount(10); - ts->assertValueAt(0, "Hello Bob 1!"); - ts->assertValueAt(9, "Hello Bob 10!"); -} - -class TestHandlerAsync : public rsocket::RSocketResponder { - public: - Reference> handleRequestStream( - Payload request, - StreamId) override { - // string from payload data - auto requestString = request.moveDataToString(); - - return Flowables::fromPublisher< - Payload>([requestString = std::move(requestString)]( - Reference> subscriber) { - std::thread([ - requestString = std::move(requestString), - subscriber = std::move(subscriber) - ]() { - Flowables::range(1, 40) - ->map([name = std::move(requestString)](int64_t v) { - std::stringstream ss; - ss << "Hello " << name << " " << v << "!"; - std::string s = ss.str(); - return Payload(s, "metadata"); - }) - ->subscribe(subscriber); - }).detach(); - }); - } -}; -} - -TEST(RequestStreamTest, HelloAsync) { - auto server = makeServer(std::make_shared()); - auto client = makeClient(*server->listeningPort()); - auto requester = client->getRequester(); - auto ts = TestSubscriber::create(); - requester->requestStream(Payload("Bob")) - ->map([](auto p) { return p.moveDataToString(); }) - ->subscribe(ts); - ts->awaitTerminalEvent(); - ts->assertSuccess(); - ts->assertValueCount(40); - ts->assertValueAt(0, "Hello Bob 1!"); - ts->assertValueAt(39, "Hello Bob 40!"); -} diff --git a/test/Test.cpp b/test/Test.cpp deleted file mode 100644 index 884d4736b..000000000 --- a/test/Test.cpp +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include -#include -#include - -int main(int argc, char** argv) { - FLAGS_logtostderr = true; - testing::InitGoogleMock(&argc, argv); - folly::init(&argc, &argv); - return RUN_ALL_TESTS(); -} diff --git a/test/WarmResumptionTest.cpp b/test/WarmResumptionTest.cpp deleted file mode 100644 index 35e98301f..000000000 --- a/test/WarmResumptionTest.cpp +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include - -#include "RSocketTests.h" - -#include "rsocket/RSocketServiceHandler.h" -#include "yarpl/flowable/TestSubscriber.h" - -using namespace rsocket; -using namespace rsocket::tests::client_server; -using namespace yarpl::flowable; - -namespace { - -class HelloServiceHandler : public RSocketServiceHandler { - public: - folly::Expected onNewSetup( - const SetupParameters&) override { - return RSocketConnectionParams( - std::make_shared()); - } - - void onNewRSocketState( - std::shared_ptr state, - ResumeIdentificationToken token) override { - store_.lock()->insert({token, std::move(state)}); - } - - folly::Expected, RSocketException> - onResume(ResumeIdentificationToken token) override { - auto itr = store_->find(token); - CHECK(itr != store_->end()); - return itr->second; - }; - - private: - folly::Synchronized< - std::map>, - std::mutex> - store_; -}; - -} // anonymous namespace - -TEST(WarmResumptionTest, SimpleStream) { - auto server = makeResumableServer(std::make_shared()); - auto client = makeResumableClient(*server->listeningPort()); - auto requester = client->getRequester(); - auto ts = TestSubscriber::create(7 /* initialRequestN */); - requester->requestStream(Payload("Bob")) - ->map([](auto p) { return p.moveDataToString(); }) - ->subscribe(ts); - // Wait for a few frames before disconnecting. - while (ts->getValueCount() < 3) { - std::this_thread::yield(); - } - client->disconnect(std::runtime_error("Test triggered disconnect")); - EXPECT_NO_THROW(client->resume().get()); - ts->request(3); - ts->awaitTerminalEvent(); - ts->assertSuccess(); - ts->assertValueCount(10); -} diff --git a/test/framing/FrameTransportTest.cpp b/test/framing/FrameTransportTest.cpp deleted file mode 100644 index adb4a83b4..000000000 --- a/test/framing/FrameTransportTest.cpp +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include - -#include "rsocket/framing/FrameTransport.h" -#include "test/test_utils/MockDuplexConnection.h" -#include "test/test_utils/MockFrameProcessor.h" - -using namespace rsocket; -using namespace testing; - -namespace { - -/* - * Compare a `const folly::IOBuf&` against a `const std::string&`. - */ -MATCHER_P(IOBufStringEq, s, "") { - return folly::IOBufEqual()(*arg, *folly::IOBuf::copyBuffer(s)); -} - -} - -TEST(FrameTransport, Close) { - auto connection = std::make_unique>( - [](auto) {}, - [](auto output) { - EXPECT_CALL(*output, onSubscribe_(_)); - EXPECT_CALL(*output, onComplete_()); - }); - - auto transport = yarpl::make_ref(std::move(connection)); - transport->setFrameProcessor( - std::make_shared>()); - transport->close(); -} - -TEST(FrameTransport, CloseWithError) { - auto connection = std::make_unique>( - [](auto) {}, - [](auto output) { - EXPECT_CALL(*output, onSubscribe_(_)); - EXPECT_CALL(*output, onError_(_)); - }); - - auto transport = yarpl::make_ref(std::move(connection)); - transport->setFrameProcessor( - std::make_shared>()); - transport->closeWithError(std::runtime_error("Uh oh")); -} - -TEST(FrameTransport, SimpleEnqueue) { - auto connection = std::make_unique>( - [](auto) {}, - [](auto output) { - EXPECT_CALL(*output, onSubscribe_(_)); - - EXPECT_CALL(*output, onNext_(IOBufStringEq("Hello"))); - EXPECT_CALL(*output, onNext_(IOBufStringEq("World"))); - - EXPECT_CALL(*output, onComplete_()); - }); - - auto transport = yarpl::make_ref(std::move(connection)); - - transport->outputFrameOrEnqueue(folly::IOBuf::copyBuffer("Hello")); - transport->outputFrameOrEnqueue(folly::IOBuf::copyBuffer("World")); - - transport->setFrameProcessor( - std::make_shared>()); - transport->close(); -} - -TEST(FrameTransport, SimpleNoQueue) { - auto connection = std::make_unique>( - [](auto) {}, - [](auto output) { - EXPECT_CALL(*output, onSubscribe_(_)); - - EXPECT_CALL(*output, onNext_(IOBufStringEq("Hello"))); - EXPECT_CALL(*output, onNext_(IOBufStringEq("World"))); - - EXPECT_CALL(*output, onComplete_()); - }); - - auto transport = yarpl::make_ref(std::move(connection)); - - transport->setFrameProcessor( - std::make_shared>()); - - transport->outputFrameOrEnqueue(folly::IOBuf::copyBuffer("Hello")); - transport->outputFrameOrEnqueue(folly::IOBuf::copyBuffer("World")); - - transport->close(); -} - -TEST(FrameTransport, InputSendsError) { - auto connection = std::make_unique>( - [](auto input) { - auto subscription = yarpl::make_ref>(); - EXPECT_CALL(*subscription, request_(_)); - EXPECT_CALL(*subscription, cancel_()); - - input->onSubscribe(std::move(subscription)); - input->onError(std::make_exception_ptr(std::runtime_error("Oops"))); - }, - [](auto output) { - EXPECT_CALL(*output, onSubscribe_(_)); - EXPECT_CALL(*output, onComplete_()); - }); - - auto transport = yarpl::make_ref(std::move(connection)); - - auto processor = std::make_shared>(); - EXPECT_CALL(*processor, onTerminal_(_)); - - transport->setFrameProcessor(std::move(processor)); - transport->close(); -} diff --git a/test/handlers/HelloStreamRequestHandler.cpp b/test/handlers/HelloStreamRequestHandler.cpp deleted file mode 100644 index 82731275d..000000000 --- a/test/handlers/HelloStreamRequestHandler.cpp +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "HelloStreamRequestHandler.h" -#include -#include "yarpl/Flowable.h" - -using namespace ::rsocket; -using namespace yarpl; -using namespace yarpl::flowable; - -namespace rsocket { -namespace tests { -/// Handles a new inbound Stream requested by the other end. -Reference> -HelloStreamRequestHandler::handleRequestStream( - rsocket::Payload request, - rsocket::StreamId) { - LOG(INFO) << "HelloStreamRequestHandler.handleRequestStream " << request; - - // string from payload data - auto requestString = request.moveDataToString(); - - return Flowables::range(1, 10)->map([name = std::move(requestString)]( - int64_t v) { - std::stringstream ss; - ss << "Hello " << name << " " << v << "!"; - std::string s = ss.str(); - return Payload(s, "metadata"); - }); -} -} -} diff --git a/test/handlers/HelloStreamRequestHandler.h b/test/handlers/HelloStreamRequestHandler.h deleted file mode 100644 index 768641bda..000000000 --- a/test/handlers/HelloStreamRequestHandler.h +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "rsocket/RSocketResponder.h" -#include "yarpl/Flowable.h" - -namespace rsocket { -namespace tests { - -class HelloStreamRequestHandler : public RSocketResponder { - public: - /// Handles a new inbound Stream requested by the other end. - yarpl::Reference> - handleRequestStream(rsocket::Payload request, rsocket::StreamId streamId) - override; -}; -} -} diff --git a/test/internal/AllowanceSemaphoreTest.cpp b/test/internal/AllowanceSemaphoreTest.cpp deleted file mode 100644 index 5c390c477..000000000 --- a/test/internal/AllowanceSemaphoreTest.cpp +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include -#include -#include "rsocket/internal/AllowanceSemaphore.h" - -using namespace ::testing; -using namespace ::rsocket; - -TEST(AllowanceSemaphoreTest, Finite) { - AllowanceSemaphore sem; - - ASSERT_FALSE(sem.canAcquire()); - ASSERT_FALSE(sem.tryAcquire()); - - ASSERT_EQ(0U, sem.release(1)); - ASSERT_FALSE(sem.canAcquire(2)); - ASSERT_TRUE(sem.canAcquire()); - ASSERT_TRUE(sem.tryAcquire()); - - ASSERT_EQ(0U, sem.release(2)); - ASSERT_EQ(2U, sem.release(1)); - ASSERT_EQ(3U, sem.drain()); - ASSERT_EQ(0U, sem.drain()); - - ASSERT_EQ(0U, sem.release(2)); - ASSERT_FALSE(sem.canAcquire(3)); - ASSERT_FALSE(sem.tryAcquire(3)); - ASSERT_TRUE(sem.canAcquire(2)); - ASSERT_TRUE(sem.tryAcquire(2)); - ASSERT_FALSE(sem.canAcquire()); -} - -TEST(AllowanceSemaphoreTest, DrainWithLimit) { - AllowanceSemaphore sem; - - ASSERT_EQ(0U, sem.release(9)); - ASSERT_EQ(4U, sem.drainWithLimit(4)); - ASSERT_EQ(1U, sem.drainWithLimit(1)); - ASSERT_EQ(4U, sem.drainWithLimit(100)); -} diff --git a/test/internal/SetupResumeAcceptorTest.cpp b/test/internal/SetupResumeAcceptorTest.cpp deleted file mode 100644 index 1ad9f25ca..000000000 --- a/test/internal/SetupResumeAcceptorTest.cpp +++ /dev/null @@ -1,284 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include - -#include - -#include "rsocket/framing/FrameTransport.h" -#include "rsocket/internal/SetupResumeAcceptor.h" -#include "test/test_utils/MockDuplexConnection.h" -#include "test/test_utils/MockFrameProcessor.h" -#include "test/test_utils/Mocks.h" - -using namespace rsocket; -using namespace testing; - -namespace { - -/* - * Make a legitimate-looking SETUP frame. - */ -Frame_SETUP makeSetup() { - Frame_SETUP frame; - frame.header_ = FrameHeader{FrameType::SETUP, FrameFlags::EMPTY, 0}; - frame.versionMajor_ = 1; - frame.versionMinor_ = 0; - frame.keepaliveTime_ = Frame_SETUP::kMaxKeepaliveTime; - frame.maxLifetime_ = Frame_SETUP::kMaxLifetime; - frame.token_ = ResumeIdentificationToken::generateNew(); - frame.metadataMimeType_ = "application/olive+oil"; - frame.dataMimeType_ = "json/vorhees"; - frame.payload_ = Payload("Test SETUP data", "Test SETUP metadata"); - return frame; -} - -/* - * Make a legitimate-looking RESUME frame. - */ -Frame_RESUME makeResume() { - Frame_RESUME frame; - frame.header_ = FrameHeader{FrameType::RESUME, FrameFlags::EMPTY, 0}; - frame.versionMajor_ = 1; - frame.versionMinor_ = 0; - frame.token_ = ResumeIdentificationToken::generateNew(); - frame.lastReceivedServerPosition_ = 500; - frame.clientPosition_ = 300; - return frame; -} - -void setupFail(yarpl::Reference transport, SetupParameters) { - transport->close(); - FAIL() << "setupFail() was called"; -} - -bool resumeFail(yarpl::Reference transport, ResumeParameters) { - transport->close(); - ADD_FAILURE() << "resumeFail() was called"; - return false; -} -} - -TEST(SetupResumeAcceptor, ImmediateDtor) { - folly::EventBase evb; - SetupResumeAcceptor acceptor1{ProtocolVersion::Latest, &evb}; - SetupResumeAcceptor acceptor2{ProtocolVersion::Unknown, &evb}; -} - -TEST(SetupResumeAcceptor, ImmediateClose) { - folly::EventBase evb; - SetupResumeAcceptor acceptor1{ProtocolVersion::Latest, &evb}; - SetupResumeAcceptor acceptor2{ProtocolVersion::Unknown, &evb}; - acceptor1.close().get(); - acceptor2.close().get(); -} - - -TEST(SetupResumeAcceptor, EarlyComplete) { - folly::EventBase evb; - SetupResumeAcceptor acceptor{ProtocolVersion::Latest, &evb}; - - auto connection = std::make_unique>( - [](auto input) { - input->onComplete(); - }, - [](auto output) { - EXPECT_CALL(*output, onSubscribe_(_)); - EXPECT_CALL(*output, onComplete_()); - }); - - acceptor.accept(std::move(connection), setupFail, resumeFail); - - evb.loop(); -} - -TEST(SetupResumeAcceptor, EarlyError) { - folly::EventBase evb; - SetupResumeAcceptor acceptor{ProtocolVersion::Latest, &evb}; - - auto connection = std::make_unique>( - [](auto input) { - input->onError(std::make_exception_ptr(std::runtime_error("Whoops"))); - }, - [](auto output) { - EXPECT_CALL(*output, onSubscribe_(_)); - EXPECT_CALL(*output, onError_(_)); - }); - - acceptor.accept(std::move(connection), setupFail, resumeFail); - - evb.loop(); -} - -TEST(SetupResumeAcceptor, SingleSetup) { - folly::EventBase evb; - SetupResumeAcceptor acceptor{ProtocolVersion::Latest, &evb}; - - auto connection = std::make_unique>( - [](auto input) { - auto serializer = FrameSerializer::createCurrentVersion(); - input->onNext(serializer->serializeOut(makeSetup())); - }, - [](auto output) { - EXPECT_CALL(*output, onSubscribe_(_)); - EXPECT_CALL(*output, onComplete_()); - }); - - bool setupCalled = false; - - acceptor.accept( - std::move(connection), - [&](auto transport, auto) { - transport->close(); - setupCalled = true; - }, - resumeFail); - - evb.loop(); - - EXPECT_TRUE(setupCalled); -} - -TEST(SetupResumeAcceptor, SetupAndFnf) { - folly::EventBase evb; - SetupResumeAcceptor acceptor{ProtocolVersion::Latest, &evb}; - - auto connection = std::make_unique>( - [](auto input) { - auto serializer = FrameSerializer::createCurrentVersion(); - - auto setup = makeSetup(); - Frame_REQUEST_FNF fnf{100, FrameFlags::EMPTY, Payload("Hi")}; - - input->onNext(serializer->serializeOut(std::move(setup))); - input->onNext(serializer->serializeOut(std::move(fnf))); - }, - [](auto output) { - EXPECT_CALL(*output, onSubscribe_(_)); - EXPECT_CALL(*output, onComplete_()); - }); - - yarpl::Reference transport; - - acceptor.accept( - std::move(connection), - [&](auto tport, auto) { transport = std::move(tport); }, - resumeFail); - - evb.loop(); - - EXPECT_TRUE(transport.get()); - - auto processor = std::make_shared>(); - EXPECT_CALL(*processor, processFrame_(_)) - .WillOnce(Invoke([](auto const& buf) { - auto serializer = FrameSerializer::createCurrentVersion(); - - Frame_REQUEST_FNF fnf; - EXPECT_TRUE(serializer->deserializeFrom(fnf, buf->clone())); - EXPECT_EQ(fnf.header_.streamId_, 100u); - EXPECT_EQ(fnf.header_.flags_, FrameFlags::EMPTY); - EXPECT_EQ(fnf.payload_.cloneDataToString(), "Hi"); - })); - transport->setFrameProcessor(processor); - transport->close(); -} - -TEST(SetupResumeAcceptor, InvalidSetup) { - folly::EventBase evb; - SetupResumeAcceptor acceptor{ProtocolVersion::Latest, &evb}; - - auto connection = std::make_unique>( - [](auto input) { - auto serializer = FrameSerializer::createCurrentVersion(); - - // Bogus keepalive time that can't be deserialized. - auto setup = makeSetup(); - setup.keepaliveTime_ = -5; - - input->onNext(serializer->serializeOut(std::move(setup))); - }, - [](auto output) { - EXPECT_CALL(*output, onSubscribe_(_)); - EXPECT_CALL(*output, onNext_(_)).WillOnce(Invoke([](auto const& buf) { - auto serializer = FrameSerializer::createCurrentVersion(); - Frame_ERROR frame; - EXPECT_TRUE(serializer->deserializeFrom(frame, buf->clone())); - EXPECT_EQ(frame.errorCode_, ErrorCode::CONNECTION_ERROR); - })); - EXPECT_CALL(*output, onError_(_)); - }); - - acceptor.accept(std::move(connection), setupFail, resumeFail); - - evb.loop(); -} - -TEST(SetupResumeAcceptor, RejectedSetup) { - folly::EventBase evb; - SetupResumeAcceptor acceptor{ProtocolVersion::Latest, &evb}; - - auto connection = std::make_unique>( - [](auto input) { - auto serializer = FrameSerializer::createCurrentVersion(); - input->onNext(serializer->serializeOut(makeSetup())); - }, - [](auto output) { - EXPECT_CALL(*output, onSubscribe_(_)); - EXPECT_CALL(*output, onNext_(_)).WillOnce(Invoke([](auto const& buf) { - auto serializer = FrameSerializer::createCurrentVersion(); - Frame_ERROR frame; - EXPECT_TRUE(serializer->deserializeFrom(frame, buf->clone())); - EXPECT_EQ(frame.errorCode_, ErrorCode::REJECTED_SETUP); - })); - EXPECT_CALL(*output, onError_(_)); - }); - - bool setupCalled = false; - - acceptor.accept( - std::move(connection), - [&](auto, auto) { - setupCalled = true; - throw std::runtime_error("Oops"); - }, - resumeFail); - - evb.loop(); - - EXPECT_TRUE(setupCalled); -} - -TEST(SetupResumeAcceptor, RejectedResume) { - folly::EventBase evb; - SetupResumeAcceptor acceptor{ProtocolVersion::Latest, &evb}; - - auto connection = std::make_unique>( - [](auto input) { - auto serializer = FrameSerializer::createCurrentVersion(); - input->onNext(serializer->serializeOut(makeResume())); - }, - [](auto output) { - EXPECT_CALL(*output, onSubscribe_(_)); - EXPECT_CALL(*output, onNext_(_)).WillOnce(Invoke([](auto const& buf) { - auto serializer = FrameSerializer::createCurrentVersion(); - Frame_ERROR frame; - EXPECT_TRUE(serializer->deserializeFrom(frame, buf->clone())); - EXPECT_EQ(frame.errorCode_, ErrorCode::REJECTED_RESUME); - })); - EXPECT_CALL(*output, onError_(_)); - }); - - bool resumeCalled = false; - - acceptor.accept(std::move(connection), setupFail, [&](auto, auto) { - resumeCalled = true; - throw std::runtime_error("Cant resume"); - }); - - evb.loop(); - - EXPECT_TRUE(resumeCalled); -} - -// TODO: Test for whether changing FrameProcessor in on{Resume,Setup} breaks -// things. diff --git a/test/statemachine/StreamStateTest.cpp b/test/statemachine/StreamStateTest.cpp deleted file mode 100644 index 142b425f2..000000000 --- a/test/statemachine/StreamStateTest.cpp +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include -#include - -#include "rsocket/statemachine/StreamState.h" -#include "test/test_utils/MockStats.h" - -using namespace rsocket; -using namespace testing; - -class StreamStateTest : public Test { - protected: - StrictMock stats_; - StreamState state_{stats_}; -}; - -TEST_F(StreamStateTest, Stats) { - auto frame1Size = 7, frame2Size = 11; - EXPECT_CALL(stats_, streamBufferChanged(1, frame1Size)); - state_.enqueueOutputPendingFrame( - folly::IOBuf::copyBuffer(std::string(frame1Size, 'x'))); - EXPECT_CALL(stats_, streamBufferChanged(1, frame2Size)); - state_.enqueueOutputPendingFrame( - folly::IOBuf::copyBuffer(std::string(frame2Size, 'x'))); - EXPECT_CALL(stats_, streamBufferChanged(-2, -(frame1Size + frame2Size))); - state_.moveOutputPendingFrames(); -} - -TEST_F(StreamStateTest, StatsUpdatedInDtor) { - auto frameSize = 7; - EXPECT_CALL(stats_, streamBufferChanged(1, frameSize)); - state_.enqueueOutputPendingFrame( - folly::IOBuf::copyBuffer(std::string(frameSize, 'x'))); - EXPECT_CALL(stats_, streamBufferChanged(-1, -frameSize)); -} diff --git a/test/test_utils/MockDuplexConnection.h b/test/test_utils/MockDuplexConnection.h deleted file mode 100644 index 068ef2538..000000000 --- a/test/test_utils/MockDuplexConnection.h +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include - -#include "rsocket/DuplexConnection.h" -#include "test/test_utils/Mocks.h" - -namespace rsocket { - -class MockDuplexConnection : public DuplexConnection { -public: - using Subscriber = yarpl::flowable::Subscriber>; - - MockDuplexConnection() { - ON_CALL(*this, getOutput_()).WillByDefault(testing::Invoke([] { - return yarpl::make_ref>>(); - })); - } - - /// Creates a DuplexConnection that always runs `in` on the input - /// subscriber and `out` on a default MockSubscriber. - template - MockDuplexConnection(InputFn in, OutputFn out) { - EXPECT_CALL(*this, setInput_(testing::_)) - .WillRepeatedly(testing::Invoke(std::move(in))); - EXPECT_CALL(*this, getOutput_()) - .WillRepeatedly(testing::Invoke([out = std::move(out)] { - auto subscriber = - yarpl::make_ref>>(); - out(subscriber); - return subscriber; - })); - } - - // DuplexConnection. - - void setInput(yarpl::Reference in) override { - setInput_(std::move(in)); - } - - yarpl::Reference getOutput() override { - return getOutput_(); - } - - // Mocks. - - MOCK_METHOD1(setInput_, void(yarpl::Reference)); - MOCK_METHOD0(getOutput_, yarpl::Reference()); -}; - -} diff --git a/test/test_utils/MockFrameProcessor.h b/test/test_utils/MockFrameProcessor.h deleted file mode 100644 index ab1393f0f..000000000 --- a/test/test_utils/MockFrameProcessor.h +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include - -#include -#include - -#include "rsocket/framing/FrameProcessor.h" - -namespace rsocket { - -class MockFrameProcessor : public FrameProcessor { -public: - void processFrame(std::unique_ptr buf) override { - processFrame_(buf); - } - - void onTerminal(folly::exception_wrapper ew) override { - onTerminal_(std::move(ew)); - } - - MOCK_METHOD1(processFrame_, void(std::unique_ptr&)); - MOCK_METHOD1(onTerminal_, void(folly::exception_wrapper)); -}; - -} diff --git a/test/test_utils/MockKeepaliveTimer.h b/test/test_utils/MockKeepaliveTimer.h deleted file mode 100644 index 272ddc357..000000000 --- a/test/test_utils/MockKeepaliveTimer.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include - -#include - -#include "rsocket/statemachine/RSocketStateMachine.h" -#include "test/deprecated/ReactiveSocket.h" - -namespace rsocket { -class MockKeepaliveTimer : public KeepaliveTimer { - public: - MOCK_METHOD1(start, void(const std::shared_ptr&)); - MOCK_METHOD0(stop, void()); - MOCK_METHOD0(keepaliveReceived, void()); - MOCK_METHOD0(keepaliveTime, std::chrono::milliseconds()); -}; -} diff --git a/test/test_utils/MockRequestHandler.h b/test/test_utils/MockRequestHandler.h deleted file mode 100644 index d0cfb7135..000000000 --- a/test/test_utils/MockRequestHandler.h +++ /dev/null @@ -1,117 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include - -#include - -#include "rsocket/Payload.h" -#include "rsocket/temporary_home/RequestHandler.h" - -namespace rsocket { - -class MockRequestHandler : public RequestHandler { - public: - MOCK_METHOD3( - handleRequestChannel_, - yarpl::Reference>( - Payload& request, - StreamId streamId, - const yarpl::Reference>&)); - MOCK_METHOD3( - handleRequestStream_, - void( - Payload& request, - StreamId streamId, - const yarpl::Reference>&)); - MOCK_METHOD3( - handleRequestResponse_, - void( - Payload& request, - StreamId streamId, - const yarpl::Reference>&)); - MOCK_METHOD2( - handleFireAndForgetRequest_, - void(Payload& request, StreamId streamId)); - MOCK_METHOD1( - handleMetadataPush_, - void(std::unique_ptr& request)); - MOCK_METHOD1( - handleSetupPayload_, - std::shared_ptr(SetupParameters& request)); - MOCK_METHOD1(handleResume_, bool(ResumeParameters& resumeParams)); - - yarpl::Reference> handleRequestChannel( - Payload request, - StreamId streamId, - const yarpl::Reference>& - response) noexcept override { - return handleRequestChannel_(request, streamId, response); - } - - void handleRequestStream( - Payload request, - StreamId streamId, - const yarpl::Reference>& - response) noexcept override { - handleRequestStream_(request, streamId, response); - } - - void handleRequestResponse( - Payload request, - StreamId streamId, - const yarpl::Reference>& - response) noexcept override { - handleRequestResponse_(request, streamId, response); - } - - void handleFireAndForgetRequest( - Payload request, - StreamId streamId) noexcept override { - handleFireAndForgetRequest_(request, streamId); - } - - void handleMetadataPush( - std::unique_ptr request) noexcept override { - handleMetadataPush_(request); - } - - std::shared_ptr handleSetupPayload( - SetupParameters request) noexcept override { - return handleSetupPayload_(request); - } - - bool handleResume(ResumeParameters resumeParams) noexcept override { - return handleResume_(resumeParams); - } - - void handleCleanResume(yarpl::Reference - response) noexcept override {} - void handleDirtyResume(yarpl::Reference - response) noexcept override {} - - MOCK_METHOD1( - onSubscriptionPaused_, - void(const yarpl::Reference&)); - void onSubscriptionPaused( - const yarpl::Reference& - subscription) noexcept override { - onSubscriptionPaused_(std::move(subscription)); - } - void onSubscriptionResumed( - const yarpl::Reference& - subscription) noexcept override {} - void onSubscriberPaused( - const yarpl::Reference>& - subscriber) noexcept override {} - void onSubscriberResumed( - const yarpl::Reference>& - subscriber) noexcept override {} - - MOCK_METHOD0(socketOnConnected, void()); - - MOCK_METHOD1(socketOnClosed, void(folly::exception_wrapper& listener)); - MOCK_METHOD1(socketOnDisconnected, void(folly::exception_wrapper& listener)); -}; -} diff --git a/test/test_utils/PrintSubscriber.cpp b/test/test_utils/PrintSubscriber.cpp deleted file mode 100644 index 862a52eb9..000000000 --- a/test/test_utils/PrintSubscriber.cpp +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "PrintSubscriber.h" -#include -#include -#include -#include "yarpl/utils/ExceptionString.h" - -namespace rsocket { - -PrintSubscriber::~PrintSubscriber() { - LOG(INFO) << "~PrintSubscriber " << this; -} - -void PrintSubscriber::onSubscribe( - yarpl::Reference subscription) noexcept { - LOG(INFO) << "PrintSubscriber " << this << " onSubscribe"; - subscription->request(std::numeric_limits::max()); -} - -void PrintSubscriber::onNext(Payload element) noexcept { - LOG(INFO) << "PrintSubscriber " << this << " onNext " << element; -} - -void PrintSubscriber::onComplete() noexcept { - LOG(INFO) << "PrintSubscriber " << this << " onComplete"; -} - -void PrintSubscriber::onError(std::exception_ptr ex) noexcept { - LOG(INFO) << "PrintSubscriber " << this << " onError " - << yarpl::exceptionStr(ex); -} -} diff --git a/test/test_utils/PrintSubscriber.h b/test/test_utils/PrintSubscriber.h deleted file mode 100644 index bfdd1265e..000000000 --- a/test/test_utils/PrintSubscriber.h +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "rsocket/Payload.h" -#include "yarpl/flowable/Subscriber.h" - -namespace rsocket { -class PrintSubscriber : public yarpl::flowable::Subscriber { - public: - ~PrintSubscriber(); - - void onSubscribe(yarpl::Reference - subscription) noexcept override; - void onNext(Payload element) noexcept override; - void onComplete() noexcept override; - void onError(std::exception_ptr ex) noexcept override; -}; -} diff --git a/test/transport/DuplexConnectionTest.cpp b/test/transport/DuplexConnectionTest.cpp deleted file mode 100644 index aff577e91..000000000 --- a/test/transport/DuplexConnectionTest.cpp +++ /dev/null @@ -1,330 +0,0 @@ -#include "DuplexConnectionTest.h" - -#include -#include "../test_utils/Mocks.h" - -namespace rsocket { -namespace tests { - -using namespace folly; -using namespace rsocket; -using namespace ::testing; - -void makeMultipleSetInputGetOutputCalls( - std::unique_ptr serverConnection, - EventBase* serverEvb, - std::unique_ptr clientConnection, - EventBase* clientEvb) { - auto serverSubscriber = - yarpl::make_ref>>(); - EXPECT_CALL(*serverSubscriber, onSubscribe_(_)); - EXPECT_CALL(*serverSubscriber, onNext_(_)).Times(10); - yarpl::Reference>> serverOutput; - yarpl::Reference serverSubscription; - - serverEvb->runInEventBaseThreadAndWait( - [&connection = serverConnection, - &input = serverSubscriber, - &output = serverOutput, - &subscription = serverSubscription]() { - // Keep receiving messages from different subscribers - connection->setInput(input); - // Get another subscriber and send messages - output = connection->getOutput(); - subscription = yarpl::make_ref(); - EXPECT_CALL(*subscription, request_(_)).Times(AtLeast(1)); - EXPECT_CALL(*subscription, cancel_()); - output->onSubscribe(subscription); - }); - - for (int i = 0; i < 10; ++i) { - auto clientSubscriber = - yarpl::make_ref>>(); - EXPECT_CALL(*clientSubscriber, onSubscribe_(_)); - EXPECT_CALL(*clientSubscriber, onNext_(_)); - yarpl::Reference clientSubscription; - - clientEvb->runInEventBaseThreadAndWait( - [&connection = clientConnection, - &input = clientSubscriber, - &subscription = clientSubscription]() { - // Set another subscriber and receive messages - connection->setInput(input); - // Get another subscriber and send messages - auto output = connection->getOutput(); - subscription = yarpl::make_ref(); - EXPECT_CALL(*subscription, request_(_)).Times(AtLeast(1)); - EXPECT_CALL(*subscription, cancel_()); - output->onSubscribe(subscription); - output->onNext(folly::IOBuf::copyBuffer("01234")); - output->onComplete(); - }); - serverSubscriber->awaitFrames(1); - - serverEvb->runInEventBaseThreadAndWait( - [&output = serverOutput]() { - output->onNext(folly::IOBuf::copyBuffer("43210")); - }); - clientSubscriber->awaitFrames(1); - - clientEvb->runInEventBaseThreadAndWait( - [subscriber = std::move(clientSubscriber), - subscription = std::move(clientSubscription)]() { - subscription->cancel(); - // Enables calling setInput again with another subscriber. - subscriber->subscription()->cancel(); - }); - } - - // Cleanup - serverEvb->runInEventBaseThreadAndWait( - [subscriber = std::move(serverSubscriber), - output = std::move(serverOutput), - subscription = std::move(serverSubscription)]() { - output->onComplete(); - subscription->cancel(); - subscriber->subscription()->cancel(); - }); - clientEvb->runInEventBaseThreadAndWait([& connection = clientConnection]() { - auto connectionDeleter = std::move(connection); - }); - serverEvb->runInEventBaseThreadAndWait([& connection = serverConnection]() { - auto connectionDeleter = std::move(connection); - }); -} - -/** - * Closing an Input or Output should not effect the other. - */ -void verifyInputAndOutputIsUntied( - std::unique_ptr serverConnection, - EventBase* serverEvb, - std::unique_ptr clientConnection, - EventBase* clientEvb) { - auto serverSubscriber = - yarpl::make_ref>>(); - EXPECT_CALL(*serverSubscriber, onSubscribe_(_)); - EXPECT_CALL(*serverSubscriber, onNext_(_)).Times(3); - yarpl::Reference>> serverOutput; - auto serverSubscription = yarpl::make_ref(); - EXPECT_CALL(*serverSubscription, request_(_)).Times(AtLeast(1)); - EXPECT_CALL(*serverSubscription, cancel_()); - - serverEvb->runInEventBaseThreadAndWait( - [&connection = serverConnection, - &input = serverSubscriber, - &output = serverOutput, - &subscription = serverSubscription]() { - connection->setInput(input); - output = connection->getOutput(); - output->onSubscribe(subscription); - }); - - auto clientSubscriber = - yarpl::make_ref>>(); - EXPECT_CALL(*clientSubscriber, onSubscribe_(_)); - yarpl::Reference>> clientOutput; - auto clientSubscription = yarpl::make_ref(); - EXPECT_CALL(*clientSubscription, request_(_)).Times(AtLeast(1)); - EXPECT_CALL(*clientSubscription, cancel_()); - - clientEvb->runInEventBaseThreadAndWait( - [&connection = clientConnection, - &input = clientSubscriber, - &output = clientOutput, - &subscription = clientSubscription]() { - connection->setInput(input); - output = connection->getOutput(); - output->onSubscribe(subscription); - output->onNext(folly::IOBuf::copyBuffer("01234")); - }); - serverSubscriber->awaitFrames(1); - - clientEvb->runInEventBaseThreadAndWait( - [&subscriber = clientSubscriber, &output = clientOutput]() { - // Close the client subscriber - { - subscriber->subscription()->cancel(); - auto deleteSubscriber = std::move(subscriber); - } - // Output is still active - output->onNext(folly::IOBuf::copyBuffer("01234")); - }); - serverSubscriber->awaitFrames(1); - - // Another client subscriber - clientSubscriber = - yarpl::make_ref>>(); - EXPECT_CALL(*clientSubscriber, onSubscribe_(_)); - EXPECT_CALL(*clientSubscriber, onNext_(_)); - clientEvb->runInEventBaseThreadAndWait( - [&connection = clientConnection, - &input = clientSubscriber, - &output = clientOutput]() { - // Set new input subscriber - connection->setInput(input); - output->onNext(folly::IOBuf::copyBuffer("01234")); - }); - serverSubscriber->awaitFrames(1); - - // Close output subscriber of client - clientEvb->runInEventBaseThreadAndWait( - [output = std::move(clientOutput), - subscription = std::move(clientSubscription)]() { - subscription->cancel(); - output->onComplete(); - }); - - // Still sending message from server to the client. - serverEvb->runInEventBaseThreadAndWait([&output = serverOutput]() { - output->onNext(folly::IOBuf::copyBuffer("43210")); - output->onComplete(); - }); - clientSubscriber->awaitFrames(1); - - // Cleanup - clientEvb->runInEventBaseThreadAndWait( - [subscriber = std::move(clientSubscriber)]() { - subscriber->subscription()->cancel(); - }); - serverEvb->runInEventBaseThreadAndWait( - [subscriber = std::move(serverSubscriber), - output = std::move(serverOutput), - subscription = std::move(serverSubscription)]() { - subscription->cancel(); - output->onComplete(); - subscriber->subscription()->cancel(); - }); - clientEvb->runInEventBaseThreadAndWait([& connection = clientConnection]() { - auto connectionDeleter = std::move(connection); - }); - serverEvb->runInEventBaseThreadAndWait([& connection = serverConnection]() { - auto connectionDeleter = std::move(connection); - }); -} - -void verifyClosingInputAndOutputDoesntCloseConnection( - std::unique_ptr serverConnection, - folly::EventBase* serverEvb, - std::unique_ptr clientConnection, - folly::EventBase* clientEvb) { - auto serverSubscriber = - yarpl::make_ref>>(); - EXPECT_CALL(*serverSubscriber, onSubscribe_(_)); - yarpl::Reference>> serverOutput; - auto serverSubscription = yarpl::make_ref(); - EXPECT_CALL(*serverSubscription, request_(_)); - EXPECT_CALL(*serverSubscription, cancel_()); - - serverEvb->runInEventBaseThreadAndWait( - [&connection = serverConnection, - &input = serverSubscriber, - &output = serverOutput, - &subscription = serverSubscription]() { - connection->setInput(input); - output = connection->getOutput(); - output->onSubscribe(subscription); - }); - - auto clientSubscriber = - yarpl::make_ref>>(); - EXPECT_CALL(*clientSubscriber, onSubscribe_(_)); - yarpl::Reference>> clientOutput; - auto clientSubscription = yarpl::make_ref(); - EXPECT_CALL(*clientSubscription, request_(_)); - EXPECT_CALL(*clientSubscription, cancel_()); - - clientEvb->runInEventBaseThreadAndWait( - [&connection = clientConnection, - &input = clientSubscriber, - &output = clientOutput, - &subscription = clientSubscription]() { - connection->setInput(input); - output = connection->getOutput(); - output->onSubscribe(subscription); - }); - - // Close all subscribers - clientEvb->runInEventBaseThreadAndWait( - [input = std::move(clientSubscriber), - output = std::move(clientOutput), - subscription = std::move(clientSubscription)]() { - subscription->cancel(); - output->onComplete(); - input->subscription()->cancel(); - }); - - serverEvb->runInEventBaseThreadAndWait( - [input = std::move(serverSubscriber), - output = std::move(serverOutput), - subscription = std::move(serverSubscription)]() { - subscription->cancel(); - output->onComplete(); - input->subscription()->cancel(); - }); - - // Set new subscribers as the connection is not closed - serverSubscriber = - yarpl::make_ref>>(); - EXPECT_CALL(*serverSubscriber, onSubscribe_(_)); - EXPECT_CALL(*serverSubscriber, onNext_(_)).Times(1); - // The subscriber is to be closed, as the subscription is not cancelled - // but the connection is closed at the end - EXPECT_CALL(*serverSubscriber, onComplete_()); - - serverSubscription = yarpl::make_ref(); - EXPECT_CALL(*serverSubscription, request_(_)); - EXPECT_CALL(*serverSubscription, cancel_()); - - serverEvb->runInEventBaseThreadAndWait( - [&connection = serverConnection, - &input = serverSubscriber, - &output = serverOutput, - &subscription = serverSubscription]() { - connection->setInput(input); - output = connection->getOutput(); - output->onSubscribe(subscription); - }); - - clientSubscriber = - yarpl::make_ref>>(); - EXPECT_CALL(*clientSubscriber, onSubscribe_(_)); - EXPECT_CALL(*clientSubscriber, onNext_(_)).Times(1); - // The subscriber is to be closed, as the subscription is not cancelled - // but the connection is closed at the end - EXPECT_CALL(*clientSubscriber, onComplete_()); - - clientSubscription = yarpl::make_ref(); - EXPECT_CALL(*clientSubscription, request_(_)); - EXPECT_CALL(*clientSubscription, cancel_()); - - clientEvb->runInEventBaseThreadAndWait( - [&connection = clientConnection, - &input = clientSubscriber, - &output = clientOutput, - &subscription = clientSubscription]() { - connection->setInput(input); - output = connection->getOutput(); - output->onSubscribe(subscription); - output->onNext(folly::IOBuf::copyBuffer("01234")); - }); - serverSubscriber->awaitFrames(1); - - // Wait till client is ready before sending message from server. - serverEvb->runInEventBaseThreadAndWait( - [&output = serverOutput]() { - output->onNext(folly::IOBuf::copyBuffer("43210")); - }); - clientSubscriber->awaitFrames(1); - - // Cleanup - clientEvb->runInEventBaseThreadAndWait([& connection = clientConnection]() { - auto connectionDeleter = std::move(connection); - }); - serverEvb->runInEventBaseThreadAndWait([& connection = serverConnection]() { - auto connectionDeleter = std::move(connection); - }); -} - -} // namespace tests -} // namespace rsocket diff --git a/yarpl/CMakeLists.txt b/yarpl/CMakeLists.txt index 34e7cb94d..f4159b82c 100644 --- a/yarpl/CMakeLists.txt +++ b/yarpl/CMakeLists.txt @@ -1,27 +1,51 @@ cmake_minimum_required (VERSION 3.2) - -# To debug the project, set the build type. -set(CMAKE_BUILD_TYPE Debug) - project (yarpl) # CMake Config +set(CMAKE_MODULE_PATH + ${CMAKE_CURRENT_SOURCE_DIR}/../cmake/ + # For shipit-transformed builds + "${CMAKE_CURRENT_SOURCE_DIR}/../build/fbcode_builder/CMake" + ${CMAKE_MODULE_PATH} +) add_definitions(-std=c++14) +option(BUILD_TESTS "BUILD_TESTS" ON) # Generate compilation database set(CMAKE_EXPORT_COMPILE_COMMANDS 1) # Common configuration for all build modes. -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Wno-unused-parameter") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-weak-vtables -Wno-padded") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -momit-leaf-frame-pointer") +if (NOT MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Wno-unused-parameter") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-weak-vtables -Wno-padded") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-omit-frame-pointer") + include(CheckCXXCompilerFlag) + CHECK_CXX_COMPILER_FLAG("-momit-leaf-frame-pointer" HAVE_OMIT_LEAF_FRAME_POINTER) + if(HAVE_OMIT_LEAF_FRAME_POINTER) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -momit-leaf-frame-pointer") + endif() +endif() + +if(YARPL_WRAP_SHARED_IN_LOCK) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DYARPL_WRAP_SHARED_IN_LOCK") + message("Compiler lacks support std::atomic; wrapping with a mutex") +elseif(YARPL_WRAP_SHARED_IN_ATOMIC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DYARPL_WRAP_SHARED_IN_ATOMIC") + message("Compiler lacks std::shared_ptr atomic overloads; wrapping in std::atomic") +else() + message("Compiler has atomic std::shared_ptr support") +endif() + + +if(${CMAKE_CXX_COMPILER_ID} MATCHES GNU) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -latomic") +endif() # The yarpl-tests binary constantly fails with an ASAN error in gtest internal # code on macOS. -if(APPLE) +if(APPLE AND ${CMAKE_CXX_COMPILER_ID} MATCHES Clang) message("== macOS detected, disabling ASAN for yarpl") add_compile_options("-fno-sanitize=address,undefined") endif() @@ -29,104 +53,141 @@ endif() # Using NDEBUG in Release builds. set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG") -find_path(GLOG_INCLUDE_DIR glog/logging.h) -find_library(GLOG_LIBRARY glog) +find_package(Gflags REQUIRED) +find_package(Glog REQUIRED) +find_package(fmt CONFIG REQUIRED) -message("glog include_dir <${GLOG_INCLUDE_DIR}> lib <${GLOG_LIBRARY}>") +IF(NOT FOLLY_VERSION) + include(${CMAKE_CURRENT_SOURCE_DIR}/../cmake/InstallFolly.cmake) +ENDIF() -include_directories(SYSTEM ${GLOG_INCLUDE_DIR}) -include_directories(${CMAKE_SOURCE_DIR}) +include_directories(SYSTEM ${GFLAGS_INCLUDE_DIR}) # library source add_library( yarpl # public API - include/yarpl/Scheduler.h - include/yarpl/Disposable.h - include/yarpl/Refcounted.h + Refcounted.h + Common.h # Flowable public API - include/yarpl/Flowable.h - include/yarpl/flowable/Flowable.h - include/yarpl/flowable/FlowableOperator.h - include/yarpl/flowable/Flowable_FromObservable.h - include/yarpl/flowable/Flowables.h - include/yarpl/flowable/Subscriber.h - include/yarpl/flowable/Subscribers.h - include/yarpl/flowable/Subscription.h - include/yarpl/flowable/TestSubscriber.h - src/yarpl/flowable/sources/Subscription.cpp + Flowable.h + flowable/DeferFlowable.h + flowable/EmitterFlowable.h + flowable/Flowable.h + flowable/FlowableOperator.h + flowable/FlowableConcatOperators.h + flowable/FlowableDoOperator.h + flowable/FlowableObserveOnOperator.h + flowable/Flowable_FromObservable.h + flowable/Flowables.h + flowable/PublishProcessor.h + flowable/Subscriber.h + flowable/Subscription.h + flowable/TestSubscriber.h + flowable/Subscription.cpp + flowable/Flowables.cpp # Observable public API - include/yarpl/Observable.h - include/yarpl/observable/Observable.h - include/yarpl/observable/Observables.h - include/yarpl/observable/ObservableOperator.h - include/yarpl/observable/Observer.h - include/yarpl/observable/Observers.h - include/yarpl/observable/Subscription.h - include/yarpl/observable/Subscriptions.h - include/yarpl/observable/TestObserver.h - src/yarpl/observable/Subscriptions.cpp + Observable.h + observable/DeferObservable.h + observable/Observable.h + observable/Observables.h + observable/ObservableOperator.h + observable/ObservableConcatOperators.h + observable/ObservableDoOperator.h + observable/Observer.h + observable/Subscription.h + observable/TestObserver.h + observable/Subscription.cpp + observable/Observables.cpp # Single - include/yarpl/Single.h - include/yarpl/single/Single.h - include/yarpl/single/Singles.h - include/yarpl/single/SingleOperator.h - include/yarpl/single/SingleObserver.h - include/yarpl/single/SingleObservers.h - include/yarpl/single/SingleSubscription.h - include/yarpl/single/SingleSubscriptions.h - include/yarpl/single/SingleTestObserver.h + Single.h + single/Single.h + single/Singles.h + single/SingleOperator.h + single/SingleObserver.h + single/SingleObservers.h + single/SingleSubscription.h + single/SingleSubscriptions.h + single/SingleTestObserver.h # utils - include/yarpl/utils/type_traits.h - include/yarpl/utils/credits.h - src/yarpl/utils/credits.cpp - include/yarpl/utils/ExceptionString.h - # Scheduler - include/yarpl/schedulers/ThreadScheduler.h - src/yarpl/schedulers/ThreadScheduler.cpp) - + utils/credits.h + utils/credits.cpp) target_include_directories( - yarpl - PUBLIC "${PROJECT_SOURCE_DIR}/include" # allow include paths such as "yarpl/observable.h" - PUBLIC "${PROJECT_SOURCE_DIR}/src" # allow include paths such as "yarpl/flowable/FlowableRange.h" - ) + yarpl + PUBLIC + $ + $ +) -target_link_libraries( - yarpl - folly - ${GLOG_LIBRARY}) - -# Executable for experimenting. -add_executable( - yarpl-playground - examples/yarpl-playground.cpp - examples/FlowableExamples.cpp - examples/FlowableExamples.h) - -target_link_libraries(yarpl-playground yarpl) - -# Unit tests. -add_executable( - yarpl-tests - test/FlowableTest.cpp - test/Observable_test.cpp - test/RefcountedTest.cpp - test/ReferenceTest.cpp - test/Scheduler_test.cpp - test/Single_test.cpp - test/Tuple.cpp - test/Tuple.h - test/credits-test.cpp - test/yarpl-tests.cpp) +message("yarpl source dir: ${CMAKE_CURRENT_SOURCE_DIR}") target_link_libraries( - yarpl-tests yarpl - ${GLOG_LIBRARY} - - # Inherited from rsocket-cpp CMake. - ${GMOCK_LIBS}) - -add_dependencies(yarpl-tests gmock) - -add_test(NAME yarpl-tests COMMAND yarpl-tests) + PUBLIC Folly::folly glog::glog gflags + INTERFACE ${EXTRA_LINK_FLAGS}) + +include(CMakePackageConfigHelpers) +configure_package_config_file( + cmake/yarpl-config.cmake.in + yarpl-config.cmake + INSTALL_DESTINATION lib/cmake/yarpl +) +install(TARGETS yarpl EXPORT yarpl-exports DESTINATION lib) +install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} DESTINATION include FILES_MATCHING PATTERN "*.h") +install( + EXPORT yarpl-exports + NAMESPACE yarpl:: + DESTINATION lib/cmake/yarpl +) +install( + FILES ${CMAKE_CURRENT_BINARY_DIR}/yarpl-config.cmake + DESTINATION lib/cmake/yarpl +) + +# RSocket's tests also has dependency on this library +add_library( + yarpl-test-utils + test_utils/Tuple.cpp + test_utils/Tuple.h + test_utils/Mocks.h) + +if (BUILD_TESTS) + # Executable for experimenting. + add_executable( + yarpl-playground + examples/yarpl-playground.cpp + examples/FlowableExamples.cpp + examples/FlowableExamples.h) + + target_link_libraries(yarpl-playground yarpl) + + # Unit tests. + add_executable( + yarpl-tests + test/MocksTest.cpp + test/FlowableTest.cpp + test/FlowableFlatMapTest.cpp + test/Observable_test.cpp + test/PublishProcessorTest.cpp + test/SubscribeObserveOnTests.cpp + test/Single_test.cpp + test/FlowableSubscriberTest.cpp + test/credits-test.cpp + test/yarpl-tests.cpp) + + add_dependencies(yarpl-tests gmock) + target_link_libraries( + yarpl-tests + yarpl + yarpl-test-utils + glog::glog + gflags + + # Inherited from rsocket-cpp CMake. + ${GMOCK_LIBS} # This also needs the preceding `add_dependencies` + ) + + add_dependencies(yarpl-tests yarpl-test-utils gmock) + + add_test(NAME yarpl-tests COMMAND yarpl-tests) +endif() diff --git a/yarpl/Common.h b/yarpl/Common.h new file mode 100644 index 000000000..be9d8287c --- /dev/null +++ b/yarpl/Common.h @@ -0,0 +1,69 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace yarpl { + +namespace observable { +template +class Observable; +} // namespace observable + +namespace flowable { +template +class Subscriber; + +// Exception thrown in case the downstream can't keep up. +class MissingBackpressureException : public std::runtime_error { + public: + MissingBackpressureException() + : std::runtime_error("BACK_PRESSURE: DROP (missing credits onNext)") {} +}; + +} // namespace flowable + +/** + *Strategy for backpressure when converting from Observable to Flowable. + */ +enum class BackpressureStrategy { + BUFFER, // Buffers all onNext values until the downstream consumes them. + DROP, // Drops the most recent onNext value if the downstream can't keep up. + ERROR, // Signals a MissingBackpressureException in case the downstream can't + // keep up. + LATEST, // Keeps only the latest onNext value, overwriting any previous value + // if the downstream can't keep up. + MISSING // OnNext events are written without any buffering or dropping. +}; + +template +class IBackpressureStrategy { + public: + virtual ~IBackpressureStrategy() = default; + + virtual void init( + std::shared_ptr> upstream, + std::shared_ptr> downstream) = 0; + + static std::shared_ptr> buffer(); + static std::shared_ptr> drop(); + static std::shared_ptr> error(); + static std::shared_ptr> latest(); + static std::shared_ptr> missing(); +}; + +} // namespace yarpl diff --git a/yarpl/Disposable.h b/yarpl/Disposable.h new file mode 100644 index 000000000..6cc5d8264 --- /dev/null +++ b/yarpl/Disposable.h @@ -0,0 +1,42 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace yarpl { + +/** + * Represents a disposable resource. + */ +class Disposable { + public: + Disposable() {} + virtual ~Disposable() = default; + Disposable(Disposable&&) = delete; + Disposable(const Disposable&) = delete; + Disposable& operator=(Disposable&&) = delete; + Disposable& operator=(const Disposable&) = delete; + + /** + * Dispose the resource, the operation should be idempotent. + */ + virtual void dispose() = 0; + + /** + * Returns true if this resource has been disposed. + * @return true if this resource has been disposed + */ + virtual bool isDisposed() = 0; +}; +} // namespace yarpl diff --git a/yarpl/Flowable.h b/yarpl/Flowable.h new file mode 100644 index 000000000..34014ddd7 --- /dev/null +++ b/yarpl/Flowable.h @@ -0,0 +1,25 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +// include all the things a developer needs for using Flowable +#include "yarpl/flowable/Flowable.h" +#include "yarpl/flowable/Flowables.h" +#include "yarpl/flowable/Subscriber.h" +#include "yarpl/flowable/Subscription.h" + +/** + * // TODO add documentation + */ diff --git a/yarpl/Observable.h b/yarpl/Observable.h new file mode 100644 index 000000000..d115d5160 --- /dev/null +++ b/yarpl/Observable.h @@ -0,0 +1,25 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +// include all the things a developer needs for using Observable +#include "yarpl/observable/Observable.h" +#include "yarpl/observable/Observables.h" +#include "yarpl/observable/Observer.h" +#include "yarpl/observable/Subscription.h" + +/** + * // TODO add documentation + */ diff --git a/yarpl/Refcounted.h b/yarpl/Refcounted.h new file mode 100644 index 000000000..ac0a4950d --- /dev/null +++ b/yarpl/Refcounted.h @@ -0,0 +1,85 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace yarpl { + +template +struct AtomicReference { + folly::Synchronized, std::mutex> ref; + + AtomicReference() = default; + + AtomicReference(std::shared_ptr&& r) { + *(ref.lock()) = std::move(r); + } +}; + +template +std::shared_ptr atomic_load(AtomicReference* ar) { + return *(ar->ref.lock()); +} + +template +std::shared_ptr atomic_exchange( + AtomicReference* ar, + std::shared_ptr r) { + auto refptr = ar->ref.lock(); + auto old = std::move(*refptr); + *refptr = std::move(r); + return old; +} + +template +std::shared_ptr atomic_exchange(AtomicReference* ar, std::nullptr_t) { + return atomic_exchange(ar, std::shared_ptr()); +} + +template +void atomic_store(AtomicReference* ar, std::shared_ptr r) { + *ar->ref.lock() = std::move(r); +} + +class enable_get_ref : public std::enable_shared_from_this { + private: + virtual void dummy_internal_get_ref() {} + + protected: + // materialize a reference to 'this', but a type even further derived from + // Derived, because C++ doesn't have covariant return types on methods + template + std::shared_ptr ref_from_this(As* ptr) { + // at runtime, ensure that the most derived class can indeed be + // converted into an 'as' + (void)ptr; // silence 'unused parameter' errors in Release builds + return std::static_pointer_cast(this->shared_from_this()); + } + + template + std::shared_ptr ref_from_this(As const* ptr) const { + // at runtime, ensure that the most derived class can indeed be + // converted into an 'as' + (void)ptr; // silence 'unused parameter' errors in Release builds + return std::static_pointer_cast(this->shared_from_this()); + } + + public: + virtual ~enable_get_ref() = default; +}; + +} /* namespace yarpl */ diff --git a/yarpl/Single.h b/yarpl/Single.h new file mode 100644 index 000000000..c5b737b5f --- /dev/null +++ b/yarpl/Single.h @@ -0,0 +1,35 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/Refcounted.h" + +// include all the things a developer needs for using Single +#include "yarpl/single/Single.h" +#include "yarpl/single/SingleObserver.h" +#include "yarpl/single/SingleObservers.h" +#include "yarpl/single/SingleSubscriptions.h" +#include "yarpl/single/Singles.h" + +/** + * Create a single with code such as this: + * + * auto a = Single::create([](std::shared_ptr> obs) { + * obs->onSubscribe(SingleSubscriptions::empty()); + * obs->onSuccess(1); + * }); + * + * // TODO add more documentation + */ diff --git a/yarpl/cmake/yarpl-config.cmake.in b/yarpl/cmake/yarpl-config.cmake.in new file mode 100644 index 000000000..d557b2135 --- /dev/null +++ b/yarpl/cmake/yarpl-config.cmake.in @@ -0,0 +1,13 @@ +# Copyright (c) 2018, Facebook, Inc. +# All rights reserved. + +@PACKAGE_INIT@ + +if(NOT TARGET yarpl::yarpl) + include("${PACKAGE_PREFIX_DIR}/lib/cmake/yarpl/yarpl-exports.cmake") +endif() + +set(YARPL_LIBRARIES yarpl::yarpl) +if (NOT yarpl_FIND_QUIETLY) + message(STATUS "Found YARPL: ${PACKAGE_PREFIX_DIR}") +endif() diff --git a/yarpl/examples/FlowableExamples.cpp b/yarpl/examples/FlowableExamples.cpp index 2beb0c8c8..224726166 100644 --- a/yarpl/examples/FlowableExamples.cpp +++ b/yarpl/examples/FlowableExamples.cpp @@ -1,30 +1,38 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "FlowableExamples.h" - +#include #include #include #include #include - -#include "yarpl/schedulers/ThreadScheduler.h" - #include "yarpl/Flowable.h" -using namespace yarpl; using namespace yarpl::flowable; namespace { template auto printer() { - return Subscribers::create( + return Subscriber::create( [](T value) { std::cout << " next: " << value << std::endl; }, 2 /* low [optional] batch size for demo */); } -Reference> getData() { - return Flowables::range(2, 5); +std::shared_ptr> getData() { + return Flowable<>::range(2, 5); } std::string getThreadId() { @@ -34,7 +42,7 @@ std::string getThreadId() { } void fromPublisherExample() { - auto onSubscribe = [](Reference> subscriber) { + auto onSubscribe = [](std::shared_ptr> subscriber) { class Subscription : public ::yarpl::flowable::Subscription { public: virtual void request(int64_t delta) override { @@ -46,7 +54,7 @@ void fromPublisherExample() { } }; - auto subscription = make_ref(); + auto subscription = std::make_shared(); subscriber->onSubscribe(subscription); subscriber->onNext(1234); subscriber->onNext(5678); @@ -54,7 +62,7 @@ void fromPublisherExample() { subscriber->onComplete(); }; - Flowables::fromPublisher(std::move(onSubscribe)) + Flowable::fromPublisher(std::move(onSubscribe)) ->subscribe(printer()); } @@ -62,31 +70,31 @@ void fromPublisherExample() { void FlowableExamples::run() { std::cout << "create a flowable" << std::endl; - Flowables::range(2, 2); + Flowable<>::range(2, 2); std::cout << "get a flowable from a method" << std::endl; getData()->subscribe(printer()); std::cout << "just: single value" << std::endl; - Flowables::just(23)->subscribe(printer()); + Flowable<>::just(23)->subscribe(printer()); std::cout << "just: multiple values." << std::endl; - Flowables::justN({1, 4, 7, 11})->subscribe(printer()); + Flowable<>::justN({1, 4, 7, 11})->subscribe(printer()); std::cout << "just: string values." << std::endl; - Flowables::justN({"the", "quick", "brown", "fox"}) + Flowable<>::justN({"the", "quick", "brown", "fox"}) ->subscribe(printer()); std::cout << "range operator." << std::endl; - Flowables::range(1, 4)->subscribe(printer()); + Flowable<>::range(1, 4)->subscribe(printer()); std::cout << "map example: squares" << std::endl; - Flowables::range(1, 4) + Flowable<>::range(1, 4) ->map([](int64_t v) { return v * v; }) ->subscribe(printer()); std::cout << "map example: convert to string" << std::endl; - Flowables::range(1, 4) + Flowable<>::range(1, 4) ->map([](int64_t v) { return v * v; }) ->map([](int64_t v) { return v * v; }) ->map([](int64_t v) { return std::to_string(v); }) @@ -94,37 +102,30 @@ void FlowableExamples::run() { ->subscribe(printer()); std::cout << "take example: 3 out of 10 items" << std::endl; - Flowables::range(1, 11)->take(3)->subscribe(printer()); + Flowable<>::range(1, 11)->take(3)->subscribe(printer()); - auto flowable = Flowable::create([total = 0]( - Subscriber & subscriber, int64_t requested) mutable { - subscriber.onNext(12345678); - subscriber.onError(std::make_exception_ptr(std::runtime_error("error"))); - return std::make_tuple(int64_t{1}, false); - }); + auto flowable = Flowable::create( + [total = 0](auto& subscriber, int64_t requested) mutable { + subscriber.onNext(12345678); + subscriber.onError(std::runtime_error("error")); + }); - auto subscriber = Subscribers::create( + auto subscriber = Subscriber::create( [](int next) { std::cout << "@next: " << next << std::endl; }, - [](std::exception_ptr eptr) { - try { - std::rethrow_exception(eptr); - } catch (const std::exception& exception) { - std::cerr << " exception: " << exception.what() << std::endl; - } catch (...) { - std::cerr << " !unknown exception!" << std::endl; - } + [](folly::exception_wrapper ex) { + std::cerr << " exception: " << ex << std::endl; }, [] { std::cout << "Completed." << std::endl; }); flowable->subscribe(subscriber); - ThreadScheduler scheduler; + folly::ScopedEventBaseThread worker; std::cout << "subscribe_on example" << std::endl; - Flowables::justN({"0: ", "1: ", "2: "}) + Flowable<>::justN({"0: ", "1: ", "2: "}) ->map([](const char* p) { return std::string(p); }) ->map([](std::string log) { return log + " on " + getThreadId(); }) - ->subscribeOn(scheduler) + ->subscribeOn(*worker.getEventBase()) ->subscribe(printer()); std::cout << " waiting on " << getThreadId() << std::endl; std::this_thread::sleep_for(std::chrono::milliseconds(10)); diff --git a/yarpl/examples/FlowableExamples.h b/yarpl/examples/FlowableExamples.h index 0613efa5a..675140b8b 100644 --- a/yarpl/examples/FlowableExamples.h +++ b/yarpl/examples/FlowableExamples.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once diff --git a/yarpl/examples/yarpl-playground.cpp b/yarpl/examples/yarpl-playground.cpp index c3cfc0f99..5fbe2f4c5 100644 --- a/yarpl/examples/yarpl-playground.cpp +++ b/yarpl/examples/yarpl-playground.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include diff --git a/yarpl/flowable/AsyncGeneratorShim.h b/yarpl/flowable/AsyncGeneratorShim.h new file mode 100644 index 000000000..72d212c83 --- /dev/null +++ b/yarpl/flowable/AsyncGeneratorShim.h @@ -0,0 +1,165 @@ +// Copyright 2004-present Facebook. All Rights Reserved. + +#pragma once +#include +#include +#include +#include +#include +#include +#include "yarpl/flowable/Flowable.h" + +namespace yarpl { +namespace detail { +template +class AsyncGeneratorShim { + public: + AsyncGeneratorShim( + folly::coro::AsyncGenerator&& generator, + folly::SequencedExecutor* ex) + : generator_(std::move(generator)), + sharedState_(std::make_shared()) { + sharedState_->executor_ = folly::getKeepAliveToken(ex); + } + + void subscribe( + std::shared_ptr> subscriber) && { + class Subscription : public yarpl::flowable::Subscription { + public: + explicit Subscription(std::weak_ptr state) + : state_(std::move(state)) {} + + void request(int64_t n) override { + if (auto state = state_.lock()) { + state->executor_->add([n, state = std::move(state)]() { + if (state->requested_ == credits::kNoFlowControl || + n == credits::kNoFlowControl) { + state->requested_ = credits::kNoFlowControl; + } else { + state->requested_ += n; + } + state->baton_.post(); + }); + } + } + + void cancel() override { + if (auto state = state_.lock()) { + state->executor_->add([state = std::move(state)]() { + // requestCancellation will execute registered CancellationCallback + // inline, but CancellationCallback should be run in + // executor_ thread + state->cancelSource_.requestCancellation(); + state->baton_.post(); + }); + } + } + + private: + std::weak_ptr state_; + }; + sharedState_->executor_->add( + [keepAlive = sharedState_->executor_.copy(), + subscriber, + subscription = std::make_shared( + std::weak_ptr(sharedState_))]() mutable { + subscriber->onSubscribe(std::move(subscription)); + }); + auto executor = sharedState_->executor_.get(); + folly::coro::co_withCancellation( + sharedState_->cancelSource_.getToken(), + folly::coro::co_invoke( + [subscriber = std::move(subscriber), + self = std::move(*this)]() mutable -> folly::coro::Task { + while (true) { + while (self.sharedState_->requested_ == 0 && + !self.sharedState_->cancelSource_ + .isCancellationRequested()) { + co_await self.sharedState_->baton_; + self.sharedState_->baton_.reset(); + } + + if (self.sharedState_->cancelSource_ + .isCancellationRequested()) { + self.sharedState_->executor_->add( + [subscriber = std::move(subscriber)]() { + // destory subscriber on executor_ thread + }); + co_return; + } + + folly::Try value; + try { + auto item = co_await self.generator_.next(); + + if (item.has_value()) { + value.emplace(std::move(*item)); + } + } catch (const std::exception& ex) { + value.emplaceException(std::current_exception(), ex); + } catch (...) { + value.emplaceException(std::current_exception()); + } + + if (value.hasValue()) { + self.sharedState_->executor_->add( + [subscriber, + keepAlive = self.sharedState_->executor_.copy(), + value = std::move(value)]() mutable { + subscriber->onNext(std::move(value).value()); + }); + } else if (value.hasException()) { + self.sharedState_->executor_->add( + [subscriber = std::move(subscriber), + keepAlive = self.sharedState_->executor_.copy(), + value = std::move(value)]() mutable { + subscriber->onError(std::move(value).exception()); + }); + co_return; + } else { + self.sharedState_->executor_->add( + [subscriber = std::move(subscriber), + keepAlive = + self.sharedState_->executor_.copy()]() mutable { + subscriber->onComplete(); + }); + co_return; + } + + if (self.sharedState_->requested_ != credits::kNoFlowControl) { + self.sharedState_->requested_--; + } + } + })) + .scheduleOn(std::move(executor)) + .start(); + } + + private: + struct SharedState { + SharedState() = default; + explicit SharedState(folly::CancellationSource source) + : cancelSource_(std::move(source)) {} + folly::Executor::KeepAlive executor_; + int64_t requested_{0}; + folly::coro::Baton baton_{0}; + folly::CancellationSource cancelSource_; + }; + + folly::coro::AsyncGenerator generator_; + std::shared_ptr sharedState_; +}; +} // namespace detail + +template +std::shared_ptr> toFlowable( + folly::coro::AsyncGenerator gen, + folly::SequencedExecutor* ex = folly::getEventBase()) { + return yarpl::flowable::internal::flowableFromSubscriber( + [gen = std::move(gen), + ex](std::shared_ptr> subscriber) mutable { + detail::AsyncGeneratorShim(std::move(gen), ex) + .subscribe(std::move(subscriber)); + }); +} +} // namespace yarpl diff --git a/yarpl/flowable/CancelingSubscriber.h b/yarpl/flowable/CancelingSubscriber.h new file mode 100644 index 000000000..0933a6908 --- /dev/null +++ b/yarpl/flowable/CancelingSubscriber.h @@ -0,0 +1,47 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/flowable/Subscriber.h" + +#include + +namespace yarpl { +namespace flowable { + +/** + * A Subscriber that always cancels the subscription passed to it. + */ +template +class CancelingSubscriber final : public BaseSubscriber { + public: + void onSubscribeImpl() override { + this->cancel(); + } + + void onNextImpl(T) override { + throw std::logic_error{"CancelingSubscriber::onNext() can never be called"}; + } + void onCompleteImpl() override { + throw std::logic_error{ + "CancelingSubscriber::onComplete() can never be called"}; + } + void onErrorImpl(folly::exception_wrapper) override { + throw std::logic_error{ + "CancelingSubscriber::onError() can never be called"}; + } +}; +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/DeferFlowable.h b/yarpl/flowable/DeferFlowable.h new file mode 100644 index 000000000..b817c85f4 --- /dev/null +++ b/yarpl/flowable/DeferFlowable.h @@ -0,0 +1,49 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/flowable/Flowable.h" + +namespace yarpl { +namespace flowable { +namespace details { + +template +class DeferFlowable : public Flowable { + static_assert( + std::is_same, FlowableFactory>::value, + "undecayed"); + + public: + template + explicit DeferFlowable(F&& factory) : factory_(std::forward(factory)) {} + + virtual void subscribe(std::shared_ptr> subscriber) { + std::shared_ptr> flowable; + try { + flowable = factory_(); + } catch (const std::exception& ex) { + flowable = Flowable::error(ex, std::current_exception()); + } + flowable->subscribe(std::move(subscriber)); + } + + private: + FlowableFactory factory_; +}; + +} // namespace details +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/EmitterFlowable.h b/yarpl/flowable/EmitterFlowable.h new file mode 100644 index 000000000..5c5089551 --- /dev/null +++ b/yarpl/flowable/EmitterFlowable.h @@ -0,0 +1,320 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include +#include + +#include "yarpl/flowable/Flowable.h" +#include "yarpl/flowable/Subscriber.h" +#include "yarpl/utils/credits.h" + +namespace yarpl { +namespace flowable { +namespace details { + +template +class EmitterBase { + public: + virtual ~EmitterBase() = default; + + virtual std::tuple emit( + std::shared_ptr>, + int64_t) = 0; +}; + +/** + * Manager for a flowable subscription. + * + * This is synchronous: the emit calls are triggered within the context + * of a request(n) call. + */ +template +class EmiterSubscription final : public Subscription, + public Subscriber, + public yarpl::enable_get_ref { + constexpr static auto kCanceled = credits::kCanceled; + constexpr static auto kNoFlowControl = credits::kNoFlowControl; + + public: + EmiterSubscription( + std::shared_ptr> emitter, + std::shared_ptr> subscriber) + : emitter_(std::move(emitter)), subscriber_(std::move(subscriber)) {} + + void init() { + subscriber_->onSubscribe(this->ref_from_this(this)); + } + + virtual ~EmiterSubscription() { + subscriber_.reset(); + } + + void request(int64_t delta) override { + while (true) { + auto current = requested_.load(std::memory_order_relaxed); + + if (current == kCanceled) { + // this can happen because there could be an async barrier between the + // subscriber and the subscription for instance while onComplete is + // being delivered (on effectively cancelled subscription) the + // subscriber can call request(n) + return; + } + + auto const total = credits::add(current, delta); + if (requested_.compare_exchange_strong(current, total)) { + break; + } + } + + process(); + } + + void cancel() override { + // if this is the first terminating signal to receive, we need to + // make sure we break the reference cycle between subscription and + // subscriber + auto previous = requested_.exchange(kCanceled, std::memory_order_relaxed); + if (previous != kCanceled) { + // this can happen because there could be an async barrier between the + // subscriber and the subscription for instance while onComplete is being + // delivered (on effectively cancelled subscription) the subscriber can + // call request(n) + process(); + } + } + + // Subscriber methods. + void onSubscribe(std::shared_ptr) override { + LOG(FATAL) << "Do not call this method"; + } + + void onNext(T value) override { +#ifndef NDEBUG + DCHECK(!hasFinished_) << "onComplete() or onError() already called"; +#endif + if (subscriber_) { + subscriber_->onNext(std::move(value)); + } else { + DCHECK(requested_.load(std::memory_order_relaxed) == kCanceled); + } + } + + void onComplete() override { +#ifndef NDEBUG + DCHECK(!hasFinished_) << "onComplete() or onError() already called"; + hasFinished_ = true; +#endif + if (subscriber_) { + subscriber_->onComplete(); + } else { + DCHECK(requested_.load(std::memory_order_relaxed) == kCanceled); + } + } + + void onError(folly::exception_wrapper error) override { +#ifndef NDEBUG + DCHECK(!hasFinished_) << "onComplete() or onError() already called"; + hasFinished_ = true; +#endif + if (subscriber_) { + subscriber_->onError(error); + } else { + DCHECK(requested_.load(std::memory_order_relaxed) == kCanceled); + } + } + + private: + // Processing loop. Note: this can delete `this` upon completion, + // error, or cancellation; thus, no fields should be accessed once + // this method returns. + // + // Thread-Safety: there is no guarantee as to which thread this is + // invoked on. However, there is a strong guarantee on cancel and + // request(n) calls: no more than one instance of either of these + // can be outstanding at any time. + void process() { + // Guards against re-entrancy in request(n) calls. + if (processing_.exchange(true)) { + return; + } + + auto guard = folly::makeGuard([this] { processing_ = false; }); + + // Keep a reference to ourselves here in case the emit() call + // frees all other references to 'this' + auto this_subscriber = this->ref_from_this(this); + + while (true) { + auto current = requested_.load(std::memory_order_relaxed); + + // Subscription was canceled, completed, or had an error. + if (current == kCanceled) { + guard.dismiss(); + release(); + return; + } + + // If no more items can be emitted now, wait for a request(n). + // See note above re: thread-safety. We are guaranteed that + // request(n) is not simultaneously invoked on another thread. + if (current <= 0) + return; + + int64_t emitted; + bool done; + + std::tie(emitted, done) = emitter_->emit(this_subscriber, current); + + while (true) { + current = requested_.load(std::memory_order_relaxed); + if (current == kCanceled) { + break; + } + int64_t updated; + // generally speaking updated will be number of credits lefted over + // after emitter_->emit(), so updated = current - emitted + // need to handle case where done = true and avoid doing arithmetic + // operation on kNoFlowControl + + // in asynchrnous emitter cases, might have emitted=kNoFlowControl + // this means that emitter will take the responsibility to send the + // whole conext and credits lefted over should be set to 0. + if (current == kNoFlowControl) { + updated = + done ? kCanceled : emitted == kNoFlowControl ? 0 : kNoFlowControl; + } else { + updated = done ? kCanceled : current - emitted; + } + if (requested_.compare_exchange_strong(current, updated)) { + break; + } + } + } + } + + void release() { + emitter_.reset(); + subscriber_.reset(); + } + + // The number of items that can be sent downstream. Each request(n) + // adds n; each onNext consumes 1. If this is MAX, flow-control is + // disabled: items sent downstream don't consume any longer. A MIN + // value represents cancellation. Other -ve values aren't permitted. + std::atomic_int_fast64_t requested_{0}; + +#ifndef NDEBUG + bool hasFinished_{false}; // onComplete or onError called +#endif + + // We don't want to recursively invoke process(); one loop should do. + std::atomic_bool processing_{false}; + + std::shared_ptr> emitter_; + std::shared_ptr> subscriber_; +}; + +template +class TrackingSubscriber : public Subscriber { + public: + TrackingSubscriber( + Subscriber& subscriber, + int64_t +#ifndef NDEBUG + requested +#endif + ) + : inner_(&subscriber) +#ifndef NDEBUG + , + requested_(requested) +#endif + { + } + + void onSubscribe(std::shared_ptr s) override { + inner_->onSubscribe(std::move(s)); + } + + void onComplete() override { + completed_ = true; + inner_->onComplete(); + } + + void onError(folly::exception_wrapper ex) override { + completed_ = true; + inner_->onError(std::move(ex)); + } + + void onNext(T value) override { +#ifndef NDEBUG + auto old = requested_; + DCHECK(old > credits::consume(requested_, 1)) + << "cannot emit more than requested"; +#endif + emitted_++; + inner_->onNext(std::move(value)); + } + + auto getResult() { + return std::make_tuple(emitted_, completed_); + } + + private: + int64_t emitted_{0}; + bool completed_{false}; + Subscriber* inner_; +#ifndef NDEBUG + int64_t requested_; +#endif +}; + +template +class EmitterWrapper : public EmitterBase, public Flowable { + static_assert( + std::is_same, Emitter>::value, + "undecayed"); + + public: + template + explicit EmitterWrapper(F&& emitter) : emitter_(std::forward(emitter)) {} + + void subscribe(std::shared_ptr> subscriber) override { + auto ef = std::make_shared>( + this->ref_from_this(this), std::move(subscriber)); + ef->init(); + } + + std::tuple emit( + std::shared_ptr> subscriber, + int64_t requested) override { + TrackingSubscriber trackingSubscriber(*subscriber, requested); + emitter_(trackingSubscriber, requested); + return trackingSubscriber.getResult(); + } + + private: + Emitter emitter_; +}; + +} // namespace details +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/Flowable.h b/yarpl/flowable/Flowable.h new file mode 100644 index 000000000..9dff78b03 --- /dev/null +++ b/yarpl/flowable/Flowable.h @@ -0,0 +1,749 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include "yarpl/Disposable.h" +#include "yarpl/Refcounted.h" +#include "yarpl/flowable/Subscriber.h" +#include "yarpl/utils/credits.h" + +namespace yarpl { + +class TimeoutException; +namespace detail { +class TimeoutExceptionGenerator; +} + +namespace flowable { + +template +class Flowable; + +namespace details { + +template +struct IsFlowable : std::false_type {}; + +template +struct IsFlowable>> : std::true_type { + using ElemType = R; +}; + +} // namespace details + +template +class Flowable : public yarpl::enable_get_ref { + public: + virtual ~Flowable() = default; + + virtual void subscribe(std::shared_ptr>) = 0; + + /** + * Subscribe overload that accepts lambdas. + */ + template < + typename Next, + typename = typename std::enable_if< + folly::is_invocable&, T>::value>::type> + std::unique_ptr subscribe( + Next&& next, + int64_t batch = credits::kNoFlowControl) { + auto subscriber = + details::LambdaSubscriber::create(std::forward(next), batch); + subscribe(subscriber); + return std::make_unique>( + std::move(subscriber)); + } + + /** + * Subscribe overload that accepts lambdas. + * + * Takes an optional batch size for request_n. Default is no flow control. + */ + template < + typename Next, + typename Error, + typename = typename std::enable_if< + folly::is_invocable&, T>::value && + folly::is_invocable&, folly::exception_wrapper>:: + value>::type> + std::unique_ptr subscribe( + Next&& next, + Error&& e, + int64_t batch = credits::kNoFlowControl) { + auto subscriber = details::LambdaSubscriber::create( + std::forward(next), std::forward(e), batch); + subscribe(subscriber); + return std::make_unique>( + std::move(subscriber)); + } + + /** + * Subscribe overload that accepts lambdas. + * + * Takes an optional batch size for request_n. Default is no flow control. + */ + template < + typename Next, + typename Error, + typename Complete, + typename = typename std::enable_if< + folly::is_invocable&, T>::value && + folly::is_invocable&, folly::exception_wrapper>:: + value && + folly::is_invocable&>::value>::type> + std::unique_ptr subscribe( + Next&& next, + Error&& e, + Complete&& complete, + int64_t batch = credits::kNoFlowControl) { + auto subscriber = details::LambdaSubscriber::create( + std::forward(next), + std::forward(e), + std::forward(complete), + batch); + subscribe(subscriber); + return std::make_unique>( + std::move(subscriber)); + } + + void subscribe() { + subscribe(Subscriber::create()); + } + + // + // creator methods: + // + + // Creates Flowable which completes the subscriber right after it subscribes + static std::shared_ptr> empty(); + + // Creates Flowable which will never terminate the subscriber + static std::shared_ptr> never(); + + // Create Flowable which will imediatelly terminate the subscriber upon + // subscription with the provided error + static std::shared_ptr> error(folly::exception_wrapper ex); + + template + static std::shared_ptr> error(Ex&) { + static_assert( + std::is_lvalue_reference::value, + "use variant of error() method accepting also exception_ptr"); + } + + template + static std::shared_ptr> error(Ex& ex, std::exception_ptr ptr) { + return Flowable::error(folly::exception_wrapper(std::move(ptr), ex)); + } + + static std::shared_ptr> just(T value) { + auto lambda = [value = std::move(value)]( + Subscriber& subscriber, int64_t requested) mutable { + DCHECK_GT(requested, 0); + subscriber.onNext(std::move(value)); + subscriber.onComplete(); + }; + + return Flowable::create(std::move(lambda)); + } + + static std::shared_ptr> justN(std::initializer_list list) { + auto lambda = [v = std::vector(std::move(list)), i = size_t{0}]( + Subscriber& subscriber, int64_t requested) mutable { + while (i < v.size() && requested-- > 0) { + subscriber.onNext(v[i++]); + } + + if (i == v.size()) { + // TODO T27302402: Even though having two subscriptions exist + // concurrently for Emitters is not possible still. At least it possible + // to resubscribe and consume the same values again. + i = 0; + subscriber.onComplete(); + } + }; + + return Flowable::create(std::move(lambda)); + } + + // this will generate a flowable which can be subscribed to only once + static std::shared_ptr> justOnce(T value) { + auto lambda = [value = std::move(value), used = false]( + Subscriber& subscriber, int64_t) mutable { + if (used) { + subscriber.onError( + std::runtime_error("justOnce value was already used")); + return; + } + + used = true; + // # requested should be > 0. Ignoring the actual parameter. + subscriber.onNext(std::move(value)); + subscriber.onComplete(); + }; + + return Flowable::create(std::move(lambda)); + } + + template + static std::shared_ptr> fromGenerator(TGenerator&& generator); + + /** + * The Defer operator waits until a subscriber subscribes to it, and then it + * generates a Flowabe with a FlowableFactory function. It + * does this afresh for each subscriber, so although each subscriber may + * think it is subscribing to the same Flowable, in fact each subscriber + * gets its own individual sequence. + */ + template < + typename FlowableFactory, + typename = typename std::enable_if>, + std::decay_t&>::value>::type> + static std::shared_ptr> defer(FlowableFactory&&); + + template < + typename Function, + typename ErrorFunction = + folly::Function, + typename R = typename folly::invoke_result_t, + typename = typename std::enable_if&, + folly::exception_wrapper&&>::value>::type> + std::shared_ptr> map( + Function&& function, + ErrorFunction&& errormapFunc = [](folly::exception_wrapper&& ew) { + return std::move(ew); + }); + + template < + typename Function, + typename R = typename details::IsFlowable< + typename folly::invoke_result_t>::ElemType> + std::shared_ptr> flatMap(Function&& func); + + template + std::shared_ptr> filter(Function&& function); + + template < + typename Function, + typename R = typename folly::invoke_result_t> + std::shared_ptr> reduce(Function&& function); + + std::shared_ptr> take(int64_t); + + std::shared_ptr> skip(int64_t); + + std::shared_ptr> ignoreElements(); + + /* + * To instruct a Flowable to do its work on a particular Executor. + * the onSubscribe, request and cancel methods will be scheduled on the + * provided executor + */ + std::shared_ptr> subscribeOn(folly::Executor&); + + std::shared_ptr> observeOn(folly::Executor&); + + std::shared_ptr> observeOn(folly::Executor::KeepAlive<>); + + std::shared_ptr> concatWith(std::shared_ptr>); + + template + std::shared_ptr> concatWith( + std::shared_ptr> first, + Args... args) { + return concatWith(first)->concatWith(args...); + } + + template + static std::shared_ptr> concat( + std::shared_ptr> first, + Args... args) { + return first->concatWith(args...); + } + + template + using enableWrapRef = + typename std::enable_if::value, Q>::type; + + // Combines multiple Flowables so that they act like a + // single Flowable. The items + // emitted by the merged Flowables may interlieve. + template + enableWrapRef merge() { + return this->flatMap([](auto f) { return std::move(f); }); + } + + // function is invoked when onComplete occurs. + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOnSubscribe(Function&& function); + + // function is invoked when onNext occurs. + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&, const T&>::value>::type> + std::shared_ptr> doOnNext(Function&& function); + + // function is invoked when onError occurs. + template < + typename Function, + typename = typename std::enable_if&, + folly::exception_wrapper&>::value>::type> + std::shared_ptr> doOnError(Function&& function); + + // function is invoked when onComplete occurs. + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOnComplete(Function&& function); + + // function is invoked when either onComplete or onError occurs. + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOnTerminate(Function&& function); + + // the function is invoked for each of onNext, onCompleted, onError + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOnEach(Function&& function); + + // function is invoked when request(n) is called. + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&, int64_t>::value>::type> + std::shared_ptr> doOnRequest(Function&& function); + + // function is invoked when cancel is called. + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOnCancel(Function&& function); + + // the callbacks will be invoked of each of the signals + template < + typename OnNextFunc, + typename OnCompleteFunc, + typename = typename std::enable_if< + folly::is_invocable&, const T&>::value>:: + type, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOn( + OnNextFunc&& onNext, + OnCompleteFunc&& onComplete); + + // the callbacks will be invoked of each of the signals + template < + typename OnNextFunc, + typename OnCompleteFunc, + typename OnErrorFunc, + typename = typename std::enable_if< + folly::is_invocable&, const T&>::value>:: + type, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type, + typename = typename std::enable_if&, + folly::exception_wrapper&>::value>::type> + std::shared_ptr> + doOn(OnNextFunc&& onNext, OnCompleteFunc&& onComplete, OnErrorFunc&& onError); + + template < + typename ExceptionGenerator = yarpl::detail::TimeoutExceptionGenerator> + std::shared_ptr> timeout( + folly::EventBase& timerEvb, + std::chrono::milliseconds timeout, + std::chrono::milliseconds initTimeout, + ExceptionGenerator&& exnGen = ExceptionGenerator()); + + template < + typename Emitter, + typename = typename std::enable_if&, + Subscriber&, + int64_t>::value>::type> + static std::shared_ptr> create(Emitter&& emitter); + + template < + typename OnSubscribe, + typename = typename std::enable_if>>::value>::type> + // TODO(lehecka): enable this warning once mobile code is clear + // [[deprecated( + // "Flowable::fromPublisher is deprecated: Use PublishProcessor or " + // "contact rsocket team if you can't figure out what to replace it " + // "with")]] + static std::shared_ptr> fromPublisher(OnSubscribe&& function); +}; + +} // namespace flowable +} // namespace yarpl + +#include "yarpl/flowable/DeferFlowable.h" +#include "yarpl/flowable/EmitterFlowable.h" +#include "yarpl/flowable/FlowableOperator.h" + +namespace yarpl { +namespace flowable { + +template +template +std::shared_ptr> Flowable::create(Emitter&& emitter) { + return std::make_shared>>( + std::forward(emitter)); +} + +template +std::shared_ptr> Flowable::empty() { + class EmptyFlowable : public Flowable { + void subscribe(std::shared_ptr> subscriber) override { + subscriber->onSubscribe(Subscription::create()); + // does not wait for request(n) to complete + subscriber->onComplete(); + } + }; + return std::make_shared(); +} + +template +std::shared_ptr> Flowable::never() { + class NeverFlowable : public Flowable { + void subscribe(std::shared_ptr> subscriber) override { + subscriber->onSubscribe(Subscription::create()); + } + }; + return std::make_shared(); +} + +template +std::shared_ptr> Flowable::error(folly::exception_wrapper ex) { + class ErrorFlowable : public Flowable { + void subscribe(std::shared_ptr> subscriber) override { + subscriber->onSubscribe(Subscription::create()); + // does not wait for request(n) to error + subscriber->onError(ex_); + } + folly::exception_wrapper ex_; + + public: + explicit ErrorFlowable(folly::exception_wrapper ew) : ex_(std::move(ew)) {} + }; + return std::make_shared(std::move(ex)); +} + +namespace internal { +template +std::shared_ptr> flowableFromSubscriber(OnSubscribe&& function) { + return std::make_shared>>( + std::forward(function)); +} +} // namespace internal + +// TODO(lehecka): remove +template +template +std::shared_ptr> Flowable::fromPublisher( + OnSubscribe&& function) { + return internal::flowableFromSubscriber( + std::forward(function)); +} + +template +template +std::shared_ptr> Flowable::fromGenerator( + TGenerator&& generator) { + auto lambda = [generator = std::forward(generator)]( + Subscriber& subscriber, int64_t requested) mutable { + try { + while (requested-- > 0) { + subscriber.onNext(generator()); + } + } catch (const std::exception& ex) { + subscriber.onError( + folly::exception_wrapper(std::current_exception(), ex)); + } catch (...) { + subscriber.onError(std::runtime_error( + "Flowable::fromGenerator() threw from Subscriber:onNext()")); + } + }; + return Flowable::create(std::move(lambda)); +} // namespace flowable + +template +template +std::shared_ptr> Flowable::defer(FlowableFactory&& factory) { + return std::make_shared< + details::DeferFlowable>>( + std::forward(factory)); +} + +template +template +std::shared_ptr> Flowable::map( + Function&& function, + ErrorFunction&& errorFunction) { + return std::make_shared< + MapOperator, std::decay_t>>( + this->ref_from_this(this), + std::forward(function), + std::forward(errorFunction)); +} + +template +template +std::shared_ptr> Flowable::filter(Function&& function) { + return std::make_shared>>( + this->ref_from_this(this), std::forward(function)); +} + +template +template +std::shared_ptr> Flowable::reduce(Function&& function) { + return std::make_shared>>( + this->ref_from_this(this), std::forward(function)); +} + +template +std::shared_ptr> Flowable::take(int64_t limit) { + return std::make_shared>(this->ref_from_this(this), limit); +} + +template +std::shared_ptr> Flowable::skip(int64_t offset) { + return std::make_shared>(this->ref_from_this(this), offset); +} + +template +std::shared_ptr> Flowable::ignoreElements() { + return std::make_shared>(this->ref_from_this(this)); +} + +template +std::shared_ptr> Flowable::subscribeOn( + folly::Executor& executor) { + return std::make_shared>( + this->ref_from_this(this), executor); +} + +template +std::shared_ptr> Flowable::observeOn(folly::Executor& executor) { + return observeOn(folly::getKeepAliveToken(executor)); +} + +template +std::shared_ptr> Flowable::observeOn( + folly::Executor::KeepAlive<> executor) { + return std::make_shared>( + this->ref_from_this(this), std::move(executor)); +} + +template +template +std::shared_ptr> Flowable::flatMap(Function&& function) { + return std::make_shared>( + this->ref_from_this(this), std::forward(function)); +} + +template +std::shared_ptr> Flowable::concatWith( + std::shared_ptr> next) { + return std::make_shared>( + this->ref_from_this(this), std::move(next)); +} + +template +template +std::shared_ptr> Flowable::doOnSubscribe(Function&& function) { + return details::createDoOperator( + ref_from_this(this), + std::forward(function), + [](const T&) {}, + [](const auto&) {}, + [] {}, + [](const auto&) {}, // onRequest + [] {}); // onCancel +} + +template +template +std::shared_ptr> Flowable::doOnNext(Function&& function) { + return details::createDoOperator( + ref_from_this(this), + [] {}, + std::forward(function), + [](const auto&) {}, + [] {}, + [](const auto&) {}, // onRequest + [] {}); // onCancel +} + +template +template +std::shared_ptr> Flowable::doOnError(Function&& function) { + return details::createDoOperator( + ref_from_this(this), + [] {}, + [](const T&) {}, + std::forward(function), + [] {}, + [](const auto&) {}, // onRequest + [] {}); // onCancel +} + +template +template +std::shared_ptr> Flowable::doOnComplete(Function&& function) { + return details::createDoOperator( + ref_from_this(this), + [] {}, + [](const T&) {}, + [](const auto&) {}, + std::forward(function), + [](const auto&) {}, // onRequest + [] {}); // onCancel +} + +template +template +std::shared_ptr> Flowable::doOnTerminate(Function&& function) { + auto sharedFunction = std::make_shared>( + std::forward(function)); + return details::createDoOperator( + ref_from_this(this), + [] {}, + [](const T&) {}, + [sharedFunction](const auto&) { (*sharedFunction)(); }, + [sharedFunction]() { (*sharedFunction)(); }, + [](const auto&) {}, // onRequest + [] {}); // onCancel +} + +template +template +std::shared_ptr> Flowable::doOnEach(Function&& function) { + auto sharedFunction = std::make_shared>( + std::forward(function)); + return details::createDoOperator( + ref_from_this(this), + [] {}, + [sharedFunction](const T&) { (*sharedFunction)(); }, + [sharedFunction](const auto&) { (*sharedFunction)(); }, + [sharedFunction]() { (*sharedFunction)(); }, + [](const auto&) {}, // onRequest + [] {}); // onCancel +} + +template +template +std::shared_ptr> Flowable::doOn( + OnNextFunc&& onNext, + OnCompleteFunc&& onComplete) { + return details::createDoOperator( + ref_from_this(this), + [] {}, + std::forward(onNext), + [](const auto&) {}, + std::forward(onComplete), + [](const auto&) {}, // onRequest + [] {}); // onCancel +} + +template +template < + typename OnNextFunc, + typename OnCompleteFunc, + typename OnErrorFunc, + typename, + typename, + typename> +std::shared_ptr> Flowable::doOn( + OnNextFunc&& onNext, + OnCompleteFunc&& onComplete, + OnErrorFunc&& onError) { + return details::createDoOperator( + ref_from_this(this), + [] {}, + std::forward(onNext), + std::forward(onError), + std::forward(onComplete), + [](const auto&) {}, // onRequest + [] {}); // onCancel +} + +template +template +std::shared_ptr> Flowable::doOnRequest(Function&& function) { + return details::createDoOperator( + ref_from_this(this), + [] {}, // onSubscribe + [](const auto&) {}, // onNext + [](const auto&) {}, // onError + [] {}, // onComplete + std::forward(function), // onRequest + [] {}); // onCancel +} + +template +template +std::shared_ptr> Flowable::doOnCancel(Function&& function) { + return details::createDoOperator( + ref_from_this(this), + [] {}, // onSubscribe + [](const auto&) {}, // onNext + [](const auto&) {}, // onError + [] {}, // onComplete + [](const auto&) {}, // onRequest + std::forward(function)); // onCancel +} + +template +template +std::shared_ptr> Flowable::timeout( + folly::EventBase& timerEvb, + std::chrono::milliseconds starvationTimeout, + std::chrono::milliseconds initTimeout, + ExceptionGenerator&& exnGen) { + return std::make_shared>( + ref_from_this(this), + timerEvb, + starvationTimeout, + initTimeout, + std::forward(exnGen)); +} + +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/FlowableConcatOperators.h b/yarpl/flowable/FlowableConcatOperators.h new file mode 100644 index 000000000..56694146b --- /dev/null +++ b/yarpl/flowable/FlowableConcatOperators.h @@ -0,0 +1,189 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/flowable/FlowableOperator.h" + +namespace yarpl { +namespace flowable { +namespace details { + +template +class ConcatWithOperator : public FlowableOperator { + using Super = FlowableOperator; + + public: + ConcatWithOperator( + std::shared_ptr> first, + std::shared_ptr> second) + : first_(std::move(first)), second_(std::move(second)) { + CHECK(first_); + CHECK(second_); + } + + void subscribe(std::shared_ptr> subscriber) override { + auto subscription = + std::make_shared(subscriber, first_, second_); + subscription->init(); + } + + private: + class ForwardSubscriber; + + // Downstream will always point to this subscription + class ConcatWithSubscription + : public yarpl::flowable::Subscription, + public std::enable_shared_from_this { + public: + ConcatWithSubscription( + std::shared_ptr> subscriber, + std::shared_ptr> first, + std::shared_ptr> second) + : downSubscriber_(std::move(subscriber)), + first_(std::move(first)), + second_(std::move(second)) {} + + void init() { + upSubscriber_ = + std::make_shared(this->shared_from_this()); + first_->subscribe(upSubscriber_); + downSubscriber_->onSubscribe(this->shared_from_this()); + } + + void request(int64_t n) override { + credits::add(&requested_, n); + if (!upSubscriber_) { + if (auto second = std::exchange(second_, nullptr)) { + upSubscriber_ = std::make_shared( + this->shared_from_this(), requested_); + second->subscribe(upSubscriber_); + } + } else { + upSubscriber_->request(n); + } + } + + void cancel() override { + if (auto subscriber = std::move(upSubscriber_)) { + subscriber->cancel(); + } + first_.reset(); + second_.reset(); + downSubscriber_.reset(); + upSubscriber_.reset(); + } + + void onNext(T value) { + credits::consume(&requested_, 1); + downSubscriber_->onNext(std::move(value)); + } + + void onComplete() { + upSubscriber_.reset(); + if (auto first = std::move(first_)) { + if (requested_ > 0) { + if (auto second = std::exchange(second_, nullptr)) { + upSubscriber_ = std::make_shared( + this->shared_from_this(), requested_); + // TODO - T28771728 + // Concat should not call 'subscribe' on onComplete + second->subscribe(upSubscriber_); + } + } + } else { + if (auto downSubscriber = std::exchange(downSubscriber_, nullptr)) { + downSubscriber->onComplete(); + } + upSubscriber_.reset(); + } + } + + void onError(folly::exception_wrapper ew) { + downSubscriber_->onError(std::move(ew)); + first_.reset(); + second_.reset(); + downSubscriber_.reset(); + upSubscriber_.reset(); + } + + private: + std::shared_ptr> downSubscriber_; + std::shared_ptr> first_; + std::shared_ptr> second_; + std::shared_ptr upSubscriber_; + std::atomic requested_{0}; + }; + + class ForwardSubscriber : public yarpl::flowable::Subscriber, + public yarpl::flowable::Subscription { + public: + ForwardSubscriber( + std::shared_ptr concatWithSubscription, + uint32_t initialRequest = 0u) + : concatWithSubscription_(std::move(concatWithSubscription)), + initialRequest_(initialRequest) {} + + void request(int64_t n) override { + subscription_->request(n); + } + + void cancel() override { + if (auto subs = std::move(subscription_)) { + subs->cancel(); + } else { + canceled_ = true; + } + } + + void onSubscribe(std::shared_ptr subscription) override { + if (canceled_) { + subscription->cancel(); + return; + } + subscription_ = std::move(subscription); + if (auto req = std::exchange(initialRequest_, 0)) { + subscription_->request(req); + } + } + + void onComplete() override { + auto sub = std::exchange(concatWithSubscription_, nullptr); + sub->onComplete(); + } + + void onError(folly::exception_wrapper ew) override { + auto sub = std::exchange(concatWithSubscription_, nullptr); + sub->onError(std::move(ew)); + } + void onNext(T value) override { + concatWithSubscription_->onNext(std::move(value)); + } + + private: + std::shared_ptr concatWithSubscription_; + std::shared_ptr subscription_; + + uint32_t initialRequest_{0}; + bool canceled_{false}; + }; + + private: + const std::shared_ptr> first_; + const std::shared_ptr> second_; +}; + +} // namespace details +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/FlowableDoOperator.h b/yarpl/flowable/FlowableDoOperator.h new file mode 100644 index 000000000..256a345ba --- /dev/null +++ b/yarpl/flowable/FlowableDoOperator.h @@ -0,0 +1,190 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/flowable/FlowableOperator.h" + +namespace yarpl { +namespace flowable { +namespace details { + +template < + typename U, + typename OnSubscribeFunc, + typename OnNextFunc, + typename OnErrorFunc, + typename OnCompleteFunc, + typename OnRequestFunc, + typename OnCancelFunc> +class DoOperator : public FlowableOperator { + using Super = FlowableOperator; + static_assert( + std::is_same, OnSubscribeFunc>::value, + "undecayed"); + static_assert( + std::is_same, OnNextFunc>::value, + "undecayed"); + static_assert( + std::is_same, OnErrorFunc>::value, + "undecayed"); + static_assert( + std::is_same, OnCompleteFunc>::value, + "undecayed"); + static_assert( + std::is_same, OnRequestFunc>::value, + "undecayed"); + static_assert( + std::is_same, OnCancelFunc>::value, + "undecayed"); + + public: + template < + typename FSubscribe, + typename FNext, + typename FError, + typename FComplete, + typename FRequest, + typename FCancel> + DoOperator( + std::shared_ptr> upstream, + FSubscribe&& onSubscribeFunc, + FNext&& onNextFunc, + FError&& onErrorFunc, + FComplete&& onCompleteFunc, + FRequest&& onRequestFunc, + FCancel&& onCancelFunc) + : upstream_(std::move(upstream)), + onSubscribeFunc_(std::forward(onSubscribeFunc)), + onNextFunc_(std::forward(onNextFunc)), + onErrorFunc_(std::forward(onErrorFunc)), + onCompleteFunc_(std::forward(onCompleteFunc)), + onRequestFunc_(std::forward(onRequestFunc)), + onCancelFunc_(std::forward(onCancelFunc)) {} + + void subscribe(std::shared_ptr> subscriber) override { + auto subscription = std::make_shared( + this->ref_from_this(this), std::move(subscriber)); + upstream_->subscribe( + // Note: implicit cast to a reference to a subscriber. + subscription); + } + + private: + class DoSubscription : public Super::Subscription { + using SuperSub = typename Super::Subscription; + + public: + DoSubscription( + std::shared_ptr flowable, + std::shared_ptr> subscriber) + : SuperSub(std::move(subscriber)), flowable_(std::move(flowable)) {} + + void onSubscribeImpl() override { + if (auto flowable = yarpl::atomic_load(&flowable_)) { + flowable->onSubscribeFunc_(); + SuperSub::onSubscribeImpl(); + } + } + + void onNextImpl(U value) override { + if (auto flowable = yarpl::atomic_load(&flowable_)) { + const auto& valueRef = value; + flowable->onNextFunc_(valueRef); + SuperSub::subscriberOnNext(std::move(value)); + } + } + + void onErrorImpl(folly::exception_wrapper ex) override { + if (auto flowable = yarpl::atomic_load(&flowable_)) { + const auto& exRef = ex; + flowable->onErrorFunc_(exRef); + SuperSub::onErrorImpl(std::move(ex)); + } + } + + void onCompleteImpl() override { + if (auto flowable = yarpl::atomic_load(&flowable_)) { + flowable->onCompleteFunc_(); + SuperSub::onCompleteImpl(); + } + } + + void cancel() override { + if (auto flowable = yarpl::atomic_load(&flowable_)) { + flowable->onCancelFunc_(); + SuperSub::cancel(); + } + } + + void request(int64_t n) override { + if (auto flowable = yarpl::atomic_load(&flowable_)) { + flowable->onRequestFunc_(n); + SuperSub::request(n); + } + } + + void onTerminateImpl() override { + yarpl::atomic_exchange(&flowable_, nullptr); + SuperSub::onTerminateImpl(); + } + + private: + AtomicReference flowable_; + }; + + std::shared_ptr> upstream_; + OnSubscribeFunc onSubscribeFunc_; + OnNextFunc onNextFunc_; + OnErrorFunc onErrorFunc_; + OnCompleteFunc onCompleteFunc_; + OnRequestFunc onRequestFunc_; + OnCancelFunc onCancelFunc_; +}; + +template < + typename U, + typename OnSubscribeFunc, + typename OnNextFunc, + typename OnErrorFunc, + typename OnCompleteFunc, + typename OnRequestFunc, + typename OnCancelFunc> +inline auto createDoOperator( + std::shared_ptr> upstream, + OnSubscribeFunc&& onSubscribeFunc, + OnNextFunc&& onNextFunc, + OnErrorFunc&& onErrorFunc, + OnCompleteFunc&& onCompleteFunc, + OnRequestFunc&& onRequestFunc, + OnCancelFunc&& onCancelFunc) { + return std::make_shared, + std::decay_t, + std::decay_t, + std::decay_t, + std::decay_t, + std::decay_t>>( + std::move(upstream), + std::forward(onSubscribeFunc), + std::forward(onNextFunc), + std::forward(onErrorFunc), + std::forward(onCompleteFunc), + std::forward(onRequestFunc), + std::forward(onCancelFunc)); +} +} // namespace details +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/FlowableObserveOnOperator.h b/yarpl/flowable/FlowableObserveOnOperator.h new file mode 100644 index 000000000..359540980 --- /dev/null +++ b/yarpl/flowable/FlowableObserveOnOperator.h @@ -0,0 +1,124 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/flowable/Flowable.h" + +namespace yarpl { +namespace flowable { +namespace detail { + +template +class ObserveOnOperatorSubscriber; + +template +class ObserveOnOperatorSubscription : public yarpl::flowable::Subscription, + public yarpl::enable_get_ref { + public: + ObserveOnOperatorSubscription( + std::shared_ptr> subscriber, + std::shared_ptr subscription) + : subscriber_(std::move(subscriber)), + subscription_(std::move(subscription)) {} + + // all requesting methods are called from 'executor_' in the + // associated subscriber + void cancel() override { + auto self = this->ref_from_this(this); + + if (auto subscriber = std::move(subscriber_)) { + subscriber->inner_ = nullptr; + } + + subscription_->cancel(); + } + + void request(int64_t n) override { + subscription_->request(n); + } + + private: + std::shared_ptr> subscriber_; + std::shared_ptr subscription_; +}; + +template +class ObserveOnOperatorSubscriber : public yarpl::flowable::Subscriber, + public yarpl::enable_get_ref { + public: + ObserveOnOperatorSubscriber( + std::shared_ptr> inner, + folly::Executor::KeepAlive<> executor) + : inner_(std::move(inner)), executor_(std::move(executor)) {} + + // all signaling methods are called from upstream EB + void onSubscribe(std::shared_ptr subscription) override { + executor_->add([self = this->ref_from_this(this), + s = std::move(subscription)]() mutable { + auto sub = std::make_shared>( + self, std::move(s)); + self->inner_->onSubscribe(std::move(sub)); + }); + } + void onNext(T next) override { + executor_->add( + [self = this->ref_from_this(this), n = std::move(next)]() mutable { + if (auto& inner = self->inner_) { + inner->onNext(std::move(n)); + } + }); + } + void onComplete() override { + executor_->add([self = this->ref_from_this(this)]() mutable { + if (auto inner = std::exchange(self->inner_, nullptr)) { + inner->onComplete(); + } + }); + } + void onError(folly::exception_wrapper err) override { + executor_->add( + [self = this->ref_from_this(this), e = std::move(err)]() mutable { + if (auto inner = std::exchange(self->inner_, nullptr)) { + inner->onError(std::move(e)); + } + }); + } + + private: + friend class ObserveOnOperatorSubscription; + + std::shared_ptr> inner_; + folly::Executor::KeepAlive<> executor_; +}; + +template +class ObserveOnOperator : public yarpl::flowable::Flowable { + public: + ObserveOnOperator( + std::shared_ptr> upstream, + folly::Executor::KeepAlive<> executor) + : upstream_(std::move(upstream)), executor_(std::move(executor)) {} + + void subscribe(std::shared_ptr> subscriber) override { + upstream_->subscribe(std::make_shared>( + std::move(subscriber), folly::getKeepAliveToken(executor_.get()))); + } + + std::shared_ptr> upstream_; + folly::Executor::KeepAlive<> executor_; +}; +} // namespace detail +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/FlowableOperator.h b/yarpl/flowable/FlowableOperator.h new file mode 100644 index 000000000..314ba7f2e --- /dev/null +++ b/yarpl/flowable/FlowableOperator.h @@ -0,0 +1,1010 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "yarpl/flowable/Flowable.h" +#include "yarpl/flowable/Subscriber.h" +#include "yarpl/flowable/Subscription.h" +#include "yarpl/utils/credits.h" + +#include +#include +#include +#include +#include + +namespace yarpl { +namespace flowable { + +/** + * Base (helper) class for operators. Operators are templated on two types: D + * (downstream) and U (upstream). Operators are created by method calls on an + * upstream Flowable, and are Flowables themselves. Multi-stage pipelines can + * be built: a Flowable heading a sequence of Operators. + */ +template +class FlowableOperator : public Flowable { + protected: + /// An Operator's subscription. + /// + /// When a pipeline chain is active, each Flowable has a corresponding + /// subscription. Except for the first one, the subscriptions are created + /// against Operators. Each operator subscription has two functions: as a + /// subscriber for the previous stage; as a subscription for the next one, the + /// user-supplied subscriber being the last of the pipeline stages. + class Subscription : public yarpl::flowable::Subscription, + public BaseSubscriber { + protected: + explicit Subscription(std::shared_ptr> subscriber) + : subscriber_(std::move(subscriber)) { + CHECK(yarpl::atomic_load(&subscriber_)); + } + + // Subscriber will be provided by the init(Subscriber) call + Subscription() {} + + virtual void init(std::shared_ptr> subscriber) { + if (yarpl::atomic_load(&subscriber_)) { + subscriber->onSubscribe(yarpl::flowable::Subscription::create()); + subscriber->onError(std::runtime_error("already initialized")); + return; + } + subscriber_ = std::move(subscriber); + } + + void subscriberOnNext(D value) { + if (auto subscriber = yarpl::atomic_load(&subscriber_)) { + subscriber->onNext(std::move(value)); + } + } + + /// Terminates both ends of an operator normally. + void terminate() { + std::shared_ptr> null; + auto subscriber = yarpl::atomic_exchange(&subscriber_, null); + BaseSubscriber::cancel(); + if (subscriber) { + subscriber->onComplete(); + } + } + + /// Terminates both ends of an operator with an error. + void terminateErr(folly::exception_wrapper ew) { + std::shared_ptr> null; + auto subscriber = yarpl::atomic_exchange(&subscriber_, null); + BaseSubscriber::cancel(); + if (subscriber) { + subscriber->onError(std::move(ew)); + } + } + + // Subscription. + + void request(int64_t n) override { + BaseSubscriber::request(n); + } + + void cancel() override { + std::shared_ptr> null; + auto subscriber = yarpl::atomic_exchange(&subscriber_, null); + BaseSubscriber::cancel(); + } + + // Subscriber. + + void onSubscribeImpl() override { + yarpl::atomic_load(&subscriber_)->onSubscribe(this->ref_from_this(this)); + } + + void onCompleteImpl() override { + std::shared_ptr> null; + if (auto subscriber = yarpl::atomic_exchange(&subscriber_, null)) { + subscriber->onComplete(); + } + } + + void onErrorImpl(folly::exception_wrapper ew) override { + std::shared_ptr> null; + if (auto subscriber = yarpl::atomic_exchange(&subscriber_, null)) { + subscriber->onError(std::move(ew)); + } + } + + private: + /// This subscription controls the life-cycle of the subscriber. The + /// subscriber is retained as long as calls on it can be made. (Note: the + /// subscriber in turn maintains a reference on this subscription object + /// until cancellation and/or completion.) + AtomicReference> subscriber_; + }; +}; + +template +class MapOperator : public FlowableOperator { + using Super = FlowableOperator; + static_assert(std::is_same, F>::value, "undecayed"); + static_assert(folly::is_invocable_r::value, "not invocable"); + static_assert( + folly::is_invocable_r< + folly::exception_wrapper, + EF, + folly::exception_wrapper&&>::value, + "exception handler not invocable"); + + public: + template + MapOperator( + std::shared_ptr> upstream, + Func&& function, + ErrorFunc&& errFunction) + : upstream_(std::move(upstream)), + function_(std::forward(function)), + errFunction_(std::move(errFunction)) {} + + void subscribe(std::shared_ptr> subscriber) override { + upstream_->subscribe(std::make_shared( + this->ref_from_this(this), std::move(subscriber))); + } + + private: + using SuperSubscription = typename Super::Subscription; + class Subscription : public SuperSubscription { + public: + Subscription( + std::shared_ptr flowable, + std::shared_ptr> subscriber) + : SuperSubscription(std::move(subscriber)), + flowable_(std::move(flowable)) {} + + void onNextImpl(U value) override { + try { + if (auto flowable = yarpl::atomic_load(&flowable_)) { + this->subscriberOnNext(flowable->function_(std::move(value))); + } + } catch (const std::exception& exn) { + folly::exception_wrapper ew{std::current_exception(), exn}; + this->terminateErr(std::move(ew)); + } + } + + void onErrorImpl(folly::exception_wrapper ew) override { + try { + if (auto flowable = yarpl::atomic_load(&flowable_)) { + SuperSubscription::onErrorImpl(flowable->errFunction_(std::move(ew))); + } + } catch (const std::exception& exn) { + this->terminateErr( + folly::exception_wrapper{std::current_exception(), exn}); + } + } + + void onTerminateImpl() override { + yarpl::atomic_exchange(&flowable_, nullptr); + SuperSubscription::onTerminateImpl(); + } + + private: + AtomicReference flowable_; + }; + + std::shared_ptr> upstream_; + F function_; + EF errFunction_; +}; + +template +class FilterOperator : public FlowableOperator { + // for use in subclasses + using Super = FlowableOperator; + static_assert(std::is_same, F>::value, "undecayed"); + static_assert(folly::is_invocable_r::value, "not invocable"); + + public: + template + FilterOperator(std::shared_ptr> upstream, Func&& function) + : upstream_(std::move(upstream)), + function_(std::forward(function)) {} + + void subscribe(std::shared_ptr> subscriber) override { + upstream_->subscribe(std::make_shared( + this->ref_from_this(this), std::move(subscriber))); + } + + private: + using SuperSubscription = typename Super::Subscription; + class Subscription : public SuperSubscription { + public: + Subscription( + std::shared_ptr flowable, + std::shared_ptr> subscriber) + : SuperSubscription(std::move(subscriber)), + flowable_(std::move(flowable)) {} + + void onNextImpl(U value) override { + if (auto flowable = yarpl::atomic_load(&flowable_)) { + if (flowable->function_(value)) { + SuperSubscription::subscriberOnNext(std::move(value)); + } else { + SuperSubscription::request(1); + } + } + } + + void onTerminateImpl() override { + yarpl::atomic_exchange(&flowable_, nullptr); + SuperSubscription::onTerminateImpl(); + } + + private: + AtomicReference flowable_; + }; + + std::shared_ptr> upstream_; + F function_; +}; + +template +class ReduceOperator : public FlowableOperator { + using Super = FlowableOperator; + static_assert(std::is_same, F>::value, "undecayed"); + static_assert(std::is_assignable::value, "not assignable"); + static_assert(folly::is_invocable_r::value, "not invocable"); + + public: + template + ReduceOperator(std::shared_ptr> upstream, Func&& function) + : upstream_(std::move(upstream)), + function_(std::forward(function)) {} + + void subscribe(std::shared_ptr> subscriber) override { + upstream_->subscribe(std::make_shared( + this->ref_from_this(this), std::move(subscriber))); + } + + private: + using SuperSubscription = typename Super::Subscription; + class Subscription : public SuperSubscription { + public: + Subscription( + std::shared_ptr flowable, + std::shared_ptr> subscriber) + : SuperSubscription(std::move(subscriber)), + flowable_(std::move(flowable)), + accInitialized_(false) {} + + void request(int64_t) override { + // Request all of the items. + SuperSubscription::request(credits::kNoFlowControl); + } + + void onNextImpl(U value) override { + if (accInitialized_) { + if (auto flowable = yarpl::atomic_load(&flowable_)) { + acc_ = flowable->function_(std::move(acc_), std::move(value)); + } + } else { + acc_ = std::move(value); + accInitialized_ = true; + } + } + + void onCompleteImpl() override { + if (accInitialized_) { + SuperSubscription::subscriberOnNext(std::move(acc_)); + } + SuperSubscription::onCompleteImpl(); + } + + void onTerminateImpl() override { + yarpl::atomic_exchange(&flowable_, nullptr); + SuperSubscription::onTerminateImpl(); + } + + private: + AtomicReference flowable_; + bool accInitialized_; + D acc_; + }; + + std::shared_ptr> upstream_; + F function_; +}; + +template +class TakeOperator : public FlowableOperator { + using Super = FlowableOperator; + + public: + TakeOperator(std::shared_ptr> upstream, int64_t limit) + : upstream_(std::move(upstream)), limit_(limit) {} + + void subscribe(std::shared_ptr> subscriber) override { + upstream_->subscribe( + std::make_shared(limit_, std::move(subscriber))); + } + + private: + using SuperSubscription = typename Super::Subscription; + class Subscription : public SuperSubscription { + public: + Subscription(int64_t limit, std::shared_ptr> subscriber) + : SuperSubscription(std::move(subscriber)), limit_(limit) {} + + void onSubscribeImpl() override { + SuperSubscription::onSubscribeImpl(); + + if (limit_ <= 0) { + SuperSubscription::terminate(); + } + } + + void onNextImpl(T value) override { + if (limit_-- > 0) { + if (pending_ > 0) { + --pending_; + } + SuperSubscription::subscriberOnNext(std::move(value)); + if (limit_ == 0) { + SuperSubscription::terminate(); + } + } + } + + void request(int64_t delta) override { + delta = std::min(delta, limit_ - pending_); + if (delta > 0) { + pending_ += delta; + SuperSubscription::request(delta); + } + } + + private: + int64_t pending_{0}; + int64_t limit_; + }; + + std::shared_ptr> upstream_; + const int64_t limit_; +}; + +template +class SkipOperator : public FlowableOperator { + using Super = FlowableOperator; + + public: + SkipOperator(std::shared_ptr> upstream, int64_t offset) + : upstream_(std::move(upstream)), offset_(offset) {} + + void subscribe(std::shared_ptr> subscriber) override { + upstream_->subscribe( + std::make_shared(offset_, std::move(subscriber))); + } + + private: + using SuperSubscription = typename Super::Subscription; + class Subscription : public SuperSubscription { + public: + Subscription(int64_t offset, std::shared_ptr> subscriber) + : SuperSubscription(std::move(subscriber)), offset_(offset) {} + + void onNextImpl(T value) override { + if (offset_ > 0) { + --offset_; + } else { + SuperSubscription::subscriberOnNext(std::move(value)); + } + } + + void request(int64_t delta) override { + if (firstRequest_) { + firstRequest_ = false; + delta = credits::add(delta, offset_); + } + SuperSubscription::request(delta); + } + + private: + int64_t offset_; + bool firstRequest_{true}; + }; + + std::shared_ptr> upstream_; + const int64_t offset_; +}; + +template +class IgnoreElementsOperator : public FlowableOperator { + using Super = FlowableOperator; + + public: + explicit IgnoreElementsOperator(std::shared_ptr> upstream) + : upstream_(std::move(upstream)) {} + + void subscribe(std::shared_ptr> subscriber) override { + upstream_->subscribe(std::make_shared(std::move(subscriber))); + } + + private: + using SuperSubscription = typename Super::Subscription; + class Subscription : public SuperSubscription { + public: + Subscription(std::shared_ptr> subscriber) + : SuperSubscription(std::move(subscriber)) {} + + void onNextImpl(T) override {} + }; + + std::shared_ptr> upstream_; +}; + +template +class SubscribeOnOperator : public FlowableOperator { + using Super = FlowableOperator; + + public: + SubscribeOnOperator( + std::shared_ptr> upstream, + folly::Executor& executor) + : upstream_(std::move(upstream)), executor_(executor) {} + + void subscribe(std::shared_ptr> subscriber) override { + executor_.add([this, self = this->ref_from_this(this), subscriber] { + upstream_->subscribe( + std::make_shared(executor_, std::move(subscriber))); + }); + } + + private: + using SuperSubscription = typename Super::Subscription; + class Subscription : public SuperSubscription { + public: + Subscription( + folly::Executor& executor, + std::shared_ptr> subscriber) + : SuperSubscription(std::move(subscriber)), executor_(executor) {} + + void request(int64_t delta) override { + executor_.add([delta, this, self = this->ref_from_this(this)] { + this->callSuperRequest(delta); + }); + } + + void cancel() override { + executor_.add([this, self = this->ref_from_this(this)] { + this->callSuperCancel(); + }); + } + + void onNextImpl(T value) override { + SuperSubscription::subscriberOnNext(std::move(value)); + } + + private: + // Trampoline to call superclass method; gcc bug 58972. + void callSuperRequest(int64_t delta) { + SuperSubscription::request(delta); + } + + // Trampoline to call superclass method; gcc bug 58972. + void callSuperCancel() { + SuperSubscription::cancel(); + } + + folly::Executor& executor_; + }; + + std::shared_ptr> upstream_; + folly::Executor& executor_; +}; + +template +class FromPublisherOperator : public Flowable { + static_assert( + std::is_same, OnSubscribe>::value, + "undecayed"); + + public: + template + explicit FromPublisherOperator(F&& function) + : function_(std::forward(function)) {} + + void subscribe(std::shared_ptr> subscriber) override { + function_(std::move(subscriber)); + } + + private: + OnSubscribe function_; +}; + +template +class FlatMapOperator : public FlowableOperator { + using Super = FlowableOperator; + + public: + FlatMapOperator( + std::shared_ptr> upstream, + folly::Function>(T)> func) + : upstream_(std::move(upstream)), function_(std::move(func)) {} + + void subscribe(std::shared_ptr> subscriber) override { + upstream_->subscribe(std::make_shared( + this->ref_from_this(this), std::move(subscriber))); + } + + private: + using SuperSubscription = typename Super::Subscription; + class FMSubscription : public SuperSubscription { + struct MappedStreamSubscriber; + + public: + FMSubscription( + std::shared_ptr flowable, + std::shared_ptr> subscriber) + : SuperSubscription(std::move(subscriber)), + flowable_(std::move(flowable)) {} + + void onSubscribeImpl() final { + liveSubscribers_++; + SuperSubscription::onSubscribeImpl(); + } + + void onNextImpl(T value) final { + std::shared_ptr> mappedStream; + + try { + mappedStream = flowable_->function_(std::move(value)); + } catch (const std::exception& exn) { + folly::exception_wrapper ew{std::current_exception(), exn}; + { + std::lock_guard g(onErrorExGuard_); + onErrorEx_ = ew; + } + // next iteration of drainLoop will cancel this subscriber as well + drainLoop(); + return; + } + + std::shared_ptr mappedSubscriber = + std::make_shared(this->ref_from_this(this)); + mappedSubscriber->fmReference_ = mappedSubscriber; + + { + // put into pendingValue queue because once the mappedSubscriber + // is subscribed to, it will request elements. We don't want the + // drainLoop to execute while it's on withoutValue, and request + // a second element before the first arrives. + auto l = lists.wlock(); + CHECK(!mappedSubscriber->is_linked()); + l->pendingValue.push_back(*mappedSubscriber.get()); + } + + liveSubscribers_++; + mappedStream->subscribe(mappedSubscriber); + drainLoop(); + } + + void drainImpl() { + // phase 1: clear out terminated subscribers + { + auto clearList = [](auto& list, SubscriberList& t) { + while (!list.empty()) { + auto& elem = list.front(); + auto r = elem.sync.wlock(); + r->freeze = true; + elem.unlink(); + t.push_back(elem); + } + }; + + SubscriberList clearTrash; + if (clearAllSubscribers_.load()) { + auto l = lists.wlock(); + clearList(l->withValue, clearTrash); + clearList(l->withoutValue, clearTrash); + clearList(l->pendingValue, clearTrash); + } + + // clear elements while no locks are held + while (!clearTrash.empty()) { + auto& elem = clearTrash.front(); + elem.unlink(); + elem.cancel(); + elem.fmReference_ = nullptr; + } + } + + // phase 2: check if the subscriber should terminate due to error + // or all subscribers completing + if (!calledDownstreamTerminate_) { + folly::exception_wrapper ex; + { + std::lock_guard exg(onErrorExGuard_); + ex = std::move(onErrorEx_); + } + if (ex) { + calledDownstreamTerminate_ = true; + cancel(); + this->terminateErr(std::move(ex)); + } else if (liveSubscribers_ == 0) { + calledDownstreamTerminate_ = true; + this->terminate(); + } + } + + // phase 3: if the downstream has requested elements, pop values out of + // subscribers which have received a value and call downstream->onNext + while (requested_ != 0) { + R val; + + { + auto l = lists.wlock(); + if (l->withValue.empty()) { + break; + } + + requested_--; + auto& elem = l->withValue.front(); + elem.unlink(); + + { + auto r = elem.sync.wlock(); + CHECK(r->hasValue); + r->hasValue = false; + val = std::move(r->value); + l->withoutValue.push_back(elem); + } + } + + SuperSubscription::subscriberOnNext(std::move(val)); + } + + // phase 4: ask any upstream flowables which don't have pending + // requests for their next element kick off any more requests. + // Put subscribers which have terminated into the trash. + { + SubscriberList terminatedTrash; + + while (true) { + MappedStreamSubscriber* elem; + { + auto l = lists.wlock(); + if (l->withoutValue.empty()) { + break; + } + elem = &l->withoutValue.front(); + + auto r = elem->sync.wlock(); + CHECK(!r->hasValue) << "failed for elem=" << elem; // sanity + + elem->unlink(); + + // Subscribers might call onNext and then terminate; delay + // removing its liveSubscriber reference until we've delivered + // its element to the downstream subscriber and dropped its + // synchronized reference to `r`, as dropping the + // flatMapSubscription_ reference may invoke its destructor + if (r->isTerminated) { + r->freeze = true; + terminatedTrash.push_back(*elem); + continue; // skips the next elem->request(1) + } + + // else, the stream hasn't terminated, request another + // element + l->pendingValue.push_back(*elem); + } + elem->request(1); + } + + // phase 5: destroy any mapped subscribers which have terminated, + // enqueue another drain loop run if we do end up discarding any + // subscribers, as our live subscriber count may have gone to zero + if (!terminatedTrash.empty()) { + drainLoopMutex_++; + } + while (!terminatedTrash.empty()) { + auto& elem = terminatedTrash.front(); + CHECK(elem.sync.wlock()->isTerminated); + elem.unlink(); + elem.fmReference_ = nullptr; + liveSubscribers_--; + } + } + } + + // called from MappedStreamSubscriber, receives the R and the + // subscriber which generated the R + void drainLoop() { + auto self = this->ref_from_this(this); + if (drainLoopMutex_++ == 0) { + do { + drainImpl(); + } while (drainLoopMutex_-- != 1); + } + } + + void onMappedSubscriberNext(MappedStreamSubscriber* elem, R value) { + { + // `elem` may not be in a list, as it may have been canceled. Push it + // on the withValue list and let drainLoop clear it if that's the case. + auto l = lists.wlock(); + auto r = elem->sync.wlock(); + + if (r->freeze) { + return; + } + + CHECK(!r->hasValue) << "failed for elem=" << elem; + r->hasValue = true; + r->value = std::move(value); + + elem->unlink(); + l->withValue.push_back(*elem); + } + + drainLoop(); + } + void onMappedSubscriberTerminate(MappedStreamSubscriber* elem) { + { + auto r = elem->sync.wlock(); + + r->isTerminated = true; + if (r->onErrorEx) { + std::lock_guard exg(onErrorExGuard_); + onErrorEx_ = std::move(r->onErrorEx); + } + + if (r->freeze) { + return; + } + } + + { + auto l = lists.wlock(); + auto r = elem->sync.wlock(); + + if (r->freeze) { + return; + } + + CHECK(elem->is_linked()); + elem->unlink(); + + if (r->hasValue) { + l->withValue.push_back(*elem); + } else { + liveSubscribers_--; + elem->fmReference_ = nullptr; + } + } + + drainLoop(); + } + + // onComplete/onError fall through to onTerminateImpl, which + // will call drainLoop and update the liveSubscribers_ count + void onCompleteImpl() final {} + void onErrorImpl(folly::exception_wrapper ex) final { + std::lock_guard g(onErrorExGuard_); + onErrorEx_ = std::move(ex); + clearAllSubscribers_.store(true); + } + + void onTerminateImpl() final { + liveSubscribers_--; + drainLoop(); + flowable_.reset(); + } + + void request(int64_t n) override { + if ((n + requested_) < requested_) { + requested_ = std::numeric_limits::max(); + } else { + requested_ += n; + } + + if (n > 0) { + // TODO: make max parallelism configurable a-la RxJava 2.x's + // FlowableFlatMapOperator + SuperSubscription::request(std::numeric_limits::max()); + } + + drainLoop(); + } + + void cancel() override { + clearAllSubscribers_.store(true); + drainLoop(); + } + + private: + // buffers at most a single element of type R + struct MappedStreamSubscriber + : public BaseSubscriber, + public boost::intrusive::list_base_hook< + boost::intrusive::link_mode> { + MappedStreamSubscriber(std::shared_ptr subscription) + : flatMapSubscription_(std::move(subscription)) {} + + void onSubscribeImpl() final { + auto fmsb = yarpl::atomic_load(&flatMapSubscription_); + if (!fmsb || fmsb->clearAllSubscribers_) { + BaseSubscriber::cancel(); + return; + } +#ifndef NDEBUG + if (auto fms = yarpl::atomic_load(&flatMapSubscription_)) { + auto l = fms->lists.wlock(); + auto r = sync.wlock(); + if (!is_in_list(*this, l->pendingValue, l)) { + LOG(INFO) << "failed: this=" << this; + LOG(INFO) << "in list: "; + debug_is_in_list(*this, l); + DCHECK(r->freeze); + } else { + } + DCHECK(!r->hasValue); + } +#endif + + BaseSubscriber::request(1); + } + + void onNextImpl(R value) final { + if (auto fms = yarpl::atomic_load(&flatMapSubscription_)) { + fms->onMappedSubscriberNext(this, std::move(value)); + } + } + + // noop + void onCompleteImpl() final {} + + void onErrorImpl(folly::exception_wrapper ex) final { + auto r = sync.wlock(); + r->onErrorEx = std::move(ex); + } + + void onTerminateImpl() override { + std::shared_ptr null; + if (auto fms = yarpl::atomic_exchange(&flatMapSubscription_, null)) { + fms->onMappedSubscriberTerminate(this); + } + } + + struct SyncData { + R value; + bool hasValue{false}; + bool isTerminated{false}; + bool freeze{false}; + folly::exception_wrapper onErrorEx{nullptr}; + }; + folly::Synchronized sync; + + // FMSubscription's 'reference' to this object. FMSubscription + // clears this reference when it drops the MappedStreamSubscriber + // from one of its atomic lists + std::shared_ptr fmReference_{nullptr}; + + // this is both a Subscriber and a Subscription + AtomicReference flatMapSubscription_{nullptr}; + }; + + // used to make sure only one thread at a time is calling subscriberOnNext + std::atomic drainLoopMutex_{0}; + + using SubscriberList = boost::intrusive::list< + MappedStreamSubscriber, + boost::intrusive::constant_time_size>; + + struct Lists { + // subscribers with a ready R + SubscriberList withValue{}; + // subscribers that have requested 1 R, waiting for it to arrive via + // onNext + SubscriberList pendingValue{}; + // idle subscribers + SubscriberList withoutValue{}; + }; + + folly::Synchronized lists; + + template + static bool is_in_list( + MappedStreamSubscriber const& elem, + SubscriberList const& list, + L const& lists) { + return in_list_impl(elem, list, lists, true); + } + template + static bool not_in_list( + MappedStreamSubscriber const& elem, + SubscriberList const& list, + L const& lists) { + return in_list_impl(elem, list, lists, false); + } + + template + static bool in_list_impl( + MappedStreamSubscriber const& elem, + SubscriberList const& list, + L const& lists, + bool should) { + if (is_in_list(elem, list) != should) { +#ifndef NDEBUG + debug_is_in_list(elem, lists); +#else + (void)lists; +#endif + return false; + } + return true; + } + + template + static void debug_is_in_list( + MappedStreamSubscriber const& elem, + L const& lists) { + LOG(INFO) << "in without: " << is_in_list(elem, lists->withoutValue); + LOG(INFO) << "in pending: " << is_in_list(elem, lists->pendingValue); + LOG(INFO) << "in withval: " << is_in_list(elem, lists->withValue); + } + + static bool is_in_list( + MappedStreamSubscriber const& elem, + SubscriberList const& list) { + bool found = false; + for (auto& e : list) { + if (&e == &elem) { + found = true; + break; + } + } + return found; + } + + std::shared_ptr flowable_; + + // got a terminating signal from the upstream flowable + // always modified in the protected drainImpl() + bool calledDownstreamTerminate_{false}; + + std::mutex onErrorExGuard_; + folly::exception_wrapper onErrorEx_{nullptr}; + + // clear all lists of + std::atomic clearAllSubscribers_{false}; + + std::atomic requested_{0}; + + // number of subscribers (FMSubscription + MappedStreamSubscriber) which + // have not received a terminating signal yet + std::atomic liveSubscribers_{0}; + }; + + std::shared_ptr> upstream_; + folly::Function>(T)> function_; +}; + +} // namespace flowable +} // namespace yarpl + +#include "yarpl/flowable/FlowableConcatOperators.h" +#include "yarpl/flowable/FlowableDoOperator.h" +#include "yarpl/flowable/FlowableObserveOnOperator.h" +#include "yarpl/flowable/FlowableTimeoutOperator.h" diff --git a/yarpl/flowable/FlowableTimeoutOperator.h b/yarpl/flowable/FlowableTimeoutOperator.h new file mode 100644 index 000000000..bb14ccc9c --- /dev/null +++ b/yarpl/flowable/FlowableTimeoutOperator.h @@ -0,0 +1,162 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// IWYU pragma: private, include "yarpl/flowable/Flowable.h" + +#pragma once + +#include + +#include "yarpl/flowable/FlowableOperator.h" + +namespace yarpl { +class TimeoutException : public std::runtime_error { + public: + TimeoutException() : std::runtime_error("yarpl::TimeoutException") {} +}; +namespace detail { +class TimeoutExceptionGenerator { + public: + TimeoutException operator()() const { + return {}; + } +}; +} // namespace detail + +namespace flowable { +namespace details { + +template +class TimeoutOperator : public FlowableOperator { + using Super = FlowableOperator; + static_assert( + std::is_same, ExceptionGenerator>::value, + "undecayed"); + + public: + template + TimeoutOperator( + std::shared_ptr> upstream, + folly::EventBase& timerEvb, + std::chrono::milliseconds timeout, + std::chrono::milliseconds initTimeout, + F&& exnGen) + : upstream_(std::move(upstream)), + timerEvb_(timerEvb), + timeout_(timeout), + initTimeout_(initTimeout), + exnGen_(std::forward(exnGen)) {} + + void subscribe(std::shared_ptr> subscriber) override { + auto subscription = std::make_shared( + this->ref_from_this(this), + subscriber, + timerEvb_, + initTimeout_, + timeout_); + upstream_->subscribe(std::move(subscription)); + } + + protected: + class TimeoutSubscription : public Super::Subscription, + public folly::HHWheelTimer::Callback { + using SuperSub = typename Super::Subscription; + + public: + TimeoutSubscription( + std::shared_ptr> flowable, + std::shared_ptr> subscriber, + folly::EventBase& timerEvb, + std::chrono::milliseconds initTimeout, + std::chrono::milliseconds timeout) + : Super::Subscription(std::move(subscriber)), + flowable_(std::move(flowable)), + timerEvb_(timerEvb), + initTimeout_(initTimeout), + timeout_(timeout) {} + + void onSubscribeImpl() override { + DCHECK(timerEvb_.isInEventBaseThread()); + if (initTimeout_.count() > 0) { + nextTime_ = std::chrono::steady_clock::now() + initTimeout_; + timerEvb_.timer().scheduleTimeout(this, initTimeout_); + } else { + nextTime_ = std::chrono::steady_clock::time_point::max(); + } + + SuperSub::onSubscribeImpl(); + } + + void onNextImpl(T value) override { + DCHECK(timerEvb_.isInEventBaseThread()); + if (flowable_) { + if (nextTime_ != std::chrono::steady_clock::time_point::max()) { + cancelTimeout(); // cancel timer before calling onNext + auto currentTime = std::chrono::steady_clock::now(); + if (currentTime > nextTime_) { + timeoutExpired(); + return; + } + nextTime_ = std::chrono::steady_clock::time_point::max(); + } + + SuperSub::subscriberOnNext(std::move(value)); + + if (timeout_.count() > 0) { + nextTime_ = std::chrono::steady_clock::now() + timeout_; + timerEvb_.timer().scheduleTimeout(this, timeout_); + } + } + } + + void onTerminateImpl() override { + DCHECK(timerEvb_.isInEventBaseThread()); + flowable_.reset(); + cancelTimeout(); + } + + void timeoutExpired() noexcept override { + if (auto flowable = std::exchange(flowable_, nullptr)) { + SuperSub::terminateErr([&]() -> folly::exception_wrapper { + try { + return flowable->exnGen_(); + } catch (...) { + return folly::make_exception_wrapper(); + } + }()); + } + } + + void callbackCanceled() noexcept override { + // Do nothing.. + } + + private: + std::shared_ptr> flowable_; + folly::EventBase& timerEvb_; + std::chrono::milliseconds initTimeout_; + std::chrono::milliseconds timeout_; + std::chrono::steady_clock::time_point nextTime_; + }; + + std::shared_ptr> upstream_; + folly::EventBase& timerEvb_; + std::chrono::milliseconds timeout_; + std::chrono::milliseconds initTimeout_; + ExceptionGenerator exnGen_; +}; + +} // namespace details +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/Flowable_FromObservable.h b/yarpl/flowable/Flowable_FromObservable.h new file mode 100644 index 000000000..e191ad7c3 --- /dev/null +++ b/yarpl/flowable/Flowable_FromObservable.h @@ -0,0 +1,348 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "yarpl/Common.h" +#include "yarpl/Flowable.h" +#include "yarpl/utils/credits.h" + +namespace yarpl { +namespace observable { +template +class Observable; + +template +class Observer; +} // namespace observable + +template +class BackpressureStrategyBase : public IBackpressureStrategy, + public flowable::Subscription, + public observable::Observer { + protected: + // + // the following methods are to be overridden + // + virtual void onCreditsAvailable(int64_t /*credits*/) = 0; + virtual void onNextWithoutCredits(T /*t*/) = 0; + + public: + void init( + std::shared_ptr> observable, + std::shared_ptr> subscriber) override { + observable_ = std::move(observable); + subscriberWeak_ = subscriber; + subscriber_ = subscriber; + subscriber->onSubscribe(this->ref_from_this(this)); + observable_->subscribe(this->ref_from_this(this)); + } + + BackpressureStrategyBase() = default; + BackpressureStrategyBase(BackpressureStrategyBase&&) = delete; + + BackpressureStrategyBase(const BackpressureStrategyBase&) = delete; + BackpressureStrategyBase& operator=(BackpressureStrategyBase&&) = delete; + BackpressureStrategyBase& operator=(const BackpressureStrategyBase&) = delete; + + // only for testing purposes + void setTestSubscriber(std::shared_ptr> subscriber) { + subscriberWeak_ = subscriber; + subscriber_ = subscriber; + subscriber->onSubscribe(this->ref_from_this(this)); + } + + void request(int64_t n) override { + if (n <= 0) { + return; + } + auto r = credits::add(&requested_, n); + if (r <= 0) { + return; + } + + // it is possible that after calling subscribe or in onCreditsAvailable + // methods, there will be a stream of + // onNext calls which the processing chain might cancel. The cancel signal + // will remove all references to this class and we need to keep this + // instance around to finish this method + auto thisPtr = this->ref_from_this(this); + + if (r > 0) { + onCreditsAvailable(r); + } + } + + void cancel() override { + if (auto subscriber = subscriber_.exchange(nullptr)) { + observable::Observer::unsubscribe(); + observable_.reset(); + } + } + + // Observer override + void onNext(T t) override { + if (subscriberWeak_.expired()) { + return; + } + if (requested_ > 0) { + downstreamOnNext(std::move(t)); + return; + } + onNextWithoutCredits(std::move(t)); + } + + // Observer override + void onComplete() override { + downstreamOnComplete(); + } + + // Observer override + void onError(folly::exception_wrapper ex) override { + downstreamOnError(std::move(ex)); + } + + virtual void downstreamOnNext(T t) { + credits::consume(&requested_, 1); + if (auto subscriber = subscriberWeak_.lock()) { + subscriber->onNext(std::move(t)); + } + } + + void downstreamOnComplete() { + if (auto subscriber = subscriber_.exchange(nullptr)) { + subscriber->onComplete(); + observable::Observer::onComplete(); + observable_.reset(); + } + } + + void downstreamOnError(folly::exception_wrapper error) { + if (auto subscriber = subscriber_.exchange(nullptr)) { + subscriber->onError(std::move(error)); + observable::Observer::onError(folly::exception_wrapper()); + observable_.reset(); + } + } + + void downstreamOnErrorAndCancel(folly::exception_wrapper error) { + if (auto subscriber = subscriber_.exchange(nullptr)) { + subscriber->onError(std::move(error)); + + observable_.reset(); + observable::Observer::unsubscribe(); + } + } + + private: + std::shared_ptr> observable_; + folly::Synchronized>> subscriber_; + std::weak_ptr> subscriberWeak_; + std::atomic requested_{0}; +}; + +template +class DropBackpressureStrategy : public BackpressureStrategyBase { + public: + void onCreditsAvailable(int64_t /*credits*/) override {} + void onNextWithoutCredits(T /*t*/) override { + // drop anything while we don't have credits + } +}; + +template +class ErrorBackpressureStrategy : public BackpressureStrategyBase { + using Super = BackpressureStrategyBase; + + void onCreditsAvailable(int64_t /*credits*/) override {} + + void onNextWithoutCredits(T /*t*/) override { + Super::downstreamOnErrorAndCancel(flowable::MissingBackpressureException()); + } +}; + +template +class BufferBackpressureStrategy : public BackpressureStrategyBase { + public: + static constexpr size_t kNoLimit = 0; + + explicit BufferBackpressureStrategy(size_t bufferSizeLimit = kNoLimit) + : buffer_(folly::in_place, bufferSizeLimit) {} + + private: + using Super = BackpressureStrategyBase; + + void onComplete() override { + if (!buffer_.rlock()->empty()) { + // we have buffered some items so we will defer delivering on complete for + // later + completed_ = true; + } else { + Super::onComplete(); + } + } + + void onNext(T t) override { + { + auto buffer = buffer_.wlock(); + if (!buffer->empty()) { + if (buffer->push(std::move(t))) { + return; + } + buffer.unlock(); + Super::downstreamOnErrorAndCancel( + flowable::MissingBackpressureException()); + return; + } + } + BackpressureStrategyBase::onNext(std::move(t)); + } + + // + // onError signal is delivered immediately by design + // + + void onNextWithoutCredits(T t) override { + if (buffer_.wlock()->push(std::move(t))) { + return; + } + Super::downstreamOnErrorAndCancel(flowable::MissingBackpressureException()); + } + + void onCreditsAvailable(int64_t credits) override { + DCHECK(credits > 0); + auto lockedBuffer = buffer_.wlock(); + while (credits-- > 0 && !lockedBuffer->empty()) { + Super::downstreamOnNext(std::move(lockedBuffer->front())); + lockedBuffer->pop(); + } + + if (lockedBuffer->empty() && completed_) { + Super::onComplete(); + } + } + + struct Buffer { + public: + explicit Buffer(size_t sizeLimit) : sizeLimit_(sizeLimit) {} + + bool empty() const { + return buffer_.empty(); + } + + bool push(T&& value) { + if (sizeLimit_ != kNoLimit && buffer_.size() >= sizeLimit_) { + return false; + } + buffer_.push(std::move(value)); + return true; + } + + T& front() { + return buffer_.front(); + } + + void pop() { + buffer_.pop(); + } + + private: + const size_t sizeLimit_; + std::queue buffer_; + }; + + folly::Synchronized buffer_; + std::atomic completed_{false}; +}; + +template +class LatestBackpressureStrategy : public BackpressureStrategyBase { + using Super = BackpressureStrategyBase; + + void onComplete() override { + if (storesLatest_) { + // we have buffered an item so we will defer delivering on complete for + // later + completed_ = true; + } else { + Super::onComplete(); + } + } + + // + // onError signal is delivered immediately by design + // + + void onNextWithoutCredits(T t) override { + storesLatest_ = true; + *latest_.wlock() = std::move(t); + } + + void onCreditsAvailable(int64_t credits) override { + DCHECK(credits > 0); + if (storesLatest_) { + storesLatest_ = false; + Super::downstreamOnNext(std::move(*latest_.wlock())); + + if (completed_) { + Super::onComplete(); + } + } + } + + std::atomic storesLatest_{false}; + std::atomic completed_{false}; + folly::Synchronized latest_; +}; + +template +class MissingBackpressureStrategy : public BackpressureStrategyBase { + using Super = BackpressureStrategyBase; + + void onCreditsAvailable(int64_t /*credits*/) override {} + + void onNextWithoutCredits(T t) override { + // call onNext anyways (and potentially violating the protocol) + Super::downstreamOnNext(std::move(t)); + } +}; + +template +std::shared_ptr> IBackpressureStrategy::buffer() { + return std::make_shared>(); +} + +template +std::shared_ptr> IBackpressureStrategy::drop() { + return std::make_shared>(); +} + +template +std::shared_ptr> IBackpressureStrategy::error() { + return std::make_shared>(); +} + +template +std::shared_ptr> IBackpressureStrategy::latest() { + return std::make_shared>(); +} + +template +std::shared_ptr> IBackpressureStrategy::missing() { + return std::make_shared>(); +} + +} // namespace yarpl diff --git a/yarpl/flowable/Flowables.cpp b/yarpl/flowable/Flowables.cpp new file mode 100644 index 000000000..4b61540f5 --- /dev/null +++ b/yarpl/flowable/Flowables.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yarpl/flowable/Flowables.h" + +namespace yarpl { +namespace flowable { + +std::shared_ptr> Flowable<>::range( + int64_t start, + int64_t count) { + auto lambda = [start, count, i = start]( + Subscriber& subscriber, + int64_t requested) mutable { + int64_t end = start + count; + + while (i < end && requested-- > 0) { + subscriber.onNext(i++); + } + + if (i >= end) { + // TODO T27302402: Even though having two subscriptions exist concurrently + // for Emitters is not possible still. At least it possible to resubscribe + // and consume the same values again. + i = start; + subscriber.onComplete(); + } + }; + return Flowable::create(std::move(lambda)); +} + +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/Flowables.h b/yarpl/flowable/Flowables.h new file mode 100644 index 000000000..56cb8c034 --- /dev/null +++ b/yarpl/flowable/Flowables.h @@ -0,0 +1,64 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "yarpl/flowable/Flowable.h" + +namespace yarpl { +namespace flowable { + +template <> +class Flowable { + public: + /** + * Emit the sequence of numbers [start, start + count). + */ + static std::shared_ptr> range(int64_t start, int64_t count); + + template + static std::shared_ptr> just(T&& value) { + return Flowable>::just(std::forward(value)); + } + + template + static std::shared_ptr> justN(std::initializer_list list) { + return Flowable>::justN(std::move(list)); + } + + // this will generate a flowable which can be subscribed to only once + template + static std::shared_ptr> justOnce(T&& value) { + return Flowable>::justOnce(std::forward(value)); + } + + template + static std::shared_ptr> concat( + std::shared_ptr> first, + Args... args) { + return first->concatWith(args...); + } + + private: + Flowable() = delete; +}; + +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/PublishProcessor.h b/yarpl/flowable/PublishProcessor.h new file mode 100644 index 000000000..8232c4a82 --- /dev/null +++ b/yarpl/flowable/PublishProcessor.h @@ -0,0 +1,255 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "yarpl/Common.h" +#include "yarpl/flowable/Flowable.h" +#include "yarpl/observable/Observable.h" +#include "yarpl/utils/credits.h" + +namespace yarpl { +namespace flowable { + +// Processor that multicasts all subsequently observed items to its current +// Subscribers. The processor does not coordinate backpressure for its +// subscribers and implements a weaker onSubscribe which calls requests +// kNoFlowControl from the incoming Subscriptions. This makes it possible to +// subscribe the PublishProcessor to multiple sources unlike the standard +// Subscriber contract. If subscribers are not able to keep up with the flow +// control, they are terminated with MissingBackpressureException. The +// implementation of onXXX() and subscribe() methods are technically thread-safe +// but non-serialized calls to them may lead to undefined state in the currently +// subscribed Subscribers. +template +class PublishProcessor : public observable::Observable, + public Subscriber { + class PublisherSubscription; + using PublishersVector = std::vector>; + + public: + static std::shared_ptr create() { + return std::shared_ptr(new PublishProcessor()); + } + + ~PublishProcessor() { + auto publishers = std::make_shared(); + publishers_.swap(publishers); + + for (const auto& publisher : *publishers) { + publisher->terminate(); + } + } + + bool hasSubscribers() const { + return !publishers_.copy()->empty(); + } + + std::shared_ptr subscribe( + std::shared_ptr> subscriber) override { + auto publisher = std::make_shared(subscriber, this); + // we have to call onSubscribe before adding it to the list of publishers + // because they might start emitting right away + subscriber->onSubscribe(publisher); + + if (publisher->isCancelled()) { + return publisher; + } + + auto publishers = tryAddPublisher(publisher); + + if (publishers == kCompletedPublishers()) { + publisher->onComplete(); + } else if (publishers == kErroredPublishers()) { + publisher->onError(std::runtime_error("ErroredPublisher")); + } + + return publisher; + } + + void onSubscribe(std::shared_ptr subscription) override { + auto publishers = publishers_.copy(); + if (publishers == kCompletedPublishers() || + publishers == kErroredPublishers()) { + subscription->cancel(); + return; + } + + subscription->request(credits::kNoFlowControl); + } + + void onNext(T value) override { + auto publishers = publishers_.copy(); + DCHECK(publishers != kCompletedPublishers()); + DCHECK(publishers != kErroredPublishers()); + + for (const auto& publisher : *publishers) { + publisher->onNext(value); + } + } + + void onError(folly::exception_wrapper ex) override { + auto publishers = kErroredPublishers(); + publishers_.swap(publishers); + DCHECK(publishers != kCompletedPublishers()); + DCHECK(publishers != kErroredPublishers()); + + for (const auto& publisher : *publishers) { + publisher->onError(ex); + } + } + + void onComplete() override { + auto publishers = kCompletedPublishers(); + publishers_.swap(publishers); + DCHECK(publishers != kCompletedPublishers()); + DCHECK(publishers != kErroredPublishers()); + + for (const auto& publisher : *publishers) { + publisher->onComplete(); + } + } + + private: + PublishProcessor() : publishers_{std::make_shared()} {} + + std::shared_ptr tryAddPublisher( + std::shared_ptr subscriber) { + while (true) { + auto oldPublishers = publishers_.copy(); + if (oldPublishers == kCompletedPublishers() || + oldPublishers == kErroredPublishers()) { + return oldPublishers; + } + + auto newPublishers = std::make_shared(); + newPublishers->reserve(oldPublishers->size() + 1); + newPublishers->insert( + newPublishers->begin(), + oldPublishers->cbegin(), + oldPublishers->cend()); + newPublishers->push_back(subscriber); + + auto locked = publishers_.lock(); + if (*locked == oldPublishers) { + *locked = newPublishers; + return newPublishers; + } + // else the vector changed so we will have to do it again + } + } + + void removePublisher(PublisherSubscription* subscriber) { + while (true) { + auto oldPublishers = publishers_.copy(); + + auto removingItem = std::find_if( + oldPublishers->cbegin(), + oldPublishers->cend(), + [&](const auto& publisherPtr) { + return publisherPtr.get() == subscriber; + }); + + if (removingItem == oldPublishers->cend()) { + // not found anymore + return; + } + + auto newPublishers = std::make_shared(); + newPublishers->reserve(oldPublishers->size() - 1); + newPublishers->insert( + newPublishers->begin(), oldPublishers->cbegin(), removingItem); + newPublishers->insert( + newPublishers->end(), std::next(removingItem), oldPublishers->cend()); + + auto locked = publishers_.lock(); + if (*locked == oldPublishers) { + *locked = std::move(newPublishers); + return; + } + // else the vector changed so we will have to do it again + } + } + + class PublisherSubscription : public observable::Subscription { + public: + PublisherSubscription( + std::shared_ptr> subscriber, + PublishProcessor* processor) + : subscriber_(std::move(subscriber)), processor_(processor) {} + + // cancel may race with terminate(), but the + // PublishProcessor::removePublisher will take care of that the race with + // on{Next, Error, Complete} methods is allowed by the spec + void cancel() override { + subscriber_.reset(); + processor_->removePublisher(this); + } + + // terminate will never race with on{Next, Error, Complete} because they are + // all called from PublishProcessor and terminate is called only from dtor + void terminate() { + if (auto subscriber = std::exchange(subscriber_, nullptr)) { + subscriber->onError(std::runtime_error("PublishProcessor shutdown")); + } + } + + void onNext(T value) { + if (subscriber_) { + subscriber_->onNext(std::move(value)); + } + } + + // used internally, not an interface method + void onError(folly::exception_wrapper ex) { + if (auto subscriber = std::exchange(subscriber_, nullptr)) { + subscriber->onError(std::move(ex)); + } + } + + // used internally, not an interface method + void onComplete() { + if (auto subscriber = std::exchange(subscriber_, nullptr)) { + subscriber->onComplete(); + } + } + + bool isCancelled() const { + return !subscriber_; + } + + private: + std::shared_ptr> subscriber_; + PublishProcessor* processor_; + }; + + static const std::shared_ptr& kCompletedPublishers() { + static std::shared_ptr constant = + std::make_shared(); + return constant; + } + + static const std::shared_ptr& kErroredPublishers() { + static std::shared_ptr constant = + std::make_shared(); + return constant; + } + + folly::Synchronized, std::mutex> + publishers_; +}; +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/Subscriber.h b/yarpl/flowable/Subscriber.h new file mode 100644 index 000000000..d1dc3b525 --- /dev/null +++ b/yarpl/flowable/Subscriber.h @@ -0,0 +1,448 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "yarpl/Disposable.h" +#include "yarpl/Refcounted.h" +#include "yarpl/flowable/Subscription.h" +#include "yarpl/utils/credits.h" + +namespace yarpl { +namespace flowable { + +template +class Subscriber : boost::noncopyable { + public: + virtual ~Subscriber() = default; + virtual void onSubscribe(std::shared_ptr) = 0; + virtual void onComplete() = 0; + virtual void onError(folly::exception_wrapper) = 0; + virtual void onNext(T) = 0; + + template < + typename Next, + typename = typename std::enable_if< + folly::is_invocable&, T>::value>::type> + static std::shared_ptr> create( + Next&& next, + int64_t batch = credits::kNoFlowControl); + + template < + typename Next, + typename Error, + typename = typename std::enable_if< + folly::is_invocable&, T>::value && + folly::is_invocable&, folly::exception_wrapper>:: + value>::type> + static std::shared_ptr> + create(Next&& next, Error&& error, int64_t batch = credits::kNoFlowControl); + + template < + typename Next, + typename Error, + typename Complete, + typename = typename std::enable_if< + folly::is_invocable&, T>::value && + folly::is_invocable&, folly::exception_wrapper>:: + value && + folly::is_invocable&>::value>::type> + static std::shared_ptr> create( + Next&& next, + Error&& error, + Complete&& complete, + int64_t batch = credits::kNoFlowControl); + + static std::shared_ptr> create() { + class NullSubscriber : public Subscriber { + void onSubscribe(std::shared_ptr s) override final { + s->request(credits::kNoFlowControl); + } + + void onNext(T) override final {} + void onComplete() override {} + void onError(folly::exception_wrapper) override {} + }; + return std::make_shared(); + } +}; + +namespace details { + +template +class BaseSubscriberDisposable; + +} // namespace details + +#define KEEP_REF_TO_THIS() \ + std::shared_ptr self; \ + if (keep_reference_to_this) { \ + self = this->ref_from_this(this); \ + } + +// T : Type of Flowable that this Subscriber operates on +// +// keep_reference_to_this : BaseSubscriber will keep a live reference to +// itself on the stack while in a signaling or requesting method, in case +// the derived class causes all other references to itself to be dropped. +// +// Classes that ensure that at least one reference will stay live can +// use `keep_reference_to_this = false` as an optimization to +// prevent an atomic inc/dec pair +template +class BaseSubscriber : public Subscriber, public yarpl::enable_get_ref { + public: + // Note: If any of the following methods is overridden in a subclass, the new + // methods SHOULD ensure that these are invoked as well. + void onSubscribe(std::shared_ptr subscription) final override { + CHECK(subscription); + CHECK(!yarpl::atomic_load(&subscription_)); + +#ifndef NDEBUG + DCHECK(!gotOnSubscribe_.exchange(true)) + << "Already subscribed to BaseSubscriber"; +#endif + + yarpl::atomic_store(&subscription_, std::move(subscription)); + KEEP_REF_TO_THIS(); + onSubscribeImpl(); + } + + // No further calls to the subscription after this method is invoked. + void onComplete() final override { +#ifndef NDEBUG + DCHECK(gotOnSubscribe_.load()) << "Not subscribed to BaseSubscriber"; + DCHECK(!gotTerminating_.exchange(true)) + << "Already got terminating signal method"; +#endif + + std::shared_ptr null; + if (auto sub = yarpl::atomic_exchange(&subscription_, null)) { + KEEP_REF_TO_THIS(); + onCompleteImpl(); + onTerminateImpl(); + } + } + + // No further calls to the subscription after this method is invoked. + void onError(folly::exception_wrapper e) final override { +#ifndef NDEBUG + DCHECK(gotOnSubscribe_.load()) << "Not subscribed to BaseSubscriber"; + DCHECK(!gotTerminating_.exchange(true)) + << "Already got terminating signal method"; +#endif + + std::shared_ptr null; + if (auto sub = yarpl::atomic_exchange(&subscription_, null)) { + KEEP_REF_TO_THIS(); + onErrorImpl(std::move(e)); + onTerminateImpl(); + } + } + + void onNext(T t) final override { +#ifndef NDEBUG + DCHECK(gotOnSubscribe_.load()) << "Not subscibed to BaseSubscriber"; + if (gotTerminating_.load()) { + VLOG(2) << "BaseSubscriber already got terminating signal method"; + } +#endif + + if (auto sub = yarpl::atomic_load(&subscription_)) { + KEEP_REF_TO_THIS(); + onNextImpl(std::move(t)); + } + } + + void cancel() { + std::shared_ptr null; + if (auto sub = yarpl::atomic_exchange(&subscription_, null)) { + KEEP_REF_TO_THIS(); + sub->cancel(); + onTerminateImpl(); + } +#ifndef NDEBUG + else { + VLOG(2) << "cancel() on BaseSubscriber with no subscription_"; + } +#endif + } + + void request(int64_t n) { + if (auto sub = yarpl::atomic_load(&subscription_)) { + KEEP_REF_TO_THIS(); + sub->request(n); + } +#ifndef NDEBUG + else { + VLOG(2) << "request() on BaseSubscriber with no subscription_"; + } +#endif + } + + protected: + virtual void onSubscribeImpl() = 0; + virtual void onCompleteImpl() = 0; + virtual void onNextImpl(T) = 0; + virtual void onErrorImpl(folly::exception_wrapper) = 0; + + virtual void onTerminateImpl() {} + + private: + bool isTerminated() { + return !yarpl::atomic_load(&subscription_); + } + + friend class ::yarpl::flowable::details::BaseSubscriberDisposable; + + // keeps a reference alive to the subscription + AtomicReference subscription_; + +#ifndef NDEBUG + std::atomic gotOnSubscribe_{false}; + std::atomic gotTerminating_{false}; +#endif +}; + +namespace details { + +template +class BaseSubscriberDisposable : public Disposable { + public: + BaseSubscriberDisposable(std::shared_ptr> subscriber) + : subscriber_(std::move(subscriber)) {} + + void dispose() override { + if (auto sub = yarpl::atomic_exchange(&subscriber_, nullptr)) { + sub->cancel(); + } + } + + bool isDisposed() override { + if (auto sub = yarpl::atomic_load(&subscriber_)) { + return sub->isTerminated(); + } else { + return true; + } + } + + private: + AtomicReference> subscriber_; +}; + +template +class LambdaSubscriber : public BaseSubscriber { + public: + template < + typename Next, + typename = typename std::enable_if< + folly::is_invocable&, T>::value>::type> + static std::shared_ptr> create( + Next&& next, + int64_t batch = credits::kNoFlowControl); + + template < + typename Next, + typename Error, + typename = typename std::enable_if< + folly::is_invocable&, T>::value && + folly::is_invocable&, folly::exception_wrapper>:: + value>::type> + static std::shared_ptr> + create(Next&& next, Error&& error, int64_t batch = credits::kNoFlowControl); + + template < + typename Next, + typename Error, + typename Complete, + typename = typename std::enable_if< + folly::is_invocable&, T>::value && + folly::is_invocable&, folly::exception_wrapper>:: + value && + folly::is_invocable&>::value>::type> + static std::shared_ptr> create( + Next&& next, + Error&& error, + Complete&& complete, + int64_t batch = credits::kNoFlowControl); +}; + +template +class Base : public LambdaSubscriber { + static_assert(std::is_same, Next>::value, "undecayed"); + + public: + template + Base(FNext&& next, int64_t batch) + : next_(std::forward(next)), batch_(batch), pending_(0) {} + + void onSubscribeImpl() override final { + pending_ = batch_; + this->request(batch_); + } + + void onNextImpl(T value) override final { + try { + next_(std::move(value)); + } catch (const std::exception& exn) { + this->cancel(); + auto ew = folly::exception_wrapper{std::current_exception(), exn}; + LOG(ERROR) << "'next' method should not throw: " << ew.what(); + onErrorImpl(ew); + return; + } + + if (--pending_ <= batch_ / 2) { + const auto delta = batch_ - pending_; + pending_ += delta; + this->request(delta); + } + } + + void onCompleteImpl() override {} + void onErrorImpl(folly::exception_wrapper) override {} + + private: + Next next_; + const int64_t batch_; + int64_t pending_; +}; + +template +class WithError : public Base { + static_assert(std::is_same, Error>::value, "undecayed"); + + public: + template + WithError(FNext&& next, FError&& error, int64_t batch) + : Base(std::forward(next), batch), + error_(std::forward(error)) {} + + void onErrorImpl(folly::exception_wrapper error) override final { + try { + error_(std::move(error)); + } catch (const std::exception& exn) { + LOG(ERROR) << "'error' method should not throw: " << exn.what(); + } + } + + private: + Error error_; +}; + +template +class WithErrorAndComplete : public WithError { + static_assert( + std::is_same, Complete>::value, + "undecayed"); + + public: + template + WithErrorAndComplete( + FNext&& next, + FError&& error, + FComplete&& complete, + int64_t batch) + : WithError( + std::forward(next), + std::forward(error), + batch), + complete_(std::forward(complete)) {} + + void onCompleteImpl() override final { + try { + complete_(); + } catch (const std::exception& exn) { + LOG(ERROR) << "'complete' method should not throw: " << exn.what(); + } + } + + private: + Complete complete_; +}; + +template +template +std::shared_ptr> LambdaSubscriber::create( + Next&& next, + int64_t batch) { + return std::make_shared>>( + std::forward(next), batch); +} + +template +template +std::shared_ptr> +LambdaSubscriber::create(Next&& next, Error&& error, int64_t batch) { + return std::make_shared< + details::WithError, std::decay_t>>( + std::forward(next), std::forward(error), batch); +} + +template +template +std::shared_ptr> LambdaSubscriber::create( + Next&& next, + Error&& error, + Complete&& complete, + int64_t batch) { + return std::make_shared, + std::decay_t, + std::decay_t>>( + std::forward(next), + std::forward(error), + std::forward(complete), + batch); +} + +} // namespace details + +template +template +std::shared_ptr> Subscriber::create( + Next&& next, + int64_t batch) { + return details::LambdaSubscriber::create(std::forward(next), batch); +} + +template +template +std::shared_ptr> +Subscriber::create(Next&& next, Error&& error, int64_t batch) { + return details::LambdaSubscriber::create( + std::forward(next), std::forward(error), batch); +} + +template +template +std::shared_ptr> Subscriber::create( + Next&& next, + Error&& error, + Complete&& complete, + int64_t batch) { + return details::LambdaSubscriber::create( + std::forward(next), + std::forward(error), + std::forward(complete), + batch); +} + +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/Subscription.cpp b/yarpl/flowable/Subscription.cpp new file mode 100644 index 000000000..a49e1c97c --- /dev/null +++ b/yarpl/flowable/Subscription.cpp @@ -0,0 +1,29 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yarpl/flowable/Subscription.h" + +namespace yarpl { +namespace flowable { + +std::shared_ptr Subscription::create() { + class NullSubscription : public Subscription { + void request(int64_t) override {} + void cancel() override {} + }; + return std::make_shared(); +} + +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/Subscription.h b/yarpl/flowable/Subscription.h new file mode 100644 index 000000000..bc4c49bbe --- /dev/null +++ b/yarpl/flowable/Subscription.h @@ -0,0 +1,87 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/Refcounted.h" + +namespace yarpl { +namespace flowable { + +class Subscription { + public: + virtual ~Subscription() = default; + + virtual void request(int64_t n) = 0; + virtual void cancel() = 0; + + static std::shared_ptr create(); + + template + static std::shared_ptr create(CancelFunc&& onCancel); + + template + static std::shared_ptr create( + CancelFunc&& onCancel, + RequestFunc&& onRequest); +}; + +namespace detail { + +template +class CallbackSubscription : public Subscription { + static_assert( + std::is_same, CancelFunc>::value, + "undecayed"); + static_assert( + std::is_same, RequestFunc>::value, + "undecayed"); + + public: + template + CallbackSubscription(FCancel&& onCancel, FRequest&& onRequest) + : onCancel_(std::forward(onCancel)), + onRequest_(std::forward(onRequest)) {} + + void request(int64_t n) override { + onRequest_(n); + } + void cancel() override { + onCancel_(); + } + + private: + CancelFunc onCancel_; + RequestFunc onRequest_; +}; +} // namespace detail + +template +std::shared_ptr Subscription::create( + CancelFunc&& onCancel, + RequestFunc&& onRequest) { + return std::make_shared, + std::decay_t>>( + std::forward(onCancel), std::forward(onRequest)); +} + +template +std::shared_ptr Subscription::create(CancelFunc&& onCancel) { + return Subscription::create( + std::forward(onCancel), [](int64_t) {}); +} + +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/TestSubscriber.h b/yarpl/flowable/TestSubscriber.h new file mode 100644 index 000000000..127b7fd0f --- /dev/null +++ b/yarpl/flowable/TestSubscriber.h @@ -0,0 +1,270 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "yarpl/flowable/Flowable.h" +#include "yarpl/flowable/Subscriber.h" +#include "yarpl/utils/credits.h" + +namespace yarpl { +namespace flowable { + +/** + * A utility class for unit testing or experimenting with Flowable. + * + * Example usage: + * + * auto flowable = ... + * auto ts = TestSubscriber::create(); + * flowable->subscribe(to); + * ts->awaitTerminalEvent(); + * ts->assert... + */ +template +class TestSubscriber : public BaseSubscriber, + public yarpl::flowable::Subscription { + public: + static_assert( + std::is_copy_constructible::value, + "Requires copyable types in case of a delegate subscriber"); + + constexpr static auto kCanceled = credits::kCanceled; + constexpr static auto kNoFlowControl = credits::kNoFlowControl; + + /** + * Create a TestSubscriber that will subscribe and store the value it + * receives. + */ + static std::shared_ptr> create( + int64_t initial = kNoFlowControl) { + return std::make_shared>(initial); + } + + /** + * Create a TestSubscriber that will delegate all on* method calls + * to the provided Subscriber. + * + * This will store the value it receives to allow assertions. + */ + static std::shared_ptr> create( + std::shared_ptr> delegate, + int64_t initial = kNoFlowControl) { + return std::make_shared>(std::move(delegate), initial); + } + + explicit TestSubscriber(int64_t initial = kNoFlowControl) + : TestSubscriber(std::shared_ptr>{}, initial) {} + + explicit TestSubscriber( + std::shared_ptr> delegate, + int64_t initial = kNoFlowControl) + : delegate_(std::move(delegate)), initial_{initial} {} + + void onSubscribeImpl() override { + if (delegate_) { + delegate_->onSubscribe(this->ref_from_this(this)); + } + this->request(initial_); + } + + void onNextImpl(T t) override final { + manuallyPush(std::move(t)); + } + + void manuallyPush(T t) { + if (dropValues_) { + valueCount_++; + } else { + if (delegate_) { + values_.push_back(t); + delegate_->onNext(std::move(t)); + } else { + values_.push_back(std::move(t)); + } + } + + terminalEventCV_.notify_all(); + } + + void onCompleteImpl() override final { + if (delegate_) { + delegate_->onComplete(); + } + } + + void onErrorImpl(folly::exception_wrapper ex) override final { + if (delegate_) { + delegate_->onError(ex); + } + e_ = std::move(ex); + } + + void onTerminateImpl() override final { + std::unique_lock lk(m_); + terminated_ = true; + terminalEventCV_.notify_all(); + } + + // flowable::Subscription methods + void request(int64_t n) override { + this->BaseSubscriber::request(n); + } + void cancel() override { + this->BaseSubscriber::cancel(); + } + + /** + * Block the current thread until either onSuccess or onError is called. + */ + void awaitTerminalEvent( + std::chrono::milliseconds ms = std::chrono::seconds{1}) { + // now block this thread + std::unique_lock lk(m_); + // if shutdown gets implemented this would then be released by it + if (!terminalEventCV_.wait_for(lk, ms, [this] { return terminated_; })) { + throw std::runtime_error("timeout in awaitTerminalEvent"); + } + } + + void awaitValueCount( + int64_t n, + std::chrono::milliseconds ms = std::chrono::seconds{1}) { + // now block this thread + std::unique_lock lk(m_); + + auto didTimeOut = terminalEventCV_.wait_for(lk, ms, [this, n] { + if (getValueCount() < n && terminated_) { + std::stringstream msg; + msg << "onComplete/onError called before valueCount() == n;\nvalueCount: " + << getValueCount() << " != " << n; + throw std::runtime_error(msg.str()); + } + return getValueCount() >= n; + }); + + if (!didTimeOut) { + throw std::runtime_error("timeout in awaitValueCount"); + }; + } + + void assertValueCount(size_t count) { + if (values_.size() != count) { + std::stringstream ss; + ss << "Value count " << values_.size() << " does not match " << count; + throw std::runtime_error(ss.str()); + } + } + + int64_t getValueCount() { + if (dropValues_) { + return valueCount_; + } else { + return values_.size(); + } + } + + std::vector& values() { + return values_; + } + + const std::vector& values() const { + return values_; + } + + bool isComplete() const { + return terminated_ && !e_; + } + + bool isError() const { + return terminated_ && e_; + } + + const folly::exception_wrapper& exceptionWrapper() const { + return e_; + } + + std::string getErrorMsg() const { + return e_ ? e_.get_exception()->what() : ""; + } + + void assertValueAt(int64_t index, T expected) { + if (index < getValueCount()) { + auto& v = values_[index]; + if (expected != v) { + std::stringstream ss; + ss << "Expected: " << expected << " Actual: " << v; + throw std::runtime_error(ss.str()); + } + } else { + std::stringstream ss; + ss << "Index " << index << " is larger than received values " + << values_.size(); + throw std::runtime_error(ss.str()); + } + } + + /** + * If an onComplete call was not received throw a runtime_error + */ + void assertSuccess() { + if (!terminated_) { + throw std::runtime_error("Did not receive terminal event."); + } + if (e_) { + throw std::runtime_error("Received onError instead of onSuccess"); + } + } + + /** + * If the onError exception_wrapper points to an error containing + * the given msg, complete successfully, otherwise throw a runtime_error + */ + void assertOnErrorMessage(std::string msg) { + if (!e_ || e_.get_exception()->what() != msg) { + std::stringstream ss; + ss << "Error is: '" << e_ << "' but expected: '" << msg << "'"; + throw std::runtime_error(ss.str()); + } + } + + folly::exception_wrapper getException() const { + return e_; + } + + void dropValues(bool drop) { + valueCount_ = getValueCount(); + dropValues_ = drop; + } + + private: + bool dropValues_{false}; + std::atomic valueCount_{0}; + + std::shared_ptr> delegate_; + std::vector values_; + folly::exception_wrapper e_; + int64_t initial_{kNoFlowControl}; + bool terminated_{false}; + std::mutex m_; + std::condition_variable terminalEventCV_; + std::shared_ptr subscription_; +}; +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/flowable/ThriftStreamShim.h b/yarpl/flowable/ThriftStreamShim.h new file mode 100644 index 000000000..7d42fef44 --- /dev/null +++ b/yarpl/flowable/ThriftStreamShim.h @@ -0,0 +1,263 @@ +// Copyright 2004-present Facebook. All Rights Reserved. + +#pragma once + +#include +#if FOLLY_HAS_COROUTINES +#include +#include +#include +#endif +#include + +#include +#include +#include +#include + +namespace yarpl { +namespace flowable { +class ThriftStreamShim { + public: +#if FOLLY_HAS_COROUTINES + template + static std::shared_ptr> fromClientStream( + apache::thrift::ClientBufferedStream&& stream, + folly::Executor::KeepAlive<> ex) { + struct SharedState { + SharedState( + apache::thrift::detail::ClientStreamBridge::ClientPtr streamBridge, + folly::Executor::KeepAlive<> ex) + : streamBridge_(std::move(streamBridge)), + ex_(folly::SerialExecutor::create(std::move(ex))) {} + apache::thrift::detail::ClientStreamBridge::Ptr streamBridge_; + folly::Executor::KeepAlive ex_; + std::atomic canceled_{false}; + }; + + return yarpl::flowable::internal::flowableFromSubscriber( + [state = + std::make_shared(std::move(stream.streamBridge_), ex), + decode = + stream.decode_](std::shared_ptr> + subscriber) mutable { + class Subscription : public yarpl::flowable::Subscription { + public: + explicit Subscription(std::weak_ptr state) + : state_(std::move(state)) {} + + void request(int64_t n) override { + CHECK(n != yarpl::credits::kNoFlowControl) + << "kNoFlowControl unsupported"; + + if (auto state = state_.lock()) { + state->ex_->add([n, state = std::move(state)]() { + state->streamBridge_->requestN(n); + }); + } + } + + void cancel() override { + if (auto state = state_.lock()) { + state->ex_->add([state = std::move(state)]() { + state->streamBridge_->cancel(); + state->canceled_ = true; + }); + } + } + + private: + std::weak_ptr state_; + }; + + state->ex_->add([keepAlive = state->ex_.copy(), + subscriber, + subscription = std::make_shared( + std::weak_ptr(state))]() mutable { + subscriber->onSubscribe(std::move(subscription)); + }); + + folly::coro::co_invoke( + [subscriber = std::move(subscriber), + state, + decode]() mutable -> folly::coro::Task { + apache::thrift::detail::ClientStreamBridge::ClientQueue queue; + class ReadyCallback + : public apache::thrift::detail::ClientStreamConsumer { + public: + void consume() override { + baton.post(); + } + + void canceled() override { + baton.post(); + } + + folly::coro::Baton baton; + }; + + while (!state->canceled_) { + if (queue.empty()) { + ReadyCallback callback; + if (state->streamBridge_->wait(&callback)) { + co_await callback.baton; + } + queue = state->streamBridge_->getMessages(); + if (queue.empty()) { + // we've been cancelled + apache::thrift::detail::ClientStreamBridge::Ptr( + state->streamBridge_.release()); + break; + } + } + + { + auto& payload = queue.front(); + if (!payload.hasValue() && !payload.hasException()) { + state->ex_->add([subscriber = std::move(subscriber), + keepAlive = state->ex_.copy()] { + subscriber->onComplete(); + }); + break; + } + auto value = decode(std::move(payload)); + queue.pop(); + if (value.hasValue()) { + state->ex_->add([subscriber, + keepAlive = state->ex_.copy(), + value = std::move(value)]() mutable { + subscriber->onNext(std::move(value).value()); + }); + } else if (value.hasException()) { + state->ex_->add([subscriber = std::move(subscriber), + keepAlive = state->ex_.copy(), + value = std::move(value)]() mutable { + subscriber->onError(std::move(value).exception()); + }); + break; + } else { + LOG(FATAL) << "unreachable"; + } + } + } + }) + .scheduleOn(state->ex_) + .start(); + }); + } +#endif + + template + static apache::thrift::ServerStream toServerStream( + std::shared_ptr> flowable) { + class StreamServerCallbackAdaptor final + : public apache::thrift::StreamServerCallback, + public Subscriber { + public: + StreamServerCallbackAdaptor( + apache::thrift::detail::StreamElementEncoder* encode, + folly::EventBase* eb, + apache::thrift::TilePtr&& interaction) + : encode_(encode), + eb_(eb), + interaction_(apache::thrift::TileStreamGuard::transferFrom( + std::move(interaction))) {} + // StreamServerCallback implementation + bool onStreamRequestN(uint64_t tokens) override { + if (!subscription_) { + tokensBeforeSubscribe_ += tokens; + } else { + DCHECK_EQ(0, tokensBeforeSubscribe_); + subscription_->request(tokens); + } + return clientCallback_; + } + void onStreamCancel() override { + clientCallback_ = nullptr; + if (auto subscription = std::move(subscription_)) { + subscription->cancel(); + } + self_.reset(); + } + void resetClientCallback( + apache::thrift::StreamClientCallback& clientCallback) override { + clientCallback_ = &clientCallback; + } + + // Subscriber implementation + void onSubscribe(std::shared_ptr subscription) override { + eb_->add([this, subscription = std::move(subscription)]() mutable { + if (!clientCallback_) { + return subscription->cancel(); + } + + subscription_ = std::move(subscription); + if (auto tokens = std::exchange(tokensBeforeSubscribe_, 0)) { + subscription_->request(tokens); + } + }); + } + void onNext(T next) override { + eb_->add([this, next = std::move(next), s = self_]() mutable { + if (clientCallback_) { + std::ignore = + clientCallback_->onStreamNext(apache::thrift::StreamPayload{ + (*encode_)(std::move(next)).value().payload, {}}); + } + }); + } + void onError(folly::exception_wrapper ew) override { + eb_->add([this, ew = std::move(ew), s = self_]() mutable { + if (clientCallback_) { + std::exchange(clientCallback_, nullptr) + ->onStreamError((*encode_)(std::move(ew)).exception()); + self_.reset(); + } + }); + } + void onComplete() override { + eb_->add([this, s = self_] { + if (clientCallback_) { + std::exchange(clientCallback_, nullptr)->onStreamComplete(); + self_.reset(); + } + }); + } + + void takeRef(std::shared_ptr self) { + self_ = std::move(self); + } + + private: + apache::thrift::StreamClientCallback* clientCallback_{nullptr}; + std::shared_ptr subscription_; + uint32_t tokensBeforeSubscribe_{0}; + apache::thrift::detail::StreamElementEncoder* encode_; + folly::Executor::KeepAlive eb_; + std::shared_ptr self_; + apache::thrift::TileStreamGuard interaction_; + }; + + return apache::thrift::ServerStream( + [flowable = std::move(flowable)]( + folly::Executor::KeepAlive<>, + apache::thrift::detail::StreamElementEncoder* encode) mutable { + return apache::thrift::detail::ServerStreamFactory( + [flowable = std::move(flowable), encode]( + apache::thrift::FirstResponsePayload&& payload, + apache::thrift::StreamClientCallback* callback, + folly::EventBase* clientEb, + apache::thrift::TilePtr&& interaction) mutable { + auto stream = std::make_shared( + encode, clientEb, std::move(interaction)); + stream->takeRef(stream); + stream->resetClientCallback(*callback); + std::ignore = callback->onFirstResponse( + std::move(payload), clientEb, stream.get()); + flowable->subscribe(std::move(stream)); + }); + }); + } +}; +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/include/yarpl/Disposable.h b/yarpl/include/yarpl/Disposable.h deleted file mode 100644 index 1c0451e06..000000000 --- a/yarpl/include/yarpl/Disposable.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -namespace yarpl { - -/** - * Represents a disposable resource. - */ -class Disposable { - public: - Disposable() {} - virtual ~Disposable() = default; - Disposable(Disposable&&) = delete; - Disposable(const Disposable&) = delete; - Disposable& operator=(Disposable&&) = delete; - Disposable& operator=(const Disposable&) = delete; - - /** - * Dispose the resource, the operation should be idempotent. - */ - virtual void dispose() = 0; - - /** - * Returns true if this resource has been disposed. - * @return true if this resource has been disposed - */ - virtual bool isDisposed() = 0; -}; -} diff --git a/yarpl/include/yarpl/Flowable.h b/yarpl/include/yarpl/Flowable.h deleted file mode 100644 index 488272411..000000000 --- a/yarpl/include/yarpl/Flowable.h +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -// include all the things a developer needs for using Flowable -#include "yarpl/flowable/Flowable.h" -#include "yarpl/flowable/Flowables.h" -#include "yarpl/flowable/Subscriber.h" -#include "yarpl/flowable/Subscribers.h" -#include "yarpl/flowable/Subscription.h" - -/** - * // TODO add documentation - */ diff --git a/yarpl/include/yarpl/Observable.h b/yarpl/include/yarpl/Observable.h deleted file mode 100644 index c70f71f77..000000000 --- a/yarpl/include/yarpl/Observable.h +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -// include all the things a developer needs for using Observable -#include "yarpl/observable/Observable.h" -#include "yarpl/observable/Observables.h" -#include "yarpl/observable/Observer.h" -#include "yarpl/observable/Observers.h" -#include "yarpl/observable/Subscription.h" -#include "yarpl/observable/Subscriptions.h" - -/** - * // TODO add documentation - */ diff --git a/yarpl/include/yarpl/Refcounted.h b/yarpl/include/yarpl/Refcounted.h deleted file mode 100644 index 2f30051f5..000000000 --- a/yarpl/include/yarpl/Refcounted.h +++ /dev/null @@ -1,304 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include -#include -#include -#include -#include - -namespace yarpl { - -/// Base of refcounted objects. The intention is the same as that -/// of boost::intrusive_ptr<>, except that we have virtual methods -/// anyway, and want to avoid argument-dependent lookup. -/// -/// NOTE: Only derive using "virtual public" inheritance. -class Refcounted { - public: - virtual ~Refcounted() = default; - - // Return the current count. For testing. - std::size_t count() const { - return refcount_; - } - - // Not intended to be broadly used by the application code mostly for library - // code (static to purposely make it more awkward). - static void incRef(Refcounted& obj) { - obj.incRef(); - } - - // Not intended to be broadly used by the application code mostly for library - // code (static to purposely make it more awkward). - static void decRef(Refcounted& obj) { - obj.decRef(); - } - - private: - void incRef() { - refcount_.fetch_add(1, std::memory_order_relaxed); - } - - void decRef() { - auto previous = refcount_.fetch_sub(1, std::memory_order_relaxed); - assert(previous >= 1 && "decRef on a destroyed object!"); - if (previous == 1) { - std::atomic_thread_fence(std::memory_order_acquire); - delete this; - } - } - - mutable std::atomic_size_t refcount_{0}; -}; - -/// RAII-enabling smart pointer for refcounted objects. Each reference -/// constructed against a target refcounted object increases its count by 1 -/// during its lifetime. -template -class Reference { - public: - template - friend class Reference; - - Reference() = default; - inline /* implicit */ Reference(std::nullptr_t) {} - - explicit Reference(T* pointer) : pointer_(pointer) { - inc(); - } - - ~Reference() { - dec(); - } - - ////////////////////////////////////////////////////////////////////////////// - - Reference(const Reference& other) : pointer_(other.pointer_) { - inc(); - } - - Reference(Reference&& other) noexcept : pointer_(other.pointer_) { - other.pointer_ = nullptr; - } - - template - Reference(const Reference& other) : pointer_(other.pointer_) { - inc(); - } - - template - Reference(Reference&& other) : pointer_(other.pointer_) { - other.pointer_ = nullptr; - } - - ////////////////////////////////////////////////////////////////////////////// - - Reference& operator=(std::nullptr_t) { - reset(); - return *this; - } - - Reference& operator=(const Reference& other) { - return assign(other); - } - - Reference& operator=(Reference&& other) { - return assign(std::move(other)); - } - - template - Reference& operator=(const Reference& other) { - return assign(other); - } - - template - Reference& operator=(Reference&& other) { - return assign(std::move(other)); - } - - ////////////////////////////////////////////////////////////////////////////// - - T* get() const { - return pointer_; - } - - T& operator*() const { - return *pointer_; - } - - T* operator->() const { - return pointer_; - } - - void reset() { - Reference{}.swap(*this); - } - - explicit operator bool() const { - return pointer_; - } - - private: - void inc() { - static_assert( - std::is_base_of::value, - "Reference must be used with types that virtually derive Refcounted"); - - if (pointer_) { - Refcounted::incRef(*pointer_); - } - } - - void dec() { - static_assert( - std::is_base_of::value, - "Reference must be used with types that virtually derive Refcounted"); - - if (pointer_) { - Refcounted::decRef(*pointer_); - } - } - - void swap(Reference& other) { - std::swap(pointer_, other.pointer_); - } - - template - Reference& assign(Ref&& other) { - Reference temp(std::forward(other)); - swap(temp); - return *this; - } - - T* pointer_{nullptr}; -}; - -template -bool operator==(const Reference& lhs, const Reference& rhs) noexcept { - return lhs.get() == rhs.get(); -} - -template -bool operator==(const Reference& lhs, std::nullptr_t) noexcept { - return lhs.get() == nullptr; -} - -template -bool operator==(std::nullptr_t, const Reference& rhs) noexcept { - return rhs.get() == nullptr; -} - -template -bool operator!=(const Reference& lhs, const Reference& rhs) noexcept { - return lhs.get() != rhs.get(); -} - -template -bool operator!=(const Reference& lhs, std::nullptr_t) noexcept { - return lhs.get() != nullptr; -} - -template -bool operator!=(std::nullptr_t, const Reference& rhs) noexcept { - return rhs.get() != nullptr; -} - -template -bool operator<(const Reference& lhs, const Reference& rhs) noexcept { - return lhs.get() < rhs.get(); -} - -template -bool operator<(const Reference& lhs, std::nullptr_t) noexcept { - return lhs.get() < nullptr; -} - -template -bool operator<(std::nullptr_t, const Reference& rhs) noexcept { - return nullptr < rhs.get(); -} - -template -bool operator<=(const Reference& lhs, const Reference& rhs) noexcept { - return lhs.get() <= rhs.get(); -} - -template -bool operator<=(const Reference& lhs, std::nullptr_t) noexcept { - return lhs.get() <= nullptr; -} - -template -bool operator<=(std::nullptr_t, const Reference& rhs) noexcept { - return nullptr <= rhs.get(); -} - -template -bool operator>(const Reference& lhs, const Reference& rhs) noexcept { - return lhs.get() > rhs.get(); -} - -template -bool operator>(const Reference& lhs, std::nullptr_t) noexcept { - return lhs.get() > nullptr; -} - -template -bool operator>(std::nullptr_t, const Reference& rhs) noexcept { - return nullptr > rhs.get(); -} - -template -bool operator>=(const Reference& lhs, const Reference& rhs) noexcept { - return lhs.get() >= rhs.get(); -} - -template -bool operator>=(const Reference& lhs, std::nullptr_t) noexcept { - return lhs.get() >= nullptr; -} - -template -bool operator>=(std::nullptr_t, const Reference& rhs) noexcept { - return nullptr >= rhs.get(); -} - -//////////////////////////////////////////////////////////////////////////////// - -template -Reference make_ref(Args&&... args) { - return Reference(new T(std::forward(args)...)); -} - -template -Reference get_ref(T& object) { - return Reference(&object); -} - -template -Reference get_ref(T* object) { - return Reference(object); -} - -} // namespace yarpl - -// -// custom specialization of std::hash> -// -namespace std -{ -template -struct hash> -{ - typedef yarpl::Reference argument_type; - typedef typename std::hash::result_type result_type; - - result_type operator()(argument_type const& s) const - { - return std::hash()(s.get()); - } -}; -} diff --git a/yarpl/include/yarpl/Scheduler.h b/yarpl/include/yarpl/Scheduler.h deleted file mode 100644 index eca844d55..000000000 --- a/yarpl/include/yarpl/Scheduler.h +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include -#include "yarpl/Disposable.h" -#include "yarpl/utils/type_traits.h" - -namespace yarpl { - -class Worker : public yarpl::Disposable { - public: - Worker() {} - Worker(Worker&&) = delete; - Worker(const Worker&) = delete; - Worker& operator=(Worker&&) = delete; - Worker& operator=(const Worker&) = delete; - - // template < - // typename F, - // typename = typename std::enable_if< - // std::is_callable::type>::value>:: - // type> - // virtual yarpl::Disposable schedule(F&&) = 0; // TODO can't do this, so how - // do we allow different impls? - - virtual std::unique_ptr schedule( - std::function&&) = 0; - - virtual void dispose() override = 0; - - virtual bool isDisposed() override = 0; - - // TODO add schedule methods with delays and periodical execution -}; - -class Scheduler { - public: - Scheduler() {} - virtual ~Scheduler() = default; - Scheduler(Scheduler&&) = delete; - Scheduler(const Scheduler&) = delete; - Scheduler& operator=(Scheduler&&) = delete; - Scheduler& operator=(const Scheduler&) = delete; - /** - * - * Retrieves or creates a new {@link Scheduler.Worker} that represents serial - * execution of actions. - *

- * When work is completed it should be disposed using - * Scheduler::Worker::dispose(). - *

- * Work on a Scheduler::Worker is guaranteed to be sequential. - * - * @return a Worker representing a serial queue of actions to be executed - */ - virtual std::unique_ptr createWorker() = 0; -}; -} diff --git a/yarpl/include/yarpl/Single.h b/yarpl/include/yarpl/Single.h deleted file mode 100644 index 2c2ad24b9..000000000 --- a/yarpl/include/yarpl/Single.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "yarpl/Refcounted.h" - -// include all the things a developer needs for using Single -#include "yarpl/single/Single.h" -#include "yarpl/single/SingleObserver.h" -#include "yarpl/single/SingleObservers.h" -#include "yarpl/single/SingleSubscriptions.h" -#include "yarpl/single/Singles.h" - -/** - * Create a single with code such as this: - * - * auto a = Single::create([](Reference> obs) { - * obs->onSubscribe(SingleSubscriptions::empty()); - * obs->onSuccess(1); - * }); - * - * // TODO add more documentation - */ diff --git a/yarpl/include/yarpl/flowable/Flowable.h b/yarpl/include/yarpl/flowable/Flowable.h deleted file mode 100644 index 70576f07e..000000000 --- a/yarpl/include/yarpl/flowable/Flowable.h +++ /dev/null @@ -1,391 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include - -#include "yarpl/Refcounted.h" -#include "yarpl/Scheduler.h" -#include "yarpl/flowable/Subscriber.h" -#include "yarpl/flowable/Subscribers.h" -#include "yarpl/utils/credits.h" -#include "yarpl/utils/type_traits.h" - -namespace yarpl { -namespace flowable { - -template -class Flowable : public virtual Refcounted { - constexpr static auto kCanceled = credits::kCanceled; - constexpr static auto kNoFlowControl = credits::kNoFlowControl; - - public: - virtual void subscribe(Reference>) = 0; - - /** - * Subscribe overload that accepts lambdas. - */ - template < - typename Next, - typename = - typename std::enable_if::value>::type> - void subscribe(Next&& next, int64_t batch = kNoFlowControl) { - subscribe(Subscribers::create(next, batch)); - } - - /** - * Subscribe overload that accepts lambdas. - * - * Takes an optional batch size for request_n. Default is no flow control. - */ - template < - typename Next, - typename Error, - typename = typename std::enable_if< - std::is_callable::value && - std::is_callable::value>::type> - void subscribe( - Next&& next, - Error&& error, - int64_t batch = kNoFlowControl) { - subscribe(Subscribers::create(next, error, batch)); - } - - /** - * Subscribe overload that accepts lambdas. - * - * Takes an optional batch size for request_n. Default is no flow control. - */ - template < - typename Next, - typename Error, - typename Complete, - typename = typename std::enable_if< - std::is_callable::value && - std::is_callable::value && - std::is_callable::value>::type> - void subscribe( - Next&& next, - Error&& error, - Complete&& complete, - int64_t batch = kNoFlowControl) { - subscribe(Subscribers::create(next, error, complete, batch)); - } - - template - auto map(Function&& function); - - template - auto filter(Function&& function); - - template - auto reduce(Function&& function); - - auto take(int64_t); - - auto skip(int64_t); - - auto ignoreElements(); - - auto subscribeOn(Scheduler&); - - /** - * \brief Create a flowable from an emitter. - * - * \param emitter function that is invoked to emit values to a subscriber. - * The emitter's signature is: - * \code{.cpp} - * std::tuple emitter(Subscriber&, int64_t requested); - * \endcode - * - * The emitter can invoke up to \b requested calls to `onNext()`, and can - * optionally make a final call to `onComplete()` or `onError()`; returns - * the actual number of `onNext()` calls; and whether the subscription is - * finished (completed/in error). - * - * \return a handle to a flowable that will use the emitter. - */ - template - class EmitterWrapper; - - template < - typename Emitter, - typename = typename std::enable_if&, int64_t), - std::tuple>::value>::type> - static auto create(Emitter&& emitter); - - private: - virtual std::tuple emit(Subscriber&, int64_t) { - return std::make_tuple(static_cast(0), false); - } - - /** - * Manager for a flowable subscription. - * - * This is synchronous: the emit calls are triggered within the context - * of a request(n) call. - */ - class SynchronousSubscription : private Subscription, private Subscriber { - public: - SynchronousSubscription( - Reference flowable, - Reference> subscriber) - : flowable_(std::move(flowable)), subscriber_(std::move(subscriber)) { - // We expect to be heap-allocated; until this subscription finishes - // (is canceled; completes; error's out), hold a reference so we are - // not deallocated (by the subscriber). - Refcounted::incRef(*this); - subscriber_->onSubscribe(Reference(this)); - } - - virtual ~SynchronousSubscription() { - subscriber_.reset(); - } - - void request(int64_t delta) override { - if (delta <= 0) { - auto message = "request(n): " + std::to_string(delta) + " <= 0"; - throw std::logic_error(message); - } - - while (true) { - auto current = requested_.load(std::memory_order_relaxed); - - if (current == kCanceled) { - // this can happen because there could be an async barrier between - // the subscriber and the subscription - // for instance while onComplete is being delivered - // (on effectively cancelled subscription) the subscriber can call call request(n) - return; - } - - auto const total = credits::add(current, delta); - if (requested_.compare_exchange_strong(current, total)) { - break; - } - } - - process(); - } - - void cancel() override { - // if this is the first terminating signal to receive, we need to - // make sure we break the reference cycle between subscription and - // subscriber - // - auto previous = requested_.exchange(kCanceled, std::memory_order_relaxed); - if(previous != kCanceled) { - // this can happen because there could be an async barrier between - // the subscriber and the subscription - // for instance while onComplete is being delivered - // (on effectively cancelled subscription) the subscriber can call call request(n) - process(); - } - } - - // Subscriber methods. - void onSubscribe(Reference) override { - LOG(FATAL) << "Do not call this method"; - } - - void onNext(T value) override { - subscriber_->onNext(std::move(value)); - } - - void onComplete() override { - // we will set the flag first to save a potential call to lock.try_lock() - // in the process method via cancel or request methods - auto old = requested_.exchange(kCanceled, std::memory_order_relaxed); - DCHECK_NE(old, kCanceled) << "Calling onComplete or onError twice or on " - << "canceled subscription"; - - subscriber_->onComplete(); - // We should already be in process(); nothing more to do. - // - // Note: we're not invoking the Subscriber superclass' method: - // we're following the Subscription's protocol instead. - } - - void onError(std::exception_ptr error) override { - // we will set the flag first to save a potential call to lock.try_lock() - // in the process method via cancel or request methods - auto old = requested_.exchange(kCanceled, std::memory_order_relaxed); - DCHECK_NE(old, kCanceled) << "Calling onComplete or onError twice or on " - << "canceled subscription"; - - subscriber_->onError(error); - // We should already be in process(); nothing more to do. - // - // Note: we're not invoking the Subscriber superclass' method: - // we're following the Subscription's protocol instead. - } - - private: - // Processing loop. Note: this can delete `this` upon completion, - // error, or cancellation; thus, no fields should be accessed once - // this method returns. - // - // Thread-Safety: there is no guarantee as to which thread this is - // invoked on. However, there is a strong guarantee on cancel and - // request(n) calls: no more than one instance of either of these - // can be outstanding at any time. - void process() { - // This lock guards against re-entrancy in request(n) calls. By - // the strict terms of the subscriber guarantees, this could be - // replaced by a re-entrancy count. - std::unique_lock lock(processing_, std::defer_lock); - if (!lock.try_lock()) { - return; - } - - while (true) { - auto current = requested_.load(std::memory_order_relaxed); - - // Subscription was canceled, completed, or had an error. - if (current == kCanceled) { - // Don't destroy a locked mutex. - lock.unlock(); - - release(); - return; - } - - // If no more items can be emitted now, wait for a request(n). - // See note above re: thread-safety. We are guaranteed that - // request(n) is not simultaneously invoked on another thread. - if (current <= 0) - return; - - int64_t emitted; - bool done; - - std::tie(emitted, done) = flowable_->emit( - *this /* implicit conversion to subscriber */, current); - - while (true) { - current = requested_.load(std::memory_order_relaxed); - if (current == kCanceled || (current == kNoFlowControl && !done)) { - break; - } - - auto updated = done ? kCanceled : current - emitted; - if (requested_.compare_exchange_strong(current, updated)) { - break; - } - } - } - } - - void release() { - flowable_.reset(); - subscriber_.reset(); - Refcounted::decRef(*this); - } - - // The number of items that can be sent downstream. Each request(n) - // adds n; each onNext consumes 1. If this is MAX, flow-control is - // disabled: items sent downstream don't consume any longer. A MIN - // value represents cancellation. Other -ve values aren't permitted. - std::atomic_int_fast64_t requested_{0}; - - // We don't want to recursively invoke process(); one loop should do. - std::mutex processing_; - - Reference flowable_; - Reference> subscriber_; - }; -}; - -} // flowable -} // yarpl - -#include "yarpl/flowable/FlowableOperator.h" - -namespace yarpl { -namespace flowable { - -template -template -class Flowable::EmitterWrapper : public Flowable { - public: - explicit EmitterWrapper(Emitter&& emitter) - : emitter_(std::forward(emitter)) {} - - void subscribe(Reference> subscriber) override { - new SynchronousSubscription( - Reference(this), std::move(subscriber)); - } - - std::tuple emit(Subscriber& subscriber, int64_t requested) - override { - return emitter_(subscriber, requested); - } - - private: - Emitter emitter_; -}; - -template -template -auto Flowable::create(Emitter&& emitter) { - return Reference>( - new Flowable::EmitterWrapper(std::forward(emitter))); -} - -template -template -auto Flowable::map(Function&& function) { - using D = typename std::result_of::type; - return Reference>(new MapOperator( - Reference>(this), std::forward(function))); -} - -template -template -auto Flowable::filter(Function&& function) { - return Reference>(new FilterOperator( - Reference>(this), std::forward(function))); -} - -template -template -auto Flowable::reduce(Function&& function) { - using D = typename std::result_of::type; - return Reference>(new ReduceOperator( - Reference>(this), std::forward(function))); -} - -template -auto Flowable::take(int64_t limit) { - return Reference>( - new TakeOperator(Reference>(this), limit)); -} - -template -auto Flowable::skip(int64_t offset) { - return Reference>( - new SkipOperator(Reference>(this), offset)); -} - -template -auto Flowable::ignoreElements() { - return Reference>( - new IgnoreElementsOperator(Reference>(this))); -} - -template -auto Flowable::subscribeOn(Scheduler& scheduler) { - return Reference>( - new SubscribeOnOperator(Reference>(this), scheduler)); -} - -} // flowable -} // yarpl diff --git a/yarpl/include/yarpl/flowable/FlowableOperator.h b/yarpl/include/yarpl/flowable/FlowableOperator.h deleted file mode 100644 index f86f33405..000000000 --- a/yarpl/include/yarpl/flowable/FlowableOperator.h +++ /dev/null @@ -1,501 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include - -#include "yarpl/flowable/Flowable.h" -#include "yarpl/flowable/Subscriber.h" -#include "yarpl/flowable/Subscription.h" -#include "yarpl/utils/credits.h" - -namespace yarpl { -namespace flowable { - -/** - * Base (helper) class for operators. Operators are templated on two types: D - * (downstream) and U (upstream). Operators are created by method calls on an - * upstream Flowable, and are Flowables themselves. Multi-stage pipelines can - * be built: a Flowable heading a sequence of Operators. - */ -template -class FlowableOperator : public Flowable { - public: - explicit FlowableOperator(Reference> upstream) - : upstream_(std::move(upstream)) {} - - protected: - /// An Operator's subscription. - /// - /// When a pipeline chain is active, each Flowable has a corresponding - /// subscription. Except for the first one, the subscriptions are created - /// against Operators. Each operator subscription has two functions: as a - /// subscriber for the previous stage; as a subscription for the next one, the - /// user-supplied subscriber being the last of the pipeline stages. - class Subscription : public yarpl::flowable::Subscription, - public Subscriber { - protected: - Subscription( - Reference> flowable, - Reference> subscriber) - : flowable_(std::move(flowable)), subscriber_(std::move(subscriber)) { - assert(flowable_); - assert(subscriber_); - - // We expect to be heap-allocated; until this subscription finishes (is - // canceled; completes; error's out), hold a reference so we are not - // deallocated (by the subscriber). - Refcounted::incRef(*this); - } - - template - TOperator* getFlowableAs() { - return static_cast(flowable_.get()); - } - - void subscriberOnNext(D value) { - if (subscriber_) { - subscriber_->onNext(std::move(value)); - } - } - - /// Terminates both ends of an operator normally. - void terminate() { - terminateImpl(TerminateState::Both()); - } - - /// Terminates both ends of an operator with an error. - void terminateErr(std::exception_ptr eptr) { - terminateImpl(TerminateState::Both(), std::move(eptr)); - } - - // Subscription. - - void request(int64_t delta) override { - if (upstream_) { - upstream_->request(delta); - } - } - - void cancel() override { - terminateImpl(TerminateState::Up()); - } - - // Subscriber. - - void onSubscribe( - Reference subscription) override { - if (upstream_) { - subscription->cancel(); - return; - } - - upstream_ = std::move(subscription); - subscriber_->onSubscribe(Reference(this)); - } - - void onComplete() override { - terminateImpl(TerminateState::Down()); - } - - void onError(std::exception_ptr eptr) override { - terminateImpl(TerminateState::Down(), std::move(eptr)); - } - - private: - struct TerminateState { - TerminateState(bool u, bool d) : up{u}, down{d} {} - - static TerminateState Down() { - return TerminateState{false, true}; - } - - static TerminateState Up() { - return TerminateState{true, false}; - } - - static TerminateState Both() { - return TerminateState{true, true}; - } - - const bool up{false}; - const bool down{false}; - }; - - bool isTerminated() const { - return !upstream_ && !subscriber_; - } - - /// Terminates an operator, sending cancel() and on{Complete,Error}() - /// signals as necessary. - void terminateImpl( - TerminateState state, - std::exception_ptr eptr = nullptr) { - if (isTerminated()) { - return; - } - - if (auto upstream = std::move(upstream_)) { - if (state.up) { - upstream->cancel(); - } - } - - if (auto subscriber = std::move(subscriber_)) { - if (state.down) { - if (eptr) { - subscriber->onError(std::move(eptr)); - } else { - subscriber->onComplete(); - } - } - } - - Refcounted::decRef(*this); - } - - /// The Flowable has the lambda, and other creation parameters. - Reference> flowable_; - - /// This subscription controls the life-cycle of the subscriber. The - /// subscriber is retained as long as calls on it can be made. (Note: the - /// subscriber in turn maintains a reference on this subscription object - /// until cancellation and/or completion.) - Reference> subscriber_; - - /// In an active pipeline, cancel and (possibly modified) request(n) calls - /// should be forwarded upstream. Note that `this` is also a subscriber for - /// the upstream stage: thus, there are cycles; all of the objects drop - /// their references at cancel/complete. - Reference upstream_; - }; - - Reference> upstream_; -}; - -template < - typename U, - typename D, - typename F, - typename = typename std::enable_if::value>::type> -class MapOperator : public FlowableOperator { - public: - MapOperator(Reference> upstream, F&& function) - : FlowableOperator(std::move(upstream)), - function_(std::forward(function)) {} - - void subscribe(Reference> subscriber) override { - FlowableOperator::upstream_->subscribe(make_ref( - Reference>(this), std::move(subscriber))); - } - - private: - class Subscription : public FlowableOperator::Subscription { - using Super = typename FlowableOperator::Subscription; - - public: - Subscription( - Reference> flowable, - Reference> subscriber) - : Super(std::move(flowable), std::move(subscriber)) {} - - void onNext(U value) override { - auto map = Super::template getFlowableAs(); - Super::subscriberOnNext(map->function_(std::move(value))); - } - }; - - F function_; -}; - -template < - typename U, - typename F, - typename = - typename std::enable_if::value>::type> -class FilterOperator : public FlowableOperator { - public: - FilterOperator(Reference> upstream, F&& function) - : FlowableOperator(std::move(upstream)), - function_(std::forward(function)) {} - - void subscribe(Reference> subscriber) override { - FlowableOperator::upstream_->subscribe(make_ref( - Reference>(this), std::move(subscriber))); - } - - private: - class Subscription : public FlowableOperator::Subscription { - using Super = typename FlowableOperator::Subscription; - - public: - Subscription( - Reference> flowable, - Reference> subscriber) - : Super(std::move(flowable), std::move(subscriber)) {} - - void onNext(U value) override { - auto filter = Super::template getFlowableAs(); - if (filter->function_(value)) { - Super::subscriberOnNext(std::move(value)); - } else { - Super::request(1); - } - } - }; - - F function_; -}; - -template < - typename U, - typename D, - typename F, - typename = typename std::enable_if::value>, - typename = - typename std::enable_if::value>::type> -class ReduceOperator : public FlowableOperator { - public: - ReduceOperator(Reference> upstream, F&& function) - : FlowableOperator(std::move(upstream)), - function_(std::forward(function)) {} - - void subscribe(Reference> subscriber) override { - FlowableOperator::upstream_->subscribe(make_ref( - Reference>(this), std::move(subscriber))); - } - - private: - class Subscription : public FlowableOperator::Subscription { - using Super = typename FlowableOperator::Subscription; - - public: - Subscription( - Reference> flowable, - Reference> subscriber) - : Super(std::move(flowable), std::move(subscriber)), - accInitialized_(false) {} - - void request(int64_t) override { - // Request all of the items - Super::request(credits::kNoFlowControl); - } - - void onNext(U value) override { - auto reduce = Super::template getFlowableAs(); - if (accInitialized_) { - acc_ = reduce->function_(std::move(acc_), std::move(value)); - } else { - acc_ = std::move(value); - accInitialized_ = true; - } - } - - void onComplete() override { - if (accInitialized_) { - Super::subscriberOnNext(std::move(acc_)); - } - Super::onComplete(); - } - - private: - bool accInitialized_; - D acc_; - }; - - F function_; -}; - -template -class TakeOperator : public FlowableOperator { - public: - TakeOperator(Reference> upstream, int64_t limit) - : FlowableOperator(std::move(upstream)), limit_(limit) {} - - void subscribe(Reference> subscriber) override { - FlowableOperator::upstream_->subscribe(make_ref( - Reference>(this), limit_, std::move(subscriber))); - } - - private: - class Subscription : public FlowableOperator::Subscription { - using Super = typename FlowableOperator::Subscription; - - public: - Subscription( - Reference> flowable, - int64_t limit, - Reference> subscriber) - : Super(std::move(flowable), std::move(subscriber)), limit_(limit) {} - - void onNext(T value) override { - if (limit_-- > 0) { - if (pending_ > 0) { - --pending_; - } - Super::subscriberOnNext(std::move(value)); - if (limit_ == 0) { - Super::terminate(); - } - } - } - - void request(int64_t delta) override { - delta = std::min(delta, limit_ - pending_); - if (delta > 0) { - pending_ += delta; - Super::request(delta); - } - } - - private: - int64_t pending_{0}; - int64_t limit_; - }; - - const int64_t limit_; -}; - -template -class SkipOperator : public FlowableOperator { - public: - SkipOperator(Reference> upstream, int64_t offset) - : FlowableOperator(std::move(upstream)), offset_(offset) {} - - void subscribe(Reference> subscriber) override { - FlowableOperator::upstream_->subscribe(make_ref( - Reference>(this), offset_, std::move(subscriber))); - } - - private: - class Subscription : public FlowableOperator::Subscription { - using Super = typename FlowableOperator::Subscription; - - public: - Subscription( - Reference> flowable, - int64_t offset, - Reference> subscriber) - : Super(std::move(flowable), std::move(subscriber)), offset_(offset) {} - - void onNext(T value) override { - if (offset_ > 0) { - --offset_; - } else { - Super::subscriberOnNext(std::move(value)); - } - } - - void request(int64_t delta) override { - if (firstRequest_) { - firstRequest_ = false; - delta = credits::add(delta, offset_); - } - Super::request(delta); - } - - private: - int64_t offset_; - bool firstRequest_{true}; - }; - - const int64_t offset_; -}; - -template -class IgnoreElementsOperator : public FlowableOperator { - public: - explicit IgnoreElementsOperator(Reference> upstream) - : FlowableOperator(std::move(upstream)) {} - - void subscribe(Reference> subscriber) override { - FlowableOperator::upstream_->subscribe(make_ref( - Reference>(this), std::move(subscriber))); - } - - private: - class Subscription : public FlowableOperator::Subscription { - using Super = typename FlowableOperator::Subscription; - - public: - Subscription( - Reference> flowable, - Reference> subscriber) - : Super(std::move(flowable), std::move(subscriber)) {} - - void onNext(T) override {} - }; -}; - -template -class SubscribeOnOperator : public FlowableOperator { - public: - SubscribeOnOperator(Reference> upstream, Scheduler& scheduler) - : FlowableOperator(std::move(upstream)), - worker_(scheduler.createWorker()) {} - - void subscribe(Reference> subscriber) override { - FlowableOperator::upstream_->subscribe(make_ref( - Reference>(this), - std::move(worker_), - std::move(subscriber))); - } - - private: - class Subscription : public FlowableOperator::Subscription { - using Super = typename FlowableOperator::Subscription; - - public: - Subscription( - Reference> flowable, - std::unique_ptr worker, - Reference> subscriber) - : Super(std::move(flowable), std::move(subscriber)), - worker_(std::move(worker)) {} - - void request(int64_t delta) override { - worker_->schedule([delta, this] { this->callSuperRequest(delta); }); - } - - void cancel() override { - worker_->schedule([this] { this->callSuperCancel(); }); - } - - void onNext(T value) override { - Super::subscriberOnNext(std::move(value)); - } - - private: - // Trampoline to call superclass method; gcc bug 58972. - void callSuperRequest(int64_t delta) { - Super::request(delta); - } - - // Trampoline to call superclass method; gcc bug 58972. - void callSuperCancel() { - Super::cancel(); - } - - std::unique_ptr worker_; - }; - - std::unique_ptr worker_; -}; - -template -class FromPublisherOperator : public Flowable { - public: - explicit FromPublisherOperator(OnSubscribe&& function) - : function_(std::move(function)) {} - - void subscribe(Reference> subscriber) override { - function_(std::move(subscriber)); - } - - private: - OnSubscribe function_; -}; - -} // namespace flowable -} // namespace yarpl diff --git a/yarpl/include/yarpl/flowable/Flowable_FromObservable.h b/yarpl/include/yarpl/flowable/Flowable_FromObservable.h deleted file mode 100644 index f86a8d04b..000000000 --- a/yarpl/include/yarpl/flowable/Flowable_FromObservable.h +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "yarpl/Flowable.h" -#include "yarpl/utils/credits.h" - -namespace yarpl { -namespace observable { -template -class Observable; -} -} - -namespace yarpl { -namespace flowable { -namespace sources { - -template -class FlowableFromObservableSubscription - : public yarpl::flowable::Subscription, - public yarpl::observable::Observer { - public: - FlowableFromObservableSubscription( - Reference> observable, - Reference> s) - : observable_(std::move(observable)), subscriber_(std::move(s)) { - // We expect to be heap-allocated; until this subscription finishes - // (is canceled; completes; error's out), hold a reference so we are - // not deallocated (by the subscriber). - Refcounted::incRef(*this); - } - - FlowableFromObservableSubscription(FlowableFromObservableSubscription&&) = - delete; - - FlowableFromObservableSubscription( - const FlowableFromObservableSubscription&) = delete; - FlowableFromObservableSubscription& operator=( - FlowableFromObservableSubscription&&) = delete; - FlowableFromObservableSubscription& operator=( - const FlowableFromObservableSubscription&) = delete; - - void request(int64_t n) override { - if (n <= 0) { - return; - } - auto const r = credits::add(&requested_, n); - if (r <= 0) { - return; - } - - if (!started) { - bool expected = false; - if (started.compare_exchange_strong(expected, true)) { - observable_->subscribe(Reference>(this)); - } - } - } - - void cancel() override { - if (credits::cancel(&requested_)) { - // if this is the first time calling cancel, send the cancel - observableSubscription_->cancel(); - release(); - } - } - - // Observer override - void onSubscribe( - Reference subscription) override { - observableSubscription_ = subscription; - } - - // Observer override - void onNext(T t) override { - if (requested_ > 0) { - subscriber_->onNext(std::move(t)); - credits::consume(&requested_, 1); - } - // drop anything else received while we don't have credits - } - - // Observer override - void onComplete() override { - subscriber_->onComplete(); - release(); - } - - // Observer override - void onError(std::exception_ptr error) override { - subscriber_->onError(error); - release(); - } - - private: - void release() { - observable_.reset(); - subscriber_.reset(); - observableSubscription_.reset(); - Refcounted::decRef(*this); - } - - Reference> observable_; - Reference> subscriber_; - Reference observableSubscription_; - std::atomic_bool started{false}; - std::atomic requested_{0}; -}; -} -} -} diff --git a/yarpl/include/yarpl/flowable/Flowables.h b/yarpl/include/yarpl/flowable/Flowables.h deleted file mode 100644 index 7b8efa1cd..000000000 --- a/yarpl/include/yarpl/flowable/Flowables.h +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include -#include - -#include - -#include "yarpl/flowable/Flowable.h" - -namespace yarpl { -namespace flowable { - -class Flowables { - public: - /** - * Emit the sequence of numbers [start, start + count). - */ - static Reference> range(int64_t start, int64_t count) { - auto lambda = [ start, count, i = start ]( - Subscriber & subscriber, int64_t requested) mutable { - int64_t emitted = 0; - bool done = false; - int64_t end = start + count; - - while (i < end && emitted < requested) { - subscriber.onNext(i++); - ++emitted; - } - - if (i >= end) { - subscriber.onComplete(); - done = true; - } - - return std::make_tuple(requested, done); - }; - - return Flowable::create(std::move(lambda)); - } - - template - static Reference> just(const T& value) { - auto lambda = [value](Subscriber& subscriber, int64_t) { - // # requested should be > 0. Ignoring the actual parameter. - subscriber.onNext(value); - subscriber.onComplete(); - return std::make_tuple(static_cast(1), true); - }; - - return Flowable::create(std::move(lambda)); - } - - template - static Reference> justN(std::initializer_list list) { - std::vector vec(list); - - auto lambda = [ v = std::move(vec), i = size_t{0} ]( - Subscriber & subscriber, int64_t requested) mutable { - int64_t emitted = 0; - bool done = false; - - while (i < v.size() && emitted < requested) { - subscriber.onNext(v[i++]); - ++emitted; - } - - if (i == v.size()) { - subscriber.onComplete(); - done = true; - } - - return std::make_tuple(emitted, done); - }; - - return Flowable::create(std::move(lambda)); - } - - // this will generate a flowable which can be subscribed to only once - template - static Reference> justOnce(T value) { - auto lambda = [value = std::move(value), used = false](Subscriber& subscriber, int64_t) mutable { - if (used) { - subscriber.onError( - std::make_exception_ptr(std::runtime_error("justOnce value was already used"))); - return std::make_tuple(static_cast(0), true); - } - - used = true; - // # requested should be > 0. Ignoring the actual parameter. - subscriber.onNext(std::move(value)); - subscriber.onComplete(); - return std::make_tuple(static_cast(1), true); - }; - - return Flowable::create(std::move(lambda)); - } - - template < - typename T, - typename OnSubscribe, - typename = typename std::enable_if>), - void>::value>::type> - static Reference> fromPublisher(OnSubscribe&& function) { - return Reference>(new FromPublisherOperator( - std::forward(function))); - } - - template - static Reference> empty() { - auto lambda = [](Subscriber& subscriber, int64_t) { - subscriber.onComplete(); - return std::make_tuple(static_cast(0), true); - }; - return Flowable::create(std::move(lambda)); - } - - template - static Reference> error(std::exception_ptr ex) { - auto lambda = [ex](Subscriber& subscriber, int64_t) { - subscriber.onError(ex); - return std::make_tuple(static_cast(0), true); - }; - return Flowable::create(std::move(lambda)); - } - - template - static Reference> error(const ExceptionType& ex) { - auto lambda = [ex](Subscriber& subscriber, int64_t) { - subscriber.onError(std::make_exception_ptr(ex)); - return std::make_tuple(static_cast(0), true); - }; - return Flowable::create(std::move(lambda)); - } - - template - static Reference> fromGenerator(TGenerator generator) { - auto lambda = [generator = std::move(generator)] - (Subscriber& subscriber, int64_t requested) { - int64_t generated = 0; - try { - while (generated < requested) { - subscriber.onNext(generator()); - ++generated; - } - return std::make_tuple(generated, false); - } catch(...) { - subscriber.onError(std::current_exception()); - return std::make_tuple(generated, true); - } - }; - return Flowable::create(std::move(lambda)); - } - - private: - Flowables() = delete; -}; - -} // flowable -} // yarpl diff --git a/yarpl/include/yarpl/flowable/Subscriber.h b/yarpl/include/yarpl/flowable/Subscriber.h deleted file mode 100644 index bb5461857..000000000 --- a/yarpl/include/yarpl/flowable/Subscriber.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include - -#include "yarpl/Refcounted.h" -#include "yarpl/flowable/Subscription.h" - -namespace yarpl { -namespace flowable { - -template -class Subscriber : public virtual Refcounted { - public: - // Note: if any of the following methods is overridden in a subclass, - // the new methods SHOULD ensure that these are invoked as well. - virtual void onSubscribe(Reference subscription) { - subscription_ = subscription; - } - - // No further calls to the subscription after this method is invoked. - virtual void onComplete() { - subscription_.reset(); - } - - // No further calls to the subscription after this method is invoked. - virtual void onError(std::exception_ptr) { - subscription_.reset(); - } - - virtual void onNext(T) = 0; - - protected: - Subscription* subscription() { - return subscription_.operator->(); - } - - private: - // "Our" reference to the subscription, to ensure that it is retained - // while calls to its methods are in-flight. - Reference subscription_{nullptr}; -}; - -} // flowable -} // yarpl diff --git a/yarpl/include/yarpl/flowable/Subscribers.h b/yarpl/include/yarpl/flowable/Subscribers.h deleted file mode 100644 index 984368ac7..000000000 --- a/yarpl/include/yarpl/flowable/Subscribers.h +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include - -#include "yarpl/flowable/Subscriber.h" -#include "yarpl/utils/credits.h" -#include "yarpl/utils/type_traits.h" - -namespace yarpl { -namespace flowable { - -/// Helper methods for constructing subscriber instances from functions: -/// one, two, or three functions (callables; can be lamda, for instance) -/// may be specified, corresponding to onNext, onError and onSubscribe -/// method bodies in the subscriber. -class Subscribers { - constexpr static auto kNoFlowControl = credits::kNoFlowControl; - - public: - template < - typename T, - typename Next, - typename = - typename std::enable_if::value>::type> - static auto create(Next&& next, int64_t batch = kNoFlowControl) { - return Reference>( - new Base(std::forward(next), batch)); - } - - template < - typename T, - typename Next, - typename Error, - typename = typename std::enable_if< - std::is_callable::value && - std::is_callable::value>::type> - static auto - create(Next&& next, Error&& error, int64_t batch = kNoFlowControl) { - return Reference>(new WithError( - std::forward(next), std::forward(error), batch)); - } - - template < - typename T, - typename Next, - typename Error, - typename Complete, - typename = typename std::enable_if< - std::is_callable::value && - std::is_callable::value && - std::is_callable::value>::type> - static auto create( - Next&& next, - Error&& error, - Complete&& complete, - int64_t batch = kNoFlowControl) { - return Reference>( - new WithErrorAndComplete( - std::forward(next), - std::forward(error), - std::forward(complete), - batch)); - } - - private: - template - class Base : public Subscriber { - public: - Base(Next&& next, int64_t batch) - : next_(std::forward(next)), batch_(batch), pending_(0) {} - - void onSubscribe(Reference subscription) override { - Subscriber::onSubscribe(subscription); - pending_ += batch_; - subscription->request(batch_); - } - - void onNext(T value) override { - next_(std::move(value)); - if (--pending_ < batch_ / 2) { - const auto delta = batch_ - pending_; - pending_ += delta; - Subscriber::subscription()->request(delta); - } - } - - private: - Next next_; - const int64_t batch_; - int64_t pending_; - }; - - template - class WithError : public Base { - public: - WithError(Next&& next, Error&& error, int64_t batch) - : Base(std::forward(next), batch), error_(error) {} - - void onError(std::exception_ptr error) override { - Subscriber::onError(error); - error_(error); - } - - private: - Error error_; - }; - - template - class WithErrorAndComplete : public WithError { - public: - WithErrorAndComplete( - Next&& next, - Error&& error, - Complete&& complete, - int64_t batch) - : WithError( - std::forward(next), - std::forward(error), - batch), - complete_(complete) {} - - void onComplete() { - Subscriber::onComplete(); - complete_(); - } - - private: - Complete complete_; - }; - - Subscribers() = delete; -}; - -} // flowable -} // yarpl diff --git a/yarpl/include/yarpl/flowable/Subscription.h b/yarpl/include/yarpl/flowable/Subscription.h deleted file mode 100644 index b4ccfe523..000000000 --- a/yarpl/include/yarpl/flowable/Subscription.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "yarpl/Refcounted.h" - -namespace yarpl { -namespace flowable { - -class Subscription : public virtual Refcounted { - public: - virtual ~Subscription() = default; - - virtual void request(int64_t n) = 0; - virtual void cancel() = 0; - - static yarpl::Reference empty(); -}; - -} // flowable -} // yarpl diff --git a/yarpl/include/yarpl/flowable/TestSubscriber.h b/yarpl/include/yarpl/flowable/TestSubscriber.h deleted file mode 100644 index 8f2bf9e1b..000000000 --- a/yarpl/include/yarpl/flowable/TestSubscriber.h +++ /dev/null @@ -1,223 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include -#include -#include - -#include "yarpl/flowable/Flowable.h" -#include "yarpl/flowable/Subscriber.h" -#include "yarpl/utils/ExceptionString.h" -#include "yarpl/utils/credits.h" - -namespace yarpl { -namespace flowable { - -/** - * A utility class for unit testing or experimenting with Flowable. - * - * Example usage: - * - * auto flowable = ... - * auto ts = TestSubscriber::create(); - * flowable->subscribe(to); - * ts->awaitTerminalEvent(); - * ts->assert... - */ -template -class TestSubscriber : public Subscriber { - public: - static_assert( - std::is_copy_constructible::value, - "Requires copyable types in case of a delegate subscriber"); - - constexpr static auto kCanceled = credits::kCanceled; - constexpr static auto kNoFlowControl = credits::kNoFlowControl; - - /** - * Create a TestSubscriber that will subscribe and store the value it - * receives. - */ - static Reference> create(int64_t initial = kNoFlowControl) { - return make_ref>(initial); - } - - /** - * Create a TestSubscriber that will delegate all on* method calls - * to the provided Subscriber. - * - * This will store the value it receives to allow assertions. - */ - static Reference> create( - Reference> delegate, - int64_t initial = kNoFlowControl) { - return make_ref>(std::move(delegate), initial); - } - - explicit TestSubscriber(int64_t initial = kNoFlowControl) - : TestSubscriber(Reference>{}, initial) {} - - explicit TestSubscriber( - Reference> delegate, - int64_t initial = kNoFlowControl) - : delegate_(std::move(delegate)), initial_{initial} {} - - void onSubscribe(Reference subscription) override { - if (delegate_) { - subscription_ = subscription; // copy - delegate_->onSubscribe(std::move(subscription)); - } else { - subscription_ = std::move(subscription); - } - subscription_->request(initial_); - } - - void onNext(T t) override { - if (delegate_) { - values_.push_back(t); - delegate_->onNext(std::move(t)); - } else { - values_.push_back(std::move(t)); - } - } - - void onComplete() override { - if (delegate_) { - delegate_->onComplete(); - } - subscription_.reset(); - terminated_ = true; - terminalEventCV_.notify_all(); - } - - void onError(std::exception_ptr ex) override { - if (delegate_) { - delegate_->onError(ex); - } - e_ = ex; - subscription_.reset(); - terminated_ = true; - terminalEventCV_.notify_all(); - } - - /** - * Block the current thread until either onSuccess or onError is called. - */ - void awaitTerminalEvent() { - // now block this thread - std::unique_lock lk(m_); - // if shutdown gets implemented this would then be released by it - terminalEventCV_.wait(lk, [this] { return terminated_; }); - } - - void assertValueCount(size_t count) { - if (values_.size() != count) { - std::stringstream ss; - ss << "Value count " << values_.size() << " does not match " << count; - throw std::runtime_error(ss.str()); - } - } - - int64_t getValueCount() { - return values_.size(); - } - - std::vector& values() { - return values_; - } - - const std::vector& values() const { - return values_; - } - - bool isComplete() const { - return terminated_ && !e_; - } - - bool isError() const { - return terminated_ && e_; - } - - std::string getErrorMsg() const { - if (e_ == nullptr) { - return ""; - } - return exceptionStr(e_); - } - - void assertValueAt(int64_t index, T expected) { - if (index < getValueCount()) { - auto& v = values_[index]; - if (expected != v) { - std::stringstream ss; - ss << "Expected: " << expected << " Actual: " << v; - throw std::runtime_error(ss.str()); - } - } else { - std::stringstream ss; - ss << "Index " << index << " is larger than received values " - << values_.size(); - throw std::runtime_error(ss.str()); - } - } - - /** - * If an onComplete call was not received throw a runtime_error - */ - void assertSuccess() { - if (!terminated_) { - throw std::runtime_error("Did not receive terminal event."); - } - if (e_) { - throw std::runtime_error("Received onError instead of onSuccess"); - } - } - - /** - * If the onError exception_ptr points to an error containing - * the given msg, complete successfully, otherwise throw a runtime_error - */ - void assertOnErrorMessage(std::string msg) { - if (e_ == nullptr) { - std::stringstream ss; - ss << "exception_ptr == nullptr, but expected " << msg; - throw std::runtime_error(ss.str()); - } - try { - std::rethrow_exception(e_); - } catch (std::runtime_error& re) { - if (re.what() != msg) { - std::stringstream ss; - ss << "Error message is: " << re.what() << " but expected: " << msg; - throw std::runtime_error(ss.str()); - } - } catch (...) { - throw std::runtime_error("Expects an std::runtime_error"); - } - } - - /** - * Submit Subscription->cancel(); - */ - void cancel() { - subscription_->cancel(); - } - - void request(int64_t n) { - subscription_->request(n); - } - - private: - Reference> delegate_; - std::vector values_; - std::exception_ptr e_; - int64_t initial_{kNoFlowControl}; - bool terminated_{false}; - std::mutex m_; - std::condition_variable terminalEventCV_; - Reference subscription_; -}; -} -} diff --git a/yarpl/include/yarpl/observable/Observable.h b/yarpl/include/yarpl/observable/Observable.h deleted file mode 100644 index 8d24a6881..000000000 --- a/yarpl/include/yarpl/observable/Observable.h +++ /dev/null @@ -1,184 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include "yarpl/Scheduler.h" -#include "yarpl/utils/type_traits.h" - -#include "yarpl/Refcounted.h" -#include "yarpl/observable/Observer.h" -#include "yarpl/observable/Observers.h" -#include "yarpl/observable/Subscription.h" - -#include "yarpl/Flowable.h" -#include "yarpl/flowable/Flowable_FromObservable.h" - -namespace yarpl { -namespace observable { - -/** -*Strategy for backpressure when converting from Observable to Flowable. -*/ -enum class BackpressureStrategy { DROP }; - -template -class Observable : public virtual Refcounted { - public: - virtual void subscribe(Reference>) = 0; - - /** - * Subscribe overload that accepts lambdas. - */ - template < - typename Next, - typename = - typename std::enable_if::value>::type> - void subscribe(Next&& next) { - subscribe(Observers::create(next)); - } - - /** - * Subscribe overload that accepts lambdas. - */ - template < - typename Next, - typename Error, - typename = typename std::enable_if< - std::is_callable::value && - std::is_callable::value>::type> - void subscribe(Next&& next, Error&& error) { - subscribe(Observers::create(next, error)); - } - - /** - * Subscribe overload that accepts lambdas. - */ - template < - typename Next, - typename Error, - typename Complete, - typename = typename std::enable_if< - std::is_callable::value && - std::is_callable::value && - std::is_callable::value>::type> - void subscribe(Next&& next, Error&& error, Complete&& complete) { - subscribe(Observers::create(next, error, complete)); - } - - template - static auto create(OnSubscribe&&); - - template - auto map(Function&& function); - - template - auto filter(Function&& function); - - template - auto reduce(Function&& function); - - auto take(int64_t); - - auto skip(int64_t); - - auto ignoreElements(); - - auto subscribeOn(Scheduler&); - - /** - * Convert from Observable to Flowable with a given BackpressureStrategy. - * - * Currently the only strategy is DROP. - */ - auto toFlowable(BackpressureStrategy strategy); -}; -} // observable -} // yarpl - -#include "yarpl/observable/ObservableOperator.h" - -namespace yarpl { -namespace observable { - -template -template -auto Observable::create(OnSubscribe&& function) { - static_assert( - std::is_callable>), void>(), - "OnSubscribe must have type `void(Reference>)`"); - - return make_ref>( - std::forward(function)); -} - -template -template -auto Observable::map(Function&& function) { - using D = typename std::result_of::type; - return Reference>(new MapOperator( - Reference>(this), std::forward(function))); -} - -template -template -auto Observable::filter(Function&& function) { - return Reference>(new FilterOperator( - Reference>(this), std::forward(function))); -} - -template -template -auto Observable::reduce(Function&& function) { - using D = typename std::result_of::type; - return Reference>(new ReduceOperator( - Reference>(this), std::forward(function))); -} - -template -auto Observable::take(int64_t limit) { - return Reference>( - new TakeOperator(Reference>(this), limit)); -} - -template -auto Observable::skip(int64_t offset) { - return Reference>( - new SkipOperator(Reference>(this), offset)); -} - -template -auto Observable::ignoreElements() { - return Reference>( - new IgnoreElementsOperator(Reference>(this))); -} - -template -auto Observable::subscribeOn(Scheduler& scheduler) { - return Reference>( - new SubscribeOnOperator(Reference>(this), scheduler)); -} - -template -auto Observable::toFlowable(BackpressureStrategy strategy) { - // we currently ONLY support the DROP strategy - // so do not use the strategy parameter for anything - auto o = Reference>(this); - return yarpl::flowable::Flowables::fromPublisher([ - o = std::move(o), // the Observable to pass through - strategy - ](Reference> s) { - s->onSubscribe(Reference( - new yarpl::flowable::sources::FlowableFromObservableSubscription( - std::move(o), std::move(s)))); - }); -} - -} // observable -} // yarpl diff --git a/yarpl/include/yarpl/observable/ObservableOperator.h b/yarpl/include/yarpl/observable/ObservableOperator.h deleted file mode 100644 index 26ffde752..000000000 --- a/yarpl/include/yarpl/observable/ObservableOperator.h +++ /dev/null @@ -1,480 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include - -#include "yarpl/Observable.h" -#include "yarpl/observable/Observer.h" -#include "yarpl/observable/Subscription.h" - -namespace yarpl { -namespace observable { - -/** - * Base (helper) class for operators. Operators are templated on two types: - * D (downstream) and U (upstream). Operators are created by method calls on - * an upstream Observable, and are Observables themselves. Multi-stage - * pipelines - * can be built: a Observable heading a sequence of Operators. - */ -template -class ObservableOperator : public Observable { - public: - explicit ObservableOperator(Reference> upstream) - : upstream_(std::move(upstream)) {} - - protected: - /// An Operator's subscription. - /// - /// When a pipeline chain is active, each Observable has a corresponding - /// subscription. Except for the first one, the subscriptions are created - /// against Operators. Each operator subscription has two functions: as a - /// subscriber for the previous stage; as a subscription for the next one, - /// the user-supplied subscriber being the last of the pipeline stages. - class Subscription : public ::yarpl::observable::Subscription, - public Observer { - protected: - Subscription( - Reference> observable, - Reference> observer) - : observable_(std::move(observable)), observer_(std::move(observer)) { - assert(observable_); - assert(observer_); - - // We expect to be heap-allocated; until this subscription finishes (is - // canceled; completes; error's out), hold a reference so we are not - // deallocated (by the subscriber). - Refcounted::incRef(*this); - } - - template - TOperator* getObservableAs() { - return static_cast(observable_.get()); - } - - void observerOnNext(D value) { - if (observer_) { - observer_->onNext(std::move(value)); - } - } - - /// Terminates both ends of an operator normally. - void terminate() { - terminateImpl(TerminateState::Both()); - } - - /// Terminates both ends of an operator with an error. - void terminateErr(std::exception_ptr eptr) { - terminateImpl(TerminateState::Both(), std::move(eptr)); - } - - // Subscription. - - void cancel() override { - terminateImpl(TerminateState::Up()); - } - - // Observer. - - void onSubscribe( - Reference subscription) override { - if (upstream_) { - subscription->cancel(); - return; - } - - upstream_ = std::move(subscription); - observer_->onSubscribe(Reference(this)); - } - - void onComplete() override { - terminateImpl(TerminateState::Down()); - } - - void onError(std::exception_ptr eptr) override { - terminateImpl(TerminateState::Down(), std::move(eptr)); - } - - private: - struct TerminateState { - TerminateState(bool u, bool d) : up{u}, down{d} {} - - static TerminateState Down() { - return TerminateState{false, true}; - } - - static TerminateState Up() { - return TerminateState{true, false}; - } - - static TerminateState Both() { - return TerminateState{true, true}; - } - - const bool up{false}; - const bool down{false}; - }; - - bool isTerminated() const { - return !upstream_ && !observer_; - } - - /// Terminates an operator, sending cancel() and on{Complete,Error}() - /// signals as necessary. - void terminateImpl( - TerminateState state, - std::exception_ptr eptr = nullptr) { - if (isTerminated()) { - return; - } - - if (auto upstream = std::move(upstream_)) { - if (state.up) { - upstream->cancel(); - } - } - - if (auto observer = std::move(observer_)) { - if (state.down) { - if (eptr) { - observer->onError(std::move(eptr)); - } else { - observer->onComplete(); - } - } - } - - Refcounted::decRef(*this); - } - - /// The Observable has the lambda, and other creation parameters. - Reference> observable_; - - /// This subscription controls the life-cycle of the observer. The - /// observer is retained as long as calls on it can be made. (Note: - /// the observer in turn maintains a reference on this subscription - /// object until cancellation and/or completion.) - Reference> observer_; - - /// In an active pipeline, cancel and (possibly modified) request(n) - /// calls should be forwarded upstream. Note that `this` is also a - /// observer for the upstream stage: thus, there are cycles; all of - /// the objects drop their references at cancel/complete. - Reference<::yarpl::observable::Subscription> upstream_; - }; - - Reference> upstream_; -}; - -template < - typename U, - typename D, - typename F, - typename = typename std::enable_if::value>::type> -class MapOperator : public ObservableOperator { - public: - MapOperator(Reference> upstream, F&& function) - : ObservableOperator(std::move(upstream)), - function_(std::forward(function)) {} - - void subscribe(Reference> observer) override { - ObservableOperator::upstream_->subscribe( - // Note: implicit cast to a reference to a observer. - Reference(new Subscription( - Reference>(this), std::move(observer)))); - } - - private: - class Subscription : public ObservableOperator::Subscription { - using Super = typename ObservableOperator::Subscription; - public: - Subscription( - Reference> observable, - Reference> observer) - : Super( - std::move(observable), - std::move(observer)) {} - - void onNext(U value) override { - auto* map = Super::template getObservableAs(); - Super::observerOnNext(map->function_(std::move(value))); - } - }; - - F function_; -}; - -template < - typename U, - typename F, - typename = - typename std::enable_if::value>::type> -class FilterOperator : public ObservableOperator { - public: - FilterOperator(Reference> upstream, F&& function) - : ObservableOperator(std::move(upstream)), - function_(std::forward(function)) {} - - void subscribe(Reference> observer) override { - ObservableOperator::upstream_->subscribe( - // Note: implicit cast to a reference to a observer. - Reference(new Subscription( - Reference>(this), std::move(observer)))); - } - - private: - class Subscription : public ObservableOperator::Subscription { - using Super = typename ObservableOperator::Subscription; - public: - Subscription( - Reference> observable, - Reference> observer) - : Super( - std::move(observable), - std::move(observer)) {} - - void onNext(U value) override { - auto* filter = Super::template getObservableAs(); - if (filter->function_(value)) { - Super::observerOnNext(std::move(value)); - } - } - }; - - F function_; -}; - -template< - typename U, - typename D, - typename F, - typename = typename std::enable_if::value>, - typename = typename std::enable_if::value>::type> -class ReduceOperator : public ObservableOperator { -public: - ReduceOperator(Reference> upstream, F &&function) - : ObservableOperator(std::move(upstream)), - function_(std::forward(function)) {} - - void subscribe(Reference> subscriber) override { - ObservableOperator::upstream_->subscribe( - // Note: implicit cast to a reference to a subscriber. - Reference(new Subscription( - Reference>(this), std::move(subscriber)))); - } - -private: - class Subscription : public ObservableOperator::Subscription { - using Super = typename ObservableOperator::Subscription; - - public: - Subscription( - Reference > flowable, - Reference > subscriber) - : Super( - std::move(flowable), - std::move(subscriber)), - accInitialized_(false) {} - - void onNext(U value) override { - auto* reduce = Super::template getObservableAs(); - if (accInitialized_) { - acc_ = reduce->function_(std::move(acc_), std::move(value)); - } else { - acc_ = std::move(value); - accInitialized_ = true; - } - } - - void onComplete() override { - if (accInitialized_) { - Super::observerOnNext(std::move(acc_)); - } - Super::onComplete(); - } - - private: - bool accInitialized_; - D acc_; - }; - - F function_; -}; - -template -class TakeOperator : public ObservableOperator { - public: - TakeOperator(Reference> upstream, int64_t limit) - : ObservableOperator(std::move(upstream)), limit_(limit) {} - - void subscribe(Reference> observer) override { - ObservableOperator::upstream_->subscribe( - Reference(new Subscription( - Reference>(this), limit_, std::move(observer)))); - } - - private: - class Subscription : public ObservableOperator::Subscription { - using Super = typename ObservableOperator::Subscription; - public: - Subscription( - Reference> observable, - int64_t limit, - Reference> observer) - : Super( - std::move(observable), - std::move(observer)), - limit_(limit) {} - - void onNext(T value) override { - if (limit_-- > 0) { - if (pending_ > 0) - --pending_; - Super::observerOnNext( - std::move(value)); - if (limit_ == 0) { - Super::terminate(); - } - } - } - - private: - int64_t pending_{0}; - int64_t limit_; - }; - - const int64_t limit_; -}; - -template -class SkipOperator : public ObservableOperator { - public: - SkipOperator(Reference> upstream, int64_t offset) - : ObservableOperator(std::move(upstream)), offset_(offset) {} - - void subscribe(Reference> observer) override { - ObservableOperator::upstream_->subscribe( - make_ref( - Reference>(this), offset_, std::move(observer))); - } - - private: - class Subscription : public ObservableOperator::Subscription { - using Super = typename ObservableOperator::Subscription; - public: - Subscription( - Reference> observable, - int64_t offset, - Reference> observer) - : Super(std::move(observable), std::move(observer)), - offset_(offset) {} - - void onNext(T value) override { - if (offset_ <= 0) { - Super::observerOnNext( - std::move(value)); - } else { - --offset_; - } - } - - private: - int64_t offset_; - }; - - const int64_t offset_; -}; - -template -class IgnoreElementsOperator : public ObservableOperator { - public: - explicit IgnoreElementsOperator(Reference> upstream) - : ObservableOperator(std::move(upstream)) {} - - void subscribe(Reference> observer) override { - ObservableOperator::upstream_->subscribe( - Reference(new Subscription( - Reference>(this), std::move(observer)))); - } - - private: - class Subscription : public ObservableOperator::Subscription { - using Super = typename ObservableOperator::Subscription; - public: - Subscription( - Reference> observable, - Reference> observer) - : ObservableOperator::Subscription( - std::move(observable), - std::move(observer)) {} - - void onNext(T) override {} - }; -}; - -template -class SubscribeOnOperator : public ObservableOperator { - public: - SubscribeOnOperator(Reference> upstream, Scheduler& scheduler) - : ObservableOperator(std::move(upstream)), - worker_(scheduler.createWorker()) {} - - void subscribe(Reference> observer) override { - ObservableOperator::upstream_->subscribe( - Reference(new Subscription( - Reference>(this), - std::move(worker_), - std::move(observer)))); - } - - private: - class Subscription : public ObservableOperator::Subscription { - public: - Subscription( - Reference> observable, - std::unique_ptr worker, - Reference> observer) - : ObservableOperator::Subscription( - std::move(observable), - std::move(observer)), - worker_(std::move(worker)) {} - - void cancel() override { - worker_->schedule([this] { this->callSuperCancel(); }); - } - - void onNext(T value) override { - auto* observer = - ObservableOperator::Subscription::observer_.get(); - observer->onNext(std::move(value)); - } - - private: - // Trampoline to call superclass method; gcc bug 58972. - void callSuperCancel() { - ObservableOperator::Subscription::cancel(); - } - - std::unique_ptr worker_; - }; - - std::unique_ptr worker_; -}; - -template -class FromPublisherOperator : public Observable { - public: - explicit FromPublisherOperator(OnSubscribe&& function) - : function_(std::move(function)) {} - - void subscribe(Reference> observer) override { - function_(std::move(observer)); - } - - private: - OnSubscribe function_; -}; - -} // observable -} // yarpl diff --git a/yarpl/include/yarpl/observable/Observables.h b/yarpl/include/yarpl/observable/Observables.h deleted file mode 100644 index 5612d712a..000000000 --- a/yarpl/include/yarpl/observable/Observables.h +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include - -#include "yarpl/observable/Observable.h" - -namespace yarpl { -namespace observable { - -class Observables { - public: - static Reference> range(int64_t start, int64_t end) { - auto lambda = [start, end](Reference> observer) { - for (int64_t i = start; i < end; ++i) { - observer->onNext(i); - } - observer->onComplete(); - }; - - return Observable::create(std::move(lambda)); - } - - template - static Reference> just(const T& value) { - auto lambda = [value](Reference> observer) { - // # requested should be > 0. Ignoring the actual parameter. - observer->onNext(value); - observer->onComplete(); - }; - - return Observable::create(std::move(lambda)); - } - - template - static Reference> justN(std::initializer_list list) { - std::vector vec(list); - - auto lambda = [v = std::move(vec)](Reference> observer) { - for (auto const& elem : v) { - observer->onNext(elem); - } - observer->onComplete(); - }; - - return Observable::create(std::move(lambda)); - } - - // this will generate an observable which can be subscribed to only once - template - static Reference> justOnce(T value) { - auto lambda = [value = std::move(value), used = false](Reference> observer) mutable { - if (used) { - observer->onError( - std::make_exception_ptr(std::runtime_error("justOnce value was already used"))); - return; - } - - used = true; - // # requested should be > 0. Ignoring the actual parameter. - observer->onNext(std::move(value)); - observer->onComplete(); - }; - - return Observable::create(std::move(lambda)); - } - - template < - typename T, - typename OnSubscribe, - typename = typename std::enable_if< - std::is_callable>), void>::value>:: - type> - static Reference> create(OnSubscribe&& function) { - return Reference>(new FromPublisherOperator( - std::forward(function))); - } - - template - static Reference> empty() { - auto lambda = [](Reference> observer) { - observer->onComplete(); - }; - return Observable::create(std::move(lambda)); - } - - template - static Reference> error(std::exception_ptr ex) { - auto lambda = [ex](Reference> observer) { - observer->onError(ex); - }; - return Observable::create(std::move(lambda)); - } - - template - static Reference> error(const ExceptionType& ex) { - auto lambda = [ex](Reference> observer) { - observer->onError(std::make_exception_ptr(ex)); - }; - return Observable::create(std::move(lambda)); - } - - private: - Observables() = delete; -}; - -} // observable -} // yarpl diff --git a/yarpl/include/yarpl/observable/Observer.h b/yarpl/include/yarpl/observable/Observer.h deleted file mode 100644 index 29fdae5ea..000000000 --- a/yarpl/include/yarpl/observable/Observer.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include - -#include "yarpl/Refcounted.h" -#include "yarpl/observable/Subscription.h" - -namespace yarpl { -namespace observable { - -template -class Observer : public virtual Refcounted { - public: - // Note: if any of the following methods is overridden in a subclass, - // the new methods SHOULD ensure that these are invoked as well. - virtual void onSubscribe(Reference subscription) { - subscription_ = subscription; - } - - // No further calls to the subscription after this method is invoked. - virtual void onComplete() { - subscription_.reset(); - } - - // No further calls to the subscription after this method is invoked. - virtual void onError(std::exception_ptr) { - subscription_.reset(); - } - - virtual void onNext(T) = 0; - - protected: - Subscription* subscription() { - return subscription_.operator->(); - } - - private: - // "Our" reference to the subscription, to ensure that it is retained - // while calls to its methods are in-flight. - Reference subscription_{nullptr}; -}; - -} // observable -} // yarpl diff --git a/yarpl/include/yarpl/observable/Observers.h b/yarpl/include/yarpl/observable/Observers.h deleted file mode 100644 index ebeb89311..000000000 --- a/yarpl/include/yarpl/observable/Observers.h +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include - -#include "yarpl/observable/Observer.h" -#include "yarpl/utils/type_traits.h" - -namespace yarpl { -namespace observable { - -/// Helper methods for constructing subscriber instances from functions: -/// one, two, or three functions (callables; can be lamda, for instance) -/// may be specified, corresponding to onNext, onError and onComplete -/// method bodies in the subscriber. -class Observers { - private: - /// Defined if Next, Error and Complete are signature-compatible with - /// onNext, onError and onComplete subscriber methods respectively. - template < - typename T, - typename Next, - typename Error = void (*)(std::exception_ptr), - typename Complete = void (*)()> - using EnableIfCompatible = typename std::enable_if< - std::is_callable::value && - std::is_callable::value && - std::is_callable::value>::type; - - public: - template > - static auto create(Next&& next) { - return Reference>(new Base(std::forward(next))); - } - - template < - typename T, - typename Next, - typename Error, - typename = EnableIfCompatible> - static auto create(Next&& next, Error&& error) { - return Reference>(new WithError( - std::forward(next), std::forward(error))); - } - - template < - typename T, - typename Next, - typename Error, - typename Complete, - typename = EnableIfCompatible> - static auto create(Next&& next, Error&& error, Complete&& complete) { - return Reference>( - new WithErrorAndComplete( - std::forward(next), - std::forward(error), - std::forward(complete))); - } - - private: - template - class Base : public Observer { - public: - explicit Base(Next&& next) : next_(std::forward(next)) {} - - void onNext(T value) override { - next_(std::move(value)); - } - - private: - Next next_; - }; - - template - class WithError : public Base { - public: - WithError(Next&& next, Error&& error) - : Base(std::forward(next)), error_(error) {} - - void onError(std::exception_ptr error) override { - error_(error); - } - - private: - Error error_; - }; - - template - class WithErrorAndComplete : public WithError { - public: - WithErrorAndComplete(Next&& next, Error&& error, Complete&& complete) - : WithError( - std::forward(next), - std::forward(error)), - complete_(complete) {} - - void onComplete() override { - complete_(); - } - - private: - Complete complete_; - }; - - Observers() = delete; -}; - -} // observable -} // yarpl diff --git a/yarpl/include/yarpl/observable/Subscription.h b/yarpl/include/yarpl/observable/Subscription.h deleted file mode 100644 index 30f243d26..000000000 --- a/yarpl/include/yarpl/observable/Subscription.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "yarpl/Refcounted.h" - -namespace yarpl { -namespace observable { - -class Subscription : public virtual Refcounted { - public: - virtual ~Subscription() = default; - virtual void cancel() = 0; - - protected: - Subscription() = default; -}; - -} // observable -} // yarpl diff --git a/yarpl/include/yarpl/observable/Subscriptions.h b/yarpl/include/yarpl/observable/Subscriptions.h deleted file mode 100644 index e1081cc9f..000000000 --- a/yarpl/include/yarpl/observable/Subscriptions.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include -#include -#include - -#include "yarpl/Refcounted.h" -#include "yarpl/observable/Subscription.h" - -namespace yarpl { -namespace observable { - -/** -* Implementation that allows checking if a Subscription is cancelled. -*/ -class AtomicBoolSubscription : public Subscription { - public: - void cancel() override; - bool isCancelled() const; - - private: - std::atomic_bool cancelled_{false}; -}; - -/** -* Implementation that gets a callback when cancellation occurs. -*/ -class CallbackSubscription : public Subscription { - public: - explicit CallbackSubscription(std::function&& onCancel); - void cancel() override; - bool isCancelled() const; - - private: - std::atomic_bool cancelled_{false}; - std::function onCancel_; -}; - -class Subscriptions { - public: - static Reference create(std::function onCancel); - static Reference create(std::atomic_bool& cancelled); - static Reference empty(); - static Reference atomicBoolSubscription(); -}; - -} // observable namespace -} // yarpl namespace diff --git a/yarpl/include/yarpl/schedulers/ThreadScheduler.h b/yarpl/include/yarpl/schedulers/ThreadScheduler.h deleted file mode 100644 index 04d0371e8..000000000 --- a/yarpl/include/yarpl/schedulers/ThreadScheduler.h +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include -#include "yarpl/Scheduler.h" - -namespace yarpl { - -class ThreadScheduler : public Scheduler { - public: - ThreadScheduler() {} - - std::unique_ptr createWorker() override; - - private: - ThreadScheduler(ThreadScheduler&&) = delete; - ThreadScheduler(const ThreadScheduler&) = delete; - ThreadScheduler& operator=(ThreadScheduler&&) = delete; - ThreadScheduler& operator=(const ThreadScheduler&) = delete; -}; -} diff --git a/yarpl/include/yarpl/single/Single.h b/yarpl/include/yarpl/single/Single.h deleted file mode 100644 index 97304386a..000000000 --- a/yarpl/include/yarpl/single/Single.h +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include - -#include "yarpl/Refcounted.h" -#include "yarpl/single/SingleObserver.h" -#include "yarpl/single/SingleObservers.h" -#include "yarpl/single/SingleSubscription.h" -#include "yarpl/utils/type_traits.h" - -namespace yarpl { -namespace single { - -namespace details { - -template -class FromPublisherOperator; - -// specialization of Single -template -class SingleVoidFromPublisherOperator; -} - -template -class Single : public virtual Refcounted { - public: - virtual void subscribe(Reference>) = 0; - - /** - * Subscribe overload that accepts lambdas. - */ - template < - typename Success, - typename = typename std::enable_if< - std::is_callable::value>::type> - void subscribe(Success&& next) { - subscribe(SingleObservers::create(next)); - } - - /** - * Subscribe overload that accepts lambdas. - */ - template < - typename Success, - typename Error, - typename = typename std::enable_if< - std::is_callable::value && - std::is_callable::value>::type> - void subscribe(Success&& next, Error&& error) { - subscribe(SingleObservers::create(next, error)); - } - - /** - * Blocking subscribe that accepts lambdas. - * - * This blocks the current thread waiting on the response. - */ - template < - typename Success, - typename = typename std::enable_if< - std::is_callable::value>::type> - void subscribeBlocking(Success&& next) { - auto waiting_ = std::make_shared>(); - subscribe(SingleObservers::create( - [ next = std::forward(next), waiting_ ](T t) { - next(std::move(t)); - waiting_->post(); - })); - // TODO get errors and throw if one is received - waiting_->wait(); - } - - template < - typename OnSubscribe, - typename = typename std::enable_if>), - void>::value>::type> - static auto create(OnSubscribe&& function) { - return Reference>( - new details::FromPublisherOperator( - std::forward(function))); - } - - template - auto map(Function&& function); -}; - -template <> -class Single : public virtual Refcounted { - public: - virtual void subscribe(Reference>) = 0; - - /** - * Subscribe overload taking lambda for onSuccess that is called upon writing - * to the network. - */ - template < - typename Success, - typename = typename std::enable_if< - std::is_callable::value>::type> - void subscribe(Success&& s) { - class SuccessSingleObserver : public SingleObserver { - public: - SuccessSingleObserver(Success&& s) : success_{std::move(s)} {} - - void onSubscribe(Reference subscription) override { - SingleObserver::onSubscribe(std::move(subscription)); - } - - virtual void onSuccess() override { - success_(); - SingleObserver::onSuccess(); - } - - // No further calls to the subscription after this method is invoked. - virtual void onError(std::exception_ptr eptr) override { - SingleObserver::onError(eptr); - } - - private: - Success success_; - }; - - subscribe(make_ref(std::forward(s))); - } - - template < - typename OnSubscribe, - typename = typename std::enable_if>), - void>::value>::type> - static auto create(OnSubscribe&& function) { - return Reference>( - new details::SingleVoidFromPublisherOperator( - std::forward(function))); - } -}; - -namespace details { - -template -class FromPublisherOperator : public Single { - public: - explicit FromPublisherOperator(OnSubscribe&& function) - : function_(std::move(function)) {} - - void subscribe(Reference> subscriber) override { - function_(std::move(subscriber)); - } - - private: - OnSubscribe function_; -}; - -template -class SingleVoidFromPublisherOperator : public Single { - public: - explicit SingleVoidFromPublisherOperator(OnSubscribe&& function) - : function_(std::move(function)) {} - - void subscribe(Reference> subscriber) override { - function_(std::move(subscriber)); - } - - private: - OnSubscribe function_; -}; -} // details - -} // observable -} // yarpl - -#include "yarpl/single/SingleOperator.h" - -namespace yarpl { -namespace single { -template -template -auto Single::map(Function&& function) { - using D = typename std::result_of::type; - return Reference>(new MapOperator( - Reference>(this), std::forward(function))); -} - -} // single -} // yarpl diff --git a/yarpl/include/yarpl/single/SingleObserver.h b/yarpl/include/yarpl/single/SingleObserver.h deleted file mode 100644 index cbedeece0..000000000 --- a/yarpl/include/yarpl/single/SingleObserver.h +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include - -#include "yarpl/Refcounted.h" -#include "yarpl/single/SingleSubscription.h" - -namespace yarpl { -namespace single { - -template -class SingleObserver : public virtual Refcounted { - public: - // Note: if any of the following methods is overridden in a subclass, - // the new methods SHOULD ensure that these are invoked as well. - virtual void onSubscribe(Reference subscription) { - subscription_ = subscription; - } - - // No further calls to the subscription after this method is invoked. - virtual void onSuccess(T) { - subscription_.reset(); - } - - // No further calls to the subscription after this method is invoked. - virtual void onError(std::exception_ptr) { - subscription_.reset(); - } - - protected: - SingleSubscription* subscription() { - return subscription_.operator->(); - } - - private: - // "Our" reference to the subscription, to ensure that it is retained - // while calls to its methods are in-flight. - Reference subscription_{nullptr}; -}; - -// specialization of SingleObserver -template <> -class SingleObserver : public virtual Refcounted { - public: - // Note: if any of the following methods is overridden in a subclass, - // the new methods SHOULD ensure that these are invoked as well. - virtual void onSubscribe(Reference subscription) { - subscription_ = subscription; - } - - // No further calls to the subscription after this method is invoked. - virtual void onSuccess() { - subscription_.reset(); - } - - // No further calls to the subscription after this method is invoked. - virtual void onError(std::exception_ptr) { - subscription_.reset(); - } - - protected: - SingleSubscription* subscription() { - return subscription_.operator->(); - } - - private: - // "Our" reference to the subscription, to ensure that it is retained - // while calls to its methods are in-flight. - Reference subscription_{nullptr}; -}; - -} // single -} // yarpl diff --git a/yarpl/include/yarpl/single/SingleObservers.h b/yarpl/include/yarpl/single/SingleObservers.h deleted file mode 100644 index 7c47ffdd8..000000000 --- a/yarpl/include/yarpl/single/SingleObservers.h +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "yarpl/utils/type_traits.h" - -#include "yarpl/single/SingleObserver.h" - -namespace yarpl { -namespace single { - -/// Helper methods for constructing subscriber instances from functions: -/// one or two functions (callables; can be lamda, for instance) -/// may be specified, corresponding to onNext, onError and onComplete -/// method bodies in the subscriber. -class SingleObservers { - private: - /// Defined if Success and Error are signature-compatible with - /// onSuccess and onError subscriber methods respectively. - template < - typename T, - typename Success, - typename Error = void (*)(std::exception_ptr)> - using EnableIfCompatible = typename std::enable_if< - std::is_callable::value && - std::is_callable::value>::type; - - public: - template > - static auto create(Next&& next) { - return Reference>( - new Base(std::forward(next))); - } - - template < - typename T, - typename Success, - typename Error, - typename = EnableIfCompatible> - static auto create(Success&& next, Error&& error) { - return Reference>(new WithError( - std::forward(next), std::forward(error))); - } - - private: - template - class Base : public SingleObserver { - public: - explicit Base(Next&& next) : next_(std::forward(next)) {} - - void onSuccess(T value) override { - next_(std::move(value)); - // TODO how do we call the super to trigger release? - // SingleObserver::onSuccess(value); - } - - private: - Next next_; - }; - - template - class WithError : public Base { - public: - WithError(Success&& next, Error&& error) - : Base(std::forward(next)), error_(error) {} - - void onError(std::exception_ptr error) override { - error_(error); - // TODO do we call the super here to trigger release? - Base::onError(error); - } - - private: - Error error_; - }; - - SingleObservers() = delete; -}; - -} // observable -} // yarpl diff --git a/yarpl/include/yarpl/single/SingleOperator.h b/yarpl/include/yarpl/single/SingleOperator.h deleted file mode 100644 index cf3ed9291..000000000 --- a/yarpl/include/yarpl/single/SingleOperator.h +++ /dev/null @@ -1,147 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include - -#include "yarpl/single/Single.h" -#include "yarpl/single/SingleObserver.h" -#include "yarpl/single/SingleSubscription.h" - -namespace yarpl { -namespace single { -/** - * Base (helper) class for operators. Operators are templated on two types: - * D (downstream) and U (upstream). Operators are created by method calls on - * an upstream Single, and are Observables themselves. Multi-stage - * pipelines - * can be built: a Single heading a sequence of Operators. - */ -template -class SingleOperator : public Single { - public: - explicit SingleOperator(Reference> upstream) - : upstream_(std::move(upstream)) {} - - protected: - /// - /// \brief An Operator's subscription. - /// - /// When a pipeline chain is active, each Single has a corresponding - /// subscription. Except for the first one, the subscriptions are created - /// against Operators. Each operator subscription has two functions: as a - /// observer for the previous stage; as a subscription for the next one, - /// the user-supplied observer being the last of the pipeline stages. - class Subscription : public ::yarpl::single::SingleSubscription, - public SingleObserver { - protected: - Subscription( - Reference> single, - Reference> observer) - : single_(std::move(single)), observer_(std::move(observer)) {} - - ~Subscription() { - observer_.reset(); - } - - void observerOnSuccess(D value) { - observer_->onSuccess(std::move(value)); - upstreamSubscription_.reset(); // should break the cycle to this - } - - template - TOperator* getObservableAs() { - return static_cast(single_.get()); - } - - private: - void onSubscribe( - Reference<::yarpl::single::SingleSubscription> subscription) override { - upstreamSubscription_ = std::move(subscription); - observer_->onSubscribe( - Reference<::yarpl::single::SingleSubscription>(this)); - } - - void onError(std::exception_ptr error) override { - observer_->onError(error); - upstreamSubscription_.reset(); // should break the cycle to this - } - - void cancel() override { - upstreamSubscription_->cancel(); - observer_.reset(); // breaking the cycle - } - - /// The Single has the lambda, and other creation parameters. - Reference> single_; - - /// This subscription controls the life-cycle of the observer. The - /// observer is retained as long as calls on it can be made. (Note: - /// the observer in turn maintains a reference on this subscription - /// object until cancellation and/or completion.) - Reference> observer_; - - /// In an active pipeline, cancel and (possibly modified) request(n) - /// calls should be forwarded upstream. Note that `this` is also a - /// observer for the upstream stage: thus, there are cycles; all of - /// the objects drop their references at cancel/complete. - Reference<::yarpl::single::SingleSubscription> upstreamSubscription_; - }; - - Reference> upstream_; -}; - -template < - typename U, - typename D, - typename F, - typename = typename std::enable_if::value>::type> -class MapOperator : public SingleOperator { - public: - MapOperator(Reference> upstream, F&& function) - : SingleOperator(std::move(upstream)), - function_(std::forward(function)) {} - - void subscribe(Reference> observer) override { - SingleOperator::upstream_->subscribe( - // Note: implicit cast to a reference to a observer. - Reference(new Subscription( - Reference>(this), std::move(observer)))); - } - - private: - class Subscription : public SingleOperator::Subscription { - using Super = typename SingleOperator::Subscription; - public: - Subscription( - Reference> single, - Reference> observer) - : SingleOperator::Subscription( - std::move(single), - std::move(observer)) {} - - void onSuccess(U value) override { - auto* map = Super::template getObservableAs(); - Super::observerOnSuccess(map->function_(std::move(value))); - } - }; - - F function_; -}; - -template -class FromPublisherOperator : public Single { - public: - explicit FromPublisherOperator(OnSubscribe&& function) - : function_(std::move(function)) {} - - void subscribe(Reference> observer) override { - function_(std::move(observer)); - } - - private: - OnSubscribe function_; -}; - -} // observable -} // yarpl diff --git a/yarpl/include/yarpl/single/SingleSubscription.h b/yarpl/include/yarpl/single/SingleSubscription.h deleted file mode 100644 index b22faf838..000000000 --- a/yarpl/include/yarpl/single/SingleSubscription.h +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "yarpl/Refcounted.h" - -namespace yarpl { -namespace single { - -class SingleSubscription : public virtual Refcounted { - public: - virtual ~SingleSubscription() = default; - virtual void cancel() = 0; - - protected: - SingleSubscription() {} -}; - -} // single -} // yarpl diff --git a/yarpl/include/yarpl/single/Singles.h b/yarpl/include/yarpl/single/Singles.h deleted file mode 100644 index 76a94174c..000000000 --- a/yarpl/include/yarpl/single/Singles.h +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include "yarpl/utils/type_traits.h" - -#include "yarpl/single/Single.h" - -namespace yarpl { -namespace single { - -class Singles { - public: - template - static Reference> just(const T& value) { - auto lambda = [value](Reference> observer) { - observer->onSuccess(value); - }; - - return Single::create(std::move(lambda)); - } - - template < - typename T, - typename OnSubscribe, - typename = typename std::enable_if>), - void>::value>::type> - static Reference> create(OnSubscribe&& function) { - return Reference>(new FromPublisherOperator( - std::forward(function))); - } - - template - static Reference> error(std::exception_ptr ex) { - auto lambda = [ex](Reference> observer) { - observer->onError(ex); - }; - return Single::create(std::move(lambda)); - } - - template - static Reference> error(const ExceptionType& ex) { - auto lambda = [ex](Reference> observer) { - observer->onError(std::make_exception_ptr(ex)); - }; - return Single::create(std::move(lambda)); - } - - private: - Singles() = delete; -}; - -} // observable -} // yarpl diff --git a/yarpl/include/yarpl/utils/ExceptionString.h b/yarpl/include/yarpl/utils/ExceptionString.h deleted file mode 100644 index 2ed8f9ffb..000000000 --- a/yarpl/include/yarpl/utils/ExceptionString.h +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include - -namespace yarpl { - -inline const char* exceptionStr(std::exception_ptr ep) { - try { - std::rethrow_exception(ep); - } catch (const std::exception& e) { - return e.what(); - } catch (...) { - return ""; - } -} -} diff --git a/yarpl/include/yarpl/utils/type_traits.h b/yarpl/include/yarpl/utils/type_traits.h deleted file mode 100644 index eb2fb35e3..000000000 --- a/yarpl/include/yarpl/utils/type_traits.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#pragma once - -#include - -#if __cplusplus < 201500 - -namespace std { - -namespace implementation { - -template -struct is_callable : std::false_type {}; - -template -struct is_callable< - F(Args...), - R, - std::enable_if_t>::value>> - : std::true_type {}; - -} // implementation - -template -struct is_callable : implementation::is_callable {}; - -} // std -#endif // __cplusplus diff --git a/yarpl/observable/DeferObservable.h b/yarpl/observable/DeferObservable.h new file mode 100644 index 000000000..302aeecaa --- /dev/null +++ b/yarpl/observable/DeferObservable.h @@ -0,0 +1,50 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/observable/Observable.h" + +namespace yarpl { +namespace observable { +namespace details { + +template +class DeferObservable : public Observable { + static_assert( + std::is_same, ObservableFactory>::value, + "undecayed"); + + public: + template + explicit DeferObservable(F&& factory) : factory_(std::forward(factory)) {} + + virtual std::shared_ptr subscribe( + std::shared_ptr> observer) { + std::shared_ptr> observable; + try { + observable = factory_(); + } catch (const std::exception& ex) { + observable = Observable::error(ex, std::current_exception()); + } + return observable->subscribe(std::move(observer)); + } + + private: + ObservableFactory factory_; +}; + +} // namespace details +} // namespace observable +} // namespace yarpl diff --git a/yarpl/observable/Observable.h b/yarpl/observable/Observable.h new file mode 100644 index 000000000..30729c360 --- /dev/null +++ b/yarpl/observable/Observable.h @@ -0,0 +1,560 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "yarpl/Refcounted.h" +#include "yarpl/observable/Observer.h" +#include "yarpl/observable/Subscription.h" + +#include "yarpl/Common.h" +#include "yarpl/Flowable.h" +#include "yarpl/flowable/Flowable_FromObservable.h" + +namespace yarpl { + +namespace observable { + +template +class Observable : public yarpl::enable_get_ref { + public: + static std::shared_ptr> empty() { + auto lambda = [](std::shared_ptr> observer) { + observer->onComplete(); + }; + return Observable::create(std::move(lambda)); + } + + static std::shared_ptr> error(folly::exception_wrapper ex) { + auto lambda = + [ex = std::move(ex)](std::shared_ptr> observer) mutable { + observer->onError(std::move(ex)); + }; + return Observable::create(std::move(lambda)); + } + + template + static std::shared_ptr> error(Ex&) { + static_assert( + std::is_lvalue_reference::value, + "use variant of error() method accepting also exception_ptr"); + } + + template + static std::shared_ptr> error(Ex& ex, std::exception_ptr ptr) { + auto lambda = [ew = folly::exception_wrapper(std::move(ptr), ex)]( + std::shared_ptr> observer) mutable { + observer->onError(std::move(ew)); + }; + return Observable::create(std::move(lambda)); + } + + static std::shared_ptr> just(T value) { + auto lambda = + [value = std::move(value)](std::shared_ptr> observer) { + observer->onNext(value); + observer->onComplete(); + }; + + return Observable::create(std::move(lambda)); + } + + /** + * The Defer operator waits until an observer subscribes to it, and then it + * generates an Observable with an ObservableFactory function. It + * does this afresh for each subscriber, so although each subscriber may + * think it is subscribing to the same Observable, in fact each subscriber + * gets its own individual sequence. + */ + template < + typename ObservableFactory, + typename = typename std::enable_if>, + std::decay_t&>::value>::type> + static std::shared_ptr> defer(ObservableFactory&&); + + static std::shared_ptr> justN(std::initializer_list list) { + auto lambda = [v = std::vector(std::move(list))]( + std::shared_ptr> observer) { + for (auto const& elem : v) { + observer->onNext(elem); + } + observer->onComplete(); + }; + + return Observable::create(std::move(lambda)); + } + + // this will generate an observable which can be subscribed to only once + static std::shared_ptr> justOnce(T value) { + auto lambda = [value = std::move(value), used = false]( + std::shared_ptr> observer) mutable { + if (used) { + observer->onError( + std::runtime_error("justOnce value was already used")); + return; + } + + used = true; + observer->onNext(std::move(value)); + observer->onComplete(); + }; + + return Observable::create(std::move(lambda)); + } + + template + static std::shared_ptr> create(OnSubscribe&&); + + template + static std::shared_ptr> createEx(OnSubscribe&&); + + virtual std::shared_ptr subscribe( + std::shared_ptr>) = 0; + + /** + * Subscribe overload that accepts lambdas. + */ + template < + typename Next, + typename = typename std::enable_if< + folly::is_invocable&, T>::value>::type> + std::shared_ptr subscribe(Next&& next) { + return subscribe(Observer::create(std::forward(next))); + } + + /** + * Subscribe overload that accepts lambdas. + */ + template < + typename Next, + typename Error, + typename = typename std::enable_if< + folly::is_invocable&, T>::value && + folly::is_invocable&, folly::exception_wrapper>:: + value>::type> + std::shared_ptr subscribe(Next&& next, Error&& error) { + return subscribe(Observer::create( + std::forward(next), std::forward(error))); + } + + /** + * Subscribe overload that accepts lambdas. + */ + template < + typename Next, + typename Error, + typename Complete, + typename = typename std::enable_if< + folly::is_invocable&, T>::value && + folly::is_invocable&, folly::exception_wrapper>:: + value && + folly::is_invocable&>::value>::type> + std::shared_ptr + subscribe(Next&& next, Error&& error, Complete&& complete) { + return subscribe(Observer::create( + std::forward(next), + std::forward(error), + std::forward(complete))); + } + + std::shared_ptr subscribe() { + return subscribe(Observer::create()); + } + + template < + typename Function, + typename R = typename folly::invoke_result_t> + std::shared_ptr> map(Function&& function); + + template + std::shared_ptr> filter(Function&& function); + + template < + typename Function, + typename R = typename folly::invoke_result_t> + std::shared_ptr> reduce(Function&& function); + + std::shared_ptr> take(int64_t); + + std::shared_ptr> skip(int64_t); + + std::shared_ptr> ignoreElements(); + + std::shared_ptr> subscribeOn(folly::Executor&); + + std::shared_ptr> concatWith(std::shared_ptr>); + + template + std::shared_ptr> concatWith( + std::shared_ptr> first, + Args... args) { + return concatWith(first)->concatWith(args...); + } + + template + static std::shared_ptr> concat( + std::shared_ptr> first, + Args... args) { + return first->concatWith(args...); + } + + // function is invoked when onComplete occurs. + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOnSubscribe(Function&& function); + + // function is invoked when onNext occurs. + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&, const T&>::value>::type> + std::shared_ptr> doOnNext(Function&& function); + + // function is invoked when onError occurs. + template < + typename Function, + typename = typename std::enable_if&, + folly::exception_wrapper&>::value>::type> + std::shared_ptr> doOnError(Function&& function); + + // function is invoked when onComplete occurs. + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOnComplete(Function&& function); + + // function is invoked when either onComplete or onError occurs. + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOnTerminate(Function&& function); + + // the function is invoked for each of onNext, onCompleted, onError + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOnEach(Function&& function); + + // the callbacks will be invoked of each of the signals + template < + typename OnNextFunc, + typename OnCompleteFunc, + typename = typename std::enable_if< + folly::is_invocable&, const T&>::value>:: + type, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOn( + OnNextFunc&& onNext, + OnCompleteFunc&& onComplete); + + // the callbacks will be invoked of each of the signals + template < + typename OnNextFunc, + typename OnCompleteFunc, + typename OnErrorFunc, + typename = typename std::enable_if< + folly::is_invocable&, const T&>::value>:: + type, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type, + typename = typename std::enable_if&, + folly::exception_wrapper&>::value>::type> + std::shared_ptr> + doOn(OnNextFunc&& onNext, OnCompleteFunc&& onComplete, OnErrorFunc&& onError); + + // function is invoked when cancel is called. + template < + typename Function, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + std::shared_ptr> doOnCancel(Function&& function); + + /** + * Convert from Observable to Flowable with a given BackpressureStrategy. + */ + auto toFlowable(BackpressureStrategy strategy); + + /** + * Convert from Observable to Flowable with a given BackpressureStrategy. + */ + auto toFlowable(std::shared_ptr> strategy); +}; +} // namespace observable +} // namespace yarpl + +#include "yarpl/observable/DeferObservable.h" +#include "yarpl/observable/ObservableOperator.h" + +namespace yarpl { +namespace observable { + +template +template +std::shared_ptr> Observable::create(OnSubscribe&& function) { + static_assert( + folly::is_invocable>>::value, + "OnSubscribe must have type `void(std::shared_ptr>)`"); + + return createEx([func = std::forward(function)]( + std::shared_ptr> observer, + std::shared_ptr) mutable { + func(std::move(observer)); + }); +} + +template +template +std::shared_ptr> Observable::createEx(OnSubscribe&& function) { + static_assert( + folly::is_invocable< + OnSubscribe&&, + std::shared_ptr>, + std::shared_ptr>::value, + "OnSubscribe must have type " + "`void(std::shared_ptr>, std::shared_ptr)`"); + + return std::make_shared>>( + std::forward(function)); +} + +template +template +std::shared_ptr> Observable::defer( + ObservableFactory&& factory) { + return std::make_shared< + details::DeferObservable>>( + std::forward(factory)); +} + +template +template +std::shared_ptr> Observable::map(Function&& function) { + return std::make_shared>>( + this->ref_from_this(this), std::forward(function)); +} + +template +template +std::shared_ptr> Observable::filter(Function&& function) { + return std::make_shared>>( + this->ref_from_this(this), std::forward(function)); +} + +template +template +std::shared_ptr> Observable::reduce(Function&& function) { + return std::make_shared>>( + this->ref_from_this(this), std::forward(function)); +} + +template +std::shared_ptr> Observable::take(int64_t limit) { + return std::make_shared>(this->ref_from_this(this), limit); +} + +template +std::shared_ptr> Observable::skip(int64_t offset) { + return std::make_shared>(this->ref_from_this(this), offset); +} + +template +std::shared_ptr> Observable::ignoreElements() { + return std::make_shared>(this->ref_from_this(this)); +} + +template +std::shared_ptr> Observable::subscribeOn( + folly::Executor& executor) { + return std::make_shared>( + this->ref_from_this(this), executor); +} + +template +template +std::shared_ptr> Observable::doOnSubscribe( + Function&& function) { + return details::createDoOperator( + ref_from_this(this), + std::forward(function), + [](const T&) {}, + [](const auto&) {}, + [] {}, + [] {}); // onCancel +} + +template +std::shared_ptr> Observable::concatWith( + std::shared_ptr> next) { + return std::make_shared>( + this->ref_from_this(this), std::move(next)); +} + +template +template +std::shared_ptr> Observable::doOnNext(Function&& function) { + return details::createDoOperator( + ref_from_this(this), + [] {}, + std::forward(function), + [](const auto&) {}, + [] {}, + [] {}); // onCancel +} + +template +template +std::shared_ptr> Observable::doOnError(Function&& function) { + return details::createDoOperator( + ref_from_this(this), + [] {}, + [](const T&) {}, + std::forward(function), + [] {}, + [] {}); // onCancel +} + +template +template +std::shared_ptr> Observable::doOnComplete( + Function&& function) { + return details::createDoOperator( + ref_from_this(this), + [] {}, + [](const T&) {}, + [](const auto&) {}, + std::forward(function), + [] {}); // onCancel +} + +template +template +std::shared_ptr> Observable::doOnTerminate( + Function&& function) { + auto sharedFunction = std::make_shared>( + std::forward(function)); + return details::createDoOperator( + ref_from_this(this), + [] {}, + [](const T&) {}, + [sharedFunction](const auto&) { (*sharedFunction)(); }, + [sharedFunction]() { (*sharedFunction)(); }, + [] {}); // onCancel +} + +template +template +std::shared_ptr> Observable::doOnEach(Function&& function) { + auto sharedFunction = std::make_shared>( + std::forward(function)); + return details::createDoOperator( + ref_from_this(this), + [] {}, + [sharedFunction](const T&) { (*sharedFunction)(); }, + [sharedFunction](const auto&) { (*sharedFunction)(); }, + [sharedFunction]() { (*sharedFunction)(); }, + [] {}); // onCancel +} + +template +template +std::shared_ptr> Observable::doOn( + OnNextFunc&& onNext, + OnCompleteFunc&& onComplete) { + return details::createDoOperator( + ref_from_this(this), + [] {}, + std::forward(onNext), + [](const auto&) {}, + std::forward(onComplete), + [] {}); // onCancel +} + +template +template < + typename OnNextFunc, + typename OnCompleteFunc, + typename OnErrorFunc, + typename, + typename, + typename> +std::shared_ptr> Observable::doOn( + OnNextFunc&& onNext, + OnCompleteFunc&& onComplete, + OnErrorFunc&& onError) { + return details::createDoOperator( + ref_from_this(this), + [] {}, + std::forward(onNext), + std::forward(onError), + std::forward(onComplete), + [] {}); // onCancel +} + +template +template +std::shared_ptr> Observable::doOnCancel(Function&& function) { + return details::createDoOperator( + ref_from_this(this), + [] {}, // onSubscribe + [](const auto&) {}, // onNext + [](const auto&) {}, // onError + [] {}, // onComplete + std::forward(function)); // onCancel +} + +template +auto Observable::toFlowable(BackpressureStrategy strategy) { + switch (strategy) { + case BackpressureStrategy::DROP: + return toFlowable(IBackpressureStrategy::drop()); + case BackpressureStrategy::ERROR: + return toFlowable(IBackpressureStrategy::error()); + case BackpressureStrategy::BUFFER: + return toFlowable(IBackpressureStrategy::buffer()); + case BackpressureStrategy::LATEST: + return toFlowable(IBackpressureStrategy::latest()); + case BackpressureStrategy::MISSING: + return toFlowable(IBackpressureStrategy::missing()); + default: + CHECK(false); // unknown value for strategy + } +} + +template +auto Observable::toFlowable( + std::shared_ptr> strategy) { + return yarpl::flowable::internal::flowableFromSubscriber( + [thisObservable = this->ref_from_this(this), + strategy = std::move(strategy)]( + std::shared_ptr> subscriber) { + strategy->init(std::move(thisObservable), std::move(subscriber)); + }); +} + +} // namespace observable +} // namespace yarpl diff --git a/yarpl/observable/ObservableConcatOperators.h b/yarpl/observable/ObservableConcatOperators.h new file mode 100644 index 000000000..4a3879a4e --- /dev/null +++ b/yarpl/observable/ObservableConcatOperators.h @@ -0,0 +1,154 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/observable/ObservableOperator.h" + +namespace yarpl { +namespace observable { +namespace details { + +template +class ConcatWithOperator : public ObservableOperator { + using Super = ObservableOperator; + + public: + ConcatWithOperator( + std::shared_ptr> first, + std::shared_ptr> second) + : first_(std::move(first)), second_(std::move(second)) { + CHECK(first_); + CHECK(second_); + } + + std::shared_ptr subscribe( + std::shared_ptr> observer) override { + auto subscription = + std::make_shared(observer, first_, second_); + subscription->init(); + + return subscription; + } + + private: + class ForwardObserver; + + // Downstream will always point to this subscription + class ConcatWithSubscription + : public yarpl::observable::Subscription, + public std::enable_shared_from_this { + public: + ConcatWithSubscription( + std::shared_ptr> observer, + std::shared_ptr> first, + std::shared_ptr> second) + : downObserver_(std::move(observer)), + first_(std::move(first)), + second_(std::move(second)) {} + + void init() { + upObserver_ = std::make_shared(this->shared_from_this()); + downObserver_->onSubscribe(this->shared_from_this()); + if (upObserver_) { + first_->subscribe(upObserver_); + } + } + + void cancel() override { + if (auto observer = std::move(upObserver_)) { + observer->cancel(); + } + first_.reset(); + second_.reset(); + upObserver_.reset(); + downObserver_.reset(); + } + + void onNext(T value) { + downObserver_->onNext(std::move(value)); + } + + void onComplete() { + if (auto first = std::move(first_)) { + upObserver_ = + std::make_shared(this->shared_from_this()); + second_->subscribe(upObserver_); + second_.reset(); + } else { + downObserver_->onComplete(); + downObserver_.reset(); + } + } + + void onError(folly::exception_wrapper ew) { + downObserver_->onError(std::move(ew)); + first_.reset(); + second_.reset(); + upObserver_.reset(); + downObserver_.reset(); + } + + private: + std::shared_ptr> downObserver_; + std::shared_ptr> first_; + std::shared_ptr> second_; + std::shared_ptr upObserver_; + }; + + class ForwardObserver : public yarpl::observable::Observer, + public yarpl::observable::Subscription { + public: + ForwardObserver( + std::shared_ptr concatWithSubscription) + : concatWithSubscription_(std::move(concatWithSubscription)) {} + + void cancel() override { + if (auto subs = std::move(subscription_)) { + subs->cancel(); + } + } + + void onSubscribe(std::shared_ptr subscription) override { + // Don't forward the subscription to downstream observer + subscription_ = std::move(subscription); + } + + void onComplete() override { + concatWithSubscription_->onComplete(); + concatWithSubscription_.reset(); + } + + void onError(folly::exception_wrapper ew) override { + concatWithSubscription_->onError(std::move(ew)); + concatWithSubscription_.reset(); + } + + void onNext(T value) override { + concatWithSubscription_->onNext(std::move(value)); + } + + private: + std::shared_ptr concatWithSubscription_; + std::shared_ptr subscription_; + }; + + private: + const std::shared_ptr> first_; + const std::shared_ptr> second_; +}; + +} // namespace details +} // namespace observable +} // namespace yarpl diff --git a/yarpl/observable/ObservableDoOperator.h b/yarpl/observable/ObservableDoOperator.h new file mode 100644 index 000000000..66f655eaf --- /dev/null +++ b/yarpl/observable/ObservableDoOperator.h @@ -0,0 +1,159 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/observable/ObservableOperator.h" + +namespace yarpl { +namespace observable { +namespace details { + +template < + typename U, + typename OnSubscribeFunc, + typename OnNextFunc, + typename OnErrorFunc, + typename OnCompleteFunc, + typename OnCancelFunc> +class DoOperator : public ObservableOperator { + using Super = ObservableOperator; + static_assert( + std::is_same, OnSubscribeFunc>::value, + "undecayed"); + static_assert( + std::is_same, OnNextFunc>::value, + "undecayed"); + static_assert( + std::is_same, OnErrorFunc>::value, + "undecayed"); + static_assert( + std::is_same, OnCompleteFunc>::value, + "undecayed"); + static_assert( + std::is_same, OnCancelFunc>::value, + "undecayed"); + + public: + template < + typename FSubscribe, + typename FNext, + typename FError, + typename FComplete, + typename FCancel> + DoOperator( + std::shared_ptr> upstream, + FSubscribe&& onSubscribeFunc, + FNext&& onNextFunc, + FError&& onErrorFunc, + FComplete&& onCompleteFunc, + FCancel&& onCancelFunc) + : upstream_(std::move(upstream)), + onSubscribeFunc_(std::forward(onSubscribeFunc)), + onNextFunc_(std::forward(onNextFunc)), + onErrorFunc_(std::forward(onErrorFunc)), + onCompleteFunc_(std::forward(onCompleteFunc)), + onCancelFunc_(std::forward(onCancelFunc)) {} + + std::shared_ptr subscribe( + std::shared_ptr> observer) override { + auto subscription = std::make_shared( + this->ref_from_this(this), std::move(observer)); + upstream_->subscribe( + // Note: implicit cast to a reference to a observer. + subscription); + return subscription; + } + + private: + class DoSubscription : public Super::OperatorSubscription { + using SuperSub = typename Super::OperatorSubscription; + + public: + DoSubscription( + std::shared_ptr observable, + std::shared_ptr> observer) + : SuperSub(std::move(observer)), observable_(std::move(observable)) {} + + void onSubscribe(std::shared_ptr + subscription) override { + observable_->onSubscribeFunc_(); + SuperSub::onSubscribe(std::move(subscription)); + } + + void onNext(U value) override { + const auto& valueRef = value; + observable_->onNextFunc_(valueRef); + SuperSub::observerOnNext(std::move(value)); + } + + void onError(folly::exception_wrapper ex) override { + const auto& exRef = ex; + observable_->onErrorFunc_(exRef); + SuperSub::onError(std::move(ex)); + } + + void onComplete() override { + observable_->onCompleteFunc_(); + SuperSub::onComplete(); + } + + void cancel() override { + observable_->onCancelFunc_(); + SuperSub::cancel(); + } + + private: + std::shared_ptr observable_; + }; + + std::shared_ptr> upstream_; + OnSubscribeFunc onSubscribeFunc_; + OnNextFunc onNextFunc_; + OnErrorFunc onErrorFunc_; + OnCompleteFunc onCompleteFunc_; + OnCancelFunc onCancelFunc_; +}; + +template < + typename U, + typename OnSubscribeFunc, + typename OnNextFunc, + typename OnErrorFunc, + typename OnCompleteFunc, + typename OnCancelFunc> +inline auto createDoOperator( + std::shared_ptr> upstream, + OnSubscribeFunc&& onSubscribeFunc, + OnNextFunc&& onNextFunc, + OnErrorFunc&& onErrorFunc, + OnCompleteFunc&& onCompleteFunc, + OnCancelFunc&& onCancelFunc) { + return std::make_shared, + std::decay_t, + std::decay_t, + std::decay_t, + std::decay_t>>( + std::move(upstream), + std::forward(onSubscribeFunc), + std::forward(onNextFunc), + std::forward(onErrorFunc), + std::forward(onCompleteFunc), + std::forward(onCancelFunc)); +} +} // namespace details +} // namespace observable +} // namespace yarpl diff --git a/yarpl/observable/ObservableOperator.h b/yarpl/observable/ObservableOperator.h new file mode 100644 index 000000000..451c6bd13 --- /dev/null +++ b/yarpl/observable/ObservableOperator.h @@ -0,0 +1,560 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include + +#include "yarpl/Observable.h" +#include "yarpl/observable/Observer.h" +#include "yarpl/observable/Observable.h" + +namespace yarpl { +namespace observable { + +/** + * Base (helper) class for operators. Operators are templated on two types: + * D (downstream) and U (upstream). Operators are created by method calls on + * an upstream Observable, and are Observables themselves. Multi-stage + * pipelines + * can be built: a Observable heading a sequence of Operators. + */ +template +class ObservableOperator : public Observable { + protected: + /// An Operator's subscription. + /// + /// When a pipeline chain is active, each Observable has a corresponding + /// subscription. Except for the first one, the subscriptions are created + /// against Operators. Each operator subscription has two functions: as a + /// subscriber for the previous stage; as a subscription for the next one, + /// the user-supplied subscriber being the last of the pipeline stages. + class OperatorSubscription : public ::yarpl::observable::Subscription, + public Observer { + protected: + explicit OperatorSubscription(std::shared_ptr> observer) + : observer_(std::move(observer)) { + assert(observer_); + } + + void observerOnNext(D value) { + if (observer_) { + observer_->onNext(std::move(value)); + } + } + + /// Terminates both ends of an operator normally. + void terminate() { + terminateImpl(TerminateState::Both()); + } + + /// Terminates both ends of an operator with an error. + void terminateErr(folly::exception_wrapper ex) { + terminateImpl(TerminateState::Both(), std::move(ex)); + } + + // Subscription. + + void cancel() override { + Subscription::cancel(); + terminateImpl(TerminateState::Up()); + } + + // Observer. + + void onSubscribe(std::shared_ptr + subscription) override { + if (upstream_) { + DLOG(ERROR) << "attempt to subscribe twice"; + subscription->cancel(); + return; + } + upstream_ = std::move(subscription); + observer_->onSubscribe(this->ref_from_this(this)); + } + + void onComplete() override { + terminateImpl(TerminateState::Down()); + } + + void onError(folly::exception_wrapper ex) override { + terminateImpl(TerminateState::Down(), std::move(ex)); + } + + private: + struct TerminateState { + TerminateState(bool u, bool d) : up{u}, down{d} {} + + static TerminateState Down() { + return TerminateState{false, true}; + } + + static TerminateState Up() { + return TerminateState{true, false}; + } + + static TerminateState Both() { + return TerminateState{true, true}; + } + + const bool up{false}; + const bool down{false}; + }; + + bool isTerminated() const { + return !upstream_ && !observer_; + } + + /// Terminates an operator, sending cancel() and on{Complete,Error}() + /// signals as necessary. + void terminateImpl( + TerminateState state, + folly::exception_wrapper ex = folly::exception_wrapper{nullptr}) { + if (isTerminated()) { + return; + } + + if (auto upstream = std::move(upstream_)) { + if (state.up) { + upstream->cancel(); + } + } + + if (auto observer = std::move(observer_)) { + if (state.down) { + if (ex) { + observer->onError(std::move(ex)); + } else { + observer->onComplete(); + } + } + } + } + + /// This subscription controls the life-cycle of the observer. The + /// observer is retained as long as calls on it can be made. (Note: + /// the observer in turn maintains a reference on this subscription + /// object until cancellation and/or completion.) + std::shared_ptr> observer_; + + /// In an active pipeline, cancel and (possibly modified) request(n) + /// calls should be forwarded upstream. Note that `this` is also a + /// observer for the upstream stage: thus, there are cycles; all of + /// the objects drop their references at cancel/complete. + // TODO(lehecka): this is extra field... base class has this member so + // remove it + std::shared_ptr<::yarpl::observable::Subscription> upstream_; + }; +}; + +template +class MapOperator : public ObservableOperator { + using Super = ObservableOperator; + static_assert(std::is_same, F>::value, "undecayed"); + static_assert(folly::is_invocable_r::value, "not invocable"); + + public: + template + MapOperator(std::shared_ptr> upstream, Func&& function) + : upstream_(std::move(upstream)), + function_(std::forward(function)) {} + + std::shared_ptr subscribe( + std::shared_ptr> observer) override { + auto subscription = std::make_shared( + this->ref_from_this(this), std::move(observer)); + upstream_->subscribe( + // Note: implicit cast to a reference to a observer. + subscription); + return subscription; + } + + private: + class MapSubscription : public Super::OperatorSubscription { + using SuperSub = typename Super::OperatorSubscription; + + public: + MapSubscription( + std::shared_ptr observable, + std::shared_ptr> observer) + : SuperSub(std::move(observer)), observable_(std::move(observable)) {} + + void onNext(U value) override { + try { + this->observerOnNext(observable_->function_(std::move(value))); + } catch (const std::exception& exn) { + folly::exception_wrapper ew{std::current_exception(), exn}; + this->terminateErr(std::move(ew)); + } + } + + private: + std::shared_ptr observable_; + }; + + std::shared_ptr> upstream_; + F function_; +}; + +template +class FilterOperator : public ObservableOperator { + using Super = ObservableOperator; + static_assert(std::is_same, F>::value, "undecayed"); + static_assert(folly::is_invocable_r::value, "not invocable"); + + public: + template + FilterOperator(std::shared_ptr> upstream, Func&& function) + : upstream_(std::move(upstream)), + function_(std::forward(function)) {} + + std::shared_ptr subscribe( + std::shared_ptr> observer) override { + auto subscription = std::make_shared( + this->ref_from_this(this), std::move(observer)); + upstream_->subscribe( + // Note: implicit cast to a reference to a observer. + subscription); + return subscription; + } + + private: + class FilterSubscription : public Super::OperatorSubscription { + using SuperSub = typename Super::OperatorSubscription; + + public: + FilterSubscription( + std::shared_ptr observable, + std::shared_ptr> observer) + : SuperSub(std::move(observer)), observable_(std::move(observable)) {} + + void onNext(U value) override { + if (observable_->function_(value)) { + SuperSub::observerOnNext(std::move(value)); + } + } + + private: + std::shared_ptr observable_; + }; + + std::shared_ptr> upstream_; + F function_; +}; + +template +class ReduceOperator : public ObservableOperator { + using Super = ObservableOperator; + static_assert(std::is_same, F>::value, "undecayed"); + static_assert(std::is_assignable::value, "not assignable"); + static_assert(folly::is_invocable_r::value, "not invocable"); + + public: + template + ReduceOperator(std::shared_ptr> upstream, Func&& function) + : upstream_(std::move(upstream)), + function_(std::forward(function)) {} + + std::shared_ptr subscribe( + std::shared_ptr> subscriber) override { + auto subscription = std::make_shared( + this->ref_from_this(this), std::move(subscriber)); + upstream_->subscribe( + // Note: implicit cast to a reference to a subscriber. + subscription); + return subscription; + } + + private: + class ReduceSubscription : public Super::OperatorSubscription { + using SuperSub = typename Super::OperatorSubscription; + + public: + ReduceSubscription( + std::shared_ptr observable, + std::shared_ptr> observer) + : SuperSub(std::move(observer)), + observable_(std::move(observable)), + accInitialized_(false) {} + + void onNext(U value) override { + if (accInitialized_) { + acc_ = observable_->function_(std::move(acc_), std::move(value)); + } else { + acc_ = std::move(value); + accInitialized_ = true; + } + } + + void onComplete() override { + if (accInitialized_) { + SuperSub::observerOnNext(std::move(acc_)); + } + SuperSub::onComplete(); + } + + private: + std::shared_ptr observable_; + bool accInitialized_; + D acc_; + }; + + std::shared_ptr> upstream_; + F function_; +}; + +template +class TakeOperator : public ObservableOperator { + using Super = ObservableOperator; + + public: + TakeOperator(std::shared_ptr> upstream, int64_t limit) + : upstream_(std::move(upstream)), limit_(limit) {} + + std::shared_ptr subscribe( + std::shared_ptr> observer) override { + auto subscription = + std::make_shared(limit_, std::move(observer)); + upstream_->subscribe(subscription); + return subscription; + } + + private: + class TakeSubscription : public Super::OperatorSubscription { + using SuperSub = typename Super::OperatorSubscription; + + public: + TakeSubscription(int64_t limit, std::shared_ptr> observer) + : SuperSub(std::move(observer)), limit_(limit) {} + + void onSubscribe(std::shared_ptr + subscription) override { + SuperSub::onSubscribe(std::move(subscription)); + + if (limit_ <= 0) { + SuperSub::terminate(); + } + } + + void onNext(T value) override { + if (limit_-- > 0) { + SuperSub::observerOnNext(std::move(value)); + if (limit_ == 0) { + SuperSub::terminate(); + } + } + } + + private: + int64_t limit_; + }; + + std::shared_ptr> upstream_; + const int64_t limit_; +}; + +template +class SkipOperator : public ObservableOperator { + using Super = ObservableOperator; + + public: + SkipOperator(std::shared_ptr> upstream, int64_t offset) + : upstream_(std::move(upstream)), offset_(offset) {} + + std::shared_ptr subscribe( + std::shared_ptr> observer) override { + auto subscription = + std::make_shared(offset_, std::move(observer)); + upstream_->subscribe(subscription); + return subscription; + } + + private: + class SkipSubscription : public Super::OperatorSubscription { + using SuperSub = typename Super::OperatorSubscription; + + public: + SkipSubscription(int64_t offset, std::shared_ptr> observer) + : SuperSub(std::move(observer)), offset_(offset) {} + + void onNext(T value) override { + if (offset_ <= 0) { + SuperSub::observerOnNext(std::move(value)); + } else { + --offset_; + } + } + + private: + int64_t offset_; + }; + + std::shared_ptr> upstream_; + const int64_t offset_; +}; + +template +class IgnoreElementsOperator : public ObservableOperator { + using Super = ObservableOperator; + + public: + explicit IgnoreElementsOperator(std::shared_ptr> upstream) + : upstream_(std::move(upstream)) {} + + std::shared_ptr subscribe( + std::shared_ptr> observer) override { + auto subscription = + std::make_shared(std::move(observer)); + upstream_->subscribe(subscription); + return subscription; + } + + private: + class IgnoreElementsSubscription : public Super::OperatorSubscription { + using SuperSub = typename Super::OperatorSubscription; + + public: + IgnoreElementsSubscription(std::shared_ptr> observer) + : SuperSub(std::move(observer)) {} + + void onNext(T) override {} + }; + + std::shared_ptr> upstream_; +}; + +template +class SubscribeOnOperator : public ObservableOperator { + using Super = ObservableOperator; + + public: + SubscribeOnOperator( + std::shared_ptr> upstream, + folly::Executor& executor) + : upstream_(std::move(upstream)), executor_(executor) {} + + std::shared_ptr subscribe( + std::shared_ptr> observer) override { + auto subscription = std::make_shared( + executor_, std::move(observer)); + executor_.add([subscription, upstream = upstream_]() mutable { + upstream->subscribe(std::move(subscription)); + }); + return subscription; + } + + private: + class SubscribeOnSubscription : public Super::OperatorSubscription { + using SuperSub = typename Super::OperatorSubscription; + + public: + SubscribeOnSubscription( + folly::Executor& executor, + std::shared_ptr> observer) + : SuperSub(std::move(observer)), executor_(executor) {} + + void cancel() override { + executor_.add([self = this->ref_from_this(this), this] { + this->callSuperCancel(); + }); + } + + void onNext(T value) override { + SuperSub::observerOnNext(std::move(value)); + } + + private: + // Trampoline to call superclass method; gcc bug 58972. + void callSuperCancel() { + SuperSub::cancel(); + } + + folly::Executor& executor_; + }; + + std::shared_ptr> upstream_; + folly::Executor& executor_; +}; + +template +class FromPublisherOperator : public Observable { + static_assert( + std::is_same, OnSubscribe>::value, + "undecayed"); + + public: + template + explicit FromPublisherOperator(F&& function) + : function_(std::forward(function)) {} + + private: + class PublisherObserver : public Observer { + public: + PublisherObserver( + std::shared_ptr> inner, + std::shared_ptr subscription) + : inner_(std::move(inner)) { + Observer::onSubscribe(std::move(subscription)); + } + + void onSubscribe(std::shared_ptr) override { + DLOG(ERROR) << "not allowed to call"; + CHECK(false); + } + + void onComplete() override { + if (auto inner = atomic_exchange(&inner_, nullptr)) { + inner->onComplete(); + } + Observer::onComplete(); + } + + void onError(folly::exception_wrapper ex) override { + if (auto inner = atomic_exchange(&inner_, nullptr)) { + inner->onError(std::move(ex)); + } + Observer::onError(folly::exception_wrapper()); + } + + void onNext(T t) override { + atomic_load(&inner_)->onNext(std::move(t)); + } + + private: + AtomicReference> inner_; + }; + + public: + std::shared_ptr subscribe( + std::shared_ptr> observer) override { + auto subscription = Subscription::create(); + observer->onSubscribe(subscription); + + if (!subscription->isCancelled()) { + function_(std::make_shared( + std::move(observer), subscription), subscription); + } + return subscription; + } + + private: + OnSubscribe function_; +}; +} // namespace observable +} // namespace yarpl + +#include "yarpl/observable/ObservableConcatOperators.h" +#include "yarpl/observable/ObservableDoOperator.h" diff --git a/yarpl/observable/Observables.cpp b/yarpl/observable/Observables.cpp new file mode 100644 index 000000000..6107938fe --- /dev/null +++ b/yarpl/observable/Observables.cpp @@ -0,0 +1,34 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yarpl/observable/Observables.h" + +namespace yarpl { +namespace observable { + +std::shared_ptr> Observable<>::range( + int64_t start, + int64_t count) { + auto lambda = [start, count](std::shared_ptr> observer) { + auto end = start + count; + for (int64_t i = start; i < end; ++i) { + observer->onNext(i); + } + observer->onComplete(); + }; + + return Observable::create(std::move(lambda)); +} +} // namespace observable +} // namespace yarpl diff --git a/yarpl/observable/Observables.h b/yarpl/observable/Observables.h new file mode 100644 index 000000000..7c30c4bec --- /dev/null +++ b/yarpl/observable/Observables.h @@ -0,0 +1,57 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "yarpl/observable/Observable.h" +#include "yarpl/observable/Subscription.h" + +namespace yarpl { +namespace observable { + +template <> +class Observable { + public: + /** + * Emit the sequence of numbers [start, start + count). + */ + static std::shared_ptr> range( + int64_t start, + int64_t count); + + template + static std::shared_ptr> just(T&& value) { + return Observable>::just(std::forward(value)); + } + + template + static std::shared_ptr> justN(std::initializer_list list) { + return Observable>::justN(std::move(list)); + } + + // this will generate an observable which can be subscribed to only once + template + static std::shared_ptr> justOnce(T&& value) { + return Observable>::justOnce( + std::forward(value)); + } + + private: + Observable() = delete; +}; + +} // namespace observable +} // namespace yarpl diff --git a/yarpl/observable/Observer.h b/yarpl/observable/Observer.h new file mode 100644 index 000000000..3e1e456b4 --- /dev/null +++ b/yarpl/observable/Observer.h @@ -0,0 +1,226 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "yarpl/Refcounted.h" +#include "yarpl/observable/Subscription.h" + +namespace yarpl { +namespace observable { + +template +class Observer : public yarpl::enable_get_ref { + public: + // Note: If any of the following methods is overridden in a subclass, the new + // methods SHOULD ensure that these are invoked as well. + virtual void onSubscribe(std::shared_ptr subscription) { + DCHECK(subscription); + + if (subscription_) { + DLOG(ERROR) << "attempt to double subscribe"; + subscription->cancel(); + return; + } + + if (cancelled_) { + subscription->cancel(); + } + + subscription_ = std::move(subscription); + } + + // No further calls to the subscription after this method is invoked. + virtual void onComplete() { + DCHECK(subscription_) << "Calling onComplete() without a subscription"; + subscription_.reset(); + } + + // No further calls to the subscription after this method is invoked. + virtual void onError(folly::exception_wrapper) { + DCHECK(subscription_) << "Calling onError() without a subscription"; + subscription_.reset(); + } + + virtual void onNext(T) = 0; + + bool isUnsubscribed() const { + CHECK(subscription_); + return subscription_->isCancelled(); + } + + // Ability to add more subscription objects which will be notified when the + // subscription has been cancelled. + // Note that calling cancel on the tied subscription is not going to cancel + // this subscriber + void addSubscription(std::shared_ptr subscription) { + if (!subscription_) { + subscription->cancel(); + return; + } + subscription_->tieSubscription(std::move(subscription)); + } + + template + void addSubscription(OnCancel onCancel) { + addSubscription(Subscription::create(std::move(onCancel))); + } + + bool isUnsubscribedOrTerminated() const { + return !subscription_ || subscription_->isCancelled(); + } + + protected: + void unsubscribe() { + if (subscription_) { + subscription_->cancel(); + } else { + cancelled_ = true; + } + } + + public: + template < + typename Next, + typename = + typename std::enable_if::value>::type> + static std::shared_ptr> create(Next next); + + template < + typename Next, + typename Error, + typename = + typename std::enable_if::value>::type, + typename = typename std::enable_if< + folly::is_invocable::value>::type> + static std::shared_ptr> create(Next next, Error error); + + template < + typename Next, + typename Error, + typename Complete, + typename = + typename std::enable_if::value>::type, + typename = typename std::enable_if< + folly::is_invocable::value>::type, + typename = + typename std::enable_if::value>::type> + static std::shared_ptr> + create(Next next, Error error, Complete complete); + + static std::shared_ptr> create() { + class NullObserver : public Observer { + public: + void onNext(T) {} + }; + return std::make_shared(); + } + + private: + std::shared_ptr subscription_; + bool cancelled_{false}; +}; + +namespace details { + +template +class Base : public Observer { + static_assert(std::is_same, Next>::value, "undecayed"); + + public: + template + explicit Base(FNext&& next) : next_(std::forward(next)) {} + + void onNext(T value) override { + next_(std::move(value)); + } + + private: + Next next_; +}; + +template +class WithError : public Base { + static_assert(std::is_same, Error>::value, "undecayed"); + + public: + template + WithError(FNext&& next, FError&& error) + : Base(std::forward(next)), + error_(std::forward(error)) {} + + void onError(folly::exception_wrapper error) override { + error_(std::move(error)); + } + + private: + Error error_; +}; + +template +class WithErrorAndComplete : public WithError { + static_assert( + std::is_same, Complete>::value, + "undecayed"); + + public: + template + WithErrorAndComplete(FNext&& next, FError&& error, FComplete&& complete) + : WithError( + std::forward(next), + std::forward(error)), + complete_(std::move(complete)) {} + + void onComplete() override { + complete_(); + } + + private: + Complete complete_; +}; +} // namespace details + +template +template +std::shared_ptr> Observer::create(Next next) { + return std::make_shared>(std::move(next)); +} + +template +template +std::shared_ptr> Observer::create(Next next, Error error) { + return std::make_shared>( + std::move(next), std::move(error)); +} + +template +template < + typename Next, + typename Error, + typename Complete, + typename, + typename, + typename> +std::shared_ptr> +Observer::create(Next next, Error error, Complete complete) { + return std::make_shared< + details::WithErrorAndComplete>( + std::move(next), std::move(error), std::move(complete)); +} + +} // namespace observable +} // namespace yarpl diff --git a/yarpl/observable/Subscription.cpp b/yarpl/observable/Subscription.cpp new file mode 100644 index 000000000..6a0abda1a --- /dev/null +++ b/yarpl/observable/Subscription.cpp @@ -0,0 +1,85 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yarpl/observable/Subscription.h" +#include +#include +#include + +namespace yarpl { +namespace observable { + +/** + * Implementation that allows checking if a Subscription is cancelled. + */ +void Subscription::cancel() { + cancelled_ = true; + // Lock must be obtained here and not in the range expression for it to + // apply to the loop body. + auto locked = tiedSubscriptions_.wlock(); + for (auto& subscription : *locked) { + subscription->cancel(); + } +} + +bool Subscription::isCancelled() const { + return cancelled_; +} + +void Subscription::tieSubscription(std::shared_ptr subscription) { + CHECK(subscription); + if (isCancelled()) { + subscription->cancel(); + } + tiedSubscriptions_.wlock()->push_back(std::move(subscription)); +} + +std::shared_ptr Subscription::create( + std::function onCancel) { + class CallbackSubscription : public Subscription { + public: + explicit CallbackSubscription(std::function onCancel) + : onCancel_(std::move(onCancel)) {} + + void cancel() override { + bool expected = false; + // mark cancelled 'true' and only if successful invoke 'onCancel()' + if (cancelled_.compare_exchange_strong(expected, true)) { + onCancel_(); + // Lock must be obtained here and not in the range expression for it to + // apply to the loop body. + auto locked = tiedSubscriptions_.wlock(); + for (auto& subscription : *locked) { + subscription->cancel(); + } + } + } + + private: + std::function onCancel_; + }; + return std::make_shared(std::move(onCancel)); +} + +std::shared_ptr Subscription::create( + std::atomic_bool& cancelled) { + return create([&cancelled]() { cancelled = true; }); +} + +std::shared_ptr Subscription::create() { + return std::make_shared(); +} + +} // namespace observable +} // namespace yarpl diff --git a/yarpl/observable/Subscription.h b/yarpl/observable/Subscription.h new file mode 100644 index 000000000..38dc17792 --- /dev/null +++ b/yarpl/observable/Subscription.h @@ -0,0 +1,45 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace yarpl { +namespace observable { + +class Subscription { + public: + virtual ~Subscription() = default; + virtual void cancel(); + bool isCancelled() const; + + // Adds ability to tie another subscription to this instance. + // Whenever *this subscription is cancelled then all tied subscriptions get + // cancelled as well + void tieSubscription(std::shared_ptr subscription); + + static std::shared_ptr create(std::function onCancel); + static std::shared_ptr create(std::atomic_bool& cancelled); + static std::shared_ptr create(); + + protected: + std::atomic cancelled_{false}; + folly::Synchronized>> + tiedSubscriptions_; +}; + +} // namespace observable +} // namespace yarpl diff --git a/yarpl/include/yarpl/observable/TestObserver.h b/yarpl/observable/TestObserver.h similarity index 71% rename from yarpl/include/yarpl/observable/TestObserver.h rename to yarpl/observable/TestObserver.h index 3fafdf266..a4d290492 100644 --- a/yarpl/include/yarpl/observable/TestObserver.h +++ b/yarpl/observable/TestObserver.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -18,7 +30,7 @@ namespace observable { * Example usage: * * auto observable = ... - * auto ts = TestObserver::create(); + * auto ts = std::make_shared>(); * observable->subscribe(ts->unique_observer()); * ts->awaitTerminalEvent(); * ts->assert... @@ -29,7 +41,8 @@ namespace observable { * * For example: * - * auto ts = TestObserver::create(std::make_unique()); + * auto ts = + * std::make_shared>(std::make_unique()); * observable->subscribe(ts->unique_observer()); * * Now when 'observable' is subscribed to, the TestObserver behavior @@ -44,29 +57,13 @@ class TestObserver : public yarpl::observable::Observer, using Observer = yarpl::observable::Observer; public: - /** - * Create a TestObserver that will subscribe upwards - * with no flow control (max value) and store all values it receives. - * @return - */ - static std::shared_ptr> create(); - - /** - * Create a TestObserver that will delegate all on* method calls - * to the provided Observer. - * - * This will store all values it receives to allow assertions. - * @return - */ - static std::shared_ptr> create(std::unique_ptr); - TestObserver(); explicit TestObserver(std::unique_ptr delegate); - void onSubscribe(Subscription* s) override; - void onNext(const T& t) override; + void onSubscribe(std::shared_ptr s) override; + void onNext(T t) override; void onComplete() override; - void onError(std::exception_ptr ex) override; + void onError(folly::exception_wrapper ex) override; /** * Get a unique Observer that can be passed into the Observable.subscribe @@ -84,7 +81,8 @@ class TestObserver : public yarpl::observable::Observer, /** * Block the current thread until either onComplete or onError is called. */ - void awaitTerminalEvent(); + void awaitTerminalEvent( + std::chrono::milliseconds ms = std::chrono::seconds{1}); /** * If the onNext values received does not match the given count, @@ -106,7 +104,7 @@ class TestObserver : public yarpl::observable::Observer, T& getValueAt(size_t index); /** - * If the onError exception_ptr points to an error containing + * If the onError exception_wrapper points to an error containing * the given msg, complete successfully, otherwise throw a runtime_error */ void assertOnErrorMessage(std::string msg); @@ -116,14 +114,24 @@ class TestObserver : public yarpl::observable::Observer, */ void cancel(); + bool isComplete() const { + return complete_; + } + + bool isError() const { + return error_; + } + private: std::unique_ptr delegate_; std::vector values_; - std::exception_ptr e_; + folly::exception_wrapper e_; bool terminated_{false}; + bool complete_{false}; + bool error_{false}; std::mutex m_; std::condition_variable terminalEventCV_; - Subscription* subscription_; + std::shared_ptr subscription_; }; template @@ -134,18 +142,7 @@ TestObserver::TestObserver(std::unique_ptr delegate) : delegate_(std::move(delegate)){}; template -std::shared_ptr> TestObserver::create() { - return std::make_shared>(); -} - -template -std::shared_ptr> TestObserver::create( - std::unique_ptr s) { - return std::make_shared>(std::move(s)); -} - -template -void TestObserver::onSubscribe(Subscription* s) { +void TestObserver::onSubscribe(std::shared_ptr s) { subscription_ = s; if (delegate_) { delegate_->onSubscribe(s); @@ -153,7 +150,7 @@ void TestObserver::onSubscribe(Subscription* s) { } template -void TestObserver::onNext(const T& t) { +void TestObserver::onNext(T t) { if (delegate_) { // std::cout << "TestObserver onNext& => copy then delegate" << // std::endl; @@ -171,25 +168,29 @@ void TestObserver::onComplete() { delegate_->onComplete(); } terminated_ = true; + complete_ = true; terminalEventCV_.notify_all(); } template -void TestObserver::onError(std::exception_ptr ex) { +void TestObserver::onError(folly::exception_wrapper ex) { if (delegate_) { delegate_->onError(ex); } - e_ = ex; + e_ = std::move(ex); terminated_ = true; + error_ = true; terminalEventCV_.notify_all(); } template -void TestObserver::awaitTerminalEvent() { +void TestObserver::awaitTerminalEvent(std::chrono::milliseconds ms) { // now block this thread std::unique_lock lk(m_); // if shutdown gets implemented this would then be released by it - terminalEventCV_.wait(lk, [this] { return terminated_; }); + if (!terminalEventCV_.wait_for(lk, ms, [this] { return terminated_; })) { + throw std::runtime_error("timeout in awaitTerminalEvent"); + } } template @@ -212,8 +213,8 @@ TestObserver::unique_observer() { ts_->onNext(t); } - void onError(std::exception_ptr e) override { - ts_->onError(e); + void onError(folly::exception_wrapper e) override { + ts_->onError(std::move(e)); } void onComplete() override { @@ -247,22 +248,11 @@ T& TestObserver::getValueAt(size_t index) { template void TestObserver::assertOnErrorMessage(std::string msg) { - if (e_ == nullptr) { + if (!e_ || e_.get_exception()->what() != msg) { std::stringstream ss; - ss << "exception_ptr == nullptr, but expected " << msg; + ss << "Error is: " << e_ << " but expected: " << msg; throw std::runtime_error(ss.str()); } - try { - std::rethrow_exception(e_); - } catch (std::runtime_error& re) { - if (re.what() != msg) { - std::stringstream ss; - ss << "Error message is: " << re.what() << " but expected: " << msg; - throw std::runtime_error(ss.str()); - } - } catch (...) { - throw std::runtime_error("Expects an std::runtime_error"); - } -} -} } +} // namespace observable +} // namespace yarpl diff --git a/yarpl/perf/CMakeLists.txt b/yarpl/perf/CMakeLists.txt deleted file mode 100644 index 03ccdcffb..000000000 --- a/yarpl/perf/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ - -benchmark(function_perf Function_perf.cpp) -benchmark(observable_perf Observable_perf.cpp) diff --git a/yarpl/perf/Function_perf.cpp b/yarpl/perf/Function_perf.cpp index 4783d2a8a..4c868a44f 100644 --- a/yarpl/perf/Function_perf.cpp +++ b/yarpl/perf/Function_perf.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include diff --git a/yarpl/perf/Observable_perf.cpp b/yarpl/perf/Observable_perf.cpp index abbe81e90..ecfb6a92c 100644 --- a/yarpl/perf/Observable_perf.cpp +++ b/yarpl/perf/Observable_perf.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -9,7 +21,7 @@ using namespace yarpl::observable; static void Observable_OnNextOne_ConstructOnly(benchmark::State& state) { while (state.KeepRunning()) { - auto a = Observable::create([](yarpl::Reference> obs) { + auto a = Observable::create([](std::shared_ptr> obs) { obs->onSubscribe(Subscriptions::empty()); obs->onNext(1); obs->onComplete(); @@ -19,20 +31,20 @@ static void Observable_OnNextOne_ConstructOnly(benchmark::State& state) { BENCHMARK(Observable_OnNextOne_ConstructOnly); static void Observable_OnNextOne_SubscribeOnly(benchmark::State& state) { - auto a = Observable::create([](yarpl::Reference> obs) { + auto a = Observable::create([](std::shared_ptr> obs) { obs->onSubscribe(Subscriptions::empty()); obs->onNext(1); obs->onComplete(); }); while (state.KeepRunning()) { - a->subscribe(Observers::create([](int /* value */) {})); + a->subscribe(Observer::create([](int /* value */) {})); } } BENCHMARK(Observable_OnNextOne_SubscribeOnly); static void Observable_OnNextN(benchmark::State& state) { auto a = - Observable::create([&state](yarpl::Reference> obs) { + Observable::create([&state](std::shared_ptr> obs) { obs->onSubscribe(Subscriptions::empty()); for (int i = 0; i < state.range(0); i++) { obs->onNext(i); @@ -40,7 +52,7 @@ static void Observable_OnNextN(benchmark::State& state) { obs->onComplete(); }); while (state.KeepRunning()) { - a->subscribe(Observers::create([](int /* value */) {})); + a->subscribe(Observer::create([](int /* value */) {})); } } diff --git a/yarpl/single/Single.h b/yarpl/single/Single.h new file mode 100644 index 000000000..1355e30ba --- /dev/null +++ b/yarpl/single/Single.h @@ -0,0 +1,175 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "yarpl/Refcounted.h" +#include "yarpl/single/SingleObserver.h" +#include "yarpl/single/SingleObservers.h" +#include "yarpl/single/SingleSubscription.h" + +namespace yarpl { +namespace single { + +template +class Single : public yarpl::enable_get_ref { + public: + virtual ~Single() = default; + + virtual void subscribe(std::shared_ptr>) = 0; + + /** + * Subscribe overload that accepts lambdas. + */ + template < + typename Success, + typename = typename std::enable_if< + folly::is_invocable&, T>::value>::type> + void subscribe(Success&& next) { + subscribe(SingleObservers::create(std::forward(next))); + } + + /** + * Subscribe overload that accepts lambdas. + */ + template < + typename Success, + typename Error, + typename = typename std::enable_if< + folly::is_invocable&, T>::value && + folly::is_invocable&, folly::exception_wrapper>:: + value>::type> + void subscribe(Success next, Error error) { + subscribe(SingleObservers::create( + std::forward(next), std::forward(error))); + } + + /** + * Blocking subscribe that accepts lambdas. + * + * This blocks the current thread waiting on the response. + */ + template < + typename Success, + typename = typename std::enable_if< + folly::is_invocable&, T>::value>::type> + void subscribeBlocking(Success&& next) { + auto waiting_ = std::make_shared>(); + subscribe( + SingleObservers::create([next = std::forward(next), waiting_](T t) { + next(std::move(t)); + waiting_->post(); + })); + // TODO get errors and throw if one is received + waiting_->wait(); + } + + template < + typename OnSubscribe, + typename = typename std::enable_if&, + std::shared_ptr>>::value>::type> + static std::shared_ptr> create(OnSubscribe&&); + + template + auto map(Function&& function); +}; + +template <> +class Single { + public: + virtual ~Single() = default; + + virtual void subscribe(std::shared_ptr>) = 0; + + /** + * Subscribe overload taking lambda for onSuccess that is called upon writing + * to the network. + */ + template < + typename Success, + typename = typename std::enable_if< + folly::is_invocable&>::value>::type> + void subscribe(Success&& s) { + class SuccessSingleObserver : public SingleObserverBase { + public: + explicit SuccessSingleObserver(Success&& success) + : success_{std::forward(success)} {} + + void onSubscribe( + std::shared_ptr subscription) override { + SingleObserverBase::onSubscribe(std::move(subscription)); + } + + void onSuccess() override { + success_(); + SingleObserverBase::onSuccess(); + } + + // No further calls to the subscription after this method is invoked. + void onError(folly::exception_wrapper ex) override { + SingleObserverBase::onError(std::move(ex)); + } + + private: + std::decay_t success_; + }; + + subscribe( + std::make_shared(std::forward(s))); + } + + template < + typename OnSubscribe, + typename = typename std::enable_if&, + std::shared_ptr>>::value>::type> + static auto create(OnSubscribe&&); +}; + +} // namespace single +} // namespace yarpl + +#include "yarpl/single/SingleOperator.h" + +namespace yarpl { +namespace single { + +template +template +std::shared_ptr> Single::create(OnSubscribe&& function) { + return std::make_shared>>( + std::forward(function)); +} + +template +auto Single::create(OnSubscribe&& function) { + return std::make_shared< + SingleVoidFromPublisherOperator>>( + std::forward(function)); +} + +template +template +auto Single::map(Function&& function) { + using D = typename folly::invoke_result_t; + return std::make_shared>>( + this->ref_from_this(this), std::forward(function)); +} + +} // namespace single +} // namespace yarpl diff --git a/yarpl/single/SingleObserver.h b/yarpl/single/SingleObserver.h new file mode 100644 index 000000000..8c74337c5 --- /dev/null +++ b/yarpl/single/SingleObserver.h @@ -0,0 +1,171 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "yarpl/single/SingleSubscription.h" + +namespace yarpl { +namespace single { + +template +class SingleObserver { + public: + virtual ~SingleObserver() = default; + virtual void onSubscribe(std::shared_ptr) = 0; + virtual void onSuccess(T) = 0; + virtual void onError(folly::exception_wrapper) = 0; + + template + static std::shared_ptr> create(Success&& success); + + template + static std::shared_ptr> create( + Success&& success, + Error&& error); +}; + +template +class SingleObserverBase : public SingleObserver { + public: + // Note: If any of the following methods is overridden in a subclass, the new + // methods SHOULD ensure that these are invoked as well. + void onSubscribe(std::shared_ptr subscription) override { + DCHECK(subscription); + + if (subscription_) { + subscription->cancel(); + return; + } + + subscription_ = std::move(subscription); + } + + void onSuccess(T) override { + DCHECK(subscription_) << "Calling onSuccess() without a subscription"; + subscription_.reset(); + } + + // No further calls to the subscription after this method is invoked. + void onError(folly::exception_wrapper) override { + DCHECK(subscription_) << "Calling onError() without a subscription"; + subscription_.reset(); + } + + protected: + SingleSubscription* subscription() { + return subscription_.operator->(); + } + + private: + std::shared_ptr subscription_; +}; + +/// Specialization of SingleObserverBase. +template <> +class SingleObserverBase { + public: + virtual ~SingleObserverBase() = default; + + // Note: If any of the following methods is overridden in a subclass, the new + // methods SHOULD ensure that these are invoked as well. + virtual void onSubscribe(std::shared_ptr subscription) { + DCHECK(subscription); + + if (subscription_) { + subscription->cancel(); + return; + } + + subscription_ = std::move(subscription); + } + + virtual void onSuccess() { + DCHECK(subscription_) << "Calling onSuccess() without a subscription"; + subscription_.reset(); + } + + // No further calls to the subscription after this method is invoked. + virtual void onError(folly::exception_wrapper) { + DCHECK(subscription_) << "Calling onError() without a subscription"; + subscription_.reset(); + } + + protected: + SingleSubscription* subscription() { + return subscription_.operator->(); + } + + private: + std::shared_ptr subscription_; +}; + +template +class SimpleSingleObserver : public SingleObserver { + public: + SimpleSingleObserver(Success success, Error error) + : success_(std::move(success)), error_(std::move(error)) {} + + void onSubscribe(std::shared_ptr) { + // throw away the subscription + } + + void onSuccess(T value) override { + success_(std::move(value)); + } + + void onError(folly::exception_wrapper ew) { + error_(std::move(ew)); + } + + Success success_; + Error error_; +}; + +template +template +std::shared_ptr> SingleObserver::create( + Success&& success) { + static_assert( + folly::is_invocable::value, + "Input `success` should be invocable with a parameter of `T`."); + return std::make_shared, + folly::Function>>( + std::forward(success), [](folly::exception_wrapper) {}); +} + +template +template +std::shared_ptr> SingleObserver::create( + Success&& success, + Error&& error) { + static_assert( + folly::is_invocable::value, + "Input `success` should be invocable with a parameter of `T`."); + static_assert( + folly::is_invocable::value, + "Input `error` should be invocable with a parameter of " + "`folly::exception_wrapper`."); + + return std::make_shared< + SimpleSingleObserver, std::decay_t>>( + std::forward(success), std::forward(error)); +} + +} // namespace single +} // namespace yarpl diff --git a/yarpl/single/SingleObservers.h b/yarpl/single/SingleObservers.h new file mode 100644 index 000000000..118b25fa9 --- /dev/null +++ b/yarpl/single/SingleObservers.h @@ -0,0 +1,107 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/single/SingleObserver.h" + +#include + +namespace yarpl { +namespace single { + +/// Helper methods for constructing subscriber instances from functions: +/// one or two functions (callables; can be lamda, for instance) +/// may be specified, corresponding to onNext, onError and onComplete +/// method bodies in the subscriber. +class SingleObservers { + private: + /// Defined if Success and Error are signature-compatible with + /// onSuccess and onError subscriber methods respectively. + template < + typename T, + typename Success, + typename Error = void (*)(folly::exception_wrapper)> + using EnableIfCompatible = typename std::enable_if< + folly::is_invocable&, T>::value && + folly::is_invocable&, folly::exception_wrapper>:: + value>::type; + + public: + template > + static auto create(Next&& next) { + return std::make_shared>>( + std::forward(next)); + } + + template < + typename T, + typename Success, + typename Error, + typename = EnableIfCompatible> + static auto create(Success&& next, Error&& error) { + return std::make_shared< + WithError, std::decay_t>>( + std::forward(next), std::forward(error)); + } + + template + static auto create() { + return std::make_shared>(); + } + + private: + template + class Base : public SingleObserverBase { + static_assert(std::is_same, Next>::value, "undecayed"); + + public: + template + explicit Base(FNext&& next) : next_(std::forward(next)) {} + + void onSuccess(T value) override { + next_(std::move(value)); + // TODO how do we call the super to trigger release? + // SingleObserver::onSuccess(value); + } + + private: + Next next_; + }; + + template + class WithError : public Base { + static_assert(std::is_same, Error>::value, "undecayed"); + + public: + template + WithError(FSuccess&& success, FError&& error) + : Base(std::forward(success)), + error_(std::forward(error)) {} + + void onError(folly::exception_wrapper error) override { + error_(error); + // TODO do we call the super here to trigger release? + Base::onError(std::move(error)); + } + + private: + Error error_; + }; + + SingleObservers() = delete; +}; + +} // namespace single +} // namespace yarpl diff --git a/yarpl/single/SingleOperator.h b/yarpl/single/SingleOperator.h new file mode 100644 index 000000000..0b3e7392e --- /dev/null +++ b/yarpl/single/SingleOperator.h @@ -0,0 +1,248 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include + +#include "yarpl/single/Single.h" +#include "yarpl/single/SingleObserver.h" +#include "yarpl/single/SingleSubscriptions.h" + +namespace yarpl { +namespace single { +/** + * Base (helper) class for operators. Operators are templated on two types: + * D (downstream) and U (upstream). Operators are created by method calls on + * an upstream Single, and are Observables themselves. Multi-stage + * pipelines + * can be built: a Single heading a sequence of Operators. + */ +template +class SingleOperator : public Single { + public: + explicit SingleOperator(std::shared_ptr> upstream) + : upstream_(std::move(upstream)) {} + + protected: + /// + /// \brief An Operator's subscription. + /// + /// When a pipeline chain is active, each Single has a corresponding + /// subscription. Except for the first one, the subscriptions are created + /// against Operators. Each operator subscription has two functions: as a + /// observer for the previous stage; as a subscription for the next one, + /// the user-supplied observer being the last of the pipeline stages. + template + class Subscription : public ::yarpl::single::SingleSubscription, + public SingleObserver, + public yarpl::enable_get_ref { + protected: + Subscription( + std::shared_ptr single, + std::shared_ptr> observer) + : single_(std::move(single)), observer_(std::move(observer)) {} + + ~Subscription() { + observer_.reset(); + } + + void observerOnSuccess(D value) { + terminateImpl(TerminateState::Down(), folly::Try{std::move(value)}); + } + + void observerOnError(folly::exception_wrapper ew) { + terminateImpl(TerminateState::Down(), folly::Try{std::move(ew)}); + } + + std::shared_ptr getOperator() { + return single_; + } + + void terminateErr(folly::exception_wrapper ew) { + terminateImpl(TerminateState::Both(), std::move(ew)); + } + + // SingleSubscription. + + void cancel() override { + terminateImpl(TerminateState::Up(), folly::Try{}); + } + + // Subscriber. + + void onSubscribe(std::shared_ptr + subscription) override { + upstream_ = std::move(subscription); + observer_->onSubscribe(this->ref_from_this(this)); + } + + void onError(folly::exception_wrapper ew) override { + terminateImpl(TerminateState::Down(), folly::Try{std::move(ew)}); + } + + private: + struct TerminateState { + TerminateState(bool u, bool d) : up{u}, down{d} {} + + static TerminateState Down() { + return TerminateState{false, true}; + } + + static TerminateState Up() { + return TerminateState{true, false}; + } + + static TerminateState Both() { + return TerminateState{true, true}; + } + + const bool up{false}; + const bool down{false}; + }; + + bool isTerminated() const { + return !upstream_ && !observer_; + } + + void terminateImpl(TerminateState state, folly::Try maybe) { + if (isTerminated()) { + return; + } + + if (auto upstream = std::move(upstream_)) { + if (state.up) { + upstream->cancel(); + } + } + + if (auto observer = std::move(observer_)) { + if (state.down) { + if (maybe.hasValue()) { + observer->onSuccess(std::move(maybe).value()); + } else { + observer->onError(std::move(maybe).exception()); + } + } + } + } + + /// The Single has the lambda, and other creation parameters. + std::shared_ptr single_; + + /// This subscription controls the life-cycle of the observer. The + /// observer is retained as long as calls on it can be made. (Note: + /// the observer in turn maintains a reference on this subscription + /// object until cancellation and/or completion.) + std::shared_ptr> observer_; + + /// In an active pipeline, cancel and (possibly modified) request(n) + /// calls should be forwarded upstream. Note that `this` is also a + /// observer for the upstream stage: thus, there are cycles; all of + /// the objects drop their references at cancel/complete. + std::shared_ptr upstream_; + }; + + std::shared_ptr> upstream_; +}; + +template < + typename U, + typename D, + typename F> +class MapOperator : public SingleOperator { + using ThisOperatorT = MapOperator; + using Super = SingleOperator; + using OperatorSubscription = + typename Super::template Subscription; + static_assert(std::is_same, F>::value, "undecayed"); + static_assert(folly::is_invocable_r::value, "not invocable"); + + public: + template + MapOperator(std::shared_ptr> upstream, Func&& function) + : Super(std::move(upstream)), function_(std::forward(function)) {} + + void subscribe(std::shared_ptr> observer) override { + Super::upstream_->subscribe( + // Note: implicit cast to a reference to a observer. + std::make_shared( + this->ref_from_this(this), std::move(observer))); + } + + private: + class MapSubscription : public OperatorSubscription { + public: + MapSubscription( + std::shared_ptr single, + std::shared_ptr> observer) + : OperatorSubscription(std::move(single), std::move(observer)) {} + + void onSuccess(U value) override { + try { + auto map_operator = this->getOperator(); + this->observerOnSuccess(map_operator->function_(std::move(value))); + } catch (const std::exception& exn) { + folly::exception_wrapper ew{std::current_exception(), exn}; + this->observerOnError(std::move(ew)); + } + } + }; + + F function_; +}; + +template +class FromPublisherOperator : public Single { + static_assert( + std::is_same, OnSubscribe>::value, + "undecayed"); + + public: + template + explicit FromPublisherOperator(F&& function) + : function_(std::forward(function)) {} + + void subscribe(std::shared_ptr> observer) override { + function_(std::move(observer)); + } + + private: + OnSubscribe function_; +}; + +template +class SingleVoidFromPublisherOperator : public Single { + static_assert( + std::is_same, OnSubscribe>::value, + "undecayed"); + + public: + template + explicit SingleVoidFromPublisherOperator(F&& function) + : function_(std::forward(function)) {} + + void subscribe(std::shared_ptr> observer) override { + function_(std::move(observer)); + } + + private: + OnSubscribe function_; +}; + +} // namespace single +} // namespace yarpl diff --git a/yarpl/single/SingleSubscription.h b/yarpl/single/SingleSubscription.h new file mode 100644 index 000000000..ef898c1ba --- /dev/null +++ b/yarpl/single/SingleSubscription.h @@ -0,0 +1,32 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/Refcounted.h" + +namespace yarpl { +namespace single { + +class SingleSubscription { + public: + virtual ~SingleSubscription() = default; + virtual void cancel() = 0; + + protected: + SingleSubscription() {} +}; + +} // namespace single +} // namespace yarpl diff --git a/yarpl/include/yarpl/single/SingleSubscriptions.h b/yarpl/single/SingleSubscriptions.h similarity index 60% rename from yarpl/include/yarpl/single/SingleSubscriptions.h rename to yarpl/single/SingleSubscriptions.h index a40c30e26..9ebfe4498 100644 --- a/yarpl/include/yarpl/single/SingleSubscriptions.h +++ b/yarpl/single/SingleSubscriptions.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -13,8 +25,8 @@ namespace yarpl { namespace single { /** -* Implementation that allows checking if a Subscription is cancelled. -*/ + * Implementation that allows checking if a Subscription is cancelled. + */ class AtomicBoolSingleSubscription : public SingleSubscription { public: void cancel() override { @@ -29,11 +41,11 @@ class AtomicBoolSingleSubscription : public SingleSubscription { }; /** -* Implementation that gets a callback when cancellation occurs. -*/ + * Implementation that gets a callback when cancellation occurs. + */ class CallbackSingleSubscription : public SingleSubscription { public: - explicit CallbackSingleSubscription(std::function&& onCancel) + explicit CallbackSingleSubscription(std::function onCancel) : onCancel_(std::move(onCancel)) {} void cancel() override { bool expected = false; @@ -52,10 +64,10 @@ class CallbackSingleSubscription : public SingleSubscription { }; /** -* Implementation that can be cancelled with or without + * Implementation that can be cancelled with or without * a delegate, and when the delegate exists (before or after cancel) * it will be cancelled in a thread-safe manner. -*/ + */ class DelegateSingleSubscription : public SingleSubscription { public: explicit DelegateSingleSubscription() {} @@ -80,7 +92,7 @@ class DelegateSingleSubscription : public SingleSubscription { /** * This can be called once. */ - void setDelegate(Reference d) { + void setDelegate(std::shared_ptr d) { bool shouldCancelDelegate = false; { std::lock_guard g(m_); @@ -102,26 +114,27 @@ class DelegateSingleSubscription : public SingleSubscription { // all must be protected by a mutex mutable std::mutex m_; bool cancelled_{false}; - Reference delegate_; + std::shared_ptr delegate_; }; class SingleSubscriptions { public: - static Reference create( + static std::shared_ptr create( std::function onCancel) { - return make_ref(std::move(onCancel)); + return std::make_shared(std::move(onCancel)); } - static Reference create( + static std::shared_ptr create( std::atomic_bool& cancelled) { return create([&cancelled]() { cancelled = true; }); } - static Reference empty() { - return Reference(new AtomicBoolSingleSubscription()); + static std::shared_ptr empty() { + return std::make_shared(); } - static Reference atomicBoolSubscription() { - return make_ref(); + static std::shared_ptr + atomicBoolSubscription() { + return std::make_shared(); } }; -} // single namespace -} // yarpl namespace +} // namespace single +} // namespace yarpl diff --git a/yarpl/include/yarpl/single/SingleTestObserver.h b/yarpl/single/SingleTestObserver.h similarity index 68% rename from yarpl/include/yarpl/single/SingleTestObserver.h rename to yarpl/single/SingleTestObserver.h index 2bdf6696b..2557f5d10 100644 --- a/yarpl/include/yarpl/single/SingleTestObserver.h +++ b/yarpl/single/SingleTestObserver.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -32,7 +44,7 @@ namespace single { * * For example: * - * auto to = SingleTestObserver::create(make_ref()); + * auto to = SingleTestObserver::create(std::make_shared()); * single->subscribe(to); * * Now when 'single' is subscribed to, the SingleTestObserver behavior @@ -49,8 +61,8 @@ class SingleTestObserver : public yarpl::single::SingleObserver { * * @return */ - static Reference> create() { - return make_ref>(); + static std::shared_ptr> create() { + return std::make_shared>(); } /** @@ -60,9 +72,9 @@ class SingleTestObserver : public yarpl::single::SingleObserver { * This will store the value it receives to allow assertions. * @return */ - static Reference> create( - Reference> delegate) { - return make_ref>(std::move(delegate)); + static std::shared_ptr> create( + std::shared_ptr> delegate) { + return std::make_shared>(std::move(delegate)); } SingleTestObserver() : delegate_(nullptr) {} @@ -74,10 +86,10 @@ class SingleTestObserver : public yarpl::single::SingleObserver { // and then access them for verification/assertion // on the unit test main thread. - explicit SingleTestObserver(Reference> delegate) + explicit SingleTestObserver(std::shared_ptr> delegate) : delegate_(std::move(delegate)) {} - void onSubscribe(Reference subscription) override { + void onSubscribe(std::shared_ptr subscription) override { if (delegate_) { delegateSubscription_->setDelegate(subscription); // copy delegate_->onSubscribe(std::move(subscription)); @@ -108,14 +120,14 @@ class SingleTestObserver : public yarpl::single::SingleObserver { terminalEventCV_.notify_all(); } - void onError(std::exception_ptr ex) override { + void onError(folly::exception_wrapper ex) override { if (delegate_) { // Do NOT hold the mutex while emitting delegate_->onError(ex); } { std::lock_guard g(m_); - e_ = ex; + e_ = std::move(ex); terminated_ = true; } terminalEventCV_.notify_all(); @@ -149,7 +161,10 @@ class SingleTestObserver : public yarpl::single::SingleObserver { throw std::runtime_error("Did not receive terminal event."); } if (e_) { - throw std::runtime_error("Received onError instead of onSuccess"); + std::stringstream ss; + ss << "Received onError instead of onSuccess"; + ss << " (error was " << e_ << ")"; + throw std::runtime_error(ss.str()); } } @@ -165,8 +180,6 @@ class SingleTestObserver : public yarpl::single::SingleObserver { /** * Get a reference to the received value if onSuccess was called. - * - * @return */ T& getOnSuccessValue() { std::lock_guard g(m_); @@ -174,27 +187,31 @@ class SingleTestObserver : public yarpl::single::SingleObserver { } /** - * If the onError exception_ptr points to an error containing + * Get the error received from onError if it was called. + */ + folly::exception_wrapper getError() { + std::lock_guard g(m_); + if (!terminated_) { + throw std::logic_error{"Must call getError() on a terminated observer"}; + } + return e_; + } + + /** + * If the onError exception_wrapper points to an error containing * the given msg, complete successfully, otherwise throw a runtime_error */ void assertOnErrorMessage(std::string msg) { std::lock_guard g(m_); - if (e_ == nullptr) { + if (!e_ || e_.get_exception()->what() != msg) { std::stringstream ss; - ss << "exception_ptr == nullptr, but expected " << msg; + ss << "Error is: " << e_ << " but expected: " << msg; throw std::runtime_error(ss.str()); } - try { - std::rethrow_exception(e_); - } catch (std::runtime_error& re) { - if (re.what() != msg) { - std::stringstream ss; - ss << "Error message is: " << re.what() << " but expected: " << msg; - throw std::runtime_error(ss.str()); - } - } catch (...) { - throw std::runtime_error("Expects an std::runtime_error"); - } + } + + folly::exception_wrapper getException() const { + return e_; } /** @@ -208,15 +225,15 @@ class SingleTestObserver : public yarpl::single::SingleObserver { private: std::mutex m_; std::condition_variable terminalEventCV_; - Reference> delegate_; + std::shared_ptr> delegate_; // The following variables must be protected by mutex m_ T value_; - std::exception_ptr e_; + folly::exception_wrapper e_; bool terminated_{false}; // allows thread-safe cancellation against a delegate // regardless of when it is received - Reference delegateSubscription_{ - make_ref()}; + std::shared_ptr delegateSubscription_{ + std::make_shared()}; }; -} -} +} // namespace single +} // namespace yarpl diff --git a/yarpl/single/Singles.h b/yarpl/single/Singles.h new file mode 100644 index 000000000..b6fe896cb --- /dev/null +++ b/yarpl/single/Singles.h @@ -0,0 +1,83 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "yarpl/single/Single.h" +#include "yarpl/single/SingleSubscriptions.h" + +#include + +namespace yarpl { +namespace single { + +class Singles { + public: + template + static std::shared_ptr> just(const T& value) { + auto lambda = [value](std::shared_ptr> observer) { + observer->onSubscribe(SingleSubscriptions::empty()); + observer->onSuccess(value); + }; + + return Single::create(std::move(lambda)); + } + + template < + typename T, + typename OnSubscribe, + typename = typename std::enable_if>>::value>::type> + static std::shared_ptr> create(OnSubscribe&& function) { + return std::make_shared< + FromPublisherOperator>>( + std::forward(function)); + } + + template + static std::shared_ptr> error(folly::exception_wrapper ex) { + auto lambda = + [e = std::move(ex)](std::shared_ptr> observer) { + observer->onSubscribe(SingleSubscriptions::empty()); + observer->onError(e); + }; + return Single::create(std::move(lambda)); + } + + template + static std::shared_ptr> error(const ExceptionType& ex) { + auto lambda = [ex](std::shared_ptr> observer) { + observer->onSubscribe(SingleSubscriptions::empty()); + observer->onError(ex); + }; + return Single::create(std::move(lambda)); + } + + template + static std::shared_ptr> fromGenerator(TGenerator&& generator) { + auto lambda = [generator = std::forward(generator)]( + std::shared_ptr> observer) mutable { + observer->onSubscribe(SingleSubscriptions::empty()); + observer->onSuccess(generator()); + }; + return Single::create(std::move(lambda)); + } + + private: + Singles() = delete; +}; + +} // namespace single +} // namespace yarpl diff --git a/yarpl/src/yarpl/flowable/sources/Subscription.cpp b/yarpl/src/yarpl/flowable/sources/Subscription.cpp deleted file mode 100644 index 905837fd2..000000000 --- a/yarpl/src/yarpl/flowable/sources/Subscription.cpp +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "yarpl/flowable/Subscription.h" - -namespace yarpl { -namespace flowable { - -yarpl::Reference Subscription::empty() { - class NullSubscription : public Subscription { - void request(int64_t) override {} - void cancel() override {} - }; - return make_ref(); -} - -} // flowable -} // yarpl diff --git a/yarpl/src/yarpl/observable/Subscriptions.cpp b/yarpl/src/yarpl/observable/Subscriptions.cpp deleted file mode 100644 index df3efd221..000000000 --- a/yarpl/src/yarpl/observable/Subscriptions.cpp +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "yarpl/observable/Subscriptions.h" -#include -#include - -namespace yarpl { -namespace observable { - -/** - * Implementation that allows checking if a Subscription is cancelled. - */ -void AtomicBoolSubscription::cancel() { - cancelled_ = true; -} - -bool AtomicBoolSubscription::isCancelled() const { - return cancelled_; -} - -/** - * Implementation that gets a callback when cancellation occurs. - */ -CallbackSubscription::CallbackSubscription(std::function&& onCancel) - : onCancel_(std::move(onCancel)) {} - -void CallbackSubscription::cancel() { - bool expected = false; - // mark cancelled 'true' and only if successful invoke 'onCancel()' - if (cancelled_.compare_exchange_strong(expected, true)) { - onCancel_(); - } -} - -bool CallbackSubscription::isCancelled() const { - return cancelled_; -} - -Reference Subscriptions::create(std::function onCancel) { - return Reference(new CallbackSubscription(std::move(onCancel))); -} - -Reference Subscriptions::create(std::atomic_bool& cancelled) { - return create([&cancelled]() { cancelled = true; }); -} - -Reference Subscriptions::empty() { - return Reference(new AtomicBoolSubscription()); -} - -Reference Subscriptions::atomicBoolSubscription() { - return Reference(new AtomicBoolSubscription()); -} -} -} diff --git a/yarpl/src/yarpl/schedulers/ThreadScheduler.cpp b/yarpl/src/yarpl/schedulers/ThreadScheduler.cpp deleted file mode 100644 index 1a70b765f..000000000 --- a/yarpl/src/yarpl/schedulers/ThreadScheduler.cpp +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "yarpl/schedulers/ThreadScheduler.h" - -#include -#include -#include -#include - -#include "yarpl/Disposable.h" - -/** - * A VERY BAD implementation of Scheduler. - * This spawns a thread for *every* schedule event. - * This also means it breaks the contract of ensuring sequential - * execution on a single Worker. - * - * And it does nothing with disposal. - */ -// TODO fix this mess by finishing a proper implementation -namespace yarpl { - -class ADisposable : public yarpl::Disposable { - void dispose() override {} - - bool isDisposed() override { - return false; - } -}; - -class ThreadWorker : public Worker { - public: - std::unique_ptr schedule( - std::function&& task) override { - std::thread([task = std::move(task)]() { task(); }).detach(); - return std::make_unique(); - } - - void dispose() override { - isDisposed_.store(true); - } - - bool isDisposed() override { - return isDisposed_; - } - - private: - std::atomic_bool isDisposed_{false}; - // std::thread loop_; -}; - -std::unique_ptr ThreadScheduler::createWorker() { - return std::make_unique(); -} -} diff --git a/yarpl/test/FlowableFlatMapTest.cpp b/yarpl/test/FlowableFlatMapTest.cpp new file mode 100644 index 000000000..5306d9971 --- /dev/null +++ b/yarpl/test/FlowableFlatMapTest.cpp @@ -0,0 +1,297 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include "yarpl/Flowable.h" +#include "yarpl/flowable/TestSubscriber.h" +#include "yarpl/test_utils/Mocks.h" + +namespace yarpl { +namespace flowable { + +namespace { + +/// Construct a pipeline with a test subscriber against the supplied +/// flowable. Return the items that were sent to the subscriber. If some +/// exception was sent, the exception is thrown. +template +std::vector run( + std::shared_ptr> flowable, + int64_t requestCount = 100) { + auto subscriber = std::make_shared>(requestCount); + flowable->subscribe(subscriber); + return std::move(subscriber->values()); +} + +} // namespace + +template +std::vector filter_run(std::vector in, Pred pred) { + std::vector ret; + std::copy_if(in.begin(), in.end(), std::back_inserter(ret), pred); + return ret; +} + +std::vector +filter_range(std::vector in, int64_t startat, int64_t endat) { + CHECK_LE(startat, endat); + return filter_run( + in, [=](int64_t i) { return (i >= startat) && (i < endat); }); +} + +auto make_flowable_mapper_func() { + return folly::Function>(int)>([](int n) { + switch (n) { + case 10: + return Flowable<>::range(n, 2); + case 20: + return Flowable<>::range(n, 3); + case 30: + return Flowable<>::range(n, 4); + } + return Flowable<>::range(n, 3); + }); +} + +// assumes that separate streams of values in separate_streams are entirely +// disjoint +template +bool validate_flatmapped_values( + std::vector flatmapped, + std::vector> separate_streams) { + for (auto elem : flatmapped) { + bool found_match = false; + for (auto& stream : separate_streams) { + if (stream.size() > 0) { + if (elem == stream[0]) { + stream.pop_front(); + found_match = true; + break; + } + } + } + + EXPECT_TRUE(found_match) + << "Did not find elem '" << elem << "' in any input streams"; + if (!found_match) { + return false; + } + } + + return true; +} + +TEST(FlowableFlatMapTest, AllRequestedTest) { + auto f = Flowable<>::justN({10, 20, 30}) + ->flatMap(make_flowable_mapper_func()); + + std::vector res = run(f); + EXPECT_EQ(9UL, res.size()); + EXPECT_EQ(filter_range(res, 10, 20), std::vector({10, 11})); + EXPECT_EQ(filter_range(res, 20, 30), std::vector({20, 21, 22})); + EXPECT_EQ(filter_range(res, 30, 40), std::vector({30, 31, 32, 33})); +} + +TEST(FlowableFlatMapTest, FiniteRequested) { + auto f = Flowable<>::justN({10, 20, 30}) + ->flatMap(make_flowable_mapper_func()); + + auto subscriber = std::make_shared>(1); + f->subscribe(subscriber); + + EXPECT_EQ(1UL, subscriber->values().size()); + EXPECT_TRUE( + validate_flatmapped_values(subscriber->values(), {{10}, {20}, {30}})); + + subscriber->request(3); + EXPECT_TRUE(validate_flatmapped_values( + subscriber->values(), {{10, 11}, {20, 21, 22}, {30, 31, 32, 33}})); + EXPECT_EQ(subscriber->getValueCount(), 4); + subscriber->cancel(); + EXPECT_EQ(subscriber->getValueCount(), 4); +} + +TEST(FlowableFlatMapTest, MappingLambdaThrowsErrorOnFirstCall) { + folly::Function>(int)> func = [](int n) { + CHECK_EQ(1, n); + throw std::runtime_error{"throwing in mapper!"}; + return Flowable::empty(); + }; + + auto f = Flowable<>::just(1)->flatMap(std::move(func)); + + auto subscriber = std::make_shared>(1); + f->subscribe(subscriber); + + EXPECT_EQ(subscriber->getValueCount(), 0); + EXPECT_TRUE(subscriber->isError()); + EXPECT_EQ(subscriber->getErrorMsg(), "throwing in mapper!"); +} + +TEST(FlowableFlatMapTest, MappedStreamThrows) { + folly::Function>(int)> func = [](int n) { + CHECK_EQ(1, n); + + // flowable which emits an onNext, then the next iteration, emits an error + int64_t i = 1; + return Flowable::create( + [i](auto& subscriber, int64_t req) mutable { + CHECK_EQ(1, req); + if (i > 0) { + subscriber.onNext(i); + i--; + } else { + subscriber.onError(std::runtime_error{"throwing in stream!"}); + } + }); + }; + + auto f = Flowable<>::just(1)->flatMap(std::move(func)); + + auto subscriber = std::make_shared>(2); + f->subscribe(subscriber); + + EXPECT_EQ(subscriber->values(), std::vector({1})); + EXPECT_TRUE(subscriber->isError()); + EXPECT_EQ(subscriber->getErrorMsg(), "throwing in stream!"); +} + +struct CBSubscription : yarpl::flowable::Subscription { + template + CBSubscription(OnReq&& r, OnCancel&& c) + : onRequest(std::move(r)), onCancel(std::move(c)) {} + + void request(int64_t n) override { + onRequest(n); + }; + void cancel() override { + onCancel(); + } + + folly::Function onRequest; + folly::Function onCancel; +}; + +struct FlowableEvbPair { + FlowableEvbPair() = default; + std::shared_ptr> flowable{nullptr}; + folly::EventBaseThread evb{}; +}; + +std::shared_ptr make_range_flowable(int start, int end) { + auto ret = std::make_shared(); + ret->evb.start("MRF_Worker"); + ret->flowable = Flowable<>::range(start, end - start) + ->map([](int64_t val) { return (int)val; }) + ->subscribeOn(*ret->evb.getEventBase()); + return ret; +} + +TEST(FlowableFlatMapTest, Multithreaded) { + auto p1 = make_range_flowable(10, 12); + auto p2 = make_range_flowable(20, 25); + + auto f = Flowable<>::range(0, 2)->flatMap([&](auto i) { + if (i == 0) { + return p1->flowable; + } else { + return p2->flowable; + } + }); + + auto sub = std::make_shared>(0); + f->subscribe(sub); + + sub->request(2); + sub->awaitValueCount(2); + EXPECT_TRUE(validate_flatmapped_values(sub->values(), {{10, 11}, {20, 21}})); + + sub->cancel(); + p1->evb.stop(); + p2->evb.stop(); +} + +TEST(FlowableFlatMapTest, MultithreadedLargeAmount) { + auto p1 = make_range_flowable(10000, 40000); + auto p2 = make_range_flowable(50000, 80000); + + auto f = Flowable<>::range(0, 2)->flatMap([&](auto i) { + if (i == 0) { + return p1->flowable; + } else { + return p2->flowable; + } + }); + + auto sub = std::make_shared>(); + sub->dropValues(true); + + f->subscribe(sub); + + sub->awaitTerminalEvent(std::chrono::seconds{5}); + EXPECT_EQ(60000, sub->getValueCount()); + EXPECT_TRUE(sub->isComplete()); + + p1->evb.stop(); + p2->evb.stop(); +} + +TEST(FlowableFlatMapTest, MergeOperator) { + auto sub = std::make_shared>(0); + + auto p1 = Flowable<>::justN({"foo", "bar"}); + auto p2 = Flowable<>::justN({"baz", "quxx"}); + std::shared_ptr>>> p3 = + Flowable<>::justN>>({p1, p2}); + + std::shared_ptr> p4 = p3->merge(); + p4->subscribe(sub); + + EXPECT_EQ(0, sub->getValueCount()); + sub->request(1); + EXPECT_EQ(1, sub->getValueCount()); + EXPECT_EQ(false, sub->isComplete()); + EXPECT_TRUE(validate_flatmapped_values(sub->values(), {{"foo"}, {"baz"}})); + + sub->request(1); + EXPECT_EQ(2, sub->getValueCount()); + EXPECT_EQ(false, sub->isComplete()); + EXPECT_EQ(false, sub->isError()); + EXPECT_TRUE(validate_flatmapped_values( + sub->values(), {{"foo", "bar"}, {"baz", "quxx"}})); + + sub->request(1); + EXPECT_EQ(3, sub->getValueCount()); + EXPECT_EQ(false, sub->isComplete()); + EXPECT_EQ(false, sub->isError()); + EXPECT_TRUE(validate_flatmapped_values( + sub->values(), {{"foo", "bar"}, {"baz", "quxx"}})); + + sub->request(1); + EXPECT_EQ(4, sub->getValueCount()); + EXPECT_EQ(true, sub->isComplete()); + EXPECT_EQ(false, sub->isError()); + EXPECT_TRUE(validate_flatmapped_values( + sub->values(), {{"foo", "bar"}, {"baz", "quxx"}})); +} + +} // namespace flowable +} // namespace yarpl diff --git a/yarpl/test/FlowableSubscriberTest.cpp b/yarpl/test/FlowableSubscriberTest.cpp new file mode 100644 index 000000000..683f57f4b --- /dev/null +++ b/yarpl/test/FlowableSubscriberTest.cpp @@ -0,0 +1,146 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yarpl/flowable/Subscriber.h" +#include "yarpl/test_utils/Mocks.h" + +using namespace yarpl; +using namespace yarpl::flowable; +using namespace yarpl::mocks; +using namespace testing; + +namespace { + +TEST(FlowableSubscriberTest, CreateSubscriber) { + int calls{0}; + struct Functor { + explicit Functor(int& calls) : calls_(calls) {} + // If we update the template definition of the Subscriber, + // then we should comment out this method and observe the compiler output + // with and without the change. + void operator()(int) & { + ++calls_; + } + void operator()(int) && { + FAIL() << "onNext lambda should be stored as l-value"; + } + void operator()(std::string) const& { + ++calls_; + } + void operator()(std::string) const&& { + FAIL() << "onNext lambda should be stored as l-value"; + } + int& calls_; + }; + auto s1 = Subscriber::create(Functor(calls)); + s1->onSubscribe(yarpl::flowable::Subscription::create()); + s1->onNext(1); + EXPECT_EQ(1, calls); + + auto s2 = Subscriber::create(Functor(calls)); + s2->onSubscribe(yarpl::flowable::Subscription::create()); + s2->onNext((long)1); + EXPECT_EQ(2, calls); + + auto s3 = Subscriber::create(Functor(calls)); + s3->onSubscribe(yarpl::flowable::Subscription::create()); + s3->onNext("test"); + EXPECT_EQ(3, calls); + + // by reference + auto f = Functor(calls); + auto s4 = Subscriber::create(f); + s4->onSubscribe(yarpl::flowable::Subscription::create()); + s4->onNext(1); + EXPECT_EQ(4, calls); +} + +TEST(FlowableSubscriberTest, TestBasicFunctionality) { + Sequence subscriber_seq; + auto subscriber = std::make_shared>>(); + + EXPECT_CALL(*subscriber, onSubscribeImpl()) + .Times(1) + .InSequence(subscriber_seq) + .WillOnce(Invoke([&] { subscriber->request(3); })); + EXPECT_CALL(*subscriber, onNextImpl(5)).Times(1).InSequence(subscriber_seq); + EXPECT_CALL(*subscriber, onCompleteImpl()) + .Times(1) + .InSequence(subscriber_seq); + + auto subscription = std::make_shared>(); + EXPECT_CALL(*subscription, request_(3)) + .Times(1) + .WillOnce(InvokeWithoutArgs([&] { + subscriber->onNext(5); + subscriber->onComplete(); + })); + + subscriber->onSubscribe(subscription); +} + +TEST(FlowableSubscriberTest, TestKeepRefToThisIsDisabled) { + auto subscriber = + std::make_shared>>(); + auto subscription = std::make_shared>(); + + // tests that only a single reference exists to the Subscriber; clearing + // reference in `auto subscriber` would cause it to deallocate + { + InSequence s; + EXPECT_CALL(*subscriber, onSubscribeImpl()).Times(1).WillOnce(Invoke([&] { + EXPECT_EQ(1UL, subscriber.use_count()); + })); + } + + subscriber->onSubscribe(subscription); +} +TEST(FlowableSubscriberTest, TestKeepRefToThisIsEnabled) { + auto subscriber = std::make_shared>>(); + auto subscription = std::make_shared>(); + + // tests that only a reference is held somewhere on the stack, so clearing + // references to `BaseSubscriber` while in a signaling method won't + // deallocate it (until it's safe to do so) + { + InSequence s; + EXPECT_CALL(*subscriber, onSubscribeImpl()).Times(1).WillOnce(Invoke([&] { + EXPECT_EQ(2UL, subscriber.use_count()); + })); + } + + subscriber->onSubscribe(subscription); +} + +TEST(FlowableSubscriberTest, AutoFlowControl) { + size_t count = 0; + auto subscriber = Subscriber::create( + [&](int value) { + ++count; + EXPECT_EQ(value, count); + }, + 1); + auto subscription = std::make_shared>(); + + EXPECT_CALL(*subscription, request_(1)) + .Times(3) + .WillOnce(InvokeWithoutArgs([&] { subscriber->onNext(1); })) + .WillOnce(InvokeWithoutArgs([&] { + subscriber->onNext(2); + subscriber->onComplete(); + })); + + subscriber->onSubscribe(subscription); +} +} // namespace diff --git a/yarpl/test/FlowableTest.cpp b/yarpl/test/FlowableTest.cpp index 8cb108cae..7c8a77353 100644 --- a/yarpl/test/FlowableTest.cpp +++ b/yarpl/test/FlowableTest.cpp @@ -1,44 +1,65 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include #include #include -#include #include "yarpl/Flowable.h" +#include "yarpl/flowable/Subscriber.h" #include "yarpl/flowable/TestSubscriber.h" -#include "yarpl/utils/ExceptionString.h" +#include "yarpl/test_utils/Mocks.h" + +#if FOLLY_HAS_COROUTINES +#include +#include "yarpl/flowable/AsyncGeneratorShim.h" +#endif + +using namespace yarpl::flowable; +using namespace testing; -namespace yarpl { -namespace flowable { namespace { /* * Used in place of TestSubscriber where we have move-only types. */ template -class CollectingSubscriber : public Subscriber { +class CollectingSubscriber : public BaseSubscriber { public: explicit CollectingSubscriber(int64_t requestCount = 100) : requestCount_(requestCount) {} - void onSubscribe(Reference subscription) override { - Subscriber::onSubscribe(subscription); - subscription->request(requestCount_); + void onSubscribeImpl() override { + this->request(requestCount_); } - void onNext(T next) override { + void onNextImpl(T next) override { values_.push_back(std::move(next)); } - void onComplete() override { - Subscriber::onComplete(); + void onCompleteImpl() override { complete_ = true; } - void onError(std::exception_ptr ex) override { - Subscriber::onError(ex); + void onErrorImpl(folly::exception_wrapper ex) override { error_ = true; - errorMsg_ = yarpl::exceptionStr(ex); + errorMsg_ = ex.get_exception()->what(); } std::vector& values() { @@ -58,7 +79,7 @@ class CollectingSubscriber : public Subscriber { } void cancelSubscription() { - Subscriber::subscription()->cancel(); + this->cancel(); } private: @@ -74,29 +95,29 @@ class CollectingSubscriber : public Subscriber { /// exception was sent, the exception is thrown. template std::vector run( - Reference> flowable, + std::shared_ptr> flowable, int64_t requestCount = 100) { - auto subscriber = make_ref>(requestCount); + auto subscriber = std::make_shared>(requestCount); flowable->subscribe(subscriber); + subscriber->awaitTerminalEvent(std::chrono::seconds(1)); return std::move(subscriber->values()); } - } // namespace TEST(FlowableTest, SingleFlowable) { - auto flowable = Flowables::just(10); + auto flowable = Flowable<>::just(10); flowable.reset(); } TEST(FlowableTest, SingleMovableFlowable) { auto value = std::make_unique(123456); - auto flowable = Flowables::justOnce(std::move(value)); - EXPECT_EQ(std::size_t{1}, flowable->count()); + auto flowable = Flowable<>::justOnce(std::move(value)); + EXPECT_EQ(1, flowable.use_count()); size_t received = 0; auto subscriber = - Subscribers::create>([&](std::unique_ptr p) { + Subscriber>::create([&](std::unique_ptr p) { EXPECT_EQ(*p, 123456); received++; }); @@ -106,24 +127,24 @@ TEST(FlowableTest, SingleMovableFlowable) { } TEST(FlowableTest, JustFlowable) { - EXPECT_EQ(run(Flowables::just(22)), std::vector{22}); + EXPECT_EQ(run(Flowable<>::just(22)), std::vector{22}); EXPECT_EQ( - run(Flowables::justN({12, 34, 56, 98})), + run(Flowable<>::justN({12, 34, 56, 98})), std::vector({12, 34, 56, 98})); EXPECT_EQ( - run(Flowables::justN({"ab", "pq", "yz"})), + run(Flowable<>::justN({"ab", "pq", "yz"})), std::vector({"ab", "pq", "yz"})); } TEST(FlowableTest, JustIncomplete) { - auto flowable = Flowables::justN({"a", "b", "c"})->take(2); + auto flowable = Flowable<>::justN({"a", "b", "c"})->take(2); EXPECT_EQ(run(std::move(flowable)), std::vector({"a", "b"})); - flowable = Flowables::justN({"a", "b", "c"})->take(2)->take(1); + flowable = Flowable<>::justN({"a", "b", "c"})->take(2)->take(1); EXPECT_EQ(run(std::move(flowable)), std::vector({"a"})); flowable.reset(); - flowable = Flowables::justN( + flowable = Flowable<>::justN( {"a", "b", "c", "d", "e", "f", "g", "h", "i"}) ->map([](std::string s) { s[0] = ::toupper(s[0]); @@ -137,14 +158,30 @@ TEST(FlowableTest, JustIncomplete) { flowable.reset(); } +TEST(FlowableTest, MapWithException) { + auto flowable = Flowable<>::justN({1, 2, 3, 4})->map([](int n) { + if (n > 2) { + throw std::runtime_error{"Too big!"}; + } + return n; + }); + + auto subscriber = std::make_shared>(); + flowable->subscribe(subscriber); + + EXPECT_EQ(subscriber->values(), std::vector({1, 2})); + EXPECT_TRUE(subscriber->isError()); + EXPECT_EQ(subscriber->getErrorMsg(), "Too big!"); +} + TEST(FlowableTest, Range) { EXPECT_EQ( - run(Flowables::range(10, 5)), + run(Flowable<>::range(10, 5)), std::vector({10, 11, 12, 13, 14})); } TEST(FlowableTest, RangeWithMap) { - auto flowable = Flowables::range(1, 3) + auto flowable = Flowable<>::range(1, 3) ->map([](int64_t v) { return v * v; }) ->map([](int64_t v) { return v * v; }) ->map([](int64_t v) { return std::to_string(v); }); @@ -153,43 +190,40 @@ TEST(FlowableTest, RangeWithMap) { } TEST(FlowableTest, RangeWithReduceMoreItems) { - auto flowable = Flowables::range(0, 10) - ->reduce([](int64_t acc, int64_t v) { return acc + v; }); - EXPECT_EQ( - run(std::move(flowable)), std::vector({45})); + auto flowable = Flowable<>::range(0, 10)->reduce( + [](int64_t acc, int64_t v) { return acc + v; }); + EXPECT_EQ(run(std::move(flowable)), std::vector({45})); } TEST(FlowableTest, RangeWithReduceByMultiplication) { - auto flowable = Flowables::range(0, 10) - ->reduce([](int64_t acc, int64_t v) { return acc * v; }); - EXPECT_EQ( - run(std::move(flowable)), std::vector({0})); + auto flowable = Flowable<>::range(0, 10)->reduce( + [](int64_t acc, int64_t v) { return acc * v; }); + EXPECT_EQ(run(std::move(flowable)), std::vector({0})); - flowable = Flowables::range(1, 10) - ->reduce([](int64_t acc, int64_t v) { return acc * v; }); + flowable = Flowable<>::range(1, 10)->reduce( + [](int64_t acc, int64_t v) { return acc * v; }); EXPECT_EQ( - run(std::move(flowable)), std::vector({2*3*4*5*6*7*8*9*10})); + run(std::move(flowable)), + std::vector({2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10})); } TEST(FlowableTest, RangeWithReduceLessItems) { - auto flowable = Flowables::range(0, 10) - ->reduce([](int64_t acc, int64_t v) { return acc + v; }); + auto flowable = Flowable<>::range(0, 10)->reduce( + [](int64_t acc, int64_t v) { return acc + v; }); // Even if we ask for 1 item only, it will reduce all the items - EXPECT_EQ( - run(std::move(flowable), 5), std::vector({45})); + EXPECT_EQ(run(std::move(flowable), 5), std::vector({45})); } TEST(FlowableTest, RangeWithReduceOneItem) { - auto flowable = Flowables::range(5, 1) - ->reduce([](int64_t acc, int64_t v) { return acc + v; }); - EXPECT_EQ( - run(std::move(flowable)), std::vector({5})); + auto flowable = Flowable<>::range(5, 1)->reduce( + [](int64_t acc, int64_t v) { return acc + v; }); + EXPECT_EQ(run(std::move(flowable)), std::vector({5})); } TEST(FlowableTest, RangeWithReduceNoItem) { - auto flowable = Flowables::range(0, 0) - ->reduce([](int64_t acc, int64_t v) { return acc + v; }); - auto subscriber = make_ref>(100); + auto flowable = Flowable<>::range(0, 0)->reduce( + [](int64_t acc, int64_t v) { return acc + v; }); + auto subscriber = std::make_shared>(100); flowable->subscribe(subscriber); EXPECT_TRUE(subscriber->isComplete()); @@ -197,76 +231,100 @@ TEST(FlowableTest, RangeWithReduceNoItem) { } TEST(FlowableTest, RangeWithFilterAndReduce) { - auto flowable = Flowables::range(0, 10) - ->filter([](int64_t v) { return v % 2 != 0; }) - ->reduce([](int64_t acc, int64_t v) { return acc + v; }); + auto flowable = Flowable<>::range(0, 10) + ->filter([](int64_t v) { return v % 2 != 0; }) + ->reduce([](int64_t acc, int64_t v) { return acc + v; }); EXPECT_EQ( - run(std::move(flowable)), std::vector({1+3+5+7+9})); + run(std::move(flowable)), std::vector({1 + 3 + 5 + 7 + 9})); } TEST(FlowableTest, RangeWithReduceToBiggerType) { - auto flowable = Flowables::range(5, 1) - ->map([](int64_t v){ return (char)(v + 10); }) - ->reduce([](int64_t acc, char v) { return acc + v; }); - EXPECT_EQ( - run(std::move(flowable)), std::vector({15})); + auto flowable = Flowable<>::range(5, 1) + ->map([](int64_t v) { return (char)(v + 10); }) + ->reduce([](int64_t acc, char v) { return acc + v; }); + EXPECT_EQ(run(std::move(flowable)), std::vector({15})); } TEST(FlowableTest, StringReduce) { - auto flowable = Flowables::justN( - {"a", "b", "c", "d", "e", "f", "g", "h", "i"}) - ->reduce([](std::string acc, std::string v) { - return acc + v; - }); - EXPECT_EQ( - run(std::move(flowable)), std::vector({"abcdefghi"})); + auto flowable = + Flowable<>::justN( + {"a", "b", "c", "d", "e", "f", "g", "h", "i"}) + ->reduce([](std::string acc, std::string v) { return acc + v; }); + EXPECT_EQ(run(std::move(flowable)), std::vector({"abcdefghi"})); } TEST(FlowableTest, RangeWithFilterRequestMoreItems) { auto flowable = - Flowables::range(0, 10)->filter([](int64_t v) { return v % 2 != 0; }); + Flowable<>::range(0, 10)->filter([](int64_t v) { return v % 2 != 0; }); EXPECT_EQ(run(std::move(flowable)), std::vector({1, 3, 5, 7, 9})); } TEST(FlowableTest, RangeWithFilterRequestLessItems) { auto flowable = - Flowables::range(0, 10)->filter([](int64_t v) { return v % 2 != 0; }); + Flowable<>::range(0, 10)->filter([](int64_t v) { return v % 2 != 0; }); EXPECT_EQ(run(std::move(flowable), 5), std::vector({1, 3, 5, 7, 9})); } TEST(FlowableTest, RangeWithFilterAndMap) { - auto flowable = Flowables::range(0, 10) - ->filter([](int64_t v) { return v % 2 != 0; }) - ->map([](int64_t v){ return v + 10; }); - EXPECT_EQ(run(std::move(flowable)), std::vector({11, 13, 15, 17, 19})); + auto flowable = Flowable<>::range(0, 10) + ->filter([](int64_t v) { return v % 2 != 0; }) + ->map([](int64_t v) { return v + 10; }); + EXPECT_EQ( + run(std::move(flowable)), std::vector({11, 13, 15, 17, 19})); } TEST(FlowableTest, RangeWithMapAndFilter) { - auto flowable = Flowables::range(0, 10) - ->map([](int64_t v){ return (char)(v + 10); }) - ->filter([](char v) { return v % 2 != 0; }); + auto flowable = Flowable<>::range(0, 10) + ->map([](int64_t v) { return (char)(v + 10); }) + ->filter([](char v) { return v % 2 != 0; }); EXPECT_EQ(run(std::move(flowable)), std::vector({11, 13, 15, 17, 19})); } TEST(FlowableTest, SimpleTake) { EXPECT_EQ( - run(Flowables::range(0, 100)->take(3)), std::vector({0, 1, 2})); + run(Flowable<>::range(0, 100)->take(3)), std::vector({0, 1, 2})); EXPECT_EQ( - run(Flowables::range(10, 5)), + run(Flowable<>::range(10, 5)), std::vector({10, 11, 12, 13, 14})); + + EXPECT_EQ(run(Flowable<>::range(0, 100)->take(0)), std::vector({})); +} + +TEST(FlowableTest, TakeError) { + auto take0 = + Flowable::error(std::runtime_error("something broke!"))->take(0); + + auto subscriber = std::make_shared>(); + take0->subscribe(subscriber); + + EXPECT_EQ(subscriber->values(), std::vector({})); + EXPECT_TRUE(subscriber->isComplete()); + EXPECT_FALSE(subscriber->isError()); +} + +TEST(FlowableTes, NeverTake) { + auto take0 = Flowable::never()->take(0); + + auto subscriber = std::make_shared>(); + take0->subscribe(subscriber); + + EXPECT_EQ(subscriber->values(), std::vector({})); + EXPECT_TRUE(subscriber->isComplete()); + EXPECT_FALSE(subscriber->isError()); } TEST(FlowableTest, SimpleSkip) { - EXPECT_EQ(run(Flowables::range(0, 10)->skip(8)), std::vector({8, 9})); + EXPECT_EQ( + run(Flowable<>::range(0, 10)->skip(8)), std::vector({8, 9})); } TEST(FlowableTest, OverflowSkip) { - EXPECT_EQ(run(Flowables::range(0, 10)->skip(12)), std::vector({})); + EXPECT_EQ(run(Flowable<>::range(0, 10)->skip(12)), std::vector({})); } TEST(FlowableTest, SkipPartial) { - auto subscriber = make_ref>(2); - auto flowable = Flowables::range(0, 10)->skip(5); + auto subscriber = std::make_shared>(2); + auto flowable = Flowable<>::range(0, 10)->skip(5); flowable->subscribe(subscriber); EXPECT_EQ(subscriber->values(), std::vector({5, 6})); @@ -274,15 +332,14 @@ TEST(FlowableTest, SkipPartial) { } TEST(FlowableTest, IgnoreElements) { - auto flowable = Flowables::range(0, 100) - ->ignoreElements() - ->map([](int64_t v) { return v * v; }); + auto flowable = Flowable<>::range(0, 100)->ignoreElements()->map( + [](int64_t v) { return v * v; }); EXPECT_EQ(run(flowable), std::vector({})); } TEST(FlowableTest, IgnoreElementsPartial) { - auto subscriber = make_ref>(5); - auto flowable = Flowables::range(0, 10)->ignoreElements(); + auto subscriber = std::make_shared>(5); + auto flowable = Flowable<>::range(0, 10)->ignoreElements(); flowable->subscribe(subscriber); EXPECT_EQ(subscriber->values(), std::vector({})); @@ -292,11 +349,11 @@ TEST(FlowableTest, IgnoreElementsPartial) { subscriber->cancel(); } -TEST(FlowableTest, IgnoreElementsError) { +TEST(FlowableTest, FlowableErrorNoRequestN) { constexpr auto kMsg = "Failure"; - auto subscriber = make_ref>(); - auto flowable = Flowables::error(std::runtime_error(kMsg)); + auto subscriber = std::make_shared>(0); + auto flowable = Flowable::error(std::runtime_error(kMsg)); flowable->subscribe(subscriber); EXPECT_TRUE(subscriber->isError()); @@ -306,8 +363,8 @@ TEST(FlowableTest, IgnoreElementsError) { TEST(FlowableTest, FlowableError) { constexpr auto kMsg = "something broke!"; - auto flowable = Flowables::error(std::runtime_error(kMsg)); - auto subscriber = make_ref>(); + auto flowable = Flowable::error(std::runtime_error(kMsg)); + auto subscriber = std::make_shared>(); flowable->subscribe(subscriber); EXPECT_FALSE(subscriber->isComplete()); @@ -315,33 +372,58 @@ TEST(FlowableTest, FlowableError) { EXPECT_EQ(subscriber->getErrorMsg(), kMsg); } -TEST(FlowableTest, FlowableErrorPtr) { - constexpr auto kMsg = "something broke!"; +TEST(FlowableTest, FlowableEmpty) { + auto flowable = Flowable::empty(); + auto subscriber = std::make_shared>(); + flowable->subscribe(subscriber); - auto flowable = Flowables::error( - std::make_exception_ptr(std::runtime_error(kMsg))); - auto subscriber = make_ref>(); + EXPECT_TRUE(subscriber->isComplete()); + EXPECT_FALSE(subscriber->isError()); +} + +TEST(FlowableTest, FlowableEmptyNoRequestN) { + auto flowable = Flowable::empty(); + auto subscriber = std::make_shared>(0); flowable->subscribe(subscriber); + EXPECT_TRUE(subscriber->isComplete()); + EXPECT_FALSE(subscriber->isError()); +} + +TEST(FlowableTest, FlowableNever) { + auto flowable = Flowable::never(); + auto subscriber = std::make_shared>(); + flowable->subscribe(subscriber); + EXPECT_THROW( + subscriber->awaitTerminalEvent(std::chrono::milliseconds(100)), + std::runtime_error); + EXPECT_FALSE(subscriber->isComplete()); - EXPECT_TRUE(subscriber->isError()); - EXPECT_EQ(subscriber->getErrorMsg(), kMsg); + EXPECT_FALSE(subscriber->isError()); + + subscriber->cancel(); } -TEST(FlowableTest, FlowableEmpty) { - auto flowable = Flowables::empty(); - auto subscriber = make_ref>(); +TEST(FlowableTest, FlowableNeverNoRequestN) { + auto flowable = Flowable::never(); + auto subscriber = std::make_shared>(0); flowable->subscribe(subscriber); + EXPECT_THROW( + subscriber->awaitTerminalEvent(std::chrono::milliseconds(100)), + std::runtime_error); - EXPECT_TRUE(subscriber->isComplete()); + EXPECT_FALSE(subscriber->isComplete()); EXPECT_FALSE(subscriber->isError()); + + subscriber->cancel(); } TEST(FlowableTest, FlowableFromGenerator) { - auto flowable = Flowables::fromGenerator>( + auto flowable = Flowable>::fromGenerator( [] { return std::unique_ptr(); }); - auto subscriber = make_ref>>(10); + auto subscriber = + std::make_shared>>(10); flowable->subscribe(subscriber); EXPECT_FALSE(subscriber->isComplete()); @@ -354,13 +436,15 @@ TEST(FlowableTest, FlowableFromGenerator) { TEST(FlowableTest, FlowableFromGeneratorException) { constexpr auto errorMsg = "error from generator"; int count = 5; - auto flowable = Flowables::fromGenerator>( - [&] { - while (count--) { return std::unique_ptr(); } + auto flowable = Flowable>::fromGenerator([&] { + while (count--) { + return std::unique_ptr(); + } throw std::runtime_error(errorMsg); }); - auto subscriber = make_ref>>(10); + auto subscriber = + std::make_shared>>(10); flowable->subscribe(subscriber); EXPECT_FALSE(subscriber->isComplete()); @@ -370,29 +454,29 @@ TEST(FlowableTest, FlowableFromGeneratorException) { } TEST(FlowableTest, SubscribersComplete) { - auto flowable = Flowables::empty(); - auto subscriber = Subscribers::create( - [](int) { FAIL(); }, [](std::exception_ptr) { FAIL(); }, [&] {}); + auto flowable = Flowable::empty(); + auto subscriber = Subscriber::create( + [](int) { FAIL(); }, [](folly::exception_wrapper) { FAIL(); }, [&] {}); flowable->subscribe(std::move(subscriber)); } TEST(FlowableTest, SubscribersError) { - auto flowable = Flowables::error(std::runtime_error("Whoops")); - auto subscriber = Subscribers::create( - [](int) { FAIL(); }, [&](std::exception_ptr) {}, [] { FAIL(); }); + auto flowable = Flowable::error(std::runtime_error("Whoops")); + auto subscriber = Subscriber::create( + [](int) { FAIL(); }, [&](folly::exception_wrapper) {}, [] { FAIL(); }); flowable->subscribe(std::move(subscriber)); } TEST(FlowableTest, FlowableCompleteInTheMiddle) { - auto flowable = Flowable::create( - [](Subscriber & subscriber, int64_t requested) { + auto flowable = + Flowable::create([](auto& subscriber, int64_t requested) { EXPECT_GT(requested, 1); subscriber.onNext(123); subscriber.onComplete(); - return std::make_tuple(int64_t(1), true); - })->map([](int v) { return std::to_string(v); }); + }) + ->map([](int v) { return std::to_string(v); }); - auto subscriber = make_ref>(10); + auto subscriber = std::make_shared>(10); flowable->subscribe(subscriber); EXPECT_TRUE(subscriber->isComplete()); @@ -400,5 +484,1036 @@ TEST(FlowableTest, FlowableCompleteInTheMiddle) { EXPECT_EQ(std::size_t{1}, subscriber->values().size()); } -} // flowable -} // yarpl +class RangeCheckingSubscriber : public BaseSubscriber { + public: + explicit RangeCheckingSubscriber(int32_t total, folly::Baton<>& b) + : total_(total), onComplete_(b) {} + + void onSubscribeImpl() override { + this->request(total_); + } + + void onNextImpl(int32_t val) override { + EXPECT_EQ(val, current_); + current_++; + } + + void onErrorImpl(folly::exception_wrapper) override { + FAIL() << "shouldn't call onError"; + } + + void onCompleteImpl() override { + EXPECT_EQ(total_, current_); + onComplete_.post(); + } + + private: + int32_t current_{0}; + int32_t total_; + folly::Baton<>& onComplete_; +}; + +namespace { +// workaround for gcc-4.9 +auto const expect_count = 10000; +TEST(FlowableTest, FlowableFromDifferentThreads) { + auto flowable = Flowable::create([&](auto& subscriber, int64_t req) { + EXPECT_EQ(req, expect_count); + auto t1 = std::thread([&] { + for (int32_t i = 0; i < req; i++) { + subscriber.onNext(i); + } + subscriber.onComplete(); + }); + t1.join(); + }); + + auto t2 = std::thread([&] { + folly::Baton<> on_flowable_complete; + flowable->subscribe(std::make_shared( + expect_count, on_flowable_complete)); + on_flowable_complete.timed_wait(std::chrono::milliseconds(100)); + }); + + t2.join(); +} +} // namespace + +class ErrorRangeCheckingSubscriber : public BaseSubscriber { + public: + explicit ErrorRangeCheckingSubscriber( + int32_t expect, + int32_t request, + folly::Baton<>& b, + folly::exception_wrapper expected_err) + : expect_(expect), + request_(request), + onError_(b), + expectedErr_(expected_err) {} + + void onSubscribeImpl() override { + this->request(request_); + } + + void onNextImpl(int32_t val) override { + EXPECT_EQ(val, current_); + current_++; + } + + void onErrorImpl(folly::exception_wrapper err) override { + EXPECT_EQ(expect_, current_); + EXPECT_TRUE(err); + EXPECT_EQ( + err.get_exception()->what(), expectedErr_.get_exception()->what()); + onError_.post(); + } + + void onCompleteImpl() override { + FAIL() << "shouldn't ever onComplete"; + } + + private: + int32_t expect_; + int32_t request_; + folly::Baton<>& onError_; + folly::exception_wrapper expectedErr_; + int32_t current_{0}; +}; + +namespace { +// workaround for gcc-4.9 +auto const request = 10000; +auto const expect = 5000; +auto const the_ex = folly::make_exception_wrapper("wat"); + +TEST(FlowableTest, FlowableFromDifferentThreadsWithError) { + auto flowable = Flowable::create([=](auto& subscriber, int64_t req) { + EXPECT_EQ(req, request); + EXPECT_LT(expect, request); + + auto t1 = std::thread([&] { + for (int32_t i = 0; i < expect; i++) { + subscriber.onNext(i); + } + subscriber.onError(the_ex); + }); + t1.join(); + }); + + auto t2 = std::thread([&] { + folly::Baton<> on_flowable_error; + flowable->subscribe(std::make_shared( + expect, request, on_flowable_error, the_ex)); + on_flowable_error.timed_wait(std::chrono::milliseconds(100)); + }); + + t2.join(); +} +} // namespace + +TEST(FlowableTest, SubscribeMultipleTimes) { + using namespace ::testing; + using StrictMockSubscriber = + testing::StrictMock>; + auto f = Flowable::create([](auto& subscriber, int64_t req) { + for (int64_t i = 0; i < req; i++) { + subscriber.onNext(i); + } + + subscriber.onComplete(); + }); + + auto setup_mock = [](auto request_num, auto& resps) { + auto mock = std::make_shared(request_num); + + Sequence seq; + EXPECT_CALL(*mock, onSubscribe_(_)).InSequence(seq); + EXPECT_CALL(*mock, onNext_(_)) + .InSequence(seq) + .WillRepeatedly( + Invoke([&resps](int64_t value) { resps.push_back(value); })); + EXPECT_CALL(*mock, onComplete_()).InSequence(seq); + return mock; + }; + + std::vector> results{5}; + auto mock1 = setup_mock(5, results[0]); + auto mock2 = setup_mock(5, results[1]); + auto mock3 = setup_mock(5, results[2]); + auto mock4 = setup_mock(5, results[3]); + auto mock5 = setup_mock(5, results[4]); + + // map on the same flowable twice + auto stream1 = f->map([](auto i) { return i + 1; }); + auto stream2 = f->map([](auto i) { return i * 2; }); + auto stream3 = stream2->skip(2); // skip operator chained after a map operator + auto stream4 = stream1->take(3); // take operator chained after a map operator + auto stream5 = stream1; // test subscribing to exact same flowable twice + + stream1->subscribe(mock1); + stream2->subscribe(mock2); + stream3->subscribe(mock3); + stream4->subscribe(mock4); + stream5->subscribe(mock5); + + EXPECT_EQ(results[0], std::vector({1, 2, 3, 4, 5})); + EXPECT_EQ(results[1], std::vector({0, 2, 4, 6, 8})); + EXPECT_EQ(results[2], std::vector({4, 6, 8, 10, 12})); + EXPECT_EQ(results[3], std::vector({1, 2, 3})); + EXPECT_EQ(results[4], std::vector({1, 2, 3, 4, 5})); +} + +/* following test should probably behave like: + * +TEST(FlowableTest, ConsumerThrows_OnNext) { + auto range = Flowable<>::range(1, 10); + + EXPECT_THROWS({ + range->subscribe( + // onNext + [](auto) { throw std::runtime_error("throw at consumption"); }, + // onError + [](auto) { FAIL(); }, + // onComplete + []() { FAIL(); }); + }); +} +*/ +TEST(FlowableTest, ConsumerThrows_OnNext) { + bool onErrorIsCalled{false}; + + Flowable<>::range(1, 10)->subscribe( + [](auto) { throw std::runtime_error("throw at consumption"); }, + [&onErrorIsCalled](auto ex) { onErrorIsCalled = true; }, + []() { FAIL() << "onError should have been called"; }); + + EXPECT_TRUE(onErrorIsCalled); +} + +TEST(FlowableTest, ConsumerThrows_OnNext_Cancel) { + class TestOperator : public FlowableOperator { + public: + void subscribe(std::shared_ptr> subscriber) override { + auto subscription = + std::make_shared>(); + EXPECT_CALL(*subscription, request_(_)); + EXPECT_CALL(*subscription, cancel_()); + subscriber->onSubscribe(subscription); + + try { + subscriber->onNext(1); + } catch (const std::exception&) { + FAIL() + << "onNext should not throw but subscription should get canceled."; + } + } + }; + + auto testOperator = std::make_shared(); + auto mapped = testOperator->map([](uint32_t i) { + throw std::runtime_error("test"); + return i; + }); + auto mockSubscriber = + std::make_shared>>(); + EXPECT_CALL(*mockSubscriber, onSubscribe_(_)); + EXPECT_CALL(*mockSubscriber, onError_(_)); + mapped->subscribe(mockSubscriber); +} + +TEST(FlowableTest, DeferTest) { + int switchValue = 0; + auto flowable = Flowable::defer([&]() { + if (switchValue == 0) { + return Flowable<>::range(1, 1); + } else { + return Flowable<>::range(3, 1); + } + }); + + EXPECT_EQ(run(flowable), std::vector({1})); + switchValue = 1; + EXPECT_EQ(run(flowable), std::vector({3})); +} + +TEST(FlowableTest, DeferExceptionTest) { + auto flowable = Flowable::defer([&]() -> std::shared_ptr> { + throw std::runtime_error{"Too big!"}; + }); + + auto subscriber = std::make_shared>(); + flowable->subscribe(subscriber); + + EXPECT_TRUE(subscriber->isError()); + EXPECT_EQ(subscriber->getErrorMsg(), "Too big!"); +} + +TEST(FlowableTest, DoOnSubscribeTest) { + auto a = Flowable::empty(); + + MockFunction checkpoint; + EXPECT_CALL(checkpoint, Call()); + + a->doOnSubscribe([&] { checkpoint.Call(); })->subscribe(); +} + +TEST(FlowableTest, DoOnNextTest) { + std::vector values; + auto a = Flowable<>::range(10, 14)->doOnNext( + [&](int64_t v) { values.push_back(v); }); + auto values2 = run(std::move(a)); + EXPECT_EQ(values, values2); +} + +TEST(FlowableTest, DoOnErrorTest) { + auto a = Flowable::error(std::runtime_error("something broke!")); + + MockFunction checkpoint; + EXPECT_CALL(checkpoint, Call()); + + a->doOnError([&](const auto&) { checkpoint.Call(); })->subscribe(); +} + +TEST(FlowableTest, DoOnTerminateTest) { + auto a = Flowable::empty(); + + MockFunction checkpoint; + EXPECT_CALL(checkpoint, Call()); + + a->doOnTerminate([&]() { checkpoint.Call(); })->subscribe(); +} + +TEST(FlowableTest, DoOnTerminate2Test) { + auto a = Flowable::error(std::runtime_error("something broke!")); + + MockFunction checkpoint; + EXPECT_CALL(checkpoint, Call()); + + a->doOnTerminate([&]() { checkpoint.Call(); })->subscribe(); +} + +TEST(FlowableTest, DoOnEachTest) { + // TODO(lehecka): rewrite with concatWith + auto a = Flowable::create([](Subscriber& s, int64_t) { + s.onNext(5); + s.onError(std::runtime_error("something broke!")); + }); + + MockFunction checkpoint; + EXPECT_CALL(checkpoint, Call()).Times(2); + a->doOnEach([&]() { checkpoint.Call(); })->subscribe(); +} + +TEST(FlowableTest, DoOnTest) { + // TODO(lehecka): rewrite with concatWith + auto a = Flowable::create([](Subscriber& s, int64_t) { + s.onNext(5); + s.onError(std::runtime_error("something broke!")); + }); + + MockFunction checkpoint1; + EXPECT_CALL(checkpoint1, Call()); + MockFunction checkpoint2; + EXPECT_CALL(checkpoint2, Call()); + + a->doOn( + [&](int value) { + checkpoint1.Call(); + EXPECT_EQ(value, 5); + }, + [] { FAIL(); }, + [&](const auto&) { checkpoint2.Call(); }) + ->subscribe(); +} + +TEST(FlowableTest, DoOnCancelTest) { + auto a = Flowable<>::range(1, 10); + + MockFunction checkpoint; + EXPECT_CALL(checkpoint, Call()); + + a->doOnCancel([&]() { checkpoint.Call(); })->take(1)->subscribe(); +} + +template +void cancelDuringOnNext(Op&& op, F&& f) { + folly::Baton<> next, cancelled; + + folly::ScopedEventBaseThread thread; + + auto d = op(Flowable<>::justN({1, 2}), + [&, marker = std::make_shared(1), f](auto&&... args) { + auto weak = std::weak_ptr(marker); + // This simulates subscription cancellation during onNext + next.post(); + cancelled.wait(); + // Lambda with all captures should still exist, while it's + // handling onNext call. If it doesn't exist, the following + // lock will fail. + EXPECT_TRUE(weak.lock()); + return f(args...); + }) + ->observeOn(thread.getEventBase()) + ->subscribe([](int) {}); + + // Wait till onNext is called, and cancel subscription while onNext is still + // in progress + ASSERT_TRUE(next.try_wait_for(std::chrono::seconds(1))); + d->dispose(); + + // Let onNext finish + cancelled.post(); +} + +TEST(FlowableTest, CancelDuringMapOnNext) { + cancelDuringOnNext( + [](auto&& flowable, auto&& f) { return flowable->map(f); }, + [](int value) { return value; }); +} + +TEST(FlowableTest, CancelDuringFilterOnNext) { + cancelDuringOnNext( + [](auto&& flowable, auto&& f) { return flowable->filter(f); }, + [](int value) { return value > 0; }); +} + +TEST(FlowableTest, CancelDuringReduceOnNext) { + cancelDuringOnNext( + [](auto&& flowable, auto&& f) { return flowable->reduce(f); }, + [](int acc, int value) { return acc + value; }); +} + +TEST(FlowableTest, CancelDuringDoOnNext) { + cancelDuringOnNext( + [](auto&& flowable, auto&& f) { return flowable->doOnNext(f); }, + [](int) {}); +} + +TEST(FlowableTest, DoOnRequestTest) { + auto a = Flowable<>::range(1, 10); + + MockFunction checkpoint; + EXPECT_CALL(checkpoint, Call(2)); + + a->doOnRequest([&](int64_t n) { checkpoint.Call(n); })->take(2)->subscribe(); +} + +TEST(FlowableTest, ConcatWithTest) { + auto first = Flowable<>::range(1, 2); + auto second = Flowable<>::range(5, 2); + auto combined = first->concatWith(second); + + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6})); +} + +TEST(FlowableTest, ConcatWithMultipleTest) { + auto first = Flowable<>::range(1, 2); + auto second = Flowable<>::range(5, 2); + auto third = Flowable<>::range(10, 2); + auto fourth = Flowable<>::range(15, 2); + auto firstSecond = first->concatWith(second); + auto thirdFourth = third->concatWith(fourth); + auto combined = firstSecond->concatWith(thirdFourth); + + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6, 10, 11, 15, 16})); +} + +TEST(FlowableTest, ConcatWithExceptionTest) { + auto first = Flowable<>::range(1, 2); + auto second = Flowable<>::range(5, 2); + auto third = Flowable::error(std::runtime_error("error")); + + auto combined = first->concatWith(second)->concatWith(third); + + auto subscriber = std::make_shared>(); + combined->subscribe(subscriber); + + EXPECT_EQ(subscriber->values(), std::vector({1, 2, 5, 6})); + EXPECT_TRUE(subscriber->isError()); + EXPECT_EQ(subscriber->getErrorMsg(), "error"); +} + +TEST(FlowableTest, ConcatWithFlowControlTest) { + auto first = Flowable<>::range(1, 2); + auto second = Flowable<>::range(5, 2); + auto third = Flowable<>::range(10, 2); + auto fourth = Flowable<>::range(15, 2); + auto firstSecond = first->concatWith(second); + auto thirdFourth = third->concatWith(fourth); + auto combined = firstSecond->concatWith(thirdFourth); + + auto subscriber = std::make_shared>(0); + combined->subscribe(subscriber); + EXPECT_EQ(subscriber->values(), std::vector{}); + + const std::vector allResults{1, 2, 5, 6, 10, 11, 15, 16}; + for (int i = 1; i <= 8; ++i) { + subscriber->request(1); + subscriber->awaitValueCount(1, std::chrono::seconds(1)); + EXPECT_EQ( + subscriber->values(), + std::vector(allResults.begin(), allResults.begin() + i)); + } +} + +TEST(FlowableTest, ConcatWithCancel) { + auto first = Flowable<>::range(1, 2); + auto second = Flowable<>::range(5, 2); + + auto combined = first->concatWith(second); + auto subscriber = std::make_shared>(0); + + MockFunction checkpoint; + EXPECT_CALL(checkpoint, Call()); + combined->doOnCancel([&]() { checkpoint.Call(); })->subscribe(subscriber); + + subscriber->request(3); + subscriber->awaitValueCount(3, std::chrono::seconds(1)); + + subscriber->cancel(); + EXPECT_EQ(subscriber->values(), std::vector({1, 2, 5})); +} + +TEST(FlowableTest, ConcatWithCompleteAtSubscription) { + auto first = Flowable<>::range(1, 2); + auto second = Flowable<>::range(5, 2); + + auto combined = first->concatWith(second)->take(0); + EXPECT_EQ(run(combined), std::vector({})); +} + +TEST(FlowableTest, ConcatWithVarArgsTest) { + auto first = Flowable<>::range(1, 2); + auto second = Flowable<>::range(5, 2); + auto third = Flowable<>::range(10, 2); + auto fourth = Flowable<>::range(15, 2); + + auto combined = first->concatWith(second, third, fourth); + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6, 10, 11, 15, 16})); +} + +TEST(FlowableTest, ConcatTest) { + auto combined = Flowable::concat( + Flowable<>::range(1, 2), Flowable<>::range(5, 2)); + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6})); + + // Flowable::concat shoud not accept one parameter! + // Next line should cause compiler failure: OK! + // combined = Flowable::concat(Flowable<>::range(1, 2)); + + combined = Flowable::concat( + Flowable<>::range(1, 2), + Flowable<>::range(5, 2), + Flowable<>::range(10, 2)); + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6, 10, 11})); + + combined = Flowable::concat( + Flowable<>::range(1, 2), + Flowable<>::range(5, 2), + Flowable<>::range(10, 2), + Flowable<>::range(15, 2)); + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6, 10, 11, 15, 16})); +} + +TEST(FlowableTest, ConcatWith_DelaySubscribe) { + // If there is no request for the second flowable, don't subscribe to it + bool subscribed = false; + auto a = Flowable<>::range(1, 1); + auto b = Flowable<>::range(2, 1)->doOnSubscribe( + [&subscribed]() { subscribed = true; }); + auto combined = a->concatWith(b); + + uint32_t request = 0; + auto subscriber = std::make_shared>(request); + combined->subscribe(subscriber); + subscriber->request(1); + + ASSERT_EQ(subscriber->values(), std::vector({1})); + ASSERT_FALSE(subscribed); + + // termination signal! + subscriber->cancel(); // otherwise we leak the active subscription +} + +TEST(FlowableTest, ConcatWith_EagerCancel) { + // If there is no request for the second flowable, don't subscribe to it + bool subscribed = false; + + // Control the execution of SubscribeOn operator + folly::EventBase evb; + + auto a = Flowable<>::range(1, 1); + auto b = Flowable<>::range(2, 1)->subscribeOn(evb)->doOnSubscribe( + [&subscribed]() { subscribed = true; }); + auto combined = a->concatWith(b); + + uint32_t request = 2; + std::vector values; + auto subscriber = yarpl::flowable::Subscriber::create( + [&values](int64_t value) { values.push_back(value); }, request); + + combined->subscribe(subscriber); + + // Even though we requested 2 items, we received 1 item + ASSERT_EQ(values, std::vector({1})); + ASSERT_FALSE(subscribed); // not yet, callback did not arrive yet! + + // We have requested 2 items, but did not consume the second item yet + // and we send a cancel before looping the eventBase + auto baseSubscriber = static_cast*>(subscriber.get()); + baseSubscriber->cancel(); + + // If the evb is never looped, it will cause memory leak + evb.loop(); + ASSERT_EQ(values, std::vector({1})); // no change! + ASSERT_TRUE(subscribed); // subscribe() already issued before the cancel +} + +class TestTimeout : public folly::AsyncTimeout { + public: + explicit TestTimeout(folly::EventBase* eventBase, folly::Function fn) + : AsyncTimeout(eventBase), fn_(std::move(fn)) {} + + void timeoutExpired() noexcept override { + fn_(); + } + + folly::Function fn_; +}; + +TEST(FlowableTest, Timeout_SpecialException) { + class RestrictedType { + public: + RestrictedType() = default; + RestrictedType(RestrictedType&&) noexcept = default; + RestrictedType& operator=(RestrictedType&&) noexcept = default; + auto operator()() { + return std::logic_error("RestrictedType"); + } + }; + + folly::EventBase timerEvb; + auto flowable = Flowable::never()->timeout( + timerEvb, + std::chrono::milliseconds(0), + std::chrono::milliseconds(1), + RestrictedType{}); + + int requestCount = 1; + auto subscriber = std::make_shared>(requestCount); + flowable->subscribe(subscriber); + + timerEvb.loop(); + + EXPECT_EQ(subscriber->values(), std::vector({})); + EXPECT_TRUE(subscriber->exceptionWrapper().with_exception( + [](const std::logic_error& ex) { + EXPECT_STREQ("RestrictedType", ex.what()); + })); +} + +TEST(FlowableTest, Timeout_NoTimeout) { + folly::EventBase timerEvb; + auto flowable = Flowable<>::range(1, 1)->observeOn(timerEvb)->timeout( + timerEvb, std::chrono::milliseconds(0), std::chrono::milliseconds(0)); + + int requestCount = 1; + auto subscriber = std::make_shared>(requestCount); + flowable->subscribe(subscriber); + flowable.reset(); + + timerEvb.loop(); + + subscriber->awaitTerminalEvent(std::chrono::seconds(1)); + EXPECT_EQ(subscriber->values(), std::vector({1})); + + flowable = + Flowable::create([=](auto& subscriber, int64_t) { + subscriber.onNext(2); + subscriber.onComplete(); + }) + ->observeOn(timerEvb) + ->timeout(timerEvb, std::chrono::seconds(0), std::chrono::seconds(0)); + + subscriber = std::make_shared>(requestCount); + flowable->subscribe(subscriber); + flowable.reset(); + + timerEvb.loop(); + + subscriber->awaitTerminalEvent(std::chrono::seconds(1)); + EXPECT_EQ(subscriber->values(), std::vector({2})); +} + +TEST(FlowableTest, Timeout_OnNextTimeout) { + folly::EventBase timerEvb; + + auto flowable = Flowable<>::range(1, 2)->observeOn(timerEvb)->timeout( + timerEvb, + std::chrono::milliseconds(50), + std::chrono::milliseconds(0)); // no init_timeout + + int requestCount = 1; + auto subscriber = std::make_shared>(requestCount); + flowable->subscribe(subscriber); + flowable.reset(); + + TestTimeout timeout(&timerEvb, [subscriber]() { subscriber->request(1); }); + timeout.scheduleTimeout(100); // request next in 100 msec, timeout! + + timerEvb.loop(); + + subscriber->awaitTerminalEvent(std::chrono::seconds(1)); + + // first one is consumed + EXPECT_EQ(subscriber->values(), std::vector({1})); + EXPECT_TRUE(subscriber->isError()); +} + +TEST(FlowableTest, Timeout_InitTimeout) { + folly::EventBase timerEvb; + auto flowable = Flowable::create([=](auto& subscriber, int64_t req) { + if (req > 0) { + subscriber.onNext(2); + subscriber.onComplete(); + } + }) + ->observeOn(timerEvb) + ->timeout( + timerEvb, + std::chrono::milliseconds(0), + std::chrono::milliseconds(10)); + + int requestCount = 0; + auto subscriber = std::make_shared>(requestCount); + + TestTimeout timeout(&timerEvb, [subscriber]() { subscriber->request(1); }); + timeout.scheduleTimeout(100); // timeout the init + + flowable->subscribe(subscriber); + flowable.reset(); + timerEvb.loop(); + + subscriber->awaitTerminalEvent(std::chrono::seconds(1)); + + EXPECT_EQ(subscriber->values(), std::vector({})); + EXPECT_TRUE(subscriber->isError()); +} + +TEST(FlowableTest, Timeout_StopUsageOfTimer) { + // When the consumption completes, it should stop using the timer + auto flowable = Flowable<>::range(1, 1); + { + // EventBase will be deleted before the flowable + folly::EventBase timerEvb; + auto flowableIn = flowable->timeout( + timerEvb, std::chrono::milliseconds(1), std::chrono::milliseconds(0)); + EXPECT_EQ(run(flowableIn), std::vector({1})); + } +} + +TEST(FlowableTest, Timeout_NeverOperator_Timesout) { + folly::EventBase timerEvb; + auto flowable = Flowable::never()->observeOn(timerEvb)->timeout( + timerEvb, std::chrono::milliseconds(10), std::chrono::milliseconds(10)); + + int requestCount = 10; + auto subscriber = std::make_shared>(requestCount); + flowable->subscribe(subscriber); + flowable.reset(); + + timerEvb.loop(); + + subscriber->awaitTerminalEvent(std::chrono::seconds(1)); + + EXPECT_EQ(subscriber->values(), std::vector({})); + EXPECT_TRUE(subscriber->isError()); +} + +TEST(FlowableTest, Timeout_BecauseOfNoRequest) { + folly::ScopedEventBaseThread timerThread; + auto flowable = Flowable<>::range(1, 2) + ->observeOn(*timerThread.getEventBase()) + ->timeout( + *timerThread.getEventBase(), + std::chrono::seconds(1), + std::chrono::milliseconds(10)); + + int requestCount = 0; + auto subscriber = std::make_shared>(requestCount); + flowable->subscribe(subscriber); + subscriber->awaitTerminalEvent(std::chrono::seconds(1)); + + EXPECT_EQ(subscriber->values(), std::vector({})); + EXPECT_TRUE(subscriber->isError()); +} + +TEST(FlowableTest, Timeout_WithObserveOnSubscribeOn) { + folly::ScopedEventBaseThread subscribeOnThread; + folly::EventBase timerEvb; + auto flowable = Flowable<>::range(1, 2) + ->subscribeOn(*subscribeOnThread.getEventBase()) + ->observeOn(timerEvb) + ->timeout( + timerEvb, + std::chrono::milliseconds(10), + std::chrono::milliseconds(100)); + + int requestCount = 1; + auto subscriber = std::make_shared>(requestCount); + + TestTimeout timeout(&timerEvb, [subscriber]() { subscriber->request(1); }); + timeout.scheduleTimeout(100); // timeout onNext + + flowable->subscribe(subscriber); + flowable.reset(); + timerEvb.loop(); + + subscriber->awaitTerminalEvent(std::chrono::seconds(1)); + + // first one is consumed + EXPECT_EQ(subscriber->values(), std::vector({1})); + EXPECT_TRUE(subscriber->isError()); +} + +TEST(FlowableTest, Timeout_SameThread) { + folly::EventBase timerEvb; + auto flowable = Flowable<>::range(1, 2) + ->subscribeOn(timerEvb) + ->observeOn(timerEvb) + ->timeout( + timerEvb, + std::chrono::milliseconds(10), + std::chrono::milliseconds(100)); + + int requestCount = 1; + auto subscriber = std::make_shared>(requestCount); + + TestTimeout timeout(&timerEvb, [subscriber]() { subscriber->request(1); }); + timeout.scheduleTimeout(100); // timeout onNext + + flowable->subscribe(subscriber); + flowable.reset(); + timerEvb.loop(); + + subscriber->awaitTerminalEvent(std::chrono::seconds(1)); + + // first one is consumed + EXPECT_EQ(subscriber->values(), std::vector({1})); + EXPECT_TRUE(subscriber->isError()); +} + +TEST(FlowableTest, SwapException) { + auto flowable = Flowable::error(std::runtime_error("private")); + flowable = flowable->map( + [](auto&& a) { return a; }, + [](auto) { return std::runtime_error("public"); }); + + auto subscriber = std::make_shared>(); + flowable->subscribe(subscriber); + + EXPECT_EQ(subscriber->values(), std::vector({})); + EXPECT_TRUE(subscriber->isError()); + EXPECT_EQ(subscriber->getErrorMsg(), "public"); +} + +#if FOLLY_HAS_COROUTINES +TEST(AsyncGeneratorShimTest, CoroAsyncGeneratorIntType) { + folly::ScopedEventBaseThread th; + const int length = 5; + folly::Baton<> baton; + auto stream = + folly::coro::co_invoke([]() -> folly::coro::AsyncGenerator { + for (int i = 0; i < length; i++) { + co_yield std::move(i); + } + }); + + int expected_i = 0; + yarpl::toFlowable(std::move(stream), th.getEventBase()) + ->subscribe( + [&](int i) { EXPECT_EQ(expected_i++, i); }, + [&](folly::exception_wrapper) { + ADD_FAILURE() << "on Error"; + baton.post(); + }, + [&] { + EXPECT_EQ(expected_i, length); + baton.post(); + }, + 2); + baton.wait(); +} + +TEST(AsyncGeneratorShimTest, CoroAsyncGeneratorStringType) { + folly::ScopedEventBaseThread th; + const int length = 5; + folly::Baton<> baton; + auto stream = folly::coro::co_invoke( + []() -> folly::coro::AsyncGenerator { + for (int i = 0; i < length; i++) { + co_yield folly::to(i); + } + }); + + int expected_i = 0; + yarpl::toFlowable(std::move(stream), th.getEventBase()) + ->subscribe( + [&](std::string i) { EXPECT_EQ(expected_i++, folly::to(i)); }, + [&](folly::exception_wrapper) { + ADD_FAILURE() << "on Error"; + baton.post(); + }, + [&] { + EXPECT_EQ(expected_i, length); + baton.post(); + }, + 2); + baton.wait(); +} + +TEST(AsyncGeneratorShimTest, CoroAsyncGeneratorReverseFulfill) { + folly::ScopedEventBaseThread th; + folly::ScopedEventBaseThread pth; + const int length = 5; + std::vector> vp(length); + + int i = 0; + folly::Baton<> baton; + auto stream = + folly::coro::co_invoke([&]() -> folly::coro::AsyncGenerator { + while (i < length) { + co_yield co_await vp[i++].getSemiFuture(); + } + }); + + // intentionally let promised fulfilled in reverse order, but the result + // should come back to stream in order + for (int i = length - 1; i >= 0; i--) { + pth.add([&vp, i]() { vp[i].setValue(i); }); + } + + int expected_i = 0; + yarpl::toFlowable(std::move(stream), th.getEventBase()) + ->subscribe( + [&](int i) { EXPECT_EQ(expected_i++, i); }, + [&](folly::exception_wrapper) { + ADD_FAILURE() << "on Error"; + baton.post(); + }, + [&] { + EXPECT_EQ(expected_i, length); + baton.post(); + }, + 2); + baton.wait(); +} + +TEST(AsyncGeneratorShimTest, CoroAsyncGeneratorLambdaGaptureVariable) { + folly::ScopedEventBaseThread th; + std::string t = "test"; + folly::Baton<> baton; + auto stream = folly::coro::co_invoke( + [&, t = std::move(t) ]() mutable + -> folly::coro::AsyncGenerator { + co_yield std::move(t); + co_return; + }); + + std::string result; + yarpl::toFlowable(std::move(stream), th.getEventBase()) + ->subscribe( + [&](std::string t) { result = t; }, + [&](folly::exception_wrapper ex) { + ADD_FAILURE() << "on Error " << ex.what(); + baton.post(); + }, + [&] { baton.post(); }, + 2); + baton.wait(); + + EXPECT_EQ("test", result); +} + +TEST(AsyncGeneratorShimTest, ShouldNotHaveCoAwaitMoreThanOnce) { + folly::ScopedEventBaseThread th; + folly::ScopedEventBaseThread pth; + const int length = 5; + std::vector> vp(length); + + int i = 0; + folly::Baton<> baton; + auto stream = + folly::coro::co_invoke([&]() -> folly::coro::AsyncGenerator { + while (i < length) { + co_yield co_await vp[i++].getSemiFuture(); + } + }); + + int expected_i = 0; + yarpl::toFlowable(std::move(stream), th.getEventBase()) + ->subscribe( + [&](int i) { EXPECT_EQ(expected_i++, i); }, + [&](folly::exception_wrapper) { + ADD_FAILURE() << "on Error"; + baton.post(); + }, + [&] { + EXPECT_EQ(expected_i, length); + baton.post(); + }, + 5); + // subscribe before fulfill future, expecting co_await on the future will + // happen before setValue() + for (int i = 0; i < length; i++) { + pth.add([&vp, i]() { vp[i].setValue(i); }); + } + baton.wait(); +} + +TEST(AsyncGeneratorShimTest, CoroAsyncGeneratorPreemptiveCancel) { + folly::ScopedEventBaseThread th; + folly::coro::Baton b; + bool canceled = false; + auto stream = folly::coro:: + co_invoke([&]() -> folly::coro::AsyncGenerator { + // cancelCallback will be execute in the same event loop + // as async generator + folly::CancellationCallback cancelCallback( + co_await folly::coro::co_current_cancellation_token, [&]() { + canceled = true; + b.post(); + }); + co_yield "first"; + co_await b; + if (!canceled) { + co_yield "never_reach"; + } + }); + + struct TestSubscriber : public Subscriber { + void onSubscribe(std::shared_ptr s) override final { + s->request(2); + s_ = std::move(s); + } + + void onNext(std::string s) override { + EXPECT_EQ("first", s); + b1_.post(); + } + void onComplete() override { + b2_.post(); + } + void onError(folly::exception_wrapper) override {} + std::shared_ptr s_; + folly::Baton<> b1_, b2_; + }; + auto subscriber = std::make_shared(); + yarpl::toFlowable(std::move(stream), th.getEventBase()) + ->subscribe(subscriber); + subscriber->b1_.wait(); + subscriber->s_->cancel(); + subscriber->b2_.wait(); + EXPECT_TRUE(canceled); +} +#endif diff --git a/test/MocksTest.cpp b/yarpl/test/MocksTest.cpp similarity index 56% rename from test/MocksTest.cpp rename to yarpl/test/MocksTest.cpp index 3c8fdcb28..db97df1ff 100644 --- a/test/MocksTest.cpp +++ b/yarpl/test/MocksTest.cpp @@ -1,13 +1,25 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -#include "test/test_utils/Mocks.h" +#include "yarpl/test_utils/Mocks.h" #include #include using namespace ::testing; -using namespace rsocket; using namespace yarpl::flowable; +using namespace yarpl::mocks; TEST(MocksTest, SelfManagedMocks) { // Best run with ASAN, to detect potential leaks, use-after-free or @@ -15,12 +27,12 @@ TEST(MocksTest, SelfManagedMocks) { int value = 42; MockFlowable flowable; - auto subscription = yarpl::make_ref(); - auto subscriber = yarpl::make_ref>(0); + auto subscription = std::make_shared(); + auto subscriber = std::make_shared>(0); { InSequence dummy; EXPECT_CALL(flowable, subscribe_(_)) - .WillOnce(Invoke([&](yarpl::Reference> consumer) { + .WillOnce(Invoke([&](std::shared_ptr> consumer) { consumer->onSubscribe(subscription); })); EXPECT_CALL(*subscriber, onSubscribe_(_)); diff --git a/yarpl/test/Observable_test.cpp b/yarpl/test/Observable_test.cpp index a024e67a4..f2c31a1ba 100644 --- a/yarpl/test/Observable_test.cpp +++ b/yarpl/test/Observable_test.cpp @@ -1,24 +1,39 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include #include #include +#include #include "yarpl/Observable.h" +#include "yarpl/flowable/Flowable.h" #include "yarpl/flowable/Subscriber.h" -#include "yarpl/flowable/Subscribers.h" -#include "yarpl/schedulers/ThreadScheduler.h" - -#include "Tuple.h" +#include "yarpl/test_utils/Mocks.h" +#include "yarpl/test_utils/Tuple.h" // TODO can we eliminate need to import both of these? using namespace yarpl; +using namespace yarpl::mocks; using namespace yarpl::observable; +using namespace testing; namespace { void unreachable() { - EXPECT_TRUE(false); + EXPECT_TRUE(false) << "unreachable code"; } template @@ -31,17 +46,14 @@ class CollectingObserver : public Observer { void onComplete() override { Observer::onComplete(); complete_ = true; + terminated_ = true; } - void onError(std::exception_ptr ex) override { + void onError(folly::exception_wrapper ex) override { Observer::onError(ex); error_ = true; - - try { - std::rethrow_exception(ex); - } catch (const std::exception& e) { - errorMsg_ = e.what(); - } + errorMsg_ = ex.get_exception()->what(); + terminated_ = true; } std::vector& values() { @@ -60,42 +72,57 @@ class CollectingObserver : public Observer { return errorMsg_; } + /** + * Block the current thread until either onSuccess or onError is called. + */ + void awaitTerminalEvent( + std::chrono::milliseconds ms = std::chrono::seconds{1}) { + // now block this thread + std::unique_lock lk(m_); + // if shutdown gets implemented this would then be released by it + if (!terminalEventCV_.wait_for(lk, ms, [this] { return terminated_; })) { + throw std::runtime_error("timeout in awaitTerminalEvent"); + } + } + private: std::vector values_; std::string errorMsg_; bool complete_{false}; bool error_{false}; + + bool terminated_{false}; + std::mutex m_; + std::condition_variable terminalEventCV_; }; /// Construct a pipeline with a collecting observer against the supplied /// observable. Return the items that were sent to the observer. If some /// exception was sent, the exception is thrown. template -std::vector run(Reference> observable) { - auto collector = make_ref>(); +std::vector run(std::shared_ptr> observable) { + auto collector = std::make_shared>(); observable->subscribe(collector); + collector->awaitTerminalEvent(std::chrono::seconds(1)); return std::move(collector->values()); } } // namespace TEST(Observable, SingleOnNext) { - auto a = Observable::create([](Reference> obs) { - auto s = Subscriptions::empty(); - obs->onSubscribe(s); + auto a = Observable::create([](std::shared_ptr> obs) { obs->onNext(1); obs->onComplete(); }); std::vector v; a->subscribe( - Observers::create([&v](const int& value) { v.push_back(value); })); + Observer::create([&v](const int& value) { v.push_back(value); })); EXPECT_EQ(v.at(0), 1); } TEST(Observable, MultiOnNext) { - auto a = Observable::create([](Reference> obs) { - obs->onSubscribe(Subscriptions::empty()); + auto a = Observable::create([](std::shared_ptr> obs) { obs->onNext(1); obs->onNext(2); obs->onNext(3); @@ -104,7 +131,7 @@ TEST(Observable, MultiOnNext) { std::vector v; a->subscribe( - Observers::create([&v](const int& value) { v.push_back(value); })); + Observer::create([&v](const int& value) { v.push_back(value); })); EXPECT_EQ(v.at(0), 1); EXPECT_EQ(v.at(1), 2); @@ -113,49 +140,33 @@ TEST(Observable, MultiOnNext) { TEST(Observable, OnError) { std::string errorMessage("DEFAULT->No Error Message"); - auto a = Observable::create([](Reference> obs) { - try { - throw std::runtime_error("something broke!"); - } catch (const std::exception&) { - obs->onError(std::current_exception()); - } + auto a = Observable::create([](std::shared_ptr> obs) { + obs->onError(std::runtime_error("something broke!")); }); - a->subscribe(Observers::create( - [](int value) { /* do nothing */ }, - [&errorMessage](std::exception_ptr e) { - try { - std::rethrow_exception(e); - } catch (const std::runtime_error& ex) { - errorMessage = std::string(ex.what()); - } + a->subscribe(Observer::create( + [](int) { /* do nothing */ }, + [&errorMessage](folly::exception_wrapper ex) { + errorMessage = ex.get_exception()->what(); })); EXPECT_EQ("something broke!", errorMessage); } -static std::atomic instanceCount; - /** * Assert that all items passed through the Observable get destroyed */ TEST(Observable, ItemsCollectedSynchronously) { - auto a = Observable::create([](Reference> obs) { - obs->onSubscribe(Subscriptions::empty()); + auto a = Observable::create([](std::shared_ptr> obs) { obs->onNext(Tuple{1, 2}); obs->onNext(Tuple{2, 3}); obs->onNext(Tuple{3, 4}); obs->onComplete(); }); - a->subscribe(Observers::create([](const Tuple& value) { + a->subscribe(Observer::create([](const Tuple& value) { std::cout << "received value " << value.a << std::endl; })); - - std::cout << "Finished ... remaining instances == " << instanceCount - << std::endl; - - EXPECT_EQ(0, Tuple::instanceCount); } /* @@ -166,46 +177,40 @@ TEST(Observable, ItemsCollectedSynchronously) { * in a Vector which could then be consumed on another thread. */ TEST(DISABLED_Observable, ItemsCollectedAsynchronously) { - // scope this so we can check destruction of Vector after this block - { - auto a = Observable::create([](Reference> obs) { - obs->onSubscribe(Subscriptions::empty()); - std::cout << "-----------------------------" << std::endl; - obs->onNext(Tuple{1, 2}); - std::cout << "-----------------------------" << std::endl; - obs->onNext(Tuple{2, 3}); - std::cout << "-----------------------------" << std::endl; - obs->onNext(Tuple{3, 4}); - std::cout << "-----------------------------" << std::endl; - obs->onComplete(); - }); + auto a = Observable::create([](std::shared_ptr> obs) { + std::cout << "-----------------------------" << std::endl; + obs->onNext(Tuple{1, 2}); + std::cout << "-----------------------------" << std::endl; + obs->onNext(Tuple{2, 3}); + std::cout << "-----------------------------" << std::endl; + obs->onNext(Tuple{3, 4}); + std::cout << "-----------------------------" << std::endl; + obs->onComplete(); + }); - std::vector v; - v.reserve(10); // otherwise it resizes and copies on each push_back - a->subscribe(Observers::create([&v](const Tuple& value) { - std::cout << "received value " << value.a << std::endl; - // copy into vector - v.push_back(value); - std::cout << "done pushing into vector" << std::endl; - })); - - // expect that 3 instances were originally created, then 3 more when copying - EXPECT_EQ(6, Tuple::createdCount); - // expect that 3 instances still exist in the vector, so only 3 destroyed so - // far - EXPECT_EQ(3, Tuple::destroyedCount); - - std::cout << "Leaving block now so Vector should release Tuples..." - << std::endl; - } - EXPECT_EQ(0, Tuple::instanceCount); + std::vector v; + v.reserve(10); // otherwise it resizes and copies on each push_back + a->subscribe(Observer::create([&v](const Tuple& value) { + std::cout << "received value " << value.a << std::endl; + // copy into vector + v.push_back(value); + std::cout << "done pushing into vector" << std::endl; + })); + + // 3 copy & 3 move and 3 more copy constructed + EXPECT_EQ(9, Tuple::createdCount); + // 3 still exists in the vector, 6 destroyed + EXPECT_EQ(6, Tuple::destroyedCount); + + std::cout << "Leaving block now so Vector should release Tuples..." + << std::endl; } class TakeObserver : public Observer { private: const int limit; int count = 0; - Reference subscription_; + std::shared_ptr subscription_; std::vector& v; public: @@ -213,7 +218,8 @@ class TakeObserver : public Observer { v.reserve(5); } - void onSubscribe(Reference s) override { + void onSubscribe( + std::shared_ptr s) override { subscription_ = std::move(s); } @@ -226,106 +232,322 @@ class TakeObserver : public Observer { } } - void onError(std::exception_ptr) override {} + void onError(folly::exception_wrapper) override {} void onComplete() override {} }; // assert behavior of onComplete after subscription.cancel TEST(Observable, SubscriptionCancellation) { - static std::atomic_int emitted{0}; - auto a = Observable::create([](Reference> obs) { - std::atomic_bool isUnsubscribed{false}; - auto s = - Subscriptions::create([&isUnsubscribed] { isUnsubscribed = true; }); - obs->onSubscribe(std::move(s)); + std::atomic_int emitted{0}; + auto a = Observable::create([&](std::shared_ptr> obs) { int i = 0; - while (!isUnsubscribed && i <= 10) { + while (!obs->isUnsubscribed() && i <= 10) { emitted++; obs->onNext(i++); } - if (!isUnsubscribed) { + if (!obs->isUnsubscribed()) { + // should be ignored obs->onComplete(); } }); std::vector v; - a->subscribe(Reference>(new TakeObserver(2, v))); + a->subscribe(std::make_shared(2, v)); EXPECT_EQ((unsigned long)2, v.size()); EXPECT_EQ(2, emitted); } -TEST(Observable, toFlowable) { - auto a = Observable::create([](Reference> obs) { - auto s = Subscriptions::empty(); - obs->onSubscribe(s); - obs->onNext(1); - obs->onComplete(); +TEST(Observable, CancelFromDifferentThread) { + std::atomic_int emitted{0}; + std::mutex m; + std::condition_variable cv; + + std::atomic cancelled1{false}; + std::atomic cancelled2{false}; + + std::thread t; + auto a = Observable::create([&](std::shared_ptr> obs) { + t = std::thread([obs, &emitted, &cancelled1]() { + obs->addSubscription([&]() { cancelled1 = true; }); + while (!obs->isUnsubscribed()) { + ++emitted; + obs->onNext(0); + } + }); + obs->addSubscription([&]() { cancelled2 = true; }); }); + auto subscription = a->subscribe([](int) {}); + + std::unique_lock lk(m); + CHECK(cv.wait_for( + lk, std::chrono::seconds(1), [&] { return emitted >= 1000; })); + + subscription->cancel(); + t.join(); + CHECK(cancelled1); + CHECK(cancelled2); + LOG(INFO) << "cancelled after " << emitted << " items"; +} + +TEST(Observable, toFlowableDrop) { + auto a = Observable<>::range(1, 10); auto f = a->toFlowable(BackpressureStrategy::DROP); - std::vector v; - f->subscribe(yarpl::flowable::Subscribers::create( - [&v](const int& value) { v.push_back(value); })); + std::vector v; - EXPECT_EQ(v.at(0), 1); + auto subscriber = + std::make_shared>>(5); + + EXPECT_CALL(*subscriber, onSubscribe_(_)); + EXPECT_CALL(*subscriber, onNext_(_)) + .WillRepeatedly(Invoke([&](int64_t value) { v.push_back(value); })); + EXPECT_CALL(*subscriber, onComplete_()); + + f->subscribe(subscriber); + + EXPECT_EQ(v, std::vector({1, 2, 3, 4, 5})); } -TEST(Observable, toFlowableWithCancel) { - auto a = Observable::create([](Reference> obs) { - auto s = Subscriptions::atomicBoolSubscription(); - obs->onSubscribe(s); +TEST(Observable, toFlowableDropWithCancel) { + auto a = Observable::create([](std::shared_ptr> obs) { int i = 0; - while (!s->isCancelled()) { + while (!obs->isUnsubscribed()) { obs->onNext(++i); } - if (!s->isCancelled()) { - obs->onComplete(); - } }); auto f = a->toFlowable(BackpressureStrategy::DROP); std::vector v; - f->take(5)->subscribe(yarpl::flowable::Subscribers::create( + f->take(5)->subscribe(yarpl::flowable::Subscriber::create( [&v](const int& value) { v.push_back(value); })); EXPECT_EQ(v, std::vector({1, 2, 3, 4, 5})); } +TEST(Observable, toFlowableErrorStrategy) { + auto a = Observable::createEx([](auto observer, auto subscription) { + int64_t i = 1; + for (; !subscription->isCancelled() && i <= 10; ++i) { + observer->onNext(i); + } + EXPECT_EQ(7, i); + }); + auto f = a->toFlowable(BackpressureStrategy::ERROR); + + std::vector v; + + auto subscriber = + std::make_shared>>(5); + + EXPECT_CALL(*subscriber, onSubscribe_(_)); + EXPECT_CALL(*subscriber, onNext_(_)) + .WillRepeatedly(Invoke([&](int64_t value) { v.push_back(value); })); + EXPECT_CALL(*subscriber, onError_(_)) + .WillOnce(Invoke([&](folly::exception_wrapper ex) { + EXPECT_TRUE(ex.is_compatible_with< + yarpl::flowable::MissingBackpressureException>()); + })); + + f->subscribe(subscriber); + + EXPECT_EQ(v, std::vector({1, 2, 3, 4, 5})); +} + +TEST(Observable, toFlowableBufferStrategy) { + auto a = Observable<>::range(1, 10); + auto f = a->toFlowable(BackpressureStrategy::BUFFER); + + std::vector v; + + auto subscriber = + std::make_shared>>(5); + + EXPECT_CALL(*subscriber, onSubscribe_(_)); + EXPECT_CALL(*subscriber, onNext_(_)) + .WillRepeatedly(Invoke([&](int64_t value) { v.push_back(value); })); + EXPECT_CALL(*subscriber, onComplete_()); + + f->subscribe(subscriber); + EXPECT_EQ(v, std::vector({1, 2, 3, 4, 5})); + + subscriber->subscription()->request(5); + EXPECT_EQ(v, std::vector({1, 2, 3, 4, 5, 6, 7, 8, 9, 10})); +} + +TEST(Observable, toFlowableBufferStrategyLimit) { + std::shared_ptr> observer; + std::shared_ptr subscription; + + auto a = Observable::createEx([&](auto o, auto s) { + observer = std::move(o); + subscription = std::move(s); + }); + auto f = + a->toFlowable(std::make_shared>(3)); + + std::vector v; + + auto subscriber = + std::make_shared>>(5); + + EXPECT_CALL(*subscriber, onSubscribe_(_)); + EXPECT_CALL(*subscriber, onNext_(_)) + .WillRepeatedly(Invoke([&](int64_t value) { v.push_back(value); })); + + EXPECT_FALSE(observer); + EXPECT_FALSE(subscription); + + f->subscribe(subscriber); + + EXPECT_TRUE(observer); + EXPECT_TRUE(subscription); + + for (size_t i = 1; i <= 5; ++i) { + observer->onNext(i); + } + + EXPECT_EQ(v, std::vector({1, 2, 3, 4, 5})); + + observer->onNext(6); + observer->onNext(7); + observer->onNext(8); + + EXPECT_FALSE(observer->isUnsubscribedOrTerminated()); + EXPECT_FALSE(subscription->isCancelled()); + + EXPECT_CALL(*subscriber, onError_(_)) + .WillOnce(Invoke([&](folly::exception_wrapper ex) { + EXPECT_TRUE(ex.is_compatible_with< + yarpl::flowable::MissingBackpressureException>()); + })); + + observer->onNext(9); + + EXPECT_TRUE(observer->isUnsubscribedOrTerminated()); + EXPECT_TRUE(subscription->isCancelled()); +} + +TEST(Observable, toFlowableBufferStrategyStress) { + std::shared_ptr> observer; + auto a = Observable::createEx( + [&](auto o, auto) { observer = std::move(o); }); + auto f = a->toFlowable(BackpressureStrategy::BUFFER); + + std::vector v; + std::atomic tokens{0}; + + auto subscriber = + std::make_shared>>(0); + + EXPECT_CALL(*subscriber, onSubscribe_(_)); + EXPECT_CALL(*subscriber, onNext_(_)) + .WillRepeatedly(Invoke([&](int64_t value) { v.push_back(value); })); + EXPECT_CALL(*subscriber, onComplete_()); + + f->subscribe(subscriber); + EXPECT_TRUE(observer); + + constexpr size_t kNumElements = 100000; + + std::thread nextThread([&] { + for (size_t i = 0; i < kNumElements; ++i) { + while (tokens.load() < -5) { + std::this_thread::yield(); + } + + observer->onNext(i); + --tokens; + } + observer->onComplete(); + }); + + std::thread requestThread([&] { + for (size_t i = 0; i < kNumElements; ++i) { + while (tokens.load() > 5) { + std::this_thread::yield(); + } + + subscriber->subscription()->request(1); + ++tokens; + } + }); + + nextThread.join(); + requestThread.join(); + + for (size_t i = 0; i < kNumElements; ++i) { + CHECK_EQ(i, v[i]); + } +} + +TEST(Observable, toFlowableLatestStrategy) { + auto a = Observable<>::range(1, 10); + auto f = a->toFlowable(BackpressureStrategy::LATEST); + + std::vector v; + + auto subscriber = + std::make_shared>>(5); + + EXPECT_CALL(*subscriber, onSubscribe_(_)); + EXPECT_CALL(*subscriber, onNext_(_)) + .WillRepeatedly(Invoke([&](int64_t value) { v.push_back(value); })); + EXPECT_CALL(*subscriber, onComplete_()); + + f->subscribe(subscriber); + EXPECT_EQ(v, std::vector({1, 2, 3, 4, 5})); + + subscriber->subscription()->request(5); + EXPECT_EQ(v, std::vector({1, 2, 3, 4, 5, 10})); +} + TEST(Observable, Just) { - EXPECT_EQ(run(Observables::just(22)), std::vector{22}); + EXPECT_EQ(run(Observable<>::just(22)), std::vector{22}); EXPECT_EQ( - run(Observables::justN({12, 34, 56, 98})), + run(Observable<>::justN({12, 34, 56, 98})), std::vector({12, 34, 56, 98})); EXPECT_EQ( - run(Observables::justN({"ab", "pq", "yz"})), + run(Observable<>::justN({"ab", "pq", "yz"})), std::vector({"ab", "pq", "yz"})); } TEST(Observable, SingleMovable) { auto value = std::make_unique(123456); - auto observable = Observables::justOnce(std::move(value)); - EXPECT_EQ(std::size_t{1}, observable->count()); + auto observable = Observable<>::justOnce(std::move(value)); + EXPECT_EQ(std::size_t{1}, observable.use_count()); auto values = run(std::move(observable)); - EXPECT_EQ( - values.size(), - size_t(1)); + EXPECT_EQ(values.size(), size_t(1)); - EXPECT_EQ( - *values[0], - 123456); + EXPECT_EQ(*values[0], 123456); +} + +TEST(Observable, MapWithException) { + auto observable = Observable<>::justN({1, 2, 3, 4})->map([](int n) { + if (n > 2) { + throw std::runtime_error{"Too big!"}; + } + return n; + }); + + auto observer = std::make_shared>(); + observable->subscribe(observer); + + EXPECT_EQ(observer->values(), std::vector({1, 2})); + EXPECT_TRUE(observer->error()); + EXPECT_EQ(observer->errorMsg(), "Too big!"); } TEST(Observable, Range) { - auto observable = Observables::range(10, 14); + auto observable = Observable<>::range(10, 4); EXPECT_EQ(run(std::move(observable)), std::vector({10, 11, 12, 13})); } TEST(Observable, RangeWithMap) { - auto observable = Observables::range(1, 4) + auto observable = Observable<>::range(1, 3) ->map([](int64_t v) { return v * v; }) ->map([](int64_t v) { return v * v; }) ->map([](int64_t v) { return std::to_string(v); }); @@ -334,51 +556,49 @@ TEST(Observable, RangeWithMap) { } TEST(Observable, RangeWithReduce) { - auto observable = Observables::range(0, 10) - ->reduce([](int64_t acc, int64_t v) { return acc + v; }); - EXPECT_EQ( - run(std::move(observable)), std::vector({45})); + auto observable = Observable<>::range(0, 10)->reduce( + [](int64_t acc, int64_t v) { return acc + v; }); + EXPECT_EQ(run(std::move(observable)), std::vector({45})); } TEST(Observable, RangeWithReduceByMultiplication) { - auto observable = Observables::range(0, 10) - ->reduce([](int64_t acc, int64_t v) { return acc * v; }); - EXPECT_EQ( - run(std::move(observable)), std::vector({0})); + auto observable = Observable<>::range(0, 10)->reduce( + [](int64_t acc, int64_t v) { return acc * v; }); + EXPECT_EQ(run(std::move(observable)), std::vector({0})); - observable = Observables::range(1, 10) - ->reduce([](int64_t acc, int64_t v) { return acc * v; }); + observable = Observable<>::range(1, 10)->reduce( + [](int64_t acc, int64_t v) { return acc * v; }); EXPECT_EQ( - run(std::move(observable)), std::vector({2*3*4*5*6*7*8*9})); + run(std::move(observable)), + std::vector({1 * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * 10})); } TEST(Observable, RangeWithReduceOneItem) { - auto observable = Observables::range(5, 6) - ->reduce([](int64_t acc, int64_t v) { return acc + v; }); - EXPECT_EQ( - run(std::move(observable)), std::vector({5})); + auto observable = Observable<>::range(5, 1)->reduce( + [](int64_t acc, int64_t v) { return acc + v; }); + EXPECT_EQ(run(std::move(observable)), std::vector({5})); } TEST(Observable, RangeWithReduceNoItem) { - auto observable = Observables::range(0, 0)->reduce( + auto observable = Observable<>::range(0, 0)->reduce( [](int64_t acc, int64_t v) { return acc + v; }); - auto collector = make_ref>(); + auto collector = std::make_shared>(); observable->subscribe(collector); EXPECT_EQ(collector->error(), false); EXPECT_EQ(collector->values(), std::vector({})); } TEST(Observable, RangeWithReduceToBiggerType) { - auto observable = Observables::range(5, 6) - ->map([](int64_t v){ return (int32_t)v; }) - ->reduce([](int64_t acc, int32_t v) { return acc + v; }); - EXPECT_EQ( - run(std::move(observable)), std::vector({5})); + auto observable = + Observable<>::range(5, 1) + ->map([](int64_t v) { return (int32_t)v; }) + ->reduce([](int64_t acc, int32_t v) { return acc + v; }); + EXPECT_EQ(run(std::move(observable)), std::vector({5})); } TEST(Observable, StringReduce) { auto observable = - Observables::justN( + Observable<>::justN( {"a", "b", "c", "d", "e", "f", "g", "h", "i"}) ->reduce([](std::string acc, std::string v) { return acc + v; }); EXPECT_EQ( @@ -387,28 +607,45 @@ TEST(Observable, StringReduce) { TEST(Observable, RangeWithFilter) { auto observable = - Observables::range(0, 10)->filter([](int64_t v) { return v % 2 != 0; }); + Observable<>::range(0, 10)->filter([](int64_t v) { return v % 2 != 0; }); EXPECT_EQ(run(std::move(observable)), std::vector({1, 3, 5, 7, 9})); } TEST(Observable, SimpleTake) { EXPECT_EQ( - run(Observables::range(0, 100)->take(3)), + run(Observable<>::range(0, 100)->take(3)), std::vector({0, 1, 2})); + + EXPECT_EQ( + run(Observable<>::range(0, 100)->take(0)), std::vector({})); +} + +TEST(Observable, TakeError) { + auto take0 = + Observable::error(std::runtime_error("something broke!")) + ->take(0); + + auto collector = std::make_shared>(); + take0->subscribe(collector); + + EXPECT_EQ(collector->values(), std::vector({})); + EXPECT_TRUE(collector->complete()); + EXPECT_FALSE(collector->error()); } TEST(Observable, SimpleSkip) { EXPECT_EQ( - run(Observables::range(0, 10)->skip(8)), std::vector({8, 9})); + run(Observable<>::range(0, 10)->skip(8)), std::vector({8, 9})); } TEST(Observable, OverflowSkip) { - EXPECT_EQ(run(Observables::range(0, 10)->skip(12)), std::vector({})); + EXPECT_EQ( + run(Observable<>::range(0, 10)->skip(12)), std::vector({})); } TEST(Observable, IgnoreElements) { - auto collector = make_ref>(); - auto observable = Observables::range(0, 105)->ignoreElements()->map( + auto collector = std::make_shared>(); + auto observable = Observable<>::range(0, 105)->ignoreElements()->map( [](int64_t v) { return v + 1; }); observable->subscribe(collector); @@ -419,8 +656,8 @@ TEST(Observable, IgnoreElements) { TEST(Observable, Error) { auto observable = - Observables::error(std::runtime_error("something broke!")); - auto collector = make_ref>(); + Observable::error(std::runtime_error("something broke!")); + auto collector = std::make_shared>(); observable->subscribe(collector); EXPECT_EQ(collector->complete(), false); @@ -429,9 +666,9 @@ TEST(Observable, Error) { } TEST(Observable, ErrorPtr) { - auto observable = Observables::error( - std::make_exception_ptr(std::runtime_error("something broke!"))); - auto collector = make_ref>(); + auto observable = + Observable::error(std::runtime_error("something broke!")); + auto collector = std::make_shared>(); observable->subscribe(collector); EXPECT_EQ(collector->complete(), false); @@ -440,8 +677,8 @@ TEST(Observable, ErrorPtr) { } TEST(Observable, Empty) { - auto observable = Observables::empty(); - auto collector = make_ref>(); + auto observable = Observable::empty(); + auto collector = std::make_shared>(); observable->subscribe(collector); EXPECT_EQ(collector->complete(), true); @@ -449,12 +686,12 @@ TEST(Observable, Empty) { } TEST(Observable, ObserversComplete) { - auto observable = Observables::empty(); + auto observable = Observable::empty(); bool completed = false; - auto observer = Observers::create( + auto observer = Observer::create( [](int) { unreachable(); }, - [](std::exception_ptr) { unreachable(); }, + [](folly::exception_wrapper) { unreachable(); }, [&] { completed = true; }); observable->subscribe(std::move(observer)); @@ -462,12 +699,12 @@ TEST(Observable, ObserversComplete) { } TEST(Observable, ObserversError) { - auto observable = Observables::error(std::runtime_error("Whoops")); + auto observable = Observable::error(std::runtime_error("Whoops")); bool errored = false; - auto observer = Observers::create( + auto observer = Observer::create( [](int) { unreachable(); }, - [&](std::exception_ptr) { errored = true; }, + [&](folly::exception_wrapper) { errored = true; }, [] { unreachable(); }); observable->subscribe(std::move(observer)); @@ -475,11 +712,381 @@ TEST(Observable, ObserversError) { } TEST(Observable, CancelReleasesObjects) { - auto lambda = [](Reference> observer) { + auto lambda = [](std::shared_ptr> observer) { // we will send nothing }; auto observable = Observable::create(std::move(lambda)); - auto collector = make_ref>(); + auto collector = std::make_shared>(); observable->subscribe(collector); } + +TEST(Observable, CompleteReleasesObjects) { + auto shared = std::make_shared>>(); + { + auto observable = Observable::create( + [shared](std::shared_ptr> observer) { + *shared = observer; + // onComplete releases the DoOnComplete operator + // so the lambda params will be freed + observer->onComplete(); + }) + ->doOnComplete([shared] {}); + observable->subscribe(); + } + EXPECT_EQ(1, shared->use_count()); +} + +TEST(Observable, ErrorReleasesObjects) { + auto shared = std::make_shared>>(); + { + auto observable = Observable::create( + [shared](std::shared_ptr> observer) { + *shared = observer; + // onError releases the DoOnComplete operator + // so the lambda params will be freed + observer->onError(std::runtime_error("error")); + }) + ->doOnComplete([shared] { /*never executed*/ }); + observable->subscribe(); + } + EXPECT_EQ(1, shared->use_count()); +} + +class InfiniteAsyncTestOperator : public ObservableOperator { + using Super = ObservableOperator; + + public: + InfiniteAsyncTestOperator( + std::shared_ptr> upstream, + MockFunction& checkpoint) + : upstream_(std::move(upstream)), checkpoint_(checkpoint) {} + + std::shared_ptr subscribe( + std::shared_ptr> observer) override { + auto subscription = + std::make_shared(std::move(observer), checkpoint_); + upstream_->subscribe( + // Note: implicit cast to a reference to a observer. + subscription); + return subscription; + } + + private: + class TestSubscription : public Super::OperatorSubscription { + using SuperSub = typename Super::OperatorSubscription; + + public: + ~TestSubscription() override { + t_.join(); + } + + void sendSuperNext() { + // workaround for gcc bug 58972. + SuperSub::observerOnNext(1); + } + + TestSubscription( + std::shared_ptr> observer, + MockFunction& checkpoint) + : SuperSub(std::move(observer)), checkpoint_(checkpoint) {} + + void onSubscribe(std::shared_ptr subscription) override { + SuperSub::onSubscribe(std::move(subscription)); + t_ = std::thread([this]() { + while (!isCancelled()) { + sendSuperNext(); + } + checkpoint_.Call(); + }); + } + void onNext(int /*value*/) override {} + + std::thread t_; + MockFunction& checkpoint_; + }; + + std::shared_ptr> upstream_; + MockFunction& checkpoint_; +}; + +// FIXME: This hits an ASAN heap-use-after-free. Disabling for now, but we need +// to get back to this and fix it. +TEST(Observable, DISABLED_CancelSubscriptionChain) { + std::atomic_int emitted{0}; + std::mutex m; + std::condition_variable cv; + + MockFunction checkpoint; + MockFunction checkpoint2; + MockFunction checkpoint3; + std::thread t; + auto infinite1 = + Observable::create([&](std::shared_ptr> obs) { + EXPECT_CALL(checkpoint, Call()).Times(1); + EXPECT_CALL(checkpoint2, Call()).Times(1); + EXPECT_CALL(checkpoint3, Call()).Times(1); + t = std::thread([obs, &emitted, &checkpoint]() { + while (!obs->isUnsubscribed()) { + ++emitted; + obs->onNext(0); + } + checkpoint.Call(); + }); + }); + auto infinite2 = infinite1->skip(1)->skip(1); + auto test1 = + std::make_shared(infinite2, checkpoint2); + auto test2 = + std::make_shared(test1->skip(1), checkpoint3); + auto skip = test2->skip(8); + + auto subscription = skip->subscribe([](int) {}); + + std::unique_lock lk(m); + CHECK(cv.wait_for( + lk, std::chrono::seconds(1), [&] { return emitted >= 1000; })); + + subscription->cancel(); + t.join(); + + LOG(INFO) << "cancelled after " << emitted << " items"; +} + +TEST(Observable, DoOnSubscribeTest) { + auto a = Observable::empty(); + + MockFunction checkpoint; + EXPECT_CALL(checkpoint, Call()); + + a->doOnSubscribe([&] { checkpoint.Call(); })->subscribe(); +} + +TEST(Observable, DoOnNextTest) { + std::vector values; + auto observable = Observable<>::range(10, 14)->doOnNext( + [&](int64_t v) { values.push_back(v); }); + auto values2 = run(std::move(observable)); + EXPECT_EQ(values, values2); +} + +TEST(Observable, DoOnErrorTest) { + auto a = Observable::error(std::runtime_error("something broke!")); + + MockFunction checkpoint; + EXPECT_CALL(checkpoint, Call()); + + a->doOnError([&](const auto&) { checkpoint.Call(); })->subscribe(); +} + +TEST(Observable, DoOnTerminateTest) { + auto a = Observable::empty(); + + MockFunction checkpoint; + EXPECT_CALL(checkpoint, Call()); + + a->doOnTerminate([&]() { checkpoint.Call(); })->subscribe(); +} + +TEST(Observable, DoOnTerminate2Test) { + auto a = Observable::error(std::runtime_error("something broke!")); + + MockFunction checkpoint; + EXPECT_CALL(checkpoint, Call()); + + a->doOnTerminate([&]() { checkpoint.Call(); })->subscribe(); +} + +TEST(Observable, DoOnEachTest) { + // TODO(lehecka): rewrite with concatWith + auto a = Observable::create([](std::shared_ptr> obs) { + obs->onNext(5); + obs->onError(std::runtime_error("something broke!")); + }); + + MockFunction checkpoint; + EXPECT_CALL(checkpoint, Call()).Times(2); + a->doOnEach([&]() { checkpoint.Call(); })->subscribe(); +} + +TEST(Observable, DoOnTest) { + // TODO(lehecka): rewrite with concatWith + auto a = Observable::create([](std::shared_ptr> obs) { + obs->onNext(5); + obs->onError(std::runtime_error("something broke!")); + }); + + MockFunction checkpoint1; + EXPECT_CALL(checkpoint1, Call()); + MockFunction checkpoint2; + EXPECT_CALL(checkpoint2, Call()); + + a->doOn( + [&](int value) { + checkpoint1.Call(); + EXPECT_EQ(value, 5); + }, + [] { FAIL(); }, + [&](const auto&) { checkpoint2.Call(); }) + ->subscribe(); +} + +TEST(Observable, DoOnCancelTest) { + auto a = Observable<>::range(1, 10); + + MockFunction checkpoint; + EXPECT_CALL(checkpoint, Call()); + + a->doOnCancel([&]() { checkpoint.Call(); })->take(1)->subscribe(); +} + +TEST(Observable, DeferTest) { + int switchValue = 0; + auto observable = Observable::defer([&]() { + if (switchValue == 0) { + return Observable<>::range(1, 1); + } else { + return Observable<>::range(3, 1); + } + }); + + EXPECT_EQ(run(observable), std::vector({1})); + switchValue = 1; + EXPECT_EQ(run(observable), std::vector({3})); +} + +TEST(Observable, DeferExceptionTest) { + auto observable = + Observable::defer([&]() -> std::shared_ptr> { + throw std::runtime_error{"Too big!"}; + }); + + auto observer = std::make_shared>(); + observable->subscribe(observer); + + EXPECT_TRUE(observer->error()); + EXPECT_EQ(observer->errorMsg(), "Too big!"); +} + +TEST(Observable, ConcatWithTest) { + auto first = Observable<>::range(1, 2); + auto second = Observable<>::range(5, 2); + auto combined = first->concatWith(second); + + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6})); + // Subscribe again + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6})); +} + +TEST(Observable, ConcatWithMultipleTest) { + auto first = Observable<>::range(1, 2); + auto second = Observable<>::range(5, 2); + auto third = Observable<>::range(10, 2); + auto fourth = Observable<>::range(15, 2); + auto firstSecond = first->concatWith(second); + auto thirdFourth = third->concatWith(fourth); + auto combined = firstSecond->concatWith(thirdFourth); + + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6, 10, 11, 15, 16})); +} + +TEST(Observable, ConcatWithExceptionTest) { + auto first = Observable<>::range(1, 2); + auto second = Observable<>::range(5, 2); + auto third = Observable::error(std::runtime_error("error")); + + auto combined = first->concatWith(second)->concatWith(third); + + auto observer = std::make_shared>(); + combined->subscribe(observer); + + EXPECT_EQ(observer->values(), std::vector({1, 2, 5, 6})); + EXPECT_TRUE(observer->error()); + EXPECT_EQ(observer->errorMsg(), "error"); +} + +TEST(Observable, ConcatWithCancelTest) { + auto first = Observable<>::range(1, 2); + auto second = Observable<>::range(5, 2); + auto combined = first->concatWith(second); + auto take0 = combined->take(0); + + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6})); + EXPECT_EQ(run(take0), std::vector({})); +} + +TEST(Observable, ConcatWithCompleteAtSubscription) { + auto first = Observable<>::range(1, 2); + auto second = Observable<>::range(5, 2); + + auto combined = first->concatWith(second)->take(0); + EXPECT_EQ(run(combined), std::vector({})); +} + +TEST(Observable, ConcatWithVarArgsTest) { + auto first = Observable<>::range(1, 2); + auto second = Observable<>::range(5, 2); + auto third = Observable<>::range(10, 2); + auto fourth = Observable<>::range(15, 2); + + auto combined = first->concatWith(second, third, fourth); + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6, 10, 11, 15, 16})); +} + +TEST(Observable, ConcatTest) { + auto combined = Observable::concat( + Observable<>::range(1, 2), Observable<>::range(5, 2)); + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6})); + + // Observable::concat shoud not accept one parameter! + // Next line should cause compiler failure: OK! + // combined = Observable::concat(Observable<>::range(1, 2)); + + combined = Observable::concat( + Observable<>::range(1, 2), + Observable<>::range(5, 2), + Observable<>::range(10, 2)); + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6, 10, 11})); + + combined = Observable::concat( + Observable<>::range(1, 2), + Observable<>::range(5, 2), + Observable<>::range(10, 2), + Observable<>::range(15, 2)); + EXPECT_EQ(run(combined), std::vector({1, 2, 5, 6, 10, 11, 15, 16})); +} + +TEST(Observable, ToFlowableConcat) { + // Concat a flowable with an observable. + // Convert the observable to flowable before concat. + // Use ERROR as backpressure strategy. + + // Test: Request only as much as the initial flowable provides + // - Check that the observable is not subscribed to so it doesn't flood + + auto a = yarpl::flowable::Flowable<>::range(1, 1); + auto b = Observable<>::range(2, 9)->toFlowable(BackpressureStrategy::ERROR); + + auto c = a->concatWith(b); + + uint32_t request = 1; + auto subscriber = + std::make_shared>>(request); + + std::vector v; + + EXPECT_CALL(*subscriber, onSubscribe_(_)); + EXPECT_CALL(*subscriber, onNext_(_)) + .WillRepeatedly(Invoke([&](int64_t value) { v.push_back(value); })); + EXPECT_CALL(*subscriber, onError_(_)).Times(0); + + c->subscribe(subscriber); + + // As only 1 item is requested, the second flowable will not be subscribed. So + // the observer will not flood the stream and cause ERROR. + EXPECT_EQ(v, std::vector({1})); + + // Now flood the stream + EXPECT_CALL(*subscriber, onError_(_)); + subscriber->subscription()->request(1); +} diff --git a/yarpl/test/PublishProcessorTest.cpp b/yarpl/test/PublishProcessorTest.cpp new file mode 100644 index 000000000..802f41e24 --- /dev/null +++ b/yarpl/test/PublishProcessorTest.cpp @@ -0,0 +1,219 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "yarpl/flowable/PublishProcessor.h" +#include +#include "yarpl/Flowable.h" +#include "yarpl/flowable/TestSubscriber.h" + +using namespace yarpl; +using namespace yarpl::flowable; + +TEST(PublishProcessorTest, OnNextTest) { + auto pp = PublishProcessor::create(); + + auto subscriber = std::make_shared>(); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber); + + pp->onNext(1); + pp->onNext(2); + pp->onNext(3); + + EXPECT_EQ(subscriber->values(), std::vector({1, 2, 3})); + + // cancel the subscription as its a cyclic reference + subscriber->cancel(); +} + +TEST(PublishProcessorTest, OnCompleteTest) { + auto pp = PublishProcessor::create(); + + auto subscriber = std::make_shared>(); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber); + + pp->onNext(1); + pp->onNext(2); + pp->onComplete(); + + EXPECT_EQ( + subscriber->values(), + std::vector({ + 1, + 2, + })); + EXPECT_TRUE(subscriber->isComplete()); + + auto subscriber2 = std::make_shared>(); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber2); + EXPECT_EQ(subscriber2->values(), std::vector()); + EXPECT_TRUE(subscriber2->isComplete()); +} + +TEST(PublishProcessorTest, OnErrorTest) { + auto pp = PublishProcessor::create(); + + auto subscriber = std::make_shared>(); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber); + + pp->onNext(1); + pp->onNext(2); + pp->onError(std::runtime_error("error!")); + + EXPECT_EQ( + subscriber->values(), + std::vector({ + 1, + 2, + })); + EXPECT_TRUE(subscriber->isError()); + EXPECT_EQ(subscriber->getErrorMsg(), "error!"); + + auto subscriber2 = std::make_shared>(); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber2); + EXPECT_EQ(subscriber2->values(), std::vector()); + EXPECT_TRUE(subscriber2->isError()); +} + +TEST(PublishProcessorTest, OnNextMultipleSubscribersTest) { + auto pp = PublishProcessor::create(); + + auto subscriber1 = std::make_shared>(); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber1); + auto subscriber2 = std::make_shared>(); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber2); + + pp->onNext(1); + pp->onNext(2); + pp->onNext(3); + + EXPECT_EQ(subscriber1->values(), std::vector({1, 2, 3})); + EXPECT_EQ(subscriber2->values(), std::vector({1, 2, 3})); + + subscriber1->cancel(); + subscriber2->cancel(); +} + +TEST(PublishProcessorTest, OnNextSlowSubscriberTest) { + auto pp = PublishProcessor::create(); + + auto subscriber1 = std::make_shared>(); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber1); + auto subscriber2 = std::make_shared>(1); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber2); + + pp->onNext(1); + pp->onNext(2); + pp->onNext(3); + + EXPECT_EQ(subscriber1->values(), std::vector({1, 2, 3})); + subscriber1->cancel(); + + EXPECT_EQ(subscriber2->values(), std::vector({1})); + EXPECT_TRUE(subscriber2->isError()); + EXPECT_EQ( + subscriber2->exceptionWrapper().type(), + typeid(MissingBackpressureException)); +} + +TEST(PublishProcessorTest, CancelTest) { + auto pp = PublishProcessor::create(); + + auto subscriber = std::make_shared>(); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber); + + pp->onNext(1); + pp->onNext(2); + + subscriber->cancel(); + + pp->onNext(3); + pp->onNext(4); + + EXPECT_EQ(subscriber->values(), std::vector({1, 2})); + + subscriber->onComplete(); // to break any reference cycles +} + +TEST(PublishProcessorTest, OnMultipleSubscribersMultithreadedWithErrorTest) { + auto pp = PublishProcessor::create(); + + std::vector threads; + std::atomic threadsDone{0}; + + for (int i = 0; i < 100; i++) { + threads.push_back(std::thread([&] { + for (int j = 0; j < 100; j++) { + auto subscriber = std::make_shared>(1); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber); + + subscriber->awaitTerminalEvent(std::chrono::milliseconds(500)); + + EXPECT_EQ(subscriber->values().size(), 1ULL); + + EXPECT_TRUE(subscriber->isError()); + EXPECT_EQ( + subscriber->exceptionWrapper().type(), + typeid(MissingBackpressureException)); + } + ++threadsDone; + })); + } + + int k = 0; + while (threadsDone < threads.size()) { + pp->onNext(k++); + } + + for (auto& thread : threads) { + thread.join(); + } +} + +TEST(PublishProcessorTest, OnMultipleSubscribersMultithreadedTest) { + auto pp = PublishProcessor::create(); + + std::vector threads; + std::atomic subscribersReady{0}; + std::atomic threadsDone{0}; + + for (int i = 0; i < 100; i++) { + threads.push_back(std::thread([&] { + auto subscriber = std::make_shared>(); + pp->toFlowable(BackpressureStrategy::ERROR)->subscribe(subscriber); + + ++subscribersReady; + subscriber->awaitTerminalEvent(std::chrono::milliseconds(50)); + + EXPECT_EQ(subscriber->values(), std::vector({1, 2, 3, 4, 5})); + EXPECT_FALSE(subscriber->isError()); + EXPECT_TRUE(subscriber->isComplete()); + + ++threadsDone; + })); + } + + while (subscribersReady < threads.size()) + ; + + pp->onNext(1); + pp->onNext(2); + pp->onNext(3); + pp->onNext(4); + pp->onNext(5); + pp->onComplete(); + + for (auto& thread : threads) { + thread.join(); + } +} diff --git a/yarpl/test/RefcountedTest.cpp b/yarpl/test/RefcountedTest.cpp deleted file mode 100644 index 49f18a39f..000000000 --- a/yarpl/test/RefcountedTest.cpp +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include -#include - -#include - -#include "yarpl/Refcounted.h" - -namespace yarpl { - -TEST(RefcountedTest, ObjectCountsAreMaintained) { - std::vector> v; - for (std::size_t i = 0; i < 16; ++i) { - v.push_back(std::make_unique()); - EXPECT_EQ(0U, v[i]->count()); // no references. - } - - v.resize(11); -} - -TEST(RefcountedTest, ReferenceCountingWorks) { - auto first = Reference(new Refcounted); - EXPECT_EQ(1U, first->count()); - - auto second = first; - - EXPECT_EQ(second.get(), first.get()); - EXPECT_EQ(2U, first->count()); - - auto third = std::move(second); - EXPECT_EQ(nullptr, second.get()); - EXPECT_EQ(third.get(), first.get()); - EXPECT_EQ(2U, first->count()); - - // second was already moved from, above. - second.reset(); - EXPECT_EQ(nullptr, second.get()); - EXPECT_EQ(2U, first->count()); - - auto fourth = third; - EXPECT_EQ(3U, first->count()); - - fourth.reset(); - EXPECT_EQ(nullptr, fourth.get()); - EXPECT_EQ(2U, first->count()); -} -} // yarpl diff --git a/yarpl/test/ReferenceTest.cpp b/yarpl/test/ReferenceTest.cpp deleted file mode 100644 index 3bc462a33..000000000 --- a/yarpl/test/ReferenceTest.cpp +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include -#include - -#include - -#include "yarpl/Flowable.h" -#include "yarpl/Refcounted.h" - -using yarpl::Refcounted; -using yarpl::Reference; -using yarpl::flowable::Subscriber; - -namespace { - -template -class MySubscriber : public Subscriber { - void onNext(T) override {} -}; -} - -TEST(ReferenceTest, Upcast) { - Reference> derived(new MySubscriber()); - Reference> base1(derived); - - Reference> base2; - base2 = derived; - - Reference> derivedCopy1(derived); - Reference> derivedCopy2(derived); - - Reference> base3(std::move(derivedCopy1)); - - Reference> base4; - base4 = std::move(derivedCopy2); -} - -TEST(RefcountedTest, CopyAssign) { - using Sub = MySubscriber; - Reference a(new Sub()); - Reference b(a); - EXPECT_EQ(2u, a->count()); - Sub* ptr = nullptr; - Reference c(ptr = new Sub()); - b = c; - EXPECT_EQ(1u, a->count()); - EXPECT_EQ(ptr, b.get()); -} - -TEST(RefcountedTest, MoveAssign) { - using Sub = MySubscriber; - Reference a(new Sub()); - Reference b(a); - EXPECT_EQ(2u, a->count()); - Sub* ptr = nullptr; - b = Reference(ptr = new Sub()); - EXPECT_EQ(1u, a->count()); - EXPECT_EQ(ptr, b.get()); -} - -TEST(RefcountedTest, CopyAssignTemplate) { - using Sub = MySubscriber; - Reference a(new Sub()); - Reference b(a); - EXPECT_EQ(2u, a->count()); - using Sub2 = MySubscriber; - Sub2* ptr = nullptr; - Reference c(ptr = new Sub2()); - b = c; - EXPECT_EQ(1u, a->count()); - EXPECT_EQ(ptr, b.get()); -} - -TEST(RefcountedTest, MoveAssignTemplate) { - using Sub = MySubscriber; - Reference a(new Sub()); - Reference b(a); - EXPECT_EQ(2u, a->count()); - using Sub2 = MySubscriber; - Sub2* ptr = nullptr; - b = Reference(ptr = new Sub2()); - EXPECT_EQ(1u, a->count()); - EXPECT_EQ(ptr, b.get()); -} diff --git a/yarpl/test/Scheduler_test.cpp b/yarpl/test/Scheduler_test.cpp deleted file mode 100644 index c9f072bd1..000000000 --- a/yarpl/test/Scheduler_test.cpp +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include -#include -#include -#include -#include "yarpl/schedulers/ThreadScheduler.h" - -using namespace yarpl; - -TEST(Scheduler, ThreadScheduler_Task) { - ThreadScheduler scheduler; - auto worker = scheduler.createWorker(); - worker->schedule([]() { - std::cout << "doing work on thread id: " << std::this_thread::get_id() - << std::endl; - }); - worker->dispose(); - - // TODO add condition variable into task above instead of sleep - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - // TODO add validation of above, right now just testing it doesn't blow up -} diff --git a/yarpl/test/Single_test.cpp b/yarpl/test/Single_test.cpp index 1dd5135b1..48d3c666e 100644 --- a/yarpl/test/Single_test.cpp +++ b/yarpl/test/Single_test.cpp @@ -1,20 +1,30 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. -#include +#include +#include #include #include #include "yarpl/Single.h" #include "yarpl/single/SingleTestObserver.h" +#include "yarpl/test_utils/Tuple.h" -#include "Tuple.h" - -// TODO can we eliminate need to import both of these? -using namespace yarpl; using namespace yarpl::single; TEST(Single, SingleOnNext) { - auto a = Single::create([](Reference> obs) { + auto a = Single::create([](std::shared_ptr> obs) { obs->onSubscribe(SingleSubscriptions::empty()); obs->onSuccess(1); }); @@ -27,12 +37,9 @@ TEST(Single, SingleOnNext) { TEST(Single, OnError) { std::string errorMessage("DEFAULT->No Error Message"); - auto a = Single::create([](Reference> obs) { - try { - throw std::runtime_error("something broke!"); - } catch (const std::exception&) { - obs->onError(std::current_exception()); - } + auto a = Single::create([](std::shared_ptr> obs) { + obs->onError( + folly::exception_wrapper(std::runtime_error("something broke!"))); }); auto to = SingleTestObserver::create(); @@ -61,13 +68,27 @@ TEST(Single, Error) { } TEST(Single, SingleMap) { - auto a = Single::create([](Reference> obs) { + auto a = Single::create([](std::shared_ptr> obs) { obs->onSubscribe(SingleSubscriptions::empty()); obs->onSuccess(1); }); auto to = SingleTestObserver::create(); - a->map([](int v) { return "hello"; })->subscribe(to); + a->map([](int) { return "hello"; })->subscribe(to); to->awaitTerminalEvent(); to->assertOnSuccessValue("hello"); } + +TEST(Single, MapWithException) { + auto single = Singles::just(3)->map([](int n) { + if (n > 2) { + throw std::runtime_error{"Too big!"}; + } + return n; + }); + + auto observer = std::make_shared>(); + single->subscribe(observer); + + observer->assertOnErrorMessage("Too big!"); +} diff --git a/yarpl/test/SubscribeObserveOnTests.cpp b/yarpl/test/SubscribeObserveOnTests.cpp new file mode 100644 index 000000000..dc95e28d9 --- /dev/null +++ b/yarpl/test/SubscribeObserveOnTests.cpp @@ -0,0 +1,219 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include "yarpl/Flowable.h" +#include "yarpl/Observable.h" +#include "yarpl/flowable/TestSubscriber.h" +#include "yarpl/observable/TestObserver.h" + +using namespace yarpl::flowable; +using namespace yarpl::observable; + +constexpr std::chrono::milliseconds timeout{100}; + +TEST(FlowableTests, SubscribeOnWorksAsExpected) { + folly::ScopedEventBaseThread worker; + + auto f = Flowable::create([&](auto& subscriber, auto req) { + EXPECT_TRUE(worker.getEventBase()->isInEventBaseThread()); + EXPECT_EQ(1, req); + subscriber.onNext("foo"); + subscriber.onComplete(); + }); + + auto subscriber = std::make_shared>(1); + f->subscribeOn(*worker.getEventBase())->subscribe(subscriber); + subscriber->awaitTerminalEvent(std::chrono::milliseconds(100)); + EXPECT_EQ(1, subscriber->getValueCount()); + EXPECT_TRUE(subscriber->isComplete()); +} + +TEST(ObservableTests, SubscribeOnWorksAsExpected) { + folly::ScopedEventBaseThread worker; + + auto f = Observable::create([&](auto observer) { + EXPECT_TRUE(worker.getEventBase()->isInEventBaseThread()); + observer->onNext("foo"); + observer->onComplete(); + }); + + auto observer = std::make_shared>(); + f->subscribeOn(*worker.getEventBase())->subscribe(observer); + observer->awaitTerminalEvent(std::chrono::milliseconds(100)); + EXPECT_EQ(1, observer->getValueCount()); + EXPECT_TRUE(observer->isComplete()); +} + +TEST(FlowableTests, ObserveOnWorksAsExpectedSuccess) { + folly::ScopedEventBaseThread worker; + folly::Baton<> subscriber_complete; + + auto f = Flowable::create([&](auto& subscriber, auto req) { + EXPECT_EQ(1, req); + subscriber.onNext("foo"); + subscriber.onComplete(); + }); + + bool calledOnNext{false}; + + f->observeOn(*worker.getEventBase()) + ->subscribe( + // onNext + [&](std::string s) { + EXPECT_TRUE(worker.getEventBase()->isInEventBaseThread()); + EXPECT_EQ(s, "foo"); + calledOnNext = true; + }, + + // onError + [&](folly::exception_wrapper) { FAIL(); }, + + // onComplete + [&] { + EXPECT_TRUE(worker.getEventBase()->isInEventBaseThread()); + EXPECT_TRUE(calledOnNext); + subscriber_complete.post(); + }, + + 1 /* initial request(n) */ + ); + + subscriber_complete.timed_wait(timeout); +} + +TEST(FlowableTests, ObserveOnWorksAsExpectedError) { + folly::ScopedEventBaseThread worker; + folly::Baton<> subscriber_complete; + + auto f = Flowable::create([&](auto& subscriber, auto req) { + EXPECT_EQ(1, req); + subscriber.onError(std::runtime_error("oops!")); + }); + + f->observeOn(*worker.getEventBase()) + ->subscribe( + // onNext + [&](std::string s) { FAIL(); }, + + // onError + [&](folly::exception_wrapper) { + EXPECT_TRUE(worker.getEventBase()->isInEventBaseThread()); + subscriber_complete.post(); + }, + + // onComplete + [&] { FAIL(); }, + + 1 /* initial request(n) */ + ); + + subscriber_complete.timed_wait(timeout); +} + +TEST(FlowableTests, BothObserveAndSubscribeOn) { + folly::ScopedEventBaseThread subscriber_eb; + folly::ScopedEventBaseThread producer_eb; + folly::Baton<> subscriber_complete; + + auto f = Flowable::create([&](auto& subscriber, auto req) { + EXPECT_EQ(1, req); + EXPECT_TRUE(producer_eb.getEventBase()->isInEventBaseThread()); + subscriber.onNext("foo"); + subscriber.onComplete(); + }) + ->subscribeOn(*producer_eb.getEventBase()) + ->observeOn(*subscriber_eb.getEventBase()); + + bool calledOnNext{false}; + + f->subscribe( + // onNext + [&](std::string s) { + EXPECT_TRUE(subscriber_eb.getEventBase()->isInEventBaseThread()); + EXPECT_EQ(s, "foo"); + calledOnNext = true; + }, + + // onError + [&](folly::exception_wrapper) { FAIL(); }, + + // onComplete + [&] { + EXPECT_TRUE(subscriber_eb.getEventBase()->isInEventBaseThread()); + EXPECT_TRUE(calledOnNext); + subscriber_complete.post(); + }, + + 1 /* initial request(n) */ + ); + + subscriber_complete.timed_wait(timeout); +} + +namespace { +class EarlyCancelSubscriber : public yarpl::flowable::BaseSubscriber { + public: + EarlyCancelSubscriber( + folly::EventBase& on_base, + folly::Baton<>& subscriber_complete) + : on_base_(on_base), subscriber_complete_(subscriber_complete) {} + + void onSubscribeImpl() override { + this->request(5); + } + + void onNextImpl(int64_t n) override { + if (did_cancel_) { + FAIL(); + } + + EXPECT_TRUE(on_base_.isInEventBaseThread()); + EXPECT_EQ(n, 1); + this->cancel(); + did_cancel_ = true; + subscriber_complete_.post(); + } + + void onErrorImpl(folly::exception_wrapper /*e*/) override { + FAIL(); + } + + void onCompleteImpl() override { + FAIL(); + } + + bool did_cancel_{false}; + folly::EventBase& on_base_; + folly::Baton<>& subscriber_complete_; +}; +} // namespace + +TEST(FlowableTests, EarlyCancelObserveOn) { + folly::ScopedEventBaseThread worker; + + folly::Baton<> subscriber_complete; + + Flowable<>::range(1, 100) + ->observeOn(*worker.getEventBase()) + ->subscribe(std::make_shared( + *worker.getEventBase(), subscriber_complete)); + + subscriber_complete.timed_wait(timeout); +} diff --git a/yarpl/test/ThriftStreamShimTest.cpp b/yarpl/test/ThriftStreamShimTest.cpp new file mode 100644 index 000000000..78d178bbb --- /dev/null +++ b/yarpl/test/ThriftStreamShimTest.cpp @@ -0,0 +1,72 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "yarpl/Flowable.h" +#include "yarpl/flowable/TestSubscriber.h" +#include "yarpl/flowable/ThriftStreamShim.h" + +using namespace yarpl::flowable; + +template +std::vector run( + std::shared_ptr> flowable, + int64_t requestCount = 100) { + auto subscriber = std::make_shared>(requestCount); + flowable->subscribe(subscriber); + subscriber->awaitTerminalEvent(std::chrono::seconds(1)); + return std::move(subscriber->values()); +} +template +std::vector run(apache::thrift::ServerStream&& stream) { + std::vector values; + std::move(stream).toClientStreamUnsafeDoNotUse().subscribeInline([&](auto&& val) { + if (val.hasValue()) { + values.push_back(std::move(*val)); + } + }); + return values; +} + +apache::thrift::ClientBufferedStream makeRange(int start, int count) { + auto streamAndPublisher = + apache::thrift::ServerStream::createPublisher(); + for (int i = 0; i < count; ++i) { + streamAndPublisher.second.next(i + start); + } + std::move(streamAndPublisher.second).complete(); + return std::move(streamAndPublisher.first).toClientStreamUnsafeDoNotUse(); +} + +TEST(ThriftStreamShimTest, ClientStream) { + auto flowable = ThriftStreamShim::fromClientStream( + makeRange(1, 5), folly::getEventBase()); + EXPECT_EQ(run(flowable), std::vector({1, 2, 3, 4, 5})); +} + +TEST(ThriftStreamShimTest, ServerStream) { + auto stream = ThriftStreamShim::toServerStream(Flowable<>::range(1, 5)); + EXPECT_EQ(run(std::move(stream)), std::vector({1, 2, 3, 4, 5})); + + stream = ThriftStreamShim::toServerStream(Flowable::never()); + auto sub = std::move(stream).toClientStreamUnsafeDoNotUse().subscribeExTry( + folly::getEventBase(), [](auto) {}); + sub.cancel(); + std::move(sub).join(); + + ThriftStreamShim::toServerStream(Flowable<>::just(std::make_unique(42))); +} diff --git a/yarpl/test/Tuple.cpp b/yarpl/test/Tuple.cpp deleted file mode 100644 index 9d23a2cc9..000000000 --- a/yarpl/test/Tuple.cpp +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright 2004-present Facebook. All Rights Reserved. - -#include "Tuple.h" - -namespace yarpl { - -std::atomic Tuple::createdCount; -std::atomic Tuple::destroyedCount; -std::atomic Tuple::instanceCount; -} diff --git a/yarpl/test/credits-test.cpp b/yarpl/test/credits-test.cpp index 26e2206ca..41d23880d 100644 --- a/yarpl/test/credits-test.cpp +++ b/yarpl/test/credits-test.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include @@ -76,6 +88,14 @@ TEST(Credits, cancel3) { ASSERT_TRUE(isCancelled(&rn)); } +TEST(Credits, cancel4) { + std::atomic rn{9999}; + cancel(&rn); + // it should stay cancelled once cancelled + consume(&rn, 1); + ASSERT_TRUE(isCancelled(&rn)); +} + TEST(Credits, isInfinite) { std::atomic rn{0}; add(&rn, INT64_MAX); diff --git a/yarpl/test/test_has_shared_ptr_support.cpp b/yarpl/test/test_has_shared_ptr_support.cpp new file mode 100644 index 000000000..61bcbe004 --- /dev/null +++ b/yarpl/test/test_has_shared_ptr_support.cpp @@ -0,0 +1,22 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +int main() { + std::shared_ptr i; + auto il = std::atomic_load(&i); + return 0; +} diff --git a/yarpl/test/test_wrap_shared_in_atomic_support.cpp b/yarpl/test/test_wrap_shared_in_atomic_support.cpp new file mode 100644 index 000000000..136e87f2b --- /dev/null +++ b/yarpl/test/test_wrap_shared_in_atomic_support.cpp @@ -0,0 +1,23 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +int main() { + std::atomic> i; + std::shared_ptr j; + std::atomic_store(&i, j); + return 0; +} diff --git a/yarpl/test/yarpl-tests.cpp b/yarpl/test/yarpl-tests.cpp index bc5bdf61a..5046e46dd 100644 --- a/yarpl/test/yarpl-tests.cpp +++ b/yarpl/test/yarpl-tests.cpp @@ -1,8 +1,31 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include #include +#include "yarpl/Refcounted.h" + int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); + int ret; + { + FLAGS_logtostderr = true; + ::testing::InitGoogleTest(&argc, argv); + folly::init(&argc, &argv); + ret = RUN_ALL_TESTS(); + } + + return ret; } diff --git a/test/test_utils/Mocks.h b/yarpl/test_utils/Mocks.h similarity index 52% rename from test/test_utils/Mocks.h rename to yarpl/test_utils/Mocks.h index f5a788a80..8662fcdf1 100644 --- a/test/test_utils/Mocks.h +++ b/yarpl/test_utils/Mocks.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -10,22 +22,22 @@ #include #include -#include "rsocket/framing/FrameProcessor.h" -#include "rsocket/internal/Common.h" #include "yarpl/flowable/Flowable.h" -namespace rsocket { - -using namespace yarpl::flowable; +namespace yarpl { +namespace mocks { /// GoogleMock-compatible Publisher implementation for fast prototyping. /// UnmanagedMockPublisher's lifetime MUST be managed externally. template -class MockFlowable : public Flowable { +class MockFlowable : public flowable::Flowable { public: - MOCK_METHOD1_T(subscribe_, void(yarpl::Reference> subscriber)); + MOCK_METHOD1_T( + subscribe_, + void(std::shared_ptr> subscriber)); - void subscribe(yarpl::Reference> subscriber) noexcept override { + void subscribe( + std::shared_ptr> subscriber) noexcept override { subscribe_(std::move(subscriber)); } }; @@ -35,17 +47,23 @@ class MockFlowable : public Flowable { /// For the same reason putting mock instance in a smart pointer is a poor idea. /// Can only be instanciated for CopyAssignable E type. template -class MockSubscriber : public Subscriber { +class MockSubscriber : public flowable::Subscriber, + public yarpl::enable_get_ref { public: - MOCK_METHOD1(onSubscribe_, void(yarpl::Reference subscription)); + MOCK_METHOD1( + onSubscribe_, + void(std::shared_ptr subscription)); MOCK_METHOD1_T(onNext_, void(const T& value)); MOCK_METHOD0(onComplete_, void()); - MOCK_METHOD1_T(onError_, void(std::exception_ptr ex)); + MOCK_METHOD1_T(onError_, void(folly::exception_wrapper ex)); - explicit MockSubscriber(int64_t initial = kMaxRequestN) : initial_(initial) {} + explicit MockSubscriber(int64_t initial = std::numeric_limits::max()) + : initial_(initial) {} - void onSubscribe(yarpl::Reference subscription) override { + void onSubscribe( + std::shared_ptr subscription) override { subscription_ = subscription; + auto this_ = this->ref_from_this(this); onSubscribe_(subscription); if (initial_ > 0) { @@ -54,6 +72,7 @@ class MockSubscriber : public Subscriber { } void onNext(T element) override { + auto this_ = this->ref_from_this(this); onNext_(element); --waitedFrameCount_; @@ -61,19 +80,21 @@ class MockSubscriber : public Subscriber { } void onComplete() override { + auto this_ = this->ref_from_this(this); onComplete_(); subscription_.reset(); terminated_ = true; terminalEventCV_.notify_all(); } - void onError(std::exception_ptr ex) override { - onError_(ex); + void onError(folly::exception_wrapper ex) override { + auto this_ = this->ref_from_this(this); + onError_(std::move(ex)); terminated_ = true; terminalEventCV_.notify_all(); } - Subscription* subscription() const { + flowable::Subscription* subscription() const { return subscription_.operator->(); } @@ -85,25 +106,22 @@ class MockSubscriber : public Subscriber { // now block this thread std::unique_lock lk(m_); // if shutdown gets implemented this would then be released by it - bool result = terminalEventCV_.wait_for( - lk, timeout, [this] { - return terminated_; - }); + bool result = + terminalEventCV_.wait_for(lk, timeout, [this] { return terminated_; }); EXPECT_TRUE(result) << "Timed out"; } /** * Block the current thread until onNext is called 'count' times. */ - void awaitFrames(uint64_t count, - std::chrono::milliseconds timeout = std::chrono::seconds(1)) { + void awaitFrames( + uint64_t count, + std::chrono::milliseconds timeout = std::chrono::seconds(1)) { waitedFrameCount_ += count; std::unique_lock lk(mFrame_); if (waitedFrameCount_ > 0) { bool result = framesEventCV_.wait_for( - lk, timeout, [this] { - return waitedFrameCount_ <= 0; - }); + lk, timeout, [this] { return waitedFrameCount_ <= 0; }); EXPECT_TRUE(result) << "Timed out"; } } @@ -111,9 +129,9 @@ class MockSubscriber : public Subscriber { protected: // As the 'subscription_' member in the parent class is private, // we define it here again. - yarpl::Reference subscription_; + std::shared_ptr subscription_; - int64_t initial_{kMaxRequestN}; + int64_t initial_; bool terminated_{false}; mutable std::mutex m_, mFrame_; @@ -124,28 +142,29 @@ class MockSubscriber : public Subscriber { /// GoogleMock-compatible Subscriber implementation for fast prototyping. /// MockSubscriber MUST be heap-allocated, as it manages its own lifetime. /// For the same reason putting mock instance in a smart pointer is a poor idea. -class MockSubscription : public Subscription { +class MockSubscription : public flowable::Subscription { public: MOCK_METHOD1(request_, void(int64_t n)); MOCK_METHOD0(cancel_, void()); void request(int64_t n) override { - if (!requested_) { - requested_ = true; - EXPECT_CALL(checkpoint_, Call()).Times(1); - } - request_(n); } void cancel() override { cancel_(); - checkpoint_.Call(); } +}; +} // namespace mocks - protected: - bool requested_{false}; - testing::MockFunction checkpoint_; +template +class MockBaseSubscriber + : public flowable::BaseSubscriber { + public: + MOCK_METHOD0_T(onSubscribeImpl, void()); + MOCK_METHOD1_T(onNextImpl, void(T)); + MOCK_METHOD0_T(onCompleteImpl, void()); + MOCK_METHOD1_T(onErrorImpl, void(folly::exception_wrapper)); }; -} // namespace rsocket +} // namespace yarpl diff --git a/yarpl/test_utils/Tuple.cpp b/yarpl/test_utils/Tuple.cpp new file mode 100644 index 000000000..0ab948c42 --- /dev/null +++ b/yarpl/test_utils/Tuple.cpp @@ -0,0 +1,22 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "Tuple.h" + +namespace yarpl { + +std::atomic Tuple::createdCount; +std::atomic Tuple::destroyedCount; +std::atomic Tuple::instanceCount; +} // namespace yarpl diff --git a/yarpl/test/Tuple.h b/yarpl/test_utils/Tuple.h similarity index 61% rename from yarpl/test/Tuple.h rename to yarpl/test_utils/Tuple.h index ec1829132..663e29637 100644 --- a/yarpl/test/Tuple.h +++ b/yarpl/test_utils/Tuple.h @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once @@ -43,4 +55,4 @@ struct Tuple { static std::atomic instanceCount; }; -} // yarpl +} // namespace yarpl diff --git a/yarpl/src/yarpl/utils/credits.cpp b/yarpl/utils/credits.cpp similarity index 51% rename from yarpl/src/yarpl/utils/credits.cpp rename to yarpl/utils/credits.cpp index 12687fddf..28cd9e9b5 100644 --- a/yarpl/src/yarpl/utils/credits.cpp +++ b/yarpl/utils/credits.cpp @@ -1,4 +1,16 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "yarpl/utils/credits.h" @@ -8,7 +20,7 @@ namespace yarpl { namespace credits { -int64_t add(std::atomic* current, int64_t n) { +int64_t add(std::atomic* current, int64_t n) { for (;;) { auto r = current->load(); // if already "infinite" @@ -52,7 +64,7 @@ int64_t add(int64_t current, int64_t n) { return current + n; } -bool cancel(std::atomic* current) { +bool cancel(std::atomic* current) { for (;;) { auto r = current->load(); if (r == kCanceled) { @@ -67,9 +79,17 @@ bool cancel(std::atomic* current) { } } -int64_t consume(std::atomic* current, int64_t n) { +int64_t consume(std::atomic* current, int64_t n) { for (;;) { auto r = current->load(); + // if already "infinite" + if (r == kNoFlowControl) { + return kNoFlowControl; + } + // if already "cancelled" + if (r == kCanceled) { + return kCanceled; + } if (n <= 0) { // do nothing, return existing unmodified value return r; @@ -89,11 +109,47 @@ int64_t consume(std::atomic* current, int64_t n) { } } -bool isCancelled(std::atomic* current) { +bool tryConsume(std::atomic* current, int64_t n) { + if (n <= 0) { + // do nothing, return existing unmodified value + return false; + } + + for (;;) { + auto r = current->load(); + if (r < n) { + return false; + } + + auto u = r - n; + + // set the new number + if (current->compare_exchange_strong(r, u)) { + return true; + } + // if failed to set (concurrent modification) loop and try again + } +} + +bool isCancelled(std::atomic* current) { return current->load() == kCanceled; } -bool isInfinite(std::atomic* current) { +int64_t consume(int64_t& current, int64_t n) { + if (n <= 0) { + // do nothing, return existing unmodified value + return current; + } + if (current < n) { + // bad usage somewhere ... be resilient, just set to r + n = current; + } + + current -= n; + return current; +} + +bool isInfinite(std::atomic* current) { return current->load() == kNoFlowControl; } diff --git a/yarpl/include/yarpl/utils/credits.h b/yarpl/utils/credits.h similarity index 63% rename from yarpl/include/yarpl/utils/credits.h rename to yarpl/utils/credits.h index 39da9fb15..10063b728 100644 --- a/yarpl/include/yarpl/utils/credits.h +++ b/yarpl/utils/credits.h @@ -1,10 +1,23 @@ -// Copyright 2004-present Facebook. All Rights Reserved. +// Copyright (c) Facebook, Inc. and its affiliates. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #pragma once #include #include #include +#include namespace yarpl { namespace credits { @@ -34,6 +47,8 @@ constexpr int64_t kNoFlowControl{std::numeric_limits::max()}; * * If 'current' is set to "cancelled" using the magic number INT64_MIN it will * not be changed. + * + * Returns new value of credits. */ int64_t add(std::atomic*, int64_t); @@ -54,9 +69,23 @@ bool cancel(std::atomic*); * Consume (remove) credits from the 'current' atomic. * * This MUST only be used to remove credits after emitting a value via onNext. + * + * Returns new value of credits. */ int64_t consume(std::atomic*, int64_t); +/** + * Try Consume (remove) credits from the 'current' atomic. + * + * Returns true if consuming the credit was successful. + */ +bool tryConsume(std::atomic*, int64_t); + +/** + * Version of consume that works for non-atomic integers. + */ +int64_t consume(int64_t&, int64_t); + /** * Whether the current value represents a "cancelled" subscription. */ @@ -67,5 +96,5 @@ bool isCancelled(std::atomic*); */ bool isInfinite(std::atomic*); -} -} +} // namespace credits +} // namespace yarpl