diff --git a/.github/workflows/gradle-all.yml b/.github/workflows/gradle-all.yml new file mode 100644 index 000000000..abbd14106 --- /dev/null +++ b/.github/workflows/gradle-all.yml @@ -0,0 +1,152 @@ +name: Branches Java CI + +on: + # Trigger the workflow on push + # but only for the non master/1.0.x branches + push: + branches-ignore: + - 1.1.x + - master + +jobs: + build: + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew clean build -x test --no-daemon + + coretest: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew rsocket-core:test --no-daemon + + othertest: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew test -x :rsocket-core:test --no-daemon + + jcstress: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew jcstress --no-daemon + + publish: + needs: [ build, coretest, othertest, jcstress ] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Publish Packages to Artifactory + if: ${{ matrix.jdk == '1.8' }} + run: | + githubRef="${githubRef#refs/heads/}" + githubRef="${githubRef////-}" + ./gradlew -PversionSuffix="-${githubRef}-SNAPSHOT" -PbuildNumber="${buildNumber}" publishMavenPublicationToGitHubPackagesRepository --no-daemon --stacktrace + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + githubRef: ${{ github.ref }} + buildNumber: ${{ github.run_number }} \ No newline at end of file diff --git a/.github/workflows/gradle-main.yml b/.github/workflows/gradle-main.yml new file mode 100644 index 000000000..33bca8e72 --- /dev/null +++ b/.github/workflows/gradle-main.yml @@ -0,0 +1,161 @@ +name: Main Branches Java CI + +on: + # Trigger the workflow on push + # but only for the master/1.1.x branch + push: + branches: + - master + - 1.1.x + +jobs: + build: + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew clean build -x test --no-daemon + + coretest: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew rsocket-core:test --no-daemon + + othertest: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew test -x :rsocket-core:test --no-daemon + + jcstress: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew jcstress --no-daemon + + publish: + needs: [ build, coretest, othertest, jcstress ] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Publish Packages to Artifactory + if: ${{ matrix.jdk == '1.8' }} + run: ./gradlew -PversionSuffix="-SNAPSHOT" -PbuildNumber="${buildNumber}" publishMavenPublicationToSonatypeRepository --no-daemon --stacktrace + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + buildNumber: ${{ github.run_number }} + ORG_GRADLE_PROJECT_signingKey: ${{secrets.signingKey}} + ORG_GRADLE_PROJECT_signingPassword: ${{secrets.signingPassword}} + ORG_GRADLE_PROJECT_sonatypeUsername: ${{secrets.sonatypeUsername}} + ORG_GRADLE_PROJECT_sonatypePassword: ${{secrets.sonatypePassword}} + - name: Aggregate test reports with ciMate + if: always() + continue-on-error: true + env: + CIMATE_PROJECT_ID: m84qx17y + run: | + wget -q https://get.cimate.io/release/linux/cimate + chmod +x cimate + ./cimate "**/TEST-*.xml" \ No newline at end of file diff --git a/.github/workflows/gradle-pr.yml b/.github/workflows/gradle-pr.yml new file mode 100644 index 000000000..cecca085f --- /dev/null +++ b/.github/workflows/gradle-pr.yml @@ -0,0 +1,111 @@ +name: Pull Request Java CI + +on: [pull_request] + +jobs: + build: + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew clean build -x test --no-daemon + + coretest: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew rsocket-core:test --no-daemon + + othertest: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew test -x :rsocket-core:test --no-daemon + + jcstress: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew jcstress --no-daemon \ No newline at end of file diff --git a/.github/workflows/gradle-release.yml b/.github/workflows/gradle-release.yml new file mode 100644 index 000000000..922eb0e3e --- /dev/null +++ b/.github/workflows/gradle-release.yml @@ -0,0 +1,44 @@ +name: Release Java CI + +on: + # Trigger the workflow on push + push: + # Sequence of patterns matched against refs/tags + tags: + - '*' # Push events to matching *, i.e. 1.0, 20.15.10 + +jobs: + publish: + + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK 1.8 + uses: actions/setup-java@v1 + with: + java-version: 1.8 + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew clean build -x test + - name: Publish Packages to Sonotype + run: ./gradlew -Pversion="${githubRef#refs/tags/}" -PbuildNumber="${buildNumber}" sign publishMavenPublicationToSonatypeRepository + env: + githubRef: ${{ github.ref }} + buildNumber: ${{ github.run_number }} + ORG_GRADLE_PROJECT_signingKey: ${{secrets.signingKey}} + ORG_GRADLE_PROJECT_signingPassword: ${{secrets.signingPassword}} + ORG_GRADLE_PROJECT_sonatypeUsername: ${{secrets.sonatypeUsername}} + ORG_GRADLE_PROJECT_sonatypePassword: ${{secrets.sonatypePassword}} \ No newline at end of file diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 4722957c8..000000000 --- a/.travis.yml +++ /dev/null @@ -1,45 +0,0 @@ -# -# Copyright 2015-2018 the original author or authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ---- -language: java - -dist: trusty - -matrix: - include: - - jdk: openjdk8 - - jdk: openjdk11 - env: SKIP_RELEASE=true - - jdk: openjdk14 - env: SKIP_RELEASE=true - -env: - global: - - secure: "WBCy0hsF96Xybj4n0AUrGY2m5FWCUa30XR+aVElSOO8d7v7BMypAT8mAd+yC2Y+j8WUGpIv59CqgeK1JrYdR9b3qRKhJmoE1Q92TotrxXMTIC9OKuU51LaaOqGYqx4SqiA2AyaikTFPd8um7KZfUpW/dG4IXySsiJ2OKT1jMUq6TmbWHnAYtjbl3u3WdjBQTIZNMtqG1+H1vIpsWyZrvbB4TWlNzhKBAu/YnlzMtvStrDaF7XrCJ2BQdMomQO18NH2gWxUEvLbQb6ip3wFl9CRe6vID7K1dmFwm08RPt9hRPC9yDahlIy8VvuNcWrP42TV+BVYy8V/hfaIo1pPsDBrtmVyc7YZjXSUM68orDFOkRB35qGkNIaAhy5Yt6G9QfwLXJkDFofW5KMKtDFUzf+j4DwS0CiDMF4k6Qq7YN1tYFXE9R8xa6Gv+wTNHqs4RURbYMS9IlbkhKxNbtyuema2sIUbsIfDezIzLI5BnfH2uli7O6/z0/G0Vfmf6A4q5Olm+7uhzMTI0GKheUIKr16SOxABlrwJtLJftzoKz9hYd3b7C9t61vYzccC3rWYobplwIcK2w50gFHQS8HLeiCjo8yjCx+IRSvAGaZIBPQdHCktrEYCVDUTXOxdaD6k6Ef+ppm8Nn+M+iC8x/G1wYE4x1lDqHw3GfhKsEQmiHL/98=" - - secure: "mbB+rv9eWUFQ9/yr2REH2ztH6r/Uq7cq/OJ5WK6yFp0TmPzlJ8jbEVwe/sdAMW2E4qrfMu1c2h3qsVm41pNx0MwEsIW/lTIZRiRmNYon32n+SHlRWyTn8dJeY/p1HoHs450OjLgB4X4jmRmfSt8IQ/w9ZCjF6HVcgR4ctt+myECTNcRidEIOahljnSJmnFFDsKbt2UJN96AfvvhbxcarEKgKLXLd9tQT2GlvEOM+hVOY9hKD5FvIoRp9heyCEAsSBXe+MIWQlh4jx+B4zCajZJ+8KN6M8KIt40lV8z4Zbc11jgq/xULJwkQIuVZvkJ3huIfUrxwLPgYWeai/TR/m3+2jy1hFajt96pnhJzFEz0IBL0wFALwAY1n2R/6uugEUYnDsFcGQGTsO5OeeOixiRPH5HNgfOhInqJoFh/887f+gq7OLXjlRCTsw+S9KknZ3iBpHX/+khurfAUC9khiMvufEq6Wyu0TvxhmGERFrs7uugeJ1VA85SDVQ6Au9MV831PeBGqzHpYG7w2kJj1EiFjBRUhCthxyDfX2b04egozlKF8JEifZ9EVj7pNMQUvVG2c9Wj6M0fG84NusnlZlA16XxAmfLevc9b/BOSSrqc2r9Z1ZvxFnBPP9H94Uqt9ZninhW/T49jRF+lQzD45MTVogzVk77XtdpzUemf4t5mHc=" - - secure: "GcPu3U4o2Dp7QLCqaAo3mGMJTl9yd+w+elXqqt7WDjrjm5p8mrzvQfyiJA7mRJVDTGpgib8fLctL1X1+QOX4fNKElrDUFhE3bWAqwVwHGPK4D3HCb6THD5XVqE4qcPmdLWPkvJ9ZY5nSIfuRVASjZTcc4XSXISK2jUSGar0PNYlo62/OFGvNvMz/qINU9RU7iYdDlL19yd72TKDfuK0UOKhQEGypamEHam3SMNCw/p8Q5K1vQe+Oba3ILCvYHJvqWc2NLjRXJjXfIaOq/NpCK6Lx2U9etdpkb5lyW5Cx1lkzIcRUq8ZUCwbkHog9LJoZGrZFh5AzlZ6kRuejBqu7AISmZy4s9HVAb7AQmNxvXkK9EIt8lavcaHnLYUIfuxvBqK/ptcUN5P/KXCs1DsbpADjB7YbUu/EQ2OAWncV31Z+O4uMHV29eGTtaz9LoK28+mHRfFHqoazWyuUejor6iSSkrCeqsLEvU8o6rH4oenKz7hLlZsJqHGACYtYNYi2CXYlTu0bMX+Hb1EtTu6Awm9Gn04TqVdmNexgF5CdqW4A696i6jlkPpVCt4B4nq4VPs2RMTkjVl3B7uOkDm18u35dncuhgsnMfVmo9cWX5COeyefdh6kdnKsUf0+IPbV/hix/OCP72dpuhxgcyzN+DvaVLzX7YOx7TpJTzPSKNEQZc=" - - secure: "UFJEzDEv6H2Qscg9UgZFVJq5oFvq7nQkVoSuGfh5Y4ZhL9PCK5f3Ft9oYEZOQwXaxWD1qivtJjQV3DdBiqsHkrnPrJ0hi3iYVDJo26xLNtu3welFw5Veqmgu2NuwjaDn6cjRFCJRLzpszMUWO1DvfLJTs3LuJDuXEyAKDw9eQgfOakqO4xeloyXgM7xnoXz11rgqtJNU6snjVPHftXNPTHGsNDlTR7SAIbjYwLMbdIKM2qjzrXkg+a94QOz2stnTDz9V5iYNH+3XXCcYxD9nb1Ol1XGWvtDnNGEhtGmylLdjHXwGLHiW2HOXskLzSkm7ASie1WdyHVHZb4X8LjxCy62S0FPevBgat1a443Khx5HCMYR/8dQrlOI82GYTr8n9U6QQE4Li8XLw64DVP9HGs9jdbsfEdlIsiPWqB6ujlwiO6pyfmQGQCgjALA+oD87uDQLcgh+SDYgE0ZwmwGzbjeynZpoCrEE8A1GHhSwkM9khx6EJFacm9XzqoUGK0wB1f8su+51fqPglF1zye80IFA4wOMMAY+KUc9du/vQ98f0lfjsNSOC02CKYxbA5RaakQMAYjirsZraA57xLmCSIGMhhW4wClQdJBww6LLz463yZU4WPwyqU+ZW12aV5dVLb5RWXIbZKmdT74DfZajHvqgTYpb05L5cJl7ApMspUkKk=" - -script: ci/travis.sh - -before_cache: -- rm -f $HOME/.gradle/caches/modules-2/modules-2.lock -- rm -fr $HOME/.gradle/caches/*/plugin-resolution/ - -cache: - directories: - - $HOME/.gradle/caches/ - - $HOME/.gradle/wrapper/ diff --git a/README.md b/README.md index ebc9d024f..7ed3244b8 100644 --- a/README.md +++ b/README.md @@ -15,21 +15,22 @@ Learn more at http://rsocket.io ## Build and Binaries -[![Build Status](https://travis-ci.org/rsocket/rsocket-java.svg?branch=develop)](https://travis-ci.org/rsocket/rsocket-java) +[![Build Status](https://github.com/rsocket/rsocket-java/actions/workflows/gradle-main.yml/badge.svg?branch=master)](https://github.com/rsocket/rsocket-java/actions/workflows/gradle-main.yml) -⚠️ The `master` branch is now dedicated to development of the `1.1.x` line. +⚠️ The `master` branch is now dedicated to development of the `1.2.x` line. -Releases are available via Maven Central. +Releases and milestones are available via Maven Central. Example: ```groovy repositories { - mavenCentral() + mavenCentral() + maven { url 'https://repo.spring.io/milestone' } // Reactor milestones (if needed) } dependencies { - implementation 'io.rsocket:rsocket-core:1.1.0-SNAPSHOT' - implementation 'io.rsocket:rsocket-transport-netty:1.1.0-SNAPSHOT' + implementation 'io.rsocket:rsocket-core:1.2.0-SNAPSHOT' + implementation 'io.rsocket:rsocket-transport-netty:1.2.0-SNAPSHOT' } ``` @@ -39,11 +40,12 @@ Example: ```groovy repositories { - maven { url 'https://oss.jfrog.org/oss-snapshot-local' } + maven { url 'https://maven.pkg.github.com/rsocket/rsocket-java' } + maven { url 'https://repo.spring.io/snapshot' } // Reactor snapshots (if needed) } dependencies { - implementation 'io.rsocket:rsocket-core:1.1.1-SNAPSHOT' - implementation 'io.rsocket:rsocket-transport-netty:1.1.1-SNAPSHOT' + implementation 'io.rsocket:rsocket-core:1.2.0-SNAPSHOT' + implementation 'io.rsocket:rsocket-transport-netty:1.2.0-SNAPSHOT' } ``` @@ -131,7 +133,7 @@ For bugs, questions and discussions please use the [Github Issues](https://githu ## LICENSE -Copyright 2015-2018 the original author or authors. +Copyright 2015-2020 the original author or authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/benchmarks/README.md b/benchmarks/README.md index 6ba6755a6..656e2de4b 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -17,7 +17,7 @@ Specify extra profilers: Prominent profilers (for full list call `jmhProfilers` task): - comp - JitCompilations, tune your iterations - stack - which methods used most time -- gc - print garbage collection stats +- gc - print garbage collection defaultWeightedStats - hs_thr - thread usage Change report format from JSON to one of [CSV, JSON, NONE, SCSV, TEXT]: diff --git a/benchmarks/build.gradle b/benchmarks/build.gradle index f07f7c6f5..74e571d1f 100644 --- a/benchmarks/build.gradle +++ b/benchmarks/build.gradle @@ -12,12 +12,14 @@ dependencies { // Use the baseline to avoid using new APIs in the benchmarks compileOnly "io.rsocket:rsocket-core:${perfBaselineVersion}" compileOnly "io.rsocket:rsocket-transport-local:${perfBaselineVersion}" + compileOnly "io.rsocket:rsocket-transport-netty:${perfBaselineVersion}" - implementation "org.openjdk.jmh:jmh-core:1.21" - annotationProcessor "org.openjdk.jmh:jmh-generator-annprocess:1.21" + implementation "org.openjdk.jmh:jmh-core:1.35" + annotationProcessor "org.openjdk.jmh:jmh-generator-annprocess:1.35" current project(':rsocket-core') current project(':rsocket-transport-local') + current project(':rsocket-transport-netty') baseline "io.rsocket:rsocket-core:${perfBaselineVersion}", { changing = true } @@ -33,10 +35,12 @@ task jmhProfilers(type: JavaExec, description:'Lists the available profilers for } task jmh(type: JmhExecTask, description: 'Executing JMH benchmarks') { + main = 'org.openjdk.jmh.Main' classpath = sourceSets.main.runtimeClasspath + configurations.current } task jmhBaseline(type: JmhExecTask, description: 'Executing JMH baseline benchmarks') { + main = 'org.openjdk.jmh.Main' classpath = sourceSets.main.runtimeClasspath + configurations.baseline } @@ -123,7 +127,6 @@ class JmhExecTask extends JavaExec { @TaskAction public void exec() { - setMain("org.openjdk.jmh.Main"); File resultFile = getProject().file("build/reports/" + getName() + "/result." + format); if (include != null) { diff --git a/benchmarks/src/main/java/io/rsocket/core/RSocketPerf.java b/benchmarks/src/main/java/io/rsocket/core/RSocketPerf.java index f78843f5b..4437400c4 100644 --- a/benchmarks/src/main/java/io/rsocket/core/RSocketPerf.java +++ b/benchmarks/src/main/java/io/rsocket/core/RSocketPerf.java @@ -1,25 +1,34 @@ package io.rsocket.core; -import io.rsocket.AbstractRSocket; import io.rsocket.Closeable; import io.rsocket.Payload; import io.rsocket.PayloadsMaxPerfSubscriber; import io.rsocket.PayloadsPerfSubscriber; import io.rsocket.RSocket; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.transport.ClientTransport; +import io.rsocket.transport.ServerTransport; import io.rsocket.transport.local.LocalClientTransport; import io.rsocket.transport.local.LocalServerTransport; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import io.rsocket.util.ByteBufPayload; import io.rsocket.util.EmptyPayload; import java.lang.reflect.Field; import java.util.Queue; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.LockSupport; -import java.util.stream.IntStream; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; import org.openjdk.jmh.annotations.Level; import org.openjdk.jmh.annotations.Measurement; import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; import org.openjdk.jmh.annotations.Scope; import org.openjdk.jmh.annotations.Setup; import org.openjdk.jmh.annotations.State; @@ -31,19 +40,23 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -@BenchmarkMode(Mode.Throughput) -@Fork( - value = 1 // , jvmArgsAppend = {"-Dio.netty.leakDetection.level=advanced"} - ) +@BenchmarkMode({Mode.Throughput, Mode.SampleTime}) +@Fork(value = 2) @Warmup(iterations = 10) -@Measurement(iterations = 10, time = 20) +@Measurement(iterations = 10, time = 10) @State(Scope.Benchmark) +@OutputTimeUnit(TimeUnit.MICROSECONDS) public class RSocketPerf { - static final Payload PAYLOAD = EmptyPayload.INSTANCE; - static final Mono PAYLOAD_MONO = Mono.just(PAYLOAD); - static final Flux PAYLOAD_FLUX = - Flux.fromArray(IntStream.range(0, 100000).mapToObj(__ -> PAYLOAD).toArray(Payload[]::new)); + @Param({"tcp", "websocket", "local"}) + String transportType; + + @Param({"0", "64", "1024", "131072", "1048576", "15728640"}) + String payloadSize; + + Payload payload; + Mono payloadMono; + Flux payloadsFlux; RSocket client; Closeable server; @@ -53,6 +66,7 @@ public class RSocketPerf { public void tearDown() { client.dispose(); server.dispose(); + payload.release(); } @TearDown(Level.Iteration) @@ -63,12 +77,45 @@ public void awaitToBeConsumed() { } @Setup - public void setUp() throws NoSuchFieldException, IllegalAccessException { - server = + public void setUp() throws NoSuchFieldException, IllegalAccessException, ClassNotFoundException { + ClientTransport clientTransport; + ServerTransport serverTransport; + switch (transportType) { + case "tcp": + clientTransport = TcpClientTransport.create(8081); + serverTransport = TcpServerTransport.create(8081); + break; + case "websocket": + clientTransport = WebsocketClientTransport.create(8081); + serverTransport = WebsocketServerTransport.create(8081); + break; + case "local": + default: + clientTransport = LocalClientTransport.create("server"); + serverTransport = LocalServerTransport.create("server"); + break; + } + Payload payload; + int payloadSize = Integer.parseInt(this.payloadSize); + if (payloadSize == 0) { + payload = EmptyPayload.INSTANCE; + } else { + byte[] randomMetadata = new byte[payloadSize / 2]; + byte[] randomData = new byte[payloadSize / 2]; + ThreadLocalRandom.current().nextBytes(randomData); + ThreadLocalRandom.current().nextBytes(randomMetadata); + + payload = ByteBufPayload.create(randomData, randomMetadata); + } + + this.payload = payload; + this.payloadMono = Mono.fromSupplier(payload::retain); + this.payloadsFlux = Flux.range(0, 100000).map(__ -> payload.retain()); + this.server = RSocketServer.create( (setup, sendingSocket) -> Mono.just( - new AbstractRSocket() { + new RSocket() { @Override public Mono fireAndForget(Payload payload) { @@ -79,13 +126,13 @@ public Mono fireAndForget(Payload payload) { @Override public Mono requestResponse(Payload payload) { payload.release(); - return PAYLOAD_MONO; + return payloadMono; } @Override public Flux requestStream(Payload payload) { payload.release(); - return PAYLOAD_FLUX; + return payloadsFlux; } @Override @@ -94,26 +141,35 @@ public Flux requestChannel(Publisher payloads) { } })) .payloadDecoder(PayloadDecoder.ZERO_COPY) - .bind(LocalServerTransport.create("server")) + .bind(serverTransport) .block(); - client = + this.client = RSocketConnector.create() .payloadDecoder(PayloadDecoder.ZERO_COPY) - .connect(LocalClientTransport.create("server")) + .connect(clientTransport) .block(); - Field sendProcessorField = RSocketRequester.class.getDeclaredField("sendProcessor"); - sendProcessorField.setAccessible(true); + try { + Field sendProcessorField = RSocketRequester.class.getDeclaredField("sendProcessor"); + sendProcessorField.setAccessible(true); + + clientsQueue = (Queue) sendProcessorField.get(client); + } catch (Throwable t) { + Field sendProcessorField = + Class.forName("io.rsocket.core.RequesterResponderSupport") + .getDeclaredField("sendProcessor"); + sendProcessorField.setAccessible(true); - clientsQueue = (Queue) sendProcessorField.get(client); + clientsQueue = (Queue) sendProcessorField.get(client); + } } @Benchmark @SuppressWarnings("unchecked") public PayloadsPerfSubscriber fireAndForget(Blackhole blackhole) throws InterruptedException { PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); - client.fireAndForget(PAYLOAD).subscribe((CoreSubscriber) subscriber); + client.fireAndForget(payload.retain()).subscribe((CoreSubscriber) subscriber); subscriber.await(); return subscriber; @@ -122,7 +178,7 @@ public PayloadsPerfSubscriber fireAndForget(Blackhole blackhole) throws Interrup @Benchmark public PayloadsPerfSubscriber requestResponse(Blackhole blackhole) throws InterruptedException { PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); - client.requestResponse(PAYLOAD).subscribe(subscriber); + client.requestResponse(payload.retain()).subscribe(subscriber); subscriber.await(); return subscriber; @@ -132,7 +188,7 @@ public PayloadsPerfSubscriber requestResponse(Blackhole blackhole) throws Interr public PayloadsPerfSubscriber requestStreamWithRequestByOneStrategy(Blackhole blackhole) throws InterruptedException { PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); - client.requestStream(PAYLOAD).subscribe(subscriber); + client.requestStream(payload.retain()).subscribe(subscriber); subscriber.await(); return subscriber; @@ -142,7 +198,7 @@ public PayloadsPerfSubscriber requestStreamWithRequestByOneStrategy(Blackhole bl public PayloadsMaxPerfSubscriber requestStreamWithRequestAllStrategy(Blackhole blackhole) throws InterruptedException { PayloadsMaxPerfSubscriber subscriber = new PayloadsMaxPerfSubscriber(blackhole); - client.requestStream(PAYLOAD).subscribe(subscriber); + client.requestStream(payload.retain()).subscribe(subscriber); subscriber.await(); return subscriber; @@ -152,7 +208,7 @@ public PayloadsMaxPerfSubscriber requestStreamWithRequestAllStrategy(Blackhole b public PayloadsPerfSubscriber requestChannelWithRequestByOneStrategy(Blackhole blackhole) throws InterruptedException { PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); - client.requestChannel(PAYLOAD_FLUX).subscribe(subscriber); + client.requestChannel(payloadsFlux).subscribe(subscriber); subscriber.await(); return subscriber; @@ -162,7 +218,7 @@ public PayloadsPerfSubscriber requestChannelWithRequestByOneStrategy(Blackhole b public PayloadsMaxPerfSubscriber requestChannelWithRequestAllStrategy(Blackhole blackhole) throws InterruptedException { PayloadsMaxPerfSubscriber subscriber = new PayloadsMaxPerfSubscriber(blackhole); - client.requestChannel(PAYLOAD_FLUX).subscribe(subscriber); + client.requestChannel(payloadsFlux).subscribe(subscriber); subscriber.await(); return subscriber; diff --git a/benchmarks/src/main/java/io/rsocket/core/StreamIdSupplierPerf.java b/benchmarks/src/main/java/io/rsocket/core/StreamIdSupplierPerf.java deleted file mode 100644 index 6b4f3f624..000000000 --- a/benchmarks/src/main/java/io/rsocket/core/StreamIdSupplierPerf.java +++ /dev/null @@ -1,43 +0,0 @@ -package io.rsocket.core; - -import io.netty.util.collection.IntObjectMap; -import io.rsocket.internal.SynchronizedIntObjectHashMap; -import org.openjdk.jmh.annotations.Benchmark; -import org.openjdk.jmh.annotations.BenchmarkMode; -import org.openjdk.jmh.annotations.Fork; -import org.openjdk.jmh.annotations.Measurement; -import org.openjdk.jmh.annotations.Mode; -import org.openjdk.jmh.annotations.Scope; -import org.openjdk.jmh.annotations.Setup; -import org.openjdk.jmh.annotations.State; -import org.openjdk.jmh.annotations.Warmup; -import org.openjdk.jmh.infra.Blackhole; - -@BenchmarkMode(Mode.Throughput) -@Fork( - value = 1 // , jvmArgsAppend = {"-Dio.netty.leakDetection.level=advanced"} - ) -@Warmup(iterations = 10) -@Measurement(iterations = 10) -@State(Scope.Thread) -public class StreamIdSupplierPerf { - @Benchmark - public void benchmarkStreamId(Input input) { - int i = input.supplier.nextStreamId(input.map); - input.bh.consume(i); - } - - @State(Scope.Benchmark) - public static class Input { - Blackhole bh; - IntObjectMap map; - StreamIdSupplier supplier; - - @Setup - public void setup(Blackhole bh) { - this.supplier = StreamIdSupplier.clientSupplier(); - this.bh = bh; - this.map = new SynchronizedIntObjectHashMap(); - } - } -} diff --git a/benchmarks/src/main/java/io/rsocket/frame/FrameHeaderFlyweightPerf.java b/benchmarks/src/main/java/io/rsocket/frame/FrameHeaderCodecPerf.java similarity index 97% rename from benchmarks/src/main/java/io/rsocket/frame/FrameHeaderFlyweightPerf.java rename to benchmarks/src/main/java/io/rsocket/frame/FrameHeaderCodecPerf.java index b4ac808d0..402cdb353 100644 --- a/benchmarks/src/main/java/io/rsocket/frame/FrameHeaderFlyweightPerf.java +++ b/benchmarks/src/main/java/io/rsocket/frame/FrameHeaderCodecPerf.java @@ -12,7 +12,7 @@ @Warmup(iterations = 10) @Measurement(iterations = 10) @State(Scope.Thread) -public class FrameHeaderFlyweightPerf { +public class FrameHeaderCodecPerf { @Benchmark public void encode(Input input) { diff --git a/benchmarks/src/main/java/io/rsocket/frame/PayloadFlyweightPerf.java b/benchmarks/src/main/java/io/rsocket/frame/PayloadFrameCodecPerf.java similarity index 98% rename from benchmarks/src/main/java/io/rsocket/frame/PayloadFlyweightPerf.java rename to benchmarks/src/main/java/io/rsocket/frame/PayloadFrameCodecPerf.java index 01d82a08f..ead1c2fa3 100644 --- a/benchmarks/src/main/java/io/rsocket/frame/PayloadFlyweightPerf.java +++ b/benchmarks/src/main/java/io/rsocket/frame/PayloadFrameCodecPerf.java @@ -13,7 +13,7 @@ @Warmup(iterations = 10) @Measurement(iterations = 10) @State(Scope.Thread) -public class PayloadFlyweightPerf { +public class PayloadFrameCodecPerf { @Benchmark public void encode(Input input) { diff --git a/build.gradle b/build.gradle index 1c11c0511..2971a7767 100644 --- a/build.gradle +++ b/build.gradle @@ -15,31 +15,39 @@ */ plugins { - id 'com.github.sherter.google-java-format' version '0.8' apply false - id 'com.jfrog.artifactory' version '4.15.2' apply false - id 'com.jfrog.bintray' version '1.8.5' apply false - id 'me.champeau.gradle.jmh' version '0.5.0' apply false - id 'io.spring.dependency-management' version '1.0.9.RELEASE' apply false + id 'com.github.sherter.google-java-format' version '0.9' apply false + id 'me.champeau.jmh' version '0.7.1' apply false + id 'io.spring.dependency-management' version '1.1.0' apply false id 'io.morethan.jmhreport' version '0.9.0' apply false + id 'io.github.reyerizo.gradle.jcstress' version '0.8.15' apply false + id 'com.github.vlsi.gradle-extensions' version '1.89' apply false +} + +boolean isCiServer = ["CI", "CONTINUOUS_INTEGRATION", "TRAVIS", "CIRCLECI", "bamboo_planKey", "GITHUB_ACTION"].with { + retainAll(System.getenv().keySet()) + return !isEmpty() } subprojects { apply plugin: 'io.spring.dependency-management' apply plugin: 'com.github.sherter.google-java-format' + apply plugin: 'com.github.vlsi.gradle-extensions' - ext['reactor-bom.version'] = 'Dysprosium-SR8' - ext['logback.version'] = '1.2.3' - ext['findbugs.version'] = '3.0.2' - ext['netty-bom.version'] = '4.1.50.Final' - ext['netty-boringssl.version'] = '2.0.30.Final' - ext['hdrhistogram.version'] = '2.1.10' - ext['mockito.version'] = '3.2.0' - ext['slf4j.version'] = '1.7.25' - ext['jmh.version'] = '1.21' - ext['junit.version'] = '5.5.2' - ext['hamcrest.version'] = '1.3' - ext['micrometer.version'] = '1.0.6' - ext['assertj.version'] = '3.11.1' + ext['reactor-bom.version'] = '2022.0.7-SNAPSHOT' + ext['logback.version'] = '1.2.13' + ext['netty-bom.version'] = '4.1.117.Final' + ext['netty-boringssl.version'] = '2.0.69.Final' + ext['hdrhistogram.version'] = '2.1.12' + ext['mockito.version'] = '4.11.0' + ext['slf4j.version'] = '1.7.36' + ext['jmh.version'] = '1.36' + ext['junit.version'] = '5.9.3' + ext['micrometer.version'] = '1.11.12' + ext['micrometer-tracing.version'] = '1.1.13' + ext['assertj.version'] = '3.24.2' + ext['netflix.limits.version'] = '0.3.6' + ext['bouncycastle-bcpkix.version'] = '1.70' + ext['awaitility.version'] = '4.2.0' group = "io.rsocket" @@ -53,27 +61,32 @@ subprojects { } } + configurations.all { + resolutionStrategy.cacheChangingModulesFor 60, "minutes" + } + dependencyManagement { imports { mavenBom "io.projectreactor:reactor-bom:${ext['reactor-bom.version']}" mavenBom "io.netty:netty-bom:${ext['netty-bom.version']}" mavenBom "org.junit:junit-bom:${ext['junit.version']}" + mavenBom "io.micrometer:micrometer-bom:${ext['micrometer.version']}" + mavenBom "io.micrometer:micrometer-tracing-bom:${ext['micrometer-tracing.version']}" } dependencies { + dependency "com.netflix.concurrency-limits:concurrency-limits-core:${ext['netflix.limits.version']}" dependency "ch.qos.logback:logback-classic:${ext['logback.version']}" dependency "io.netty:netty-tcnative-boringssl-static:${ext['netty-boringssl.version']}" - dependency "io.micrometer:micrometer-core:${ext['micrometer.version']}" + dependency "org.bouncycastle:bcpkix-jdk15on:${ext['bouncycastle-bcpkix.version']}" dependency "org.assertj:assertj-core:${ext['assertj.version']}" dependency "org.hdrhistogram:HdrHistogram:${ext['hdrhistogram.version']}" dependency "org.slf4j:slf4j-api:${ext['slf4j.version']}" + dependency "org.awaitility:awaitility:${ext['awaitility.version']}" dependencySet(group: 'org.mockito', version: ext['mockito.version']) { entry 'mockito-junit-jupiter' entry 'mockito-core' } - // TODO: Remove after JUnit5 migration - dependency 'junit:junit:4.12' - dependency "org.hamcrest:hamcrest-library:${ext['hamcrest.version']}" dependencySet(group: 'org.openjdk.jmh', version: ext['jmh.version']) { entry 'jmh-core' entry 'jmh-generator-annprocess' @@ -87,12 +100,30 @@ subprojects { repositories { mavenCentral() - if (version.endsWith('SNAPSHOT') || project.hasProperty('platformVersion')) { - maven { url 'http://repo.spring.io/libs-snapshot' } - maven { - url 'https://oss.jfrog.org/artifactory/oss-snapshot-local' + maven { + url 'https://repo.spring.io/milestone' + content { + includeGroup "io.micrometer" + includeGroup "io.projectreactor" + includeGroup "io.projectreactor.netty" + includeGroup "io.micrometer" } } + + maven { + url 'https://repo.spring.io/snapshot' + content { + includeGroup "io.micrometer" + includeGroup "io.projectreactor" + includeGroup "io.projectreactor.netty" + } + } + + if (version.endsWith('SNAPSHOT') || project.hasProperty('versionSuffix')) { + maven { url 'https://repo.spring.io/libs-snapshot' } + maven { url 'https://oss.jfrog.org/artifactory/oss-snapshot-local' } + mavenLocal() + } } tasks.withType(GenerateModuleMetadata) { @@ -100,6 +131,7 @@ subprojects { } plugins.withType(JavaPlugin) { + compileJava { sourceCompatibility = 1.8 @@ -118,6 +150,7 @@ subprojects { links 'https://projectreactor.io/docs/core/release/api/' links 'https://netty.io/4.1/api/' } + failOnError = false } tasks.named("javadoc").configure { @@ -126,32 +159,57 @@ subprojects { test { useJUnitPlatform() - - systemProperty "io.netty.leakDetection.level", "ADVANCED" - } - - //all test tasks will show FAILED for each test method, - // common exclusions, no scanning - project.tasks.withType(Test).all { testLogging { - events "FAILED" + events "PASSED", "FAILED" showExceptions true + showCauses true exceptionFormat "FULL" stackTraceFilters "ENTRY_POINT" maxGranularity 3 } + //show progress by displaying test classes, avoiding test suite timeouts + TestDescriptor last + afterTest { TestDescriptor td, TestResult tr -> + if (last != td.getParent()) { + last = td.getParent() + println last + } + } + + if (isCiServer) { + def stdout = new LinkedList() + beforeTest { TestDescriptor td -> + stdout.clear() + } + onOutput { TestDescriptor td, TestOutputEvent toe -> + stdout.add(toe) + } + afterTest { TestDescriptor td, TestResult tr -> + if (tr.resultType == TestResult.ResultType.FAILURE && stdout.size() > 0) { + def stdOutput = stdout.collect { + it.getDestination().name() == "StdErr" + ? "STD_ERR: ${it.getMessage()}" + : "STD_OUT: ${it.getMessage()}" + } + .join() + println "This is the console output of the failing test below:\n$stdOutput" + } + } + + reports { + junitXml.outputPerTestCase = true + } + } + if (JavaVersion.current().isJava9Compatible()) { println "Java 9+: lowering MaxGCPauseMillis to 20ms in ${project.name} ${name}" - jvmArgs = ["-XX:MaxGCPauseMillis=20"] + println "Java 9+: enabling leak detection [ADVANCED]" + jvmArgs = ["-XX:MaxGCPauseMillis=20", "-Dio.netty.leakDetection.level=ADVANCED", "-Dio.netty.leakDetection.samplingInterval=32"] } systemProperty("java.awt.headless", "true") - systemProperty("reactor.trace.cancel", "true") - systemProperty("reactor.trace.nocapacity", "true") systemProperty("testGroups", project.properties.get("testGroups")) - scanForTestClasses = false - exclude '**/*Abstract*.*' //allow re-run of failed tests only without special test tasks failing // because the filter is too restrictive @@ -202,4 +260,31 @@ description = 'RSocket: Stream Oriented Messaging Passing with Reactive Stream S repositories { mavenCentral() + + maven { url 'https://repo.spring.io/snapshot' } + mavenLocal() +} + +configurations { + adoc +} + +dependencies { + adoc "io.micrometer:micrometer-docs-generator-spans:1.0.0-SNAPSHOT" + adoc "io.micrometer:micrometer-docs-generator-metrics:1.0.0-SNAPSHOT" +} + +task generateObservabilityDocs(dependsOn: ["generateObservabilityMetricsDocs", "generateObservabilitySpansDocs"]) { +} + +task generateObservabilityMetricsDocs(type: JavaExec) { + mainClass = "io.micrometer.docs.metrics.DocsFromSources" + classpath configurations.adoc + args project.rootDir.getAbsolutePath(), ".*", project.rootProject.buildDir.getAbsolutePath() +} + +task generateObservabilitySpansDocs(type: JavaExec) { + mainClass = "io.micrometer.docs.spans.DocsFromSources" + classpath configurations.adoc + args project.rootDir.getAbsolutePath(), ".*", project.rootProject.buildDir.getAbsolutePath() } diff --git a/ci/travis.sh b/ci/travis.sh deleted file mode 100755 index 74e26fdab..000000000 --- a/ci/travis.sh +++ /dev/null @@ -1,44 +0,0 @@ -#!/usr/bin/env bash - -if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then - - echo -e "Building PR #$TRAVIS_PULL_REQUEST [$TRAVIS_PULL_REQUEST_SLUG/$TRAVIS_PULL_REQUEST_BRANCH => $TRAVIS_REPO_SLUG/$TRAVIS_BRANCH]" - ./gradlew build - -elif [ "$TRAVIS_PULL_REQUEST" == "false" ] && [ "$TRAVIS_TAG" == "" ] && [ "$bintrayUser" != "" ] && [ "$TRAVIS_BRANCH" == "master" ] ; then - - echo -e "Building Develop Snapshot $TRAVIS_REPO_SLUG/$TRAVIS_BRANCH/$TRAVIS_BUILD_NUMBER" - ./gradlew \ - -PbintrayUser="${bintrayUser}" -PbintrayKey="${bintrayKey}" \ - -PsonatypeUsername="${sonatypeUsername}" -PsonatypePassword="${sonatypePassword}" \ - -PversionSuffix="-SNAPSHOT" \ - -PbuildNumber="$TRAVIS_BUILD_NUMBER" \ - build artifactoryPublish --stacktrace - -elif [ "$TRAVIS_PULL_REQUEST" == "false" ] && [ "$TRAVIS_TAG" == "" ] && [ "$bintrayUser" != "" ] ; then - - echo -e "Building Branch Snapshot $TRAVIS_REPO_SLUG/$TRAVIS_BRANCH/$TRAVIS_BUILD_NUMBER" - ./gradlew \ - -PbintrayUser="${bintrayUser}" -PbintrayKey="${bintrayKey}" \ - -PsonatypeUsername="${sonatypeUsername}" -PsonatypePassword="${sonatypePassword}" \ - -PversionSuffix="-${TRAVIS_BRANCH//\//-}-SNAPSHOT" \ - -PbuildNumber="$TRAVIS_BUILD_NUMBER" \ - build artifactoryPublish --stacktrace - -elif [ "$TRAVIS_PULL_REQUEST" == "false" ] && [ "$TRAVIS_TAG" != "" ] && [ "$bintrayUser" != "" ] ; then - - echo -e "Building Tag $TRAVIS_REPO_SLUG/$TRAVIS_TAG" - ./gradlew \ - -Pversion="$TRAVIS_TAG" \ - -PbintrayUser="${bintrayUser}" -PbintrayKey="${bintrayKey}" \ - -PsonatypeUsername="${sonatypeUsername}" -PsonatypePassword="${sonatypePassword}" \ - -PbuildNumber="$TRAVIS_BUILD_NUMBER" \ - build bintrayUpload --stacktrace - -else - - echo -e "Building $TRAVIS_REPO_SLUG/$TRAVIS_BRANCH" - ./gradlew build - -fi - diff --git a/gradle.properties b/gradle.properties index 18e7c5584..d138852c5 100644 --- a/gradle.properties +++ b/gradle.properties @@ -11,5 +11,5 @@ # See the License for the specific language governing permissions and # limitations under the License. # -version=1.1.0 -perfBaselineVersion=1.0.1 +version=1.2.0 +perfBaselineVersion=1.1.4 diff --git a/gradle/artifactory.gradle b/gradle/artifactory.gradle deleted file mode 100644 index cdffb2741..000000000 --- a/gradle/artifactory.gradle +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -if (project.hasProperty('bintrayUser') && project.hasProperty('bintrayKey')) { - - subprojects { - plugins.withId('com.jfrog.artifactory') { - artifactory { - publish { - contextUrl = 'https://oss.jfrog.org' - - repository { - repoKey = 'oss-snapshot-local' - - // Credentials for oss.jfrog.org are a user's Bintray credentials - username = project.property('bintrayUser') - password = project.property('bintrayKey') - } - - defaults { - publications(publishing.publications.maven) - } - - if (project.hasProperty('buildNumber')) { - clientConfig.info.setBuildNumber(project.property('buildNumber').toString()) - } - } - } - tasks.named("artifactoryPublish").configure { - onlyIf { System.getenv('SKIP_RELEASE') != "true" } - } - } - } -} diff --git a/gradle/bintray.gradle b/gradle/bintray.gradle deleted file mode 100644 index 5015f94e4..000000000 --- a/gradle/bintray.gradle +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -if (project.hasProperty('bintrayUser') && project.hasProperty('bintrayKey') && - project.hasProperty('sonatypeUsername') && project.hasProperty('sonatypePassword')) { - - subprojects { - plugins.withId('com.jfrog.bintray') { - bintray { - user = project.property('bintrayUser') - key = project.property('bintrayKey') - - publications = ['maven'] - publish = true - override = true - - pkg { - repo = 'RSocket' - name = 'rsocket-java' - licenses = ['Apache-2.0'] - - issueTrackerUrl = 'https://github.com/rsocket/rsocket-java/issues' - websiteUrl = 'https://github.com/rsocket/rsocket-java' - vcsUrl = 'https://github.com/rsocket/rsocket-java.git' - - githubRepo = 'rsocket/rsocket-java' //Optional Github repository - githubReleaseNotesFile = 'README.md' //Optional Github readme file - - version { - name = project.version - released = new Date() - vcsTag = project.version - - gpg { - sign = true - } - - mavenCentralSync { - user = project.property('sonatypeUsername') - password = project.property('sonatypePassword') - } - } - } - } - tasks.named("bintrayUpload").configure { - onlyIf { System.getenv('SKIP_RELEASE') != "true" } - } - } - } -} diff --git a/gradle/github-pkg.gradle b/gradle/github-pkg.gradle new file mode 100644 index 000000000..f53413766 --- /dev/null +++ b/gradle/github-pkg.gradle @@ -0,0 +1,21 @@ +subprojects { + + plugins.withType(MavenPublishPlugin) { + publishing { + repositories { + maven { + name = "GitHubPackages" + url = uri("https://maven.pkg.github.com/rsocket/rsocket-java") + credentials { + username = project.findProperty("gpr.user") ?: System.getenv("GITHUB_ACTOR") + password = project.findProperty("gpr.key") ?: System.getenv("GITHUB_TOKEN") + } + } + } + } + + tasks.named("publish").configure { + onlyIf { System.getenv('SKIP_RELEASE') != "true" } + } + } +} \ No newline at end of file diff --git a/gradle/publications.gradle b/gradle/publications.gradle index b12d9e9c2..9e8dd6d88 100644 --- a/gradle/publications.gradle +++ b/gradle/publications.gradle @@ -1,5 +1,5 @@ -apply from: "${rootDir}/gradle/artifactory.gradle" -apply from: "${rootDir}/gradle/bintray.gradle" +apply from: "${rootDir}/gradle/github-pkg.gradle" +apply from: "${rootDir}/gradle/sonotype.gradle" subprojects { plugins.withType(MavenPublishPlugin) { @@ -21,30 +21,15 @@ subprojects { } } developers { - developer { - id = 'robertroeser' - name = 'Robert Roeser' - email = 'robert@netifi.com' - } - developer { - id = 'rdegnan' - name = 'Ryland Degnan' - email = 'ryland@netifi.com' - } - developer { - id = 'yschimke' - name = 'Yuri Schimke' - email = 'yuri@schimke.ee' - } developer { id = 'OlegDokuka' name = 'Oleh Dokuka' - email = 'oleh@netifi.com' + email = 'oleh.dokuka@icloud.com' } developer { - id = 'mostroverkhov' - name = 'Maksym Ostroverkhov' - email = 'm.ostroverkhov@gmail.com' + id = 'rstoyanchev' + name = 'Rossen Stoyanchev' + email = 'rstoyanchev@vmware.com' } } scm { diff --git a/gradle/sonotype.gradle b/gradle/sonotype.gradle new file mode 100644 index 000000000..f339079b0 --- /dev/null +++ b/gradle/sonotype.gradle @@ -0,0 +1,36 @@ +subprojects { + if (project.hasProperty('sonatypeUsername') && project.hasProperty('sonatypePassword')) { + plugins.withType(MavenPublishPlugin) { + plugins.withType(SigningPlugin) { + + signing { + //requiring signature if there is a publish task that is not to MavenLocal + required { gradle.taskGraph.allTasks.any { it.name.toLowerCase().contains("publish") && !it.name.contains("MavenLocal") } } + def signingKey = project.findProperty("signingKey") + def signingPassword = project.findProperty("signingPassword") + + useInMemoryPgpKeys(signingKey, signingPassword) + + afterEvaluate { + sign publishing.publications.maven + } + } + + publishing { + repositories { + maven { + name = "sonatype" + url = project.version.contains("-SNAPSHOT") + ? "https://oss.sonatype.org/content/repositories/snapshots/" + : "https://oss.sonatype.org/service/local/staging/deploy/maven2" + credentials { + username project.findProperty("sonatypeUsername") + password project.findProperty("sonatypePassword") + } + } + } + } + } + } + } +} \ No newline at end of file diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index 5c2d1cf01..249e5832f 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 8f6e03af5..774fae876 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,5 @@ -#Mon Jun 08 20:22:21 EEST 2020 -distributionUrl=https\://services.gradle.org/distributions/gradle-6.5-all.zip distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -zipStorePath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.1-bin.zip zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/gradlew b/gradlew index 83f2acfdc..a69d9cb6c 100755 --- a/gradlew +++ b/gradlew @@ -1,7 +1,7 @@ -#!/usr/bin/env sh +#!/bin/sh # -# Copyright 2015 the original author or authors. +# Copyright © 2015-2021 the original authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,78 +17,113 @@ # ############################################################################## -## -## Gradle start up script for UN*X -## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/master/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# ############################################################################## # Attempt to set APP_HOME + # Resolve links: $0 may be a link -PRG="$0" -# Need this for relative symlinks. -while [ -h "$PRG" ] ; do - ls=`ls -ld "$PRG"` - link=`expr "$ls" : '.*-> \(.*\)$'` - if expr "$link" : '/.*' > /dev/null; then - PRG="$link" - else - PRG=`dirname "$PRG"`"/$link" - fi +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac done -SAVED="`pwd`" -cd "`dirname \"$PRG\"`/" >/dev/null -APP_HOME="`pwd -P`" -cd "$SAVED" >/dev/null + +APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit APP_NAME="Gradle" -APP_BASE_NAME=`basename "$0"` +APP_BASE_NAME=${0##*/} # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' # Use the maximum available, or set MAX_FD != -1 to use that value. -MAX_FD="maximum" +MAX_FD=maximum warn () { echo "$*" -} +} >&2 die () { echo echo "$*" echo exit 1 -} +} >&2 # OS specific support (must be 'true' or 'false'). cygwin=false msys=false darwin=false nonstop=false -case "`uname`" in - CYGWIN* ) - cygwin=true - ;; - Darwin* ) - darwin=true - ;; - MINGW* ) - msys=true - ;; - NONSTOP* ) - nonstop=true - ;; +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; esac CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + # Determine the Java command to use to start the JVM. if [ -n "$JAVA_HOME" ] ; then if [ -x "$JAVA_HOME/jre/sh/java" ] ; then # IBM's JDK on AIX uses strange locations for the executables - JAVACMD="$JAVA_HOME/jre/sh/java" + JAVACMD=$JAVA_HOME/jre/sh/java else - JAVACMD="$JAVA_HOME/bin/java" + JAVACMD=$JAVA_HOME/bin/java fi if [ ! -x "$JAVACMD" ] ; then die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME @@ -97,7 +132,7 @@ Please set the JAVA_HOME variable in your environment to match the location of your Java installation." fi else - JAVACMD="java" + JAVACMD=java which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. Please set the JAVA_HOME variable in your environment to match the @@ -105,84 +140,101 @@ location of your Java installation." fi # Increase the maximum file descriptors if we can. -if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then - MAX_FD_LIMIT=`ulimit -H -n` - if [ $? -eq 0 ] ; then - if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then - MAX_FD="$MAX_FD_LIMIT" - fi - ulimit -n $MAX_FD - if [ $? -ne 0 ] ; then - warn "Could not set maximum file descriptor limit: $MAX_FD" - fi - else - warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" - fi +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac fi -# For Darwin, add options to specify how the application appears in the dock -if $darwin; then - GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" -fi +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. # For Cygwin or MSYS, switch paths to Windows format before running java -if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then - APP_HOME=`cygpath --path --mixed "$APP_HOME"` - CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` - JAVACMD=`cygpath --unix "$JAVACMD"` - - # We build the pattern for arguments to be converted via cygpath - ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` - SEP="" - for dir in $ROOTDIRSRAW ; do - ROOTDIRS="$ROOTDIRS$SEP$dir" - SEP="|" - done - OURCYGPATTERN="(^($ROOTDIRS))" - # Add a user-defined pattern to the cygpath arguments - if [ "$GRADLE_CYGPATTERN" != "" ] ; then - OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" - fi +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + # Now convert the arguments - kludge to limit ourselves to /bin/sh - i=0 - for arg in "$@" ; do - CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` - CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option - - if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition - eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` - else - eval `echo args$i`="\"$arg\"" + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) fi - i=$((i+1)) + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg done - case $i in - (0) set -- ;; - (1) set -- "$args0" ;; - (2) set -- "$args0" "$args1" ;; - (3) set -- "$args0" "$args1" "$args2" ;; - (4) set -- "$args0" "$args1" "$args2" "$args3" ;; - (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; - (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; - (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; - (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; - (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; - esac fi -# Escape application args -save () { - for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done - echo " " -} -APP_ARGS=$(save "$@") +# Collect all arguments for the java command; +# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of +# shell script including quotes and variable substitutions, so put them in +# double quotes to make sure that they get re-expanded; and +# * put everything else in single quotes, so that it's not re-expanded. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi -# Collect all arguments for the java command, following the shell quoting and substitution rules -eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# -# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong -if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then - cd "$(dirname "$0")" -fi +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' exec "$JAVACMD" "$@" diff --git a/gradlew.bat b/gradlew.bat index 24467a141..53a6b238d 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -14,7 +14,7 @@ @rem limitations under the License. @rem -@if "%DEBUG%" == "" @echo off +@if "%DEBUG%"=="" @echo off @rem ########################################################################## @rem @rem Gradle startup script for Windows @@ -25,10 +25,13 @@ if "%OS%"=="Windows_NT" setlocal set DIRNAME=%~dp0 -if "%DIRNAME%" == "" set DIRNAME=. +if "%DIRNAME%"=="" set DIRNAME=. set APP_BASE_NAME=%~n0 set APP_HOME=%DIRNAME% +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" @@ -37,7 +40,7 @@ if defined JAVA_HOME goto findJavaFromJavaHome set JAVA_EXE=java.exe %JAVA_EXE% -version >NUL 2>&1 -if "%ERRORLEVEL%" == "0" goto init +if %ERRORLEVEL% equ 0 goto execute echo. echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. @@ -51,7 +54,7 @@ goto fail set JAVA_HOME=%JAVA_HOME:"=% set JAVA_EXE=%JAVA_HOME%/bin/java.exe -if exist "%JAVA_EXE%" goto init +if exist "%JAVA_EXE%" goto execute echo. echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% @@ -61,38 +64,26 @@ echo location of your Java installation. goto fail -:init -@rem Get command-line arguments, handling Windows variants - -if not "%OS%" == "Windows_NT" goto win9xME_args - -:win9xME_args -@rem Slurp the command line arguments. -set CMD_LINE_ARGS= -set _SKIP=2 - -:win9xME_args_slurp -if "x%~1" == "x" goto execute - -set CMD_LINE_ARGS=%* - :execute @rem Setup the command line set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + @rem Execute Gradle -"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* :end @rem End local scope for the variables with windows NT shell -if "%ERRORLEVEL%"=="0" goto mainEnd +if %ERRORLEVEL% equ 0 goto mainEnd :fail rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of rem the _cmd.exe /c_ return code! -if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 -exit /b 1 +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% :mainEnd if "%OS%"=="Windows_NT" endlocal diff --git a/rsocket-bom/build.gradle b/rsocket-bom/build.gradle index 2efc20a91..a75ab3bc8 100755 --- a/rsocket-bom/build.gradle +++ b/rsocket-bom/build.gradle @@ -16,8 +16,7 @@ plugins { id 'java-platform' id 'maven-publish' - id 'com.jfrog.artifactory' - id 'com.jfrog.bintray' + id 'signing' } description = 'RSocket Java Bill of materials.' diff --git a/rsocket-core/build.gradle b/rsocket-core/build.gradle index 41adbd7a8..da5b69b14 100644 --- a/rsocket-core/build.gradle +++ b/rsocket-core/build.gradle @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,10 +17,10 @@ plugins { id 'java-library' id 'maven-publish' - id 'com.jfrog.artifactory' - id 'com.jfrog.bintray' + id 'signing' id 'io.morethan.jmhreport' - id 'me.champeau.gradle.jmh' + id 'me.champeau.jmh' + id 'io.github.reyerizo.gradle.jcstress' } dependencies { @@ -29,19 +29,32 @@ dependencies { implementation 'org.slf4j:slf4j-api' + testImplementation (project(":rsocket-transport-local")) testImplementation 'io.projectreactor:reactor-test' testImplementation 'org.assertj:assertj-core' testImplementation 'org.junit.jupiter:junit-jupiter-api' testImplementation 'org.junit.jupiter:junit-jupiter-params' - testImplementation 'org.mockito:mockito-core' + testImplementation 'org.mockito:mockito-junit-jupiter' + testImplementation 'org.awaitility:awaitility' testRuntimeOnly 'ch.qos.logback:logback-classic' testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' - // TODO: Remove after JUnit5 migration - testCompileOnly 'junit:junit' - testImplementation 'org.hamcrest:hamcrest-library' - testRuntimeOnly 'org.junit.vintage:junit-vintage-engine' + jcstressImplementation(project(":rsocket-test")) + jcstressImplementation 'org.slf4j:slf4j-api' + jcstressImplementation "ch.qos.logback:logback-classic" + jcstressImplementation 'io.projectreactor:reactor-test' } -description = "Core functionality for the RSocket library" \ No newline at end of file +jcstress { + mode = 'sanity' //sanity, quick, default, tough + jcstressDependency = "org.openjdk.jcstress:jcstress-core:0.16" +} + +jar { + manifest { + attributes("Automatic-Module-Name": "rsocket.core") + } +} + +description = "Core functionality for the RSocket library" diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/FireAndForgetRequesterMonoStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/core/FireAndForgetRequesterMonoStressTest.java new file mode 100644 index 000000000..e91be2451 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/FireAndForgetRequesterMonoStressTest.java @@ -0,0 +1,115 @@ +package io.rsocket.core; + +import static org.openjdk.jcstress.annotations.Expect.ACCEPTABLE; + +import io.netty.buffer.ByteBuf; +import io.rsocket.test.TestDuplexConnection; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.LLLL_Result; + +public abstract class FireAndForgetRequesterMonoStressTest { + + abstract static class BaseStressTest { + + final StressSubscriber outboundSubscriber = new StressSubscriber<>(); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(); + + final TestDuplexConnection testDuplexConnection = + new TestDuplexConnection(this.outboundSubscriber, false); + + final TestRequesterResponderSupport requesterResponderSupport = + new TestRequesterResponderSupport(testDuplexConnection, StreamIdSupplier.clientSupplier()); + + final FireAndForgetRequesterMono source = source(); + + abstract FireAndForgetRequesterMono source(); + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 3, 1, 0"}, + expect = ACCEPTABLE) + @State + public static class TwoSubscribesRaceStressTest extends BaseStressTest { + + final StressSubscriber stressSubscriber1 = new StressSubscriber<>(); + + @Override + FireAndForgetRequesterMono source() { + return new FireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Actor + public void subscribe1() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void subscribe2() { + this.source.subscribe(this.stressSubscriber1); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber1.onCompleteCalls + + this.stressSubscriber1.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 1, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @State + public static class SubscribeAndCancelRaceStressTest extends BaseStressTest { + + @Override + FireAndForgetRequesterMono source() { + return new FireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = this.source.state; + r.r2 = this.stressSubscriber.onCompleteCalls + this.stressSubscriber.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/ReconnectMonoStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/core/ReconnectMonoStressTest.java new file mode 100644 index 000000000..ef79d344d --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/ReconnectMonoStressTest.java @@ -0,0 +1,604 @@ +/* + * Copyright 2015-Present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.core.ResolvingOperator.EMPTY_SUBSCRIBED; +import static io.rsocket.core.ResolvingOperator.EMPTY_UNSUBSCRIBED; +import static io.rsocket.core.ResolvingOperator.READY; +import static io.rsocket.core.ResolvingOperator.TERMINATED; +import static org.openjdk.jcstress.annotations.Expect.ACCEPTABLE; + +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.function.BiConsumer; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.IIIIIII_Result; +import org.openjdk.jcstress.infra.results.IIIIII_Result; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; + +public abstract class ReconnectMonoStressTest { + + abstract static class BaseStressTest { + + final StressSubscription stressSubscription = new StressSubscription<>(); + + final Mono source = source(); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(); + + volatile int onValueExpire; + + static final AtomicIntegerFieldUpdater ON_VALUE_EXPIRE = + AtomicIntegerFieldUpdater.newUpdater(BaseStressTest.class, "onValueExpire"); + + volatile int onValueReceived; + + static final AtomicIntegerFieldUpdater ON_VALUE_RECEIVED = + AtomicIntegerFieldUpdater.newUpdater(BaseStressTest.class, "onValueReceived"); + final ReconnectMono reconnectMono = + new ReconnectMono<>( + source, + (__) -> ON_VALUE_EXPIRE.incrementAndGet(BaseStressTest.this), + (__, ___) -> ON_VALUE_RECEIVED.incrementAndGet(BaseStressTest.this)); + + abstract Mono source(); + + int state() { + final BiConsumer[] subscribers = reconnectMono.resolvingInner.subscribers; + if (subscribers == EMPTY_UNSUBSCRIBED) { + return 0; + } else if (subscribers == EMPTY_SUBSCRIBED) { + return 1; + } else if (subscribers == READY) { + return 2; + } else if (subscribers == TERMINATED) { + return 3; + } else { + return 4; + } + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 0, 1, 1, 0, 3"}, + expect = ACCEPTABLE, + desc = "Disposed before value is delivered") + @Outcome( + id = {"0, 0, 0, 1, 1, 0, 3"}, + expect = ACCEPTABLE, + desc = "Disposed after onComplete but before value is delivered") + @Outcome( + id = {"0, 1, 1, 0, 1, 1, 3"}, + expect = ACCEPTABLE, + desc = "Disposed after value is delivered") + @State + public static class ExpireValueOnRacingDisposeAndNext extends BaseStressTest { + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + } + }; + } + + @Actor + void sendNext() { + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + + @Actor + void dispose() { + reconnectMono.dispose(); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + r.r1 = stressSubscription.cancelled ? 1 : 0; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = stressSubscriber.onCompleteCalls; + r.r4 = stressSubscriber.onErrorCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 0, 1, 1, 0, 3"}, + expect = ACCEPTABLE, + desc = "Disposed before error is delivered") + @Outcome( + id = {"0, 0, 0, 1, 1, 0, 3"}, + expect = ACCEPTABLE, + desc = "Disposed after onError") + @State + public static class ExpireValueOnRacingDisposeAndError extends BaseStressTest { + + { + Hooks.onErrorDropped(t -> {}); + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + } + }; + } + + @Actor + void sendNext() { + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onError(new RuntimeException("boom")); + } + + @Actor + void dispose() { + reconnectMono.dispose(); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + Hooks.resetOnErrorDropped(); + + r.r1 = stressSubscription.cancelled ? 1 : 0; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = stressSubscriber.onCompleteCalls; + r.r4 = stressSubscriber.onErrorCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"0, 1, 1, 0, 0, 1, 2"}, + expect = ACCEPTABLE, + desc = "Invalidate happens before value is delivered") + @Outcome( + id = {"0, 1, 1, 0, 1, 1, 0"}, + expect = ACCEPTABLE, + desc = "Invalidate happens after value is delivered") + @State + public static class ExpireValueOnRacingInvalidateAndNextComplete extends BaseStressTest { + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + } + }; + } + + @Actor + void sendNext() { + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + + @Actor + void invalidate() { + reconnectMono.invalidate(); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + r.r1 = stressSubscription.cancelled ? 1 : 0; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = stressSubscriber.onCompleteCalls; + r.r4 = stressSubscriber.onErrorCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"0, 1, 1, 0, 1, 1, 0"}, + expect = ACCEPTABLE) + @State + public static class ExpireValueOnceOnRacingInvalidateAndInvalidate extends BaseStressTest { + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + }; + } + + @Actor + void invalidate1() { + reconnectMono.invalidate(); + } + + @Actor + void invalidate2() { + reconnectMono.invalidate(); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + r.r1 = stressSubscription.cancelled ? 1 : 0; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = stressSubscriber.onCompleteCalls; + r.r4 = stressSubscriber.onErrorCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"0, 1, 1, 0, 1, 1, 3"}, + expect = ACCEPTABLE) + @State + public static class ExpireValueOnceOnRacingInvalidateAndDispose extends BaseStressTest { + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + }; + } + + @Actor + void invalidate() { + reconnectMono.invalidate(); + } + + @Actor + void dispose() { + reconnectMono.dispose(); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + r.r1 = stressSubscription.cancelled ? 1 : 0; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = stressSubscriber.onCompleteCalls; + r.r4 = stressSubscriber.onErrorCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 2, 2, 0, 1"}, + expect = ACCEPTABLE) + @State + public static class DeliversValueToAllSubscribersUnderRace extends BaseStressTest { + + final StressSubscriber stressSubscriber2 = new StressSubscriber<>(); + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + } + }; + } + + @Actor + void sendNextAndComplete() { + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + + @Actor + void secondSubscribe() { + reconnectMono.subscribe(stressSubscriber2); + } + + @Arbiter + public void arbiter(IIIIII_Result r) { + r.r1 = stressSubscription.requestsCount; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = stressSubscriber.onNextCalls + stressSubscriber2.onNextCalls; + r.r4 = stressSubscriber.onCompleteCalls + stressSubscriber2.onCompleteCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + } + } + + @JCStressTest + @Outcome( + id = {"2, 0, 1, 1, 1, 1, 4"}, + expect = ACCEPTABLE, + desc = "Second Subscriber subscribed after invalidate") + @Outcome( + id = {"1, 0, 2, 2, 1, 1, 0"}, + expect = ACCEPTABLE, + desc = "Second Subscriber subscribed before invalidate and received value") + @State + public static class InvalidateAndSubscribeUnderRace extends BaseStressTest { + + final StressSubscriber stressSubscriber2 = new StressSubscriber<>(); + + { + reconnectMono.subscribe(stressSubscriber); + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + } + }; + } + + @Actor + void invalidate() { + reconnectMono.invalidate(); + } + + @Actor + void secondSubscribe() { + reconnectMono.subscribe(stressSubscriber2); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + r.r1 = stressSubscription.subscribes; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = stressSubscriber.onNextCalls + stressSubscriber2.onNextCalls; + r.r4 = stressSubscriber.onCompleteCalls + stressSubscriber2.onCompleteCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"2, 0, 2, 1, 2, 2"}, + expect = ACCEPTABLE, + desc = "Subscribed again after invalidate") + @Outcome( + id = {"1, 0, 1, 1, 1, 0"}, + expect = ACCEPTABLE, + desc = "Subscribed before invalidate") + @State + public static class InvalidateAndBlockUnderRace extends BaseStressTest { + + String receivedValue; + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + actual.onNext("value" + stressSubscription.subscribes); + actual.onComplete(); + } + }; + } + + @Actor + void invalidate() { + reconnectMono.invalidate(); + } + + @Actor + void secondSubscribe() { + receivedValue = reconnectMono.block(); + } + + @Arbiter + public void arbiter(IIIIII_Result r) { + r.r1 = stressSubscription.subscribes; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = receivedValue.equals("value1") ? 1 : receivedValue.equals("value2") ? 2 : -1; + r.r4 = onValueExpire; + r.r5 = onValueReceived; + r.r6 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 1, 0, 1, 2"}, + expect = ACCEPTABLE) + @State + public static class TwoSubscribesRace extends BaseStressTest { + + StressSubscriber stressSubscriber2 = new StressSubscriber<>(); + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + actual.onNext("value" + stressSubscription.subscribes); + actual.onComplete(); + } + }; + } + + @Actor + void subscribe1() { + reconnectMono.subscribe(stressSubscriber); + } + + @Actor + void subscribe2() { + reconnectMono.subscribe(stressSubscriber2); + } + + @Arbiter + public void arbiter(IIIIII_Result r) { + r.r1 = stressSubscription.subscribes; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = stressSubscriber.values.get(0).equals(stressSubscriber2.values.get(0)) ? 1 : 2; + r.r4 = onValueExpire; + r.r5 = onValueReceived; + r.r6 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 1, 0, 1, 2"}, + expect = ACCEPTABLE) + @State + public static class SubscribeBlockConnectRace extends BaseStressTest { + + String receivedValue; + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + actual.onNext("value" + stressSubscription.subscribes); + actual.onComplete(); + } + }; + } + + @Actor + void block() { + receivedValue = reconnectMono.block(); + } + + @Actor + void subscribe() { + reconnectMono.subscribe(stressSubscriber); + } + + @Actor + void connect() { + reconnectMono.resolvingInner.connect(); + } + + @Arbiter + public void arbiter(IIIIII_Result r) { + r.r1 = stressSubscription.subscribes; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = receivedValue.equals(stressSubscriber.values.get(0)) ? 1 : 2; + r.r4 = onValueExpire; + r.r5 = onValueReceived; + r.r6 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 1, 0, 1, 2"}, + expect = ACCEPTABLE) + @State + public static class TwoBlocksRace extends BaseStressTest { + + String receivedValue1; + String receivedValue2; + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + actual.onNext("value" + stressSubscription.subscribes); + actual.onComplete(); + } + }; + } + + @Actor + void block1() { + receivedValue1 = reconnectMono.block(); + } + + @Actor + void block2() { + receivedValue2 = reconnectMono.block(); + } + + @Arbiter + public void arbiter(IIIIII_Result r) { + r.r1 = stressSubscription.subscribes; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = receivedValue1.equals(receivedValue2) ? 1 : 2; + r.r4 = onValueExpire; + r.r5 = onValueReceived; + r.r6 = state(); + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/RequestResponseRequesterMonoStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/core/RequestResponseRequesterMonoStressTest.java new file mode 100644 index 000000000..1dde77b34 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/RequestResponseRequesterMonoStressTest.java @@ -0,0 +1,650 @@ +package io.rsocket.core; + +import static org.openjdk.jcstress.annotations.Expect.ACCEPTABLE; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.test.TestDuplexConnection; +import java.util.stream.IntStream; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.LLLLLL_Result; +import org.openjdk.jcstress.infra.results.LLLLL_Result; +import org.openjdk.jcstress.infra.results.LLLL_Result; + +public abstract class RequestResponseRequesterMonoStressTest { + + abstract static class BaseStressTest { + + final StressSubscriber outboundSubscriber = new StressSubscriber<>(); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(initialRequest()); + + final TestDuplexConnection testDuplexConnection = + new TestDuplexConnection(this.outboundSubscriber, false); + + final RequesterLeaseTracker requesterLeaseTracker; + + final TestRequesterResponderSupport requesterResponderSupport; + + final RequestResponseRequesterMono source; + + BaseStressTest(RequesterLeaseTracker requesterLeaseTracker) { + this.requesterLeaseTracker = requesterLeaseTracker; + this.requesterResponderSupport = + new TestRequesterResponderSupport( + testDuplexConnection, StreamIdSupplier.clientSupplier(), requesterLeaseTracker); + this.source = source(); + } + + abstract RequestResponseRequesterMono source(); + + abstract long initialRequest(); + } + + abstract static class BaseStressTestWithLease extends BaseStressTest { + + BaseStressTestWithLease(int maximumAllowedAwaitingPermitHandlersNumber) { + super(new RequesterLeaseTracker("test", maximumAllowedAwaitingPermitHandlersNumber)); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 3, 1, 0, 0"}, + expect = ACCEPTABLE) + @State + public static class TwoSubscribesRaceStressTest extends BaseStressTestWithLease { + + final StressSubscriber stressSubscriber1 = new StressSubscriber<>(); + + public TwoSubscribesRaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + long initialRequest() { + return Long.MAX_VALUE; + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe1() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void subscribe2() { + this.source.subscribe(this.stressSubscriber1); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + final ByteBuf nextFrame = + PayloadFrameCodec.encode( + this.testDuplexConnection.alloc(), + 1, + false, + true, + true, + null, + ByteBufUtil.writeUtf8(this.testDuplexConnection.alloc(), "response-data")); + this.source.handleNext(nextFrame, false, true); + nextFrame.release(); + + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber1.onCompleteCalls + + this.stressSubscriber1.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + + this.outboundSubscriber.values.forEach(ByteBuf::release); + this.stressSubscriber.values.forEach(Payload::release); + this.stressSubscriber1.values.forEach(Payload::release); + + r.r5 = this.source.payload.refCnt() + nextFrame.refCnt(); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 0, 2, 0, 0, " + (0x04 + 2 * 0x09)}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @State + public static class SubscribeAndRequestAndCancelRaceStressTest extends BaseStressTestWithLease { + + public SubscribeAndRequestAndCancelRaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + long initialRequest() { + return 0; + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Actor + public void request() { + this.stressSubscriber.request(1); + this.stressSubscriber.request(Long.MAX_VALUE); + this.stressSubscriber.request(1); + } + + @Arbiter + public void arbiter(LLLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = this.stressSubscriber.onCompleteCalls + this.stressSubscriber.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + + r.r6 = + IntStream.range(0, this.outboundSubscriber.values.size()) + .map( + i -> + FrameHeaderCodec.frameType(this.outboundSubscriber.values.get(i)) + .getEncodedType() + * (i + 1)) + .sum(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 0, 2, 0, 0, " + (0x04 + 2 * 0x09)}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @State + public static class SubscribeAndRequestAndCancelWithDeferredLeaseRaceStressTest + extends BaseStressTestWithLease { + + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + + public SubscribeAndRequestAndCancelWithDeferredLeaseRaceStressTest() { + super(1); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + long initialRequest() { + return 0; + } + + @Actor + public void issueLease() { + final ByteBuf leaseFrame = this.leaseFrame; + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Actor + public void request() { + this.stressSubscriber.request(1); + this.stressSubscriber.request(Long.MAX_VALUE); + this.stressSubscriber.request(1); + } + + @Arbiter + public void arbiter(LLLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = this.stressSubscriber.onCompleteCalls + this.stressSubscriber.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + r.r6 = + IntStream.range(0, this.outboundSubscriber.values.size()) + .map( + i -> + FrameHeaderCodec.frameType(this.outboundSubscriber.values.get(i)) + .getEncodedType() + * (i + 1)) + .sum(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 0, 2, 0, 0, " + (0x04 + 2 * 0x09)}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 2, 0, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "NoLeaseError delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first or in between") + @Outcome( + id = {"-9223372036854775808, 3, 0, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = + "cancellation happened after lease permit requested but before it was actually decided and in the case when no lease are available. Error is dropped") + @State + public static class SubscribeAndRequestAndCancelWithDeferredLease2RaceStressTest + extends BaseStressTestWithLease { + + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + + SubscribeAndRequestAndCancelWithDeferredLease2RaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + long initialRequest() { + return 0; + } + + @Actor + public void issueLease() { + final ByteBuf leaseFrame = this.leaseFrame; + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Actor + public void request() { + this.stressSubscriber.request(1); + this.stressSubscriber.request(Long.MAX_VALUE); + this.stressSubscriber.request(1); + } + + @Arbiter + public void arbiter(LLLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + r.r6 = + IntStream.range(0, this.outboundSubscriber.values.size()) + .map( + i -> + FrameHeaderCodec.frameType(this.outboundSubscriber.values.get(i)) + .getEncodedType() + * (i + 1)) + .sum(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 0, 2, 0, " + (0x04 + 2 * 0x09)}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @State + public static class SubscribeAndRequestAndCancel extends BaseStressTest { + + SubscribeAndRequestAndCancel() { + super(null); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + long initialRequest() { + return 0; + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Actor + public void request() { + this.stressSubscriber.request(1); + this.stressSubscriber.request(Long.MAX_VALUE); + this.stressSubscriber.request(1); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.source.payload.refCnt(); + r.r5 = + IntStream.range(0, this.outboundSubscriber.values.size()) + .map( + i -> + FrameHeaderCodec.frameType(this.outboundSubscriber.values.get(i)) + .getEncodedType() + * (i + 1)) + .sum(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 1, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first or in between") + @State + public static class CancelWithInboundNextRaceStressTest extends BaseStressTestWithLease { + + final ByteBuf nextFrame = + PayloadFrameCodec.encode( + this.testDuplexConnection.alloc(), + 1, + false, + true, + true, + null, + ByteBufUtil.writeUtf8(this.testDuplexConnection.alloc(), "response-data")); + + CancelWithInboundNextRaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + + this.source.subscribe(this.stressSubscriber); + } + + @Override + long initialRequest() { + return 1; + } + + @Actor + public void inboundNext() { + this.source.handleNext(this.nextFrame, false, true); + this.nextFrame.release(); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.stressSubscriber.onNextCalls; + + this.outboundSubscriber.values.forEach(ByteBuf::release); + this.stressSubscriber.values.forEach(Payload::release); + + r.r4 = this.source.payload.refCnt() + this.nextFrame.refCnt(); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first or in between") + @State + public static class CancelWithInboundCompleteRaceStressTest extends BaseStressTestWithLease { + + CancelWithInboundCompleteRaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + + this.source.subscribe(this.stressSubscriber); + } + + @Override + long initialRequest() { + return 1; + } + + @Actor + public void inboundComplete() { + this.source.handleComplete(); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.stressSubscriber.onNextCalls; + + this.outboundSubscriber.values.forEach(ByteBuf::release); + this.stressSubscriber.values.forEach(Payload::release); + + r.r4 = this.source.payload.refCnt(); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 2, 0, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 3, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first. inbound error dropped") + @State + public static class CancelWithInboundErrorRaceStressTest extends BaseStressTestWithLease { + + static final RuntimeException ERROR = new RuntimeException("Test"); + + CancelWithInboundErrorRaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + + this.source.subscribe(this.stressSubscriber); + } + + @Override + long initialRequest() { + return 1; + } + + @Actor + public void inboundError() { + this.source.handleError(ERROR); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.stressSubscriber.onNextCalls; + + this.outboundSubscriber.values.forEach(ByteBuf::release); + this.stressSubscriber.values.forEach(Payload::release); + + r.r4 = this.source.payload.refCnt(); + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/SlowFireAndForgetRequesterMonoStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/core/SlowFireAndForgetRequesterMonoStressTest.java new file mode 100644 index 000000000..5de7eb4b9 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/SlowFireAndForgetRequesterMonoStressTest.java @@ -0,0 +1,288 @@ +package io.rsocket.core; + +import static org.openjdk.jcstress.annotations.Expect.ACCEPTABLE; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.test.TestDuplexConnection; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.LLLLL_Result; + +public abstract class SlowFireAndForgetRequesterMonoStressTest { + + abstract static class BaseStressTest { + + final StressSubscriber outboundSubscriber = new StressSubscriber<>(); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(); + + final TestDuplexConnection testDuplexConnection = + new TestDuplexConnection(this.outboundSubscriber, false); + + final RequesterLeaseTracker requesterLeaseTracker = + new RequesterLeaseTracker("test", maximumAllowedAwaitingPermitHandlersNumber()); + + final TestRequesterResponderSupport requesterResponderSupport = + new TestRequesterResponderSupport( + testDuplexConnection, StreamIdSupplier.clientSupplier(), requesterLeaseTracker); + + final SlowFireAndForgetRequesterMono source = source(); + + abstract SlowFireAndForgetRequesterMono source(); + + abstract int maximumAllowedAwaitingPermitHandlersNumber(); + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 3, 1, 0, 0"}, + expect = ACCEPTABLE) + @State + public static class TwoSubscribesRaceStressTest extends BaseStressTest { + + final StressSubscriber stressSubscriber1 = new StressSubscriber<>(); + + @Override + SlowFireAndForgetRequesterMono source() { + return new SlowFireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + int maximumAllowedAwaitingPermitHandlersNumber() { + return 0; + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe1() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void subscribe2() { + this.source.subscribe(this.stressSubscriber1); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber1.onCompleteCalls + + this.stressSubscriber1.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened in between") + @State + public static class SubscribeAndCancelRaceStressTest extends BaseStressTest { + + @Override + SlowFireAndForgetRequesterMono source() { + return new SlowFireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + int maximumAllowedAwaitingPermitHandlersNumber() { + return 0; + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = this.stressSubscriber.onCompleteCalls + this.stressSubscriber.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened in between") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @State + public static class SubscribeAndCancelWithDeferredLeaseRaceStressTest extends BaseStressTest { + + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + + @Override + SlowFireAndForgetRequesterMono source() { + return new SlowFireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + int maximumAllowedAwaitingPermitHandlersNumber() { + return 1; + } + + @Actor + public void issueLease() { + final ByteBuf leaseFrame = this.leaseFrame; + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = this.stressSubscriber.onCompleteCalls + this.stressSubscriber.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 2, 0, 1, 0"}, + expect = ACCEPTABLE, + desc = "no lease error delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened in between") + @Outcome( + id = {"-9223372036854775808, 3, 0, 1, 0"}, + expect = ACCEPTABLE, + desc = + "cancellation happened after lease permit requested but before it was actually decided and in the case when no lease are available. Error is dropped") + @State + public static class SubscribeAndCancelWithDeferredLease2RaceStressTest extends BaseStressTest { + + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + + @Override + SlowFireAndForgetRequesterMono source() { + return new SlowFireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + int maximumAllowedAwaitingPermitHandlersNumber() { + return 0; + } + + @Actor + public void issueLease() { + final ByteBuf leaseFrame = this.leaseFrame; + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscriber.java b/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscriber.java new file mode 100644 index 000000000..883077f77 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscriber.java @@ -0,0 +1,472 @@ +/* + * Copyright (c) 2020-Present Pivotal Software Inc, All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static reactor.core.publisher.Operators.addCap; + +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Consumer; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; + +public class StressSubscriber implements CoreSubscriber { + + enum Operation { + ON_NEXT, + ON_ERROR, + ON_COMPLETE, + ON_SUBSCRIBE + } + + final Context context; + final int requestedFusionMode; + + int fusionMode; + Subscription subscription; + + public Throwable error; + public boolean done; + + public List droppedErrors = new CopyOnWriteArrayList<>(); + + public List values = new ArrayList<>(); + + volatile long requested; + + @SuppressWarnings("rawtypes") + static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(StressSubscriber.class, "requested"); + + volatile int wip; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "wip"); + + public volatile Operation guard; + + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater GUARD = + AtomicReferenceFieldUpdater.newUpdater(StressSubscriber.class, Operation.class, "guard"); + + public volatile boolean concurrentOnNext; + + public volatile boolean concurrentOnError; + + public volatile boolean concurrentOnComplete; + + public volatile boolean concurrentOnSubscribe; + + public volatile int onNextCalls; + + public BlockingQueue q = new LinkedBlockingDeque<>(); + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ON_NEXT_CALLS = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "onNextCalls"); + + public volatile int onNextDiscarded; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ON_NEXT_DISCARDED = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "onNextDiscarded"); + + public volatile int onErrorCalls; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ON_ERROR_CALLS = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "onErrorCalls"); + + public volatile int onCompleteCalls; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ON_COMPLETE_CALLS = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "onCompleteCalls"); + + public volatile int onSubscribeCalls; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ON_SUBSCRIBE_CALLS = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "onSubscribeCalls"); + + /** Build a {@link StressSubscriber} that makes an unbounded request upon subscription. */ + public StressSubscriber() { + this(Long.MAX_VALUE, Fuseable.NONE); + } + + /** + * Build a {@link StressSubscriber} that requests the provided amount in {@link + * #onSubscribe(Subscription)}. Use {@code 0} to avoid any initial request upon subscription. + * + * @param initRequest the requested amount upon subscription, or zero to disable initial request + */ + public StressSubscriber(long initRequest) { + this(initRequest, Fuseable.NONE); + } + + /** + * Build a {@link StressSubscriber} that requests the provided amount in {@link + * #onSubscribe(Subscription)}. Use {@code 0} to avoid any initial request upon subscription. + * + * @param initRequest the requested amount upon subscription, or zero to disable initial request + */ + public StressSubscriber(long initRequest, int requestedFusionMode) { + this.requestedFusionMode = requestedFusionMode; + this.context = + Operators.enableOnDiscard( + Context.of( + "reactor.onErrorDropped.local", + (Consumer) throwable -> droppedErrors.add(throwable)), + (__) -> ON_NEXT_DISCARDED.incrementAndGet(this)); + REQUESTED.lazySet(this, initRequest | Long.MIN_VALUE); + } + + @Override + public Context currentContext() { + return this.context; + } + + @Override + public void onSubscribe(Subscription subscription) { + if (!GUARD.compareAndSet(this, null, Operation.ON_SUBSCRIBE)) { + concurrentOnSubscribe = true; + subscription.cancel(); + } else { + final boolean isValid = Operators.validate(this.subscription, subscription); + if (isValid) { + this.subscription = subscription; + } + GUARD.compareAndSet(this, Operation.ON_SUBSCRIBE, null); + + if (this.requestedFusionMode > 0 && subscription instanceof Fuseable.QueueSubscription) { + final int m = + ((Fuseable.QueueSubscription) subscription).requestFusion(this.requestedFusionMode); + final long requested = this.requested; + this.fusionMode = m; + if (m != Fuseable.NONE) { + if (requested == Long.MAX_VALUE) { + subscription.cancel(); + } + drain(); + return; + } + } + + if (isValid) { + long delivered = 0; + for (; ; ) { + long s = requested; + if (s == Long.MAX_VALUE) { + subscription.cancel(); + break; + } + + long r = s & Long.MAX_VALUE; + long toRequest = r - delivered; + if (toRequest > 0) { + subscription.request(toRequest); + delivered = r; + } + + if (REQUESTED.compareAndSet(this, s, 0)) { + break; + } + } + } + } + ON_SUBSCRIBE_CALLS.incrementAndGet(this); + } + + @Override + public void onNext(T value) { + if (fusionMode == Fuseable.ASYNC) { + drain(); + return; + } + + if (!GUARD.compareAndSet(this, null, Operation.ON_NEXT)) { + concurrentOnNext = true; + } else { + values.add(value); + GUARD.compareAndSet(this, Operation.ON_NEXT, null); + } + ON_NEXT_CALLS.incrementAndGet(this); + } + + @Override + public void onError(Throwable throwable) { + if (!GUARD.compareAndSet(this, null, Operation.ON_ERROR)) { + concurrentOnError = true; + } else { + GUARD.compareAndSet(this, Operation.ON_ERROR, null); + } + + if (done) { + throw new IllegalStateException("Already done"); + } + + error = throwable; + done = true; + q.offer(throwable); + ON_ERROR_CALLS.incrementAndGet(this); + + if (fusionMode == Fuseable.ASYNC) { + drain(); + } + } + + @Override + public void onComplete() { + if (!GUARD.compareAndSet(this, null, Operation.ON_COMPLETE)) { + concurrentOnComplete = true; + } else { + GUARD.compareAndSet(this, Operation.ON_COMPLETE, null); + } + if (done) { + throw new IllegalStateException("Already done"); + } + + done = true; + ON_COMPLETE_CALLS.incrementAndGet(this); + + if (fusionMode == Fuseable.ASYNC) { + drain(); + } + } + + public void request(long n) { + if (Operators.validate(n)) { + for (; ; ) { + final long s = this.requested; + if (s == 0) { + this.subscription.request(n); + return; + } + + if ((s & Long.MIN_VALUE) != Long.MIN_VALUE) { + return; + } + + final long r = s & Long.MAX_VALUE; + if (r == Long.MAX_VALUE) { + return; + } + + final long u = addCap(r, n); + if (REQUESTED.compareAndSet(this, s, u | Long.MIN_VALUE)) { + if (this.fusionMode != Fuseable.NONE) { + drain(); + } + return; + } + } + } + } + + public void cancel() { + for (; ; ) { + long s = this.requested; + if (s == 0) { + this.subscription.cancel(); + return; + } + + if (REQUESTED.compareAndSet(this, s, Long.MAX_VALUE)) { + if (this.fusionMode != Fuseable.NONE) { + drain(); + } + return; + } + } + } + + @SuppressWarnings("unchecked") + private void drain() { + final int previousState = markWorkAdded(); + if (isFinalized(previousState)) { + ((Queue) this.subscription).clear(); + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + final Subscription s = this.subscription; + final Queue q = (Queue) s; + + int expectedState = previousState + 1; + for (; ; ) { + long r = this.requested & Long.MAX_VALUE; + long e = 0L; + + while (r != e) { + // done has to be read before queue.poll to ensure there was no racing: + // Thread1: <#drain>: queue.poll(null) --------------------> this.done(true) + // Thread2: ------------------> <#onNext(V)> --> <#onComplete()> + boolean done = this.done; + + final T t = q.poll(); + final boolean empty = t == null; + + if (checkTerminated(done, empty)) { + if (!empty) { + values.add(t); + } + return; + } + + if (empty) { + break; + } + + values.add(t); + + e++; + } + + if (r == e) { + // done has to be read before queue.isEmpty to ensure there was no racing: + // Thread1: <#drain>: queue.isEmpty(true) --------------------> this.done(true) + // Thread2: --------------------> <#onNext(V)> ---> <#onComplete()> + boolean done = this.done; + boolean empty = q.isEmpty(); + + if (checkTerminated(done, empty)) { + return; + } + } + + if (e != 0) { + ON_NEXT_CALLS.addAndGet(this, (int) e); + if (r != Long.MAX_VALUE) { + produce(e); + } + } + + expectedState = markWorkDone(expectedState); + if (!isWorkInProgress(expectedState)) { + return; + } + } + } + + boolean checkTerminated(boolean done, boolean empty) { + final long state = this.requested; + if (state == Long.MAX_VALUE) { + this.subscription.cancel(); + clearAndFinalize(); + return true; + } + + if (done && empty) { + clearAndFinalize(); + return true; + } + + return false; + } + + final void produce(long produced) { + for (; ; ) { + final long s = this.requested; + + if ((s & Long.MIN_VALUE) != Long.MIN_VALUE) { + return; + } + + final long r = s & Long.MAX_VALUE; + if (r == Long.MAX_VALUE) { + return; + } + + final long u = r - produced; + if (REQUESTED.compareAndSet(this, s, u | Long.MIN_VALUE)) { + return; + } + } + } + + @SuppressWarnings("unchecked") + final void clearAndFinalize() { + final Queue q = (Queue) this.subscription; + for (; ; ) { + final int state = this.wip; + + q.clear(); + + if (WIP.compareAndSet(this, state, Integer.MIN_VALUE)) { + return; + } + } + } + + final int markWorkAdded() { + for (; ; ) { + final int state = this.wip; + + if (isFinalized(state)) { + return state; + } + + int nextState = state + 1; + if ((nextState & Integer.MAX_VALUE) == 0) { + return state; + } + + if (WIP.compareAndSet(this, state, nextState)) { + return state; + } + } + } + + final int markWorkDone(int expectedState) { + for (; ; ) { + final int state = this.wip; + + if (expectedState != state) { + return state; + } + + if (isFinalized(state)) { + return state; + } + + if (WIP.compareAndSet(this, state, 0)) { + return 0; + } + } + } + + static boolean isFinalized(int state) { + return state == Integer.MIN_VALUE; + } + + static boolean isWorkInProgress(int state) { + return (state & Integer.MAX_VALUE) > 0; + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscription.java b/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscription.java new file mode 100644 index 000000000..3b51b8ef6 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscription.java @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2020-Present Pivotal Software Inc, All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Operators; + +public class StressSubscription implements Subscription { + + CoreSubscriber actual; + + public volatile int subscribes; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater SUBSCRIBES = + AtomicIntegerFieldUpdater.newUpdater(StressSubscription.class, "subscribes"); + + public volatile long requested; + + @SuppressWarnings("rawtypes") + static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(StressSubscription.class, "requested"); + + public volatile int requestsCount; + + @SuppressWarnings("rawtype s") + static final AtomicIntegerFieldUpdater REQUESTS_COUNT = + AtomicIntegerFieldUpdater.newUpdater(StressSubscription.class, "requestsCount"); + + public volatile boolean cancelled; + + void subscribe(CoreSubscriber actual) { + this.actual = actual; + actual.onSubscribe(this); + SUBSCRIBES.getAndIncrement(this); + } + + @Override + public void request(long n) { + REQUESTS_COUNT.incrementAndGet(this); + Operators.addCap(REQUESTED, this, n); + } + + @Override + public void cancel() { + cancelled = true; + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/TestRequesterResponderSupport.java b/rsocket-core/src/jcstress/java/io/rsocket/core/TestRequesterResponderSupport.java new file mode 100644 index 000000000..420da66ba --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/TestRequesterResponderSupport.java @@ -0,0 +1,39 @@ +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.rsocket.DuplexConnection; +import io.rsocket.RSocket; +import io.rsocket.frame.decoder.PayloadDecoder; +import reactor.util.annotation.Nullable; + +public class TestRequesterResponderSupport extends RequesterResponderSupport implements RSocket { + + @Nullable private final RequesterLeaseTracker requesterLeaseTracker; + + public TestRequesterResponderSupport( + DuplexConnection connection, StreamIdSupplier streamIdSupplier) { + this(connection, streamIdSupplier, null); + } + + public TestRequesterResponderSupport( + DuplexConnection connection, + StreamIdSupplier streamIdSupplier, + @Nullable RequesterLeaseTracker requesterLeaseTracker) { + super( + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + PayloadDecoder.ZERO_COPY, + connection, + streamIdSupplier, + __ -> null); + this.requesterLeaseTracker = requesterLeaseTracker; + } + + @Override + @Nullable + public RequesterLeaseTracker getRequesterLeaseTracker() { + return this.requesterLeaseTracker; + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/UnpooledByteBufPayload.java b/rsocket-core/src/jcstress/java/io/rsocket/core/UnpooledByteBufPayload.java new file mode 100644 index 000000000..22c478979 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/UnpooledByteBufPayload.java @@ -0,0 +1,155 @@ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class UnpooledByteBufPayload extends AbstractReferenceCounted implements Payload { + + private final ByteBuf data; + private final ByteBuf metadata; + + /** + * Static factory method for a text payload. Mainly looks better than "new ByteBufPayload(data)" + * + * @param data the data of the payload. + * @return a payload. + */ + public static Payload create(String data) { + return create(data, ByteBufAllocator.DEFAULT); + } + + /** + * Static factory method for a text payload. Mainly looks better than "new ByteBufPayload(data)" + * + * @param data the data of the payload. + * @return a payload. + */ + public static Payload create(String data, ByteBufAllocator allocator) { + return new UnpooledByteBufPayload(ByteBufUtil.writeUtf8(allocator, data), null); + } + + /** + * Static factory method for a text payload. Mainly looks better than "new ByteBufPayload(data, + * metadata)" + * + * @param data the data of the payload. + * @param metadata the metadata for the payload. + * @return a payload. + */ + public static Payload create(String data, @Nullable String metadata) { + return create(data, metadata, ByteBufAllocator.DEFAULT); + } + + /** + * Static factory method for a text payload. Mainly looks better than "new ByteBufPayload(data, + * metadata)" + * + * @param data the data of the payload. + * @param metadata the metadata for the payload. + * @return a payload. + */ + public static Payload create(String data, @Nullable String metadata, ByteBufAllocator allocator) { + return new UnpooledByteBufPayload( + ByteBufUtil.writeUtf8(allocator, data), + metadata == null ? null : ByteBufUtil.writeUtf8(allocator, metadata)); + } + + public UnpooledByteBufPayload(ByteBuf data, @Nullable ByteBuf metadata) { + this.data = data; + this.metadata = metadata; + } + + @Override + public boolean hasMetadata() { + ensureAccessible(); + return metadata != null; + } + + @Override + public ByteBuf sliceMetadata() { + ensureAccessible(); + return metadata == null ? Unpooled.EMPTY_BUFFER : metadata.slice(); + } + + @Override + public ByteBuf data() { + ensureAccessible(); + return data; + } + + @Override + public ByteBuf metadata() { + ensureAccessible(); + return metadata == null ? Unpooled.EMPTY_BUFFER : metadata; + } + + @Override + public ByteBuf sliceData() { + ensureAccessible(); + return data.slice(); + } + + @Override + public UnpooledByteBufPayload retain() { + super.retain(); + return this; + } + + @Override + public UnpooledByteBufPayload retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public UnpooledByteBufPayload touch() { + ensureAccessible(); + data.touch(); + if (metadata != null) { + metadata.touch(); + } + return this; + } + + @Override + public UnpooledByteBufPayload touch(Object hint) { + ensureAccessible(); + data.touch(hint); + if (metadata != null) { + metadata.touch(hint); + } + return this; + } + + @Override + protected void deallocate() { + data.release(); + if (metadata != null) { + metadata.release(); + } + } + + /** + * Should be called by every method that tries to access the buffers content to check if the + * buffer was released before. + */ + void ensureAccessible() { + if (!isAccessible()) { + throw new IllegalReferenceCountException(0); + } + } + + /** + * Used internally by {@link UnpooledByteBufPayload#ensureAccessible()} to try to guard against + * using the buffer after it was released (best-effort). + */ + boolean isAccessible() { + return refCnt() != 0; + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/internal/UnboundedProcessorStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/internal/UnboundedProcessorStressTest.java new file mode 100644 index 000000000..a2d9fcf4d --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/internal/UnboundedProcessorStressTest.java @@ -0,0 +1,1733 @@ +package io.rsocket.internal; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.rsocket.core.StressSubscriber; +import io.rsocket.utils.FastLogger; +import java.util.Arrays; +import java.util.ConcurrentModificationException; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.Expect; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.LLLL_Result; +import org.openjdk.jcstress.infra.results.LLL_Result; +import org.openjdk.jcstress.infra.results.L_Result; +import reactor.core.Fuseable; +import reactor.core.publisher.Hooks; +import reactor.util.Logger; + +public abstract class UnboundedProcessorStressTest { + + static { + Hooks.onErrorDropped(t -> {}); + } + + final Logger logger = new FastLogger(getClass().getName()); + + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(logger); + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @Outcome( + id = { + "0, 0, 0", + "1, 0, 0", + "2, 0, 0", + "3, 0, 0", + "4, 0, 0", + // interleave with error or complete happened first but dispose suppressed them + "0, 3, 0", + "1, 3, 0", + "2, 3, 0", + "3, 3, 0", + "4, 3, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "cancel() before or interleave with dispose() || onError() || onComplete()") + @State + public static class SmokeStressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void request() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @Outcome( + id = { + "0, 0, 0", + "1, 0, 0", + "2, 0, 0", + "3, 0, 0", + "4, 0, 0", + // interleave with error or complete happened first but dispose suppressed them + "0, 3, 0", + "1, 3, 0", + "2, 3, 0", + "3, 3, 0", + "4, 3, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "cancel() before or interleave with dispose() || onError() || onComplete()") + @State + public static class SmokeFusedStressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void request() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @State + public static class Smoke2StressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + if (stressSubscriber.onCompleteCalls > 0 && stressSubscriber.onErrorCalls > 0) { + throw new RuntimeException("boom"); + } + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @State + public static class Smoke24StressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @State + public static class Smoke2FusedStressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @Outcome( + id = { + "0, 0, 0", + "1, 0, 0", + "2, 0, 0", + "3, 0, 0", + "4, 0, 0", + // interleave with error or complete happened first but dispose suppressed them + "0, 3, 0", + "1, 3, 0", + "2, 3, 0", + "3, 3, 0", + "4, 3, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "cancel() before or interleave with dispose() || onError() || onComplete()") + @State + public static class Smoke21FusedStressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", "1, 1, 0", "2, 1, 0", "3, 1, 0", "4, 1, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete()") + @Outcome( + id = { + "0, 0, 0", + "1, 0, 0", + "2, 0, 0", + "3, 0, 0", + "4, 0, 0", + // interleave with error or complete happened first but dispose suppressed them + "0, 3, 0", + "1, 3, 0", + "2, 3, 0", + "3, 3, 0", + "4, 3, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "cancel() before or interleave with onComplete()") + @State + public static class Smoke30StressTest extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void subscribeAndRequest() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", "1, 1, 0", "2, 1, 0", "3, 1, 0", "4, 1, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete()") + @State + public static class Smoke31StressTest extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void subscribeAndRequest() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + if (stressSubscriber.concurrentOnNext || stressSubscriber.concurrentOnComplete) { + throw new ConcurrentModificationException("boo"); + } + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", "1, 1, 0", "2, 1, 0", "3, 1, 0", "4, 1, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete()") + @State + public static class Smoke32StressTest extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = + new StressSubscriber<>(Long.MAX_VALUE, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0, 5", + "1, 1, 0, 5", + "2, 1, 0, 5", + "3, 1, 0, 5", + "4, 1, 0, 5", + "5, 1, 0, 5", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete()") + @State + public static class Smoke33StressTest extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = + new StressSubscriber<>(Long.MAX_VALUE, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + final ByteBuf byteBuf5 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(5); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void next1() { + unboundedProcessor.tryEmitNormal(byteBuf1); + unboundedProcessor.tryEmitPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.tryEmitPrioritized(byteBuf3); + unboundedProcessor.tryEmitNormal(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.tryEmitFinal(byteBuf5); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + r.r4 = stressSubscriber.values.get(stressSubscriber.values.size() - 1).readByte(); + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = + byteBuf1.refCnt() + + byteBuf2.refCnt() + + byteBuf3.refCnt() + + byteBuf4.refCnt() + + byteBuf5.refCnt(); + } + } + + @JCStressTest + @Outcome( + id = { + "-2954361355555045376, 4, 2, 0", + "-3242591731706757120, 4, 2, 0", + "-4107282860161892352, 4, 2, 0", + "-4395513236313604096, 4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 4, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 4, 0, 0", + "-7854277750134145024, 4, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 3, 2, 0", + "-3242591731706757120, 3, 2, 0", + "-4107282860161892352, 3, 2, 0", + "-4395513236313604096, 3, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 3, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 3, 0, 0", + "-7854277750134145024, 3, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 2, 2, 0", + "-3242591731706757120, 2, 2, 0", + "-4107282860161892352, 2, 2, 0", + "-4395513236313604096, 2, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 2, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 2, 0, 0", + "-7854277750134145024, 2, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 1, 2, 0", + "-3242591731706757120, 1, 2, 0", + "-4107282860161892352, 1, 2, 0", + "-4395513236313604096, 1, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 1, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 1, 0, 0", + "-7854277750134145024, 1, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 0, 2, 0", + "-3242591731706757120, 0, 2, 0", + "-4107282860161892352, 0, 2, 0", + "-4395513236313604096, 0, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 0, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 0, 0, 0", + "-7854277750134145024, 0, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @State + public static class RequestVsCancelVsOnNextVsDisposeStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void request() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "-3242591731706757120, 4, 2, 0", + "-4107282860161892352, 4, 2, 0", + "-4395513236313604096, 4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 4, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 3, 2, 0", + "-4107282860161892352, 3, 2, 0", + "-4395513236313604096, 3, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 3, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 2, 2, 0", + "-4107282860161892352, 2, 2, 0", + "-4395513236313604096, 2, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 2, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 1, 2, 0", + "-4107282860161892352, 1, 2, 0", + "-4395513236313604096, 1, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 1, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 0, 2, 0", + "-4107282860161892352, 0, 2, 0", + "-4395513236313604096, 0, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 0, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @State + public static class RequestVsCancelVsOnNextVsDisposeFusedStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void request() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "-2954361355555045376, 4, 2, 0", + "-3242591731706757120, 4, 2, 0", + "-4107282860161892352, 4, 2, 0", + "-4395513236313604096, 4, 2, 0", + "-4539628424389459968, 4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 4, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 4, 0, 0", + "-7854277750134145024, 4, 0, 0", + "-4539628424389459968, 4, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 3, 2, 0", + "-3242591731706757120, 3, 2, 0", + "-4107282860161892352, 3, 2, 0", + "-4395513236313604096, 3, 2, 0", + "-4539628424389459968, 3, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 3, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 3, 0, 0", + "-7854277750134145024, 3, 0, 0", + "-4539628424389459968, 3, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 2, 2, 0", + "-3242591731706757120, 2, 2, 0", + "-4107282860161892352, 2, 2, 0", + "-4395513236313604096, 2, 2, 0", + "-4539628424389459968, 2, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 2, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 2, 0, 0", + "-7854277750134145024, 2, 0, 0", + "-4539628424389459968, 2, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 1, 2, 0", + "-3242591731706757120, 1, 2, 0", + "-4107282860161892352, 1, 2, 0", + "-4395513236313604096, 1, 2, 0", + "-4539628424389459968, 1, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 1, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 1, 0, 0", + "-7854277750134145024, 1, 0, 0", + "-4539628424389459968, 1, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 0, 2, 0", + "-3242591731706757120, 0, 2, 0", + "-4107282860161892352, 0, 2, 0", + "-4395513236313604096, 0, 2, 0", + "-4539628424389459968, 0, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 0, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 0, 0, 0", + "-7854277750134145024, 0, 0, 0", + "-4539628424389459968, 0, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @State + public static class SubscribeWithFollowingRequestsVsOnNextVsDisposeStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "-3242591731706757120, 4, 2, 0", + "-4107282860161892352, 4, 2, 0", + "-4395513236313604096, 4, 2, 0", + "-4539628424389459968, 4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 4, 0, 0", + "-4539628424389459968, 4, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 3, 2, 0", + "-4107282860161892352, 3, 2, 0", + "-4395513236313604096, 3, 2, 0", + "-4539628424389459968, 3, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 3, 0, 0", + "-4539628424389459968, 3, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 2, 2, 0", + "-4107282860161892352, 2, 2, 0", + "-4395513236313604096, 2, 2, 0", + "-4539628424389459968, 2, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 2, 0, 0", + "-4539628424389459968, 2, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 1, 2, 0", + "-4107282860161892352, 1, 2, 0", + "-4395513236313604096, 1, 2, 0", + "-4539628424389459968, 1, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 1, 0, 0", + "-4539628424389459968, 1, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 0, 2, 0", + "-4107282860161892352, 0, 2, 0", + "-4395513236313604096, 0, 2, 0", + "-4539628424389459968, 0, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 0, 0, 0", + "-4539628424389459968, 0, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @State + public static class SubscribeWithFollowingRequestsVsOnNextVsDisposeFusedStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = {"-4539628424389459968, 0, 2, 0", "-3386706919782612992, 0, 2, 0"}, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = {"-4395513236313604096, 0, 2, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> dispose() before anything") + @Outcome( + id = {"-3242591731706757120, 0, 2, 0", "-3242591731706757120, 0, 0, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> (dispose() || cancel())") + @Outcome( + id = {"-7854277750134145024, 0, 0, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> cancel() before anything") + @State + public static class SubscribeWithFollowingCancelVsOnNextVsDisposeStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndCancel() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = {"-4539628424389459968, 0, 2, 0", "-3386706919782612992, 0, 2, 0"}, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = {"-4395513236313604096, 0, 2, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> dispose() before anything") + @Outcome( + id = {"-3242591731706757120, 0, 2, 0", "-3242591731706757120, 0, 0, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> (dispose() || cancel())") + @Outcome( + id = {"-7854277750134145024, 0, 0, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> cancel() before anything") + @State + public static class SubscribeWithFollowingCancelVsOnNextVsDisposeFusedStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndCancel() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = {"1"}, + expect = Expect.ACCEPTABLE) + @State + public static class SubscribeVsSubscribeStressTest extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber1 = new StressSubscriber<>(0, Fuseable.NONE); + final StressSubscriber stressSubscriber2 = new StressSubscriber<>(0, Fuseable.NONE); + + @Actor + public void subscribe1() { + unboundedProcessor.subscribe(stressSubscriber1); + } + + @Actor + public void subscribe2() { + unboundedProcessor.subscribe(stressSubscriber2); + } + + @Arbiter + public void arbiter(L_Result r) { + r.r1 = stressSubscriber1.onErrorCalls + stressSubscriber2.onErrorCalls; + + checkOutcomes(this, r.toString(), logger); + } + } + + static void checkOutcomes(Object instance, String result, Logger logger) { + if (Arrays.stream(instance.getClass().getDeclaredAnnotationsByType(Outcome.class)) + .flatMap(o -> Arrays.stream(o.id())) + .noneMatch(s -> s.equalsIgnoreCase(result))) { + throw new RuntimeException(result + " " + logger); + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/resume/InMemoryResumableFramesStoreStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/resume/InMemoryResumableFramesStoreStressTest.java new file mode 100644 index 000000000..f0b209552 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/resume/InMemoryResumableFramesStoreStressTest.java @@ -0,0 +1,118 @@ +package io.rsocket.resume; + +import static org.openjdk.jcstress.annotations.Expect.ACCEPTABLE; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.internal.UnboundedProcessor; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.LL_Result; +import reactor.core.Disposable; + +public class InMemoryResumableFramesStoreStressTest { + boolean storeClosed; + + InMemoryResumableFramesStore store = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 128); + boolean processorClosed; + UnboundedProcessor processor = new UnboundedProcessor(() -> processorClosed = true); + + void subscribe() { + store.saveFrames(processor).subscribe(); + store.onClose().subscribe(null, t -> storeClosed = true, () -> storeClosed = true); + } + + @JCStressTest + @Outcome( + id = {"true, true"}, + expect = ACCEPTABLE) + @State + public static class TwoSubscribesRaceStressTest extends InMemoryResumableFramesStoreStressTest { + + Disposable d1; + + final ByteBuf b1 = + PayloadFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + true, + false, + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello1"), + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello2")); + final ByteBuf b2 = + PayloadFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 3, + false, + true, + false, + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello3"), + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello4")); + final ByteBuf b3 = + PayloadFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 5, + false, + true, + false, + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello5"), + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello6")); + + final ByteBuf c1 = + ErrorFrameCodec.encode(ByteBufAllocator.DEFAULT, 0, new ConnectionErrorException("closed")); + + { + subscribe(); + d1 = store.doOnDiscard(ByteBuf.class, ByteBuf::release).subscribe(ByteBuf::release, t -> {}); + } + + @Actor + public void producer1() { + processor.tryEmitNormal(b1); + processor.tryEmitNormal(b2); + processor.tryEmitNormal(b3); + } + + @Actor + public void producer2() { + processor.tryEmitFinal(c1); + } + + @Actor + public void producer3() { + d1.dispose(); + store + .doOnDiscard(ByteBuf.class, ByteBuf::release) + .subscribe(ByteBuf::release, t -> {}) + .dispose(); + store + .doOnDiscard(ByteBuf.class, ByteBuf::release) + .subscribe(ByteBuf::release, t -> {}) + .dispose(); + store.doOnDiscard(ByteBuf.class, ByteBuf::release).subscribe(ByteBuf::release, t -> {}); + } + + @Actor + public void producer4() { + store.releaseFrames(0); + store.releaseFrames(0); + store.releaseFrames(0); + } + + @Arbiter + public void arbiter(LL_Result r) { + r.r1 = storeClosed; + r.r2 = processorClosed; + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/utils/FastLogger.java b/rsocket-core/src/jcstress/java/io/rsocket/utils/FastLogger.java new file mode 100644 index 000000000..c301d87cf --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/utils/FastLogger.java @@ -0,0 +1,137 @@ +package io.rsocket.utils; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import reactor.util.Logger; + +/** + * Implementation of {@link Logger} which is based on the {@link ThreadLocal} based queue which + * collects all the events on the per-thread basis.
Such logger is designed to have all events + * stored during the stress-test run and then sorted and printed out once all the Threads completed + * execution (inside the {@link org.openjdk.jcstress.annotations.Arbiter} annotated method.
+ * Note, this implementation only supports trace-level logs and ignores all others, it is intended + * to be used by {@link reactor.core.publisher.StateLogger}. + */ +public class FastLogger implements Logger { + + final Map> queues = new ConcurrentHashMap<>(); + + final ThreadLocal> logsQueueLocal = + ThreadLocal.withInitial( + () -> { + final ArrayList logs = new ArrayList<>(100); + queues.put(Thread.currentThread(), logs); + return logs; + }); + + private final String name; + + public FastLogger(String name) { + this.name = name; + } + + @Override + public String toString() { + return queues + .values() + .stream() + .flatMap(List::stream) + .sorted( + Comparator.comparingLong( + s -> { + Pattern pattern = Pattern.compile("\\[(.*?)]"); + Matcher matcher = pattern.matcher(s); + matcher.find(); + return Long.parseLong(matcher.group(1)); + })) + .collect(Collectors.joining("\n")); + } + + @Override + public String getName() { + return this.name; + } + + @Override + public boolean isTraceEnabled() { + return true; + } + + @Override + public void trace(String msg) { + logsQueueLocal.get().add(String.format("[%s] %s", System.nanoTime(), msg)); + } + + @Override + public void trace(String format, Object... arguments) { + trace(String.format(format, arguments)); + } + + @Override + public void trace(String msg, Throwable t) { + trace(String.format("%s, %s", msg, Arrays.toString(t.getStackTrace()))); + } + + @Override + public boolean isDebugEnabled() { + return false; + } + + @Override + public void debug(String msg) {} + + @Override + public void debug(String format, Object... arguments) {} + + @Override + public void debug(String msg, Throwable t) {} + + @Override + public boolean isInfoEnabled() { + return false; + } + + @Override + public void info(String msg) {} + + @Override + public void info(String format, Object... arguments) {} + + @Override + public void info(String msg, Throwable t) {} + + @Override + public boolean isWarnEnabled() { + return false; + } + + @Override + public void warn(String msg) {} + + @Override + public void warn(String format, Object... arguments) {} + + @Override + public void warn(String msg, Throwable t) {} + + @Override + public boolean isErrorEnabled() { + return false; + } + + @Override + public void error(String msg) {} + + @Override + public void error(String format, Object... arguments) {} + + @Override + public void error(String msg, Throwable t) {} +} diff --git a/rsocket-core/src/jcstress/resources/logback.xml b/rsocket-core/src/jcstress/resources/logback.xml new file mode 100644 index 000000000..e5877552c --- /dev/null +++ b/rsocket-core/src/jcstress/resources/logback.xml @@ -0,0 +1,39 @@ + + + + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/rsocket-core/src/main/java/io/rsocket/AbstractRSocket.java b/rsocket-core/src/main/java/io/rsocket/AbstractRSocket.java deleted file mode 100644 index 7f39956dc..000000000 --- a/rsocket-core/src/main/java/io/rsocket/AbstractRSocket.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; - -/** - * An abstract implementation of {@link RSocket}. All request handling methods emit {@link - * UnsupportedOperationException} and hence must be overridden to provide a valid implementation. - * - * @deprecated as of 1.0 in favor of implementing {@link RSocket} directly which has default - * methods. - */ -@Deprecated -public abstract class AbstractRSocket implements RSocket { - - private final MonoProcessor onClose = MonoProcessor.create(); - - @Override - public void dispose() { - onClose.onComplete(); - } - - @Override - public boolean isDisposed() { - return onClose.isDisposed(); - } - - @Override - public Mono onClose() { - return onClose; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java b/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java index ece2aa9fa..c39e679a1 100644 --- a/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java @@ -18,7 +18,6 @@ import io.netty.buffer.ByteBuf; import io.netty.util.AbstractReferenceCounted; -import io.rsocket.core.DefaultConnectionSetupPayload; import reactor.util.annotation.Nullable; /** @@ -57,16 +56,4 @@ public ConnectionSetupPayload retain(int increment) { @Override public abstract ConnectionSetupPayload touch(); - - /** - * Create a {@code ConnectionSetupPayload}. - * - * @deprecated as of 1.0 RC7. Please, use {@link - * DefaultConnectionSetupPayload#DefaultConnectionSetupPayload(ByteBuf) - * DefaultConnectionSetupPayload} constructor. - */ - @Deprecated - public static ConnectionSetupPayload create(final ByteBuf setupFrame) { - return new DefaultConnectionSetupPayload(setupFrame); - } } diff --git a/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java index 6190d24e3..fe91f4bf0 100644 --- a/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java +++ b/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,41 +18,30 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import java.net.SocketAddress; import java.nio.channels.ClosedChannelException; -import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; /** Represents a connection with input/output that the protocol uses. */ public interface DuplexConnection extends Availability, Closeable { /** - * Sends the source of Frames on this connection and returns the {@code Publisher} representing - * the result of this send. + * Delivers the given frame to the underlying transport connection. This method is non-blocking + * and can be safely executed from multiple threads. This method does not provide any flow-control + * mechanism. * - *

Flow control - * - *

The passed {@code Publisher} must - * - * @param frames Stream of {@code Frame}s to send on the connection. - * @return {@code Publisher} that completes when all the frames are written on the connection - * successfully and errors when it fails. - * @throws NullPointerException if {@code frames} is {@code null} + * @param streamId to which the given frame relates + * @param frame with the encoded content */ - Mono send(Publisher frames); + void sendFrame(int streamId, ByteBuf frame); /** - * Sends a single {@code Frame} on this connection and returns the {@code Publisher} representing - * the result of this send. + * Send an error frame and after it is successfully sent, close the connection. * - * @param frame {@code Frame} to send. - * @return {@code Publisher} that completes when the frame is written on the connection - * successfully and errors when it fails. + * @param errorException to encode in the error frame */ - default Mono sendOne(ByteBuf frame) { - return send(Mono.just(frame)); - } + void sendErrorAndClose(RSocketErrorException errorException); /** * Returns a stream of all {@code Frame}s received on this connection. @@ -60,7 +49,7 @@ default Mono sendOne(ByteBuf frame) { *

Completion * *

Returned {@code Publisher} MUST never emit a completion event ({@link - * Subscriber#onComplete()}. + * Subscriber#onComplete()}). * *

Error * @@ -86,6 +75,17 @@ default Mono sendOne(ByteBuf frame) { */ ByteBufAllocator alloc(); + /** + * Return the remote address that this connection is connected to. The returned {@link + * SocketAddress} varies by transport type and should be downcast to obtain more detailed + * information. For TCP and WebSocket, the address type is {@link java.net.InetSocketAddress}. For + * local transport, it is {@link io.rsocket.transport.local.LocalSocketAddress}. + * + * @return the address + * @since 1.1 + */ + SocketAddress remoteAddress(); + @Override default double availability() { return isDisposed() ? 0.0 : 1.0; diff --git a/rsocket-core/src/main/java/io/rsocket/RSocket.java b/rsocket-core/src/main/java/io/rsocket/RSocket.java index 773c93dc2..b05241365 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocket.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocket.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,8 +34,7 @@ public interface RSocket extends Availability, Closeable { * handled, otherwise errors. */ default Mono fireAndForget(Payload payload) { - payload.release(); - return Mono.error(new UnsupportedOperationException("Fire-and-Forget not implemented.")); + return RSocketAdapter.fireAndForget(payload); } /** @@ -46,8 +45,7 @@ default Mono fireAndForget(Payload payload) { * response. */ default Mono requestResponse(Payload payload) { - payload.release(); - return Mono.error(new UnsupportedOperationException("Request-Response not implemented.")); + return RSocketAdapter.requestResponse(payload); } /** @@ -57,8 +55,7 @@ default Mono requestResponse(Payload payload) { * @return {@code Publisher} containing the stream of {@code Payload}s representing the response. */ default Flux requestStream(Payload payload) { - payload.release(); - return Flux.error(new UnsupportedOperationException("Request-Stream not implemented.")); + return RSocketAdapter.requestStream(payload); } /** @@ -68,7 +65,7 @@ default Flux requestStream(Payload payload) { * @return Stream of response payloads. */ default Flux requestChannel(Publisher payloads) { - return Flux.error(new UnsupportedOperationException("Request-Channel not implemented.")); + return RSocketAdapter.requestChannel(payloads); } /** @@ -79,8 +76,7 @@ default Flux requestChannel(Publisher payloads) { * handled, otherwise errors. */ default Mono metadataPush(Payload payload) { - payload.release(); - return Mono.error(new UnsupportedOperationException("Metadata-Push not implemented.")); + return RSocketAdapter.metadataPush(payload); } @Override diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketAdapter.java b/rsocket-core/src/main/java/io/rsocket/RSocketAdapter.java new file mode 100644 index 000000000..b5a64b8dd --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/RSocketAdapter.java @@ -0,0 +1,78 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Package private class with default implementations for use in {@link RSocket}. The main purpose + * is to hide static {@link UnsupportedOperationException} declarations. + * + * @since 1.0.3 + */ +class RSocketAdapter { + + private static final Mono UNSUPPORTED_FIRE_AND_FORGET = + Mono.error(new UnsupportedInteractionException("Fire-and-Forget")); + + private static final Mono UNSUPPORTED_REQUEST_RESPONSE = + Mono.error(new UnsupportedInteractionException("Request-Response")); + + private static final Flux UNSUPPORTED_REQUEST_STREAM = + Flux.error(new UnsupportedInteractionException("Request-Stream")); + + private static final Flux UNSUPPORTED_REQUEST_CHANNEL = + Flux.error(new UnsupportedInteractionException("Request-Channel")); + + private static final Mono UNSUPPORTED_METADATA_PUSH = + Mono.error(new UnsupportedInteractionException("Metadata-Push")); + + static Mono fireAndForget(Payload payload) { + payload.release(); + return RSocketAdapter.UNSUPPORTED_FIRE_AND_FORGET; + } + + static Mono requestResponse(Payload payload) { + payload.release(); + return RSocketAdapter.UNSUPPORTED_REQUEST_RESPONSE; + } + + static Flux requestStream(Payload payload) { + payload.release(); + return RSocketAdapter.UNSUPPORTED_REQUEST_STREAM; + } + + static Flux requestChannel(Publisher payloads) { + return RSocketAdapter.UNSUPPORTED_REQUEST_CHANNEL; + } + + static Mono metadataPush(Payload payload) { + payload.release(); + return RSocketAdapter.UNSUPPORTED_METADATA_PUSH; + } + + private static class UnsupportedInteractionException extends RuntimeException { + + private static final long serialVersionUID = 5084623297446471999L; + + UnsupportedInteractionException(String interactionName) { + super(interactionName + " not implemented.", null, false, false); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java b/rsocket-core/src/main/java/io/rsocket/RSocketClient.java deleted file mode 100644 index 098cdfe9c..000000000 --- a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java +++ /dev/null @@ -1,82 +0,0 @@ -package io.rsocket; - -import org.reactivestreams.Publisher; -import reactor.core.Disposable; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.util.retry.Retry; - -/** - * Contract to perform RSocket requests from client to server, transparently connecting and ensuring - * a single, shared connection to make requests with. - * - *

{@code RSocketClient} contains a {@code Mono} {@link #source() source}. It uses it to - * obtain a live, shared {@link RSocket} connection on the first request and on subsequent requests - * if the connection is lost. This eliminates the need to obtain a connection first, and makes it - * easy to pass a single {@code RSocketClient} to use from multiple places. - * - *

Request methods of {@code RSocketClient} allow multiple subscriptions with each subscription - * performing a new request. Therefore request methods accept {@code Mono} rather than - * {@code Payload} as on {@link RSocket}. By contrast, {@link RSocket} request methods cannot be - * subscribed to more than once. - * - *

Use {@link io.rsocket.core.RSocketConnector RSocketConnector} to create a client: - * - *

{@code
- * RSocketClient client =
- *         RSocketConnector.create()
- *                 .metadataMimeType("message/x.rsocket.composite-metadata.v0")
- *                 .dataMimeType("application/cbor")
- *                 .toRSocketClient(TcpClientTransport.create("localhost", 7000));
- * }
- * - *

Use the {@link io.rsocket.core.RSocketConnector#reconnect(Retry) RSocketConnector#reconnect} - * method to configure the retry logic to use whenever a shared {@code RSocket} connection needs to - * be obtained: - * - *

{@code
- * RSocketClient client =
- *         RSocketConnector.create()
- *                 .metadataMimeType("message/x.rsocket.composite-metadata.v0")
- *                 .dataMimeType("application/cbor")
- *                 .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
- *                 .toRSocketClient(TcpClientTransport.create("localhost", 7000));
- * }
- * - * @since 1.0.1 - */ -public interface RSocketClient extends Disposable { - - /** Return the underlying source used to obtain a shared {@link RSocket} connection. */ - Mono source(); - - /** - * Perform a Fire-and-Forget interaction via {@link RSocket#fireAndForget(Payload)}. Allows - * multiple subscriptions and performs a request per subscriber. - */ - Mono fireAndForget(Mono payloadMono); - - /** - * Perform a Request-Response interaction via {@link RSocket#requestResponse(Payload)}. Allows - * multiple subscriptions and performs a request per subscriber. - */ - Mono requestResponse(Mono payloadMono); - - /** - * Perform a Request-Stream interaction via {@link RSocket#requestStream(Payload)}. Allows - * multiple subscriptions and performs a request per subscriber. - */ - Flux requestStream(Mono payloadMono); - - /** - * Perform a Request-Channel interaction via {@link RSocket#requestChannel(Publisher)}. Allows - * multiple subscriptions and performs a request per subscriber. - */ - Flux requestChannel(Publisher payloads); - - /** - * Perform a Metadata Push via {@link RSocket#metadataPush(Payload)}. Allows multiple - * subscriptions and performs a request per subscriber. - */ - Mono metadataPush(Mono payloadMono); -} diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java b/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java deleted file mode 100644 index e23bcceb2..000000000 --- a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java +++ /dev/null @@ -1,571 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.rsocket.core.RSocketConnector; -import io.rsocket.core.RSocketServer; -import io.rsocket.core.Resume; -import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.lease.LeaseStats; -import io.rsocket.lease.Leases; -import io.rsocket.plugins.DuplexConnectionInterceptor; -import io.rsocket.plugins.RSocketInterceptor; -import io.rsocket.plugins.SocketAcceptorInterceptor; -import io.rsocket.resume.ClientResume; -import io.rsocket.resume.ResumableFramesStore; -import io.rsocket.resume.ResumeStrategy; -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import java.time.Duration; -import java.util.function.Consumer; -import java.util.function.Function; -import java.util.function.Supplier; -import reactor.core.publisher.Mono; -import reactor.util.retry.Retry; - -/** - * Main entry point to create RSocket clients or servers as follows: - * - *
    - *
  • {@link ClientRSocketFactory} to connect as a client. Use {@link #connect()} for a default - * instance. - *
  • {@link ServerRSocketFactory} to start a server. Use {@link #receive()} for a default - * instance. - *
- * - * @deprecated please use {@link RSocketConnector} and {@link RSocketServer}. - */ -@Deprecated -public final class RSocketFactory { - - /** - * Create a {@code ClientRSocketFactory} to connect to a remote RSocket endpoint. Internally - * delegates to {@link RSocketConnector}. - * - * @return the {@code ClientRSocketFactory} instance - */ - public static ClientRSocketFactory connect() { - return new ClientRSocketFactory(); - } - - /** - * Create a {@code ServerRSocketFactory} to accept connections from RSocket clients. Internally - * delegates to {@link RSocketServer}. - * - * @return the {@code ClientRSocketFactory} instance - */ - public static ServerRSocketFactory receive() { - return new ServerRSocketFactory(); - } - - public interface Start { - Mono start(); - } - - public interface ClientTransportAcceptor { - Start transport(Supplier transport); - - default Start transport(ClientTransport transport) { - return transport(() -> transport); - } - } - - public interface ServerTransportAcceptor { - - ServerTransport.ConnectionAcceptor toConnectionAcceptor(); - - Start transport(Supplier> transport); - - default Start transport(ServerTransport transport) { - return transport(() -> transport); - } - } - - /** Factory to create and configure an RSocket client, and connect to a server. */ - public static class ClientRSocketFactory implements ClientTransportAcceptor { - private static final ClientResume CLIENT_RESUME = - new ClientResume(Duration.ofMinutes(2), Unpooled.EMPTY_BUFFER); - - private final RSocketConnector connector; - - private Duration tickPeriod = Duration.ofSeconds(20); - private Duration ackTimeout = Duration.ofSeconds(30); - private int missedAcks = 3; - - private Resume resume; - - public ClientRSocketFactory() { - this(RSocketConnector.create()); - } - - public ClientRSocketFactory(RSocketConnector connector) { - this.connector = connector; - } - - /** - * @deprecated this method is deprecated and deliberately has no effect anymore. Right now, in - * order configure the custom {@link ByteBufAllocator} it is recommended to use the - * following setup for Reactor Netty based transport:
- * 1. For Client:
- *
{@code
-     * TcpClient.create()
-     *          ...
-     *          .bootstrap(bootstrap -> bootstrap.option(ChannelOption.ALLOCATOR, clientAllocator))
-     * }
- *
- * 2. For server:
- *
{@code
-     * TcpServer.create()
-     *          ...
-     *          .bootstrap(serverBootstrap -> serverBootstrap.childOption(ChannelOption.ALLOCATOR, serverAllocator))
-     * }
- * Or in case of local transport, to use corresponding factory method {@code - * LocalClientTransport.creat(String, ByteBufAllocator)} - * @param allocator instance of {@link ByteBufAllocator} - * @return this factory instance - */ - public ClientRSocketFactory byteBufAllocator(ByteBufAllocator allocator) { - return this; - } - - public ClientRSocketFactory addConnectionPlugin(DuplexConnectionInterceptor interceptor) { - connector.interceptors(registry -> registry.forConnection(interceptor)); - return this; - } - - /** Deprecated. Use {@link #addRequesterPlugin(RSocketInterceptor)} instead */ - @Deprecated - public ClientRSocketFactory addClientPlugin(RSocketInterceptor interceptor) { - return addRequesterPlugin(interceptor); - } - - public ClientRSocketFactory addRequesterPlugin(RSocketInterceptor interceptor) { - connector.interceptors(registry -> registry.forRequester(interceptor)); - return this; - } - - /** Deprecated. Use {@link #addResponderPlugin(RSocketInterceptor)} instead */ - @Deprecated - public ClientRSocketFactory addServerPlugin(RSocketInterceptor interceptor) { - return addResponderPlugin(interceptor); - } - - public ClientRSocketFactory addResponderPlugin(RSocketInterceptor interceptor) { - connector.interceptors(registry -> registry.forResponder(interceptor)); - return this; - } - - public ClientRSocketFactory addSocketAcceptorPlugin(SocketAcceptorInterceptor interceptor) { - connector.interceptors(registry -> registry.forSocketAcceptor(interceptor)); - return this; - } - - /** - * Deprecated without replacement as Keep-Alive is not optional according to spec - * - * @return this ClientRSocketFactory - */ - @Deprecated - public ClientRSocketFactory keepAlive() { - connector.keepAlive(tickPeriod, ackTimeout.plus(tickPeriod.multipliedBy(missedAcks))); - return this; - } - - public ClientTransportAcceptor keepAlive( - Duration tickPeriod, Duration ackTimeout, int missedAcks) { - this.tickPeriod = tickPeriod; - this.ackTimeout = ackTimeout; - this.missedAcks = missedAcks; - keepAlive(); - return this; - } - - public ClientRSocketFactory keepAliveTickPeriod(Duration tickPeriod) { - this.tickPeriod = tickPeriod; - keepAlive(); - return this; - } - - public ClientRSocketFactory keepAliveAckTimeout(Duration ackTimeout) { - this.ackTimeout = ackTimeout; - keepAlive(); - return this; - } - - public ClientRSocketFactory keepAliveMissedAcks(int missedAcks) { - this.missedAcks = missedAcks; - keepAlive(); - return this; - } - - public ClientRSocketFactory mimeType(String metadataMimeType, String dataMimeType) { - connector.metadataMimeType(metadataMimeType); - connector.dataMimeType(dataMimeType); - return this; - } - - public ClientRSocketFactory dataMimeType(String dataMimeType) { - connector.dataMimeType(dataMimeType); - return this; - } - - public ClientRSocketFactory metadataMimeType(String metadataMimeType) { - connector.metadataMimeType(metadataMimeType); - return this; - } - - public ClientRSocketFactory lease(Supplier> supplier) { - connector.lease(supplier); - return this; - } - - public ClientRSocketFactory lease() { - connector.lease(Leases::new); - return this; - } - - /** @deprecated without a replacement and no longer used. */ - @Deprecated - public ClientRSocketFactory singleSubscriberRequester() { - return this; - } - - /** - * Enables a reconnectable, shared instance of {@code Mono} so every subscriber will - * observe the same RSocket instance up on connection establishment.
- * For example: - * - *
{@code
-     * Mono sharedRSocketMono =
-     *   RSocketFactory
-     *                .connect()
-     *                .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
-     *                .transport(transport)
-     *                .start();
-     *
-     *  RSocket r1 = sharedRSocketMono.block();
-     *  RSocket r2 = sharedRSocketMono.block();
-     *
-     *  assert r1 == r2;
-     *
-     * }
- * - * Apart of the shared behavior, if the connection is lost, the same {@code Mono} - * instance will transparently re-establish the connection for subsequent subscribers.
- * For example: - * - *
{@code
-     * Mono sharedRSocketMono =
-     *   RSocketFactory
-     *                .connect()
-     *                .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
-     *                .transport(transport)
-     *                .start();
-     *
-     *  RSocket r1 = sharedRSocketMono.block();
-     *  RSocket r2 = sharedRSocketMono.block();
-     *
-     *  assert r1 == r2;
-     *
-     *  r1.dispose()
-     *
-     *  assert r2.isDisposed()
-     *
-     *  RSocket r3 = sharedRSocketMono.block();
-     *  RSocket r4 = sharedRSocketMono.block();
-     *
-     *
-     *  assert r1 != r3;
-     *  assert r4 == r3;
-     *
-     * }
- * - * Note, having reconnect() enabled does not eliminate the need to accompany each - * individual request with the corresponding retry logic.
- * For example: - * - *
{@code
-     * Mono sharedRSocketMono =
-     *   RSocketFactory
-     *                .connect()
-     *                .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
-     *                .transport(transport)
-     *                .start();
-     *
-     *  sharedRSocket.flatMap(rSocket -> rSocket.requestResponse(...))
-     *               .retryWhen(ownRetry)
-     *               .subscribe()
-     *
-     * }
- * - * @param retrySpec a retry factory applied for {@link Mono#retryWhen(Retry)} - * @return a shared instance of {@code Mono}. - */ - public ClientRSocketFactory reconnect(Retry retrySpec) { - connector.reconnect(retrySpec); - return this; - } - - public ClientRSocketFactory resume() { - resume = resume != null ? resume : new Resume(); - connector.resume(resume); - return this; - } - - public ClientRSocketFactory resumeToken(Supplier supplier) { - resume(); - resume.token(supplier); - return this; - } - - public ClientRSocketFactory resumeStore( - Function storeFactory) { - resume(); - resume.storeFactory(storeFactory); - return this; - } - - public ClientRSocketFactory resumeSessionDuration(Duration sessionDuration) { - resume(); - resume.sessionDuration(sessionDuration); - return this; - } - - public ClientRSocketFactory resumeStreamTimeout(Duration streamTimeout) { - resume(); - resume.streamTimeout(streamTimeout); - return this; - } - - public ClientRSocketFactory resumeStrategy(Supplier strategy) { - resume(); - resume.retry( - Retry.from( - signals -> signals.flatMap(s -> strategy.get().apply(CLIENT_RESUME, s.failure())))); - return this; - } - - public ClientRSocketFactory resumeCleanupOnKeepAlive() { - resume(); - resume.cleanupStoreOnKeepAlive(); - return this; - } - - public Start transport(Supplier transport) { - return () -> connector.connect(transport); - } - - public ClientTransportAcceptor acceptor(Function acceptor) { - return acceptor(() -> acceptor); - } - - public ClientTransportAcceptor acceptor(Supplier> acceptorSupplier) { - return acceptor( - (setup, sendingSocket) -> { - acceptorSupplier.get().apply(sendingSocket); - return Mono.empty(); - }); - } - - public ClientTransportAcceptor acceptor(SocketAcceptor acceptor) { - connector.acceptor(acceptor); - return this; - } - - public ClientRSocketFactory fragment(int mtu) { - connector.fragment(mtu); - return this; - } - - /** - * @deprecated this handler is deliberately no-ops and is deprecated with no replacement. In - * order to observe errors, it is recommended to add error handler using {@code doOnError} - * on the specific logical stream. In order to observe connection, or RSocket terminal - * errors, it is recommended to hook on {@link Closeable#onClose()} handler. - */ - public ClientRSocketFactory errorConsumer(Consumer errorConsumer) { - return this; - } - - public ClientRSocketFactory setupPayload(Payload payload) { - connector.setupPayload(payload); - return this; - } - - public ClientRSocketFactory frameDecoder(PayloadDecoder payloadDecoder) { - connector.payloadDecoder(payloadDecoder); - return this; - } - } - - /** Factory to create, configure, and start an RSocket server. */ - public static class ServerRSocketFactory implements ServerTransportAcceptor { - private final RSocketServer server; - - private Resume resume; - - public ServerRSocketFactory() { - this(RSocketServer.create()); - } - - public ServerRSocketFactory(RSocketServer server) { - this.server = server; - } - - /** - * @deprecated this method is deprecated and deliberately has no effect anymore. Right now, in - * order configure the custom {@link ByteBufAllocator} it is recommended to use the - * following setup for Reactor Netty based transport:
- * 1. For Client:
- *
{@code
-     * TcpClient.create()
-     *          ...
-     *          .bootstrap(bootstrap -> bootstrap.option(ChannelOption.ALLOCATOR, clientAllocator))
-     * }
- *
- * 2. For server:
- *
{@code
-     * TcpServer.create()
-     *          ...
-     *          .bootstrap(serverBootstrap -> serverBootstrap.childOption(ChannelOption.ALLOCATOR, serverAllocator))
-     * }
- * Or in case of local transport, to use corresponding factory method {@code - * LocalClientTransport.creat(String, ByteBufAllocator)} - * @param allocator instance of {@link ByteBufAllocator} - * @return this factory instance - */ - @Deprecated - public ServerRSocketFactory byteBufAllocator(ByteBufAllocator allocator) { - return this; - } - - public ServerRSocketFactory addConnectionPlugin(DuplexConnectionInterceptor interceptor) { - server.interceptors(registry -> registry.forConnection(interceptor)); - return this; - } - /** Deprecated. Use {@link #addRequesterPlugin(RSocketInterceptor)} instead */ - @Deprecated - public ServerRSocketFactory addClientPlugin(RSocketInterceptor interceptor) { - return addRequesterPlugin(interceptor); - } - - public ServerRSocketFactory addRequesterPlugin(RSocketInterceptor interceptor) { - server.interceptors(registry -> registry.forRequester(interceptor)); - return this; - } - - /** Deprecated. Use {@link #addResponderPlugin(RSocketInterceptor)} instead */ - @Deprecated - public ServerRSocketFactory addServerPlugin(RSocketInterceptor interceptor) { - return addResponderPlugin(interceptor); - } - - public ServerRSocketFactory addResponderPlugin(RSocketInterceptor interceptor) { - server.interceptors(registry -> registry.forResponder(interceptor)); - return this; - } - - public ServerRSocketFactory addSocketAcceptorPlugin(SocketAcceptorInterceptor interceptor) { - server.interceptors(registry -> registry.forSocketAcceptor(interceptor)); - return this; - } - - public ServerTransportAcceptor acceptor(SocketAcceptor acceptor) { - server.acceptor(acceptor); - return this; - } - - public ServerRSocketFactory frameDecoder(PayloadDecoder payloadDecoder) { - server.payloadDecoder(payloadDecoder); - return this; - } - - public ServerRSocketFactory fragment(int mtu) { - server.fragment(mtu); - return this; - } - - /** - * @deprecated this handler is deliberately no-ops and is deprecated with no replacement. In - * order to observe errors, it is recommended to add error handler using {@code doOnError} - * on the specific logical stream. In order to observe connection, or RSocket terminal - * errors, it is recommended to hook on {@link Closeable#onClose()} handler. - */ - public ServerRSocketFactory errorConsumer(Consumer errorConsumer) { - return this; - } - - public ServerRSocketFactory lease(Supplier> supplier) { - server.lease(supplier); - return this; - } - - public ServerRSocketFactory lease() { - server.lease(Leases::new); - return this; - } - - /** @deprecated without a replacement and no longer used. */ - @Deprecated - public ServerRSocketFactory singleSubscriberRequester() { - return this; - } - - public ServerRSocketFactory resume() { - resume = resume != null ? resume : new Resume(); - server.resume(resume); - return this; - } - - public ServerRSocketFactory resumeStore( - Function storeFactory) { - resume(); - resume.storeFactory(storeFactory); - return this; - } - - public ServerRSocketFactory resumeSessionDuration(Duration sessionDuration) { - resume(); - resume.sessionDuration(sessionDuration); - return this; - } - - public ServerRSocketFactory resumeStreamTimeout(Duration streamTimeout) { - resume(); - resume.streamTimeout(streamTimeout); - return this; - } - - public ServerRSocketFactory resumeCleanupOnKeepAlive() { - resume(); - resume.cleanupStoreOnKeepAlive(); - return this; - } - - @Override - public ServerTransport.ConnectionAcceptor toConnectionAcceptor() { - return server.asConnectionAcceptor(); - } - - @Override - public Start transport(Supplier> transport) { - return () -> server.bind(transport.get()); - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/ResponderRSocket.java b/rsocket-core/src/main/java/io/rsocket/ResponderRSocket.java deleted file mode 100644 index 22697f130..000000000 --- a/rsocket-core/src/main/java/io/rsocket/ResponderRSocket.java +++ /dev/null @@ -1,28 +0,0 @@ -package io.rsocket; - -import java.util.function.BiFunction; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; - -/** - * Extends the {@link RSocket} that allows an implementer to peek at the first request payload of a - * channel. - * - * @deprecated as of 1.0 RC7 in favor of using {@link RSocket#requestChannel(Publisher)} with {@link - * Flux#switchOnFirst(BiFunction)} - */ -@Deprecated -public interface ResponderRSocket extends RSocket { - /** - * Implement this method to peak at the first payload of the incoming request stream without - * having to subscribe to Publish<Payload> payloads - * - * @param payload First payload in the stream - this is the same payload as the first payload in - * Publisher<Payload> payloads - * @param payloads Stream of request payloads. - * @return Stream of response payloads. - */ - default Flux requestChannel(Payload payload, Publisher payloads) { - return requestChannel(payloads); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ClientServerInputMultiplexer.java b/rsocket-core/src/main/java/io/rsocket/core/ClientServerInputMultiplexer.java new file mode 100644 index 000000000..e19d31924 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ClientServerInputMultiplexer.java @@ -0,0 +1,348 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Closeable; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.plugins.DuplexConnectionInterceptor.Type; +import io.rsocket.plugins.InitializingInterceptorRegistry; +import java.net.SocketAddress; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; + +/** + * {@link DuplexConnection#receive()} is a single stream on which the following type of frames + * arrive: + * + *
    + *
  • Frames for streams initiated by the initiator of the connection (client). + *
  • Frames for streams initiated by the acceptor of the connection (server). + *
+ * + *

The only way to differentiate these two frames is determining whether the stream Id is odd or + * even. Even IDs are for the streams initiated by server and odds are for streams initiated by the + * client. + */ +class ClientServerInputMultiplexer implements CoreSubscriber, Closeable { + + private final InternalDuplexConnection serverReceiver; + private final InternalDuplexConnection clientReceiver; + private final DuplexConnection serverConnection; + private final DuplexConnection clientConnection; + private final DuplexConnection source; + private final boolean isClient; + + private Subscription s; + + private Throwable t; + + private volatile int state; + private static final AtomicIntegerFieldUpdater STATE = + AtomicIntegerFieldUpdater.newUpdater(ClientServerInputMultiplexer.class, "state"); + + public ClientServerInputMultiplexer( + DuplexConnection source, InitializingInterceptorRegistry registry, boolean isClient) { + this.source = source; + this.isClient = isClient; + + this.serverReceiver = new InternalDuplexConnection(Type.SERVER, this, source); + this.clientReceiver = new InternalDuplexConnection(Type.CLIENT, this, source); + this.serverConnection = registry.initConnection(Type.SERVER, serverReceiver); + this.clientConnection = registry.initConnection(Type.CLIENT, clientReceiver); + } + + DuplexConnection asServerConnection() { + return serverConnection; + } + + DuplexConnection asClientConnection() { + return clientConnection; + } + + @Override + public void dispose() { + source.dispose(); + } + + @Override + public boolean isDisposed() { + return source.isDisposed(); + } + + @Override + public Mono onClose() { + return source.onClose(); + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + s.request(Long.MAX_VALUE); + } + } + + @Override + public void onNext(ByteBuf frame) { + int streamId = FrameHeaderCodec.streamId(frame); + final Type type; + if (streamId == 0) { + switch (FrameHeaderCodec.frameType(frame)) { + case LEASE: + case KEEPALIVE: + case ERROR: + type = isClient ? Type.CLIENT : Type.SERVER; + break; + default: + type = isClient ? Type.SERVER : Type.CLIENT; + } + } else if ((streamId & 0b1) == 0) { + type = Type.SERVER; + } else { + type = Type.CLIENT; + } + + switch (type) { + case CLIENT: + clientReceiver.onNext(frame); + break; + case SERVER: + serverReceiver.onNext(frame); + break; + } + } + + @Override + public void onComplete() { + final int previousState = STATE.getAndSet(this, Integer.MIN_VALUE); + if (previousState == Integer.MIN_VALUE || previousState == 0) { + return; + } + + if (clientReceiver.isSubscribed()) { + clientReceiver.onComplete(); + } + if (serverReceiver.isSubscribed()) { + serverReceiver.onComplete(); + } + } + + @Override + public void onError(Throwable t) { + this.t = t; + + final int previousState = STATE.getAndSet(this, Integer.MIN_VALUE); + if (previousState == Integer.MIN_VALUE || previousState == 0) { + return; + } + + if (clientReceiver.isSubscribed()) { + clientReceiver.onError(t); + } + if (serverReceiver.isSubscribed()) { + serverReceiver.onError(t); + } + } + + boolean notifyRequested() { + final int currentState = incrementAndGetCheckingState(); + if (currentState == Integer.MIN_VALUE) { + return false; + } + + if (currentState == 2) { + source.receive().subscribe(this); + } + + return true; + } + + int incrementAndGetCheckingState() { + int prev, next; + for (; ; ) { + prev = this.state; + + if (prev == Integer.MIN_VALUE) { + return prev; + } + + next = prev + 1; + if (STATE.compareAndSet(this, prev, next)) { + return next; + } + } + } + + @Override + public String toString() { + return "ClientServerInputMultiplexer{" + + "serverReceiver=" + + serverReceiver + + ", clientReceiver=" + + clientReceiver + + ", serverConnection=" + + serverConnection + + ", clientConnection=" + + clientConnection + + ", source=" + + source + + ", isClient=" + + isClient + + ", s=" + + s + + ", t=" + + t + + ", state=" + + state + + '}'; + } + + private static class InternalDuplexConnection extends Flux + implements Subscription, DuplexConnection { + private final Type type; + private final ClientServerInputMultiplexer clientServerInputMultiplexer; + private final DuplexConnection source; + + private volatile int state; + static final AtomicIntegerFieldUpdater STATE = + AtomicIntegerFieldUpdater.newUpdater(InternalDuplexConnection.class, "state"); + + CoreSubscriber actual; + + public InternalDuplexConnection( + Type type, + ClientServerInputMultiplexer clientServerInputMultiplexer, + DuplexConnection source) { + this.type = type; + this.clientServerInputMultiplexer = clientServerInputMultiplexer; + this.source = source; + } + + @Override + public void subscribe(CoreSubscriber actual) { + if (this.state == 0 && STATE.compareAndSet(this, 0, 1)) { + this.actual = actual; + actual.onSubscribe(this); + } else { + Operators.error( + actual, + new IllegalStateException("InternalDuplexConnection allows only single subscription")); + } + } + + @Override + public void request(long n) { + if (this.state == 1 && STATE.compareAndSet(this, 1, 2)) { + final ClientServerInputMultiplexer multiplexer = clientServerInputMultiplexer; + if (!multiplexer.notifyRequested()) { + final Throwable t = multiplexer.t; + if (t != null) { + this.actual.onError(t); + } else { + this.actual.onComplete(); + } + } + } + } + + @Override + public void cancel() { + // no ops + } + + void onNext(ByteBuf frame) { + this.actual.onNext(frame); + } + + void onComplete() { + this.actual.onComplete(); + } + + void onError(Throwable t) { + this.actual.onError(t); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + source.sendFrame(streamId, frame); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + source.sendErrorAndClose(e); + } + + @Override + public Flux receive() { + return this; + } + + @Override + public ByteBufAllocator alloc() { + return source.alloc(); + } + + @Override + public SocketAddress remoteAddress() { + return source.remoteAddress(); + } + + @Override + public void dispose() { + source.dispose(); + } + + @Override + public boolean isDisposed() { + return source.isDisposed(); + } + + public boolean isSubscribed() { + return this.state != 0; + } + + @Override + public Mono onClose() { + return source.onClose(); + } + + @Override + public double availability() { + return source.availability(); + } + + @Override + public String toString() { + return "InternalDuplexConnection{" + + "type=" + + type + + ", source=" + + source + + ", state=" + + state + + '}'; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ClientSetup.java b/rsocket-core/src/main/java/io/rsocket/core/ClientSetup.java new file mode 100644 index 000000000..3477b8d6d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ClientSetup.java @@ -0,0 +1,49 @@ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.DuplexConnection; +import java.nio.channels.ClosedChannelException; +import reactor.core.Disposable; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; + +abstract class ClientSetup { + abstract Mono> init(DuplexConnection connection); +} + +class DefaultClientSetup extends ClientSetup { + + @Override + Mono> init(DuplexConnection connection) { + return Mono.create( + sink -> sink.onRequest(__ -> sink.success(Tuples.of(Unpooled.EMPTY_BUFFER, connection)))); + } +} + +class ResumableClientSetup extends ClientSetup { + + @Override + Mono> init(DuplexConnection connection) { + return Mono.create( + sink -> { + sink.onRequest( + __ -> { + new SetupHandlingDuplexConnection(connection, sink); + }); + + Disposable subscribe = + connection + .onClose() + .doFinally(__ -> sink.error(new ClosedChannelException())) + .subscribe(); + sink.onCancel( + () -> { + subscribe.dispose(); + connection.dispose(); + connection.receive().subscribe(); + }); + }); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/DefaultRSocketClient.java b/rsocket-core/src/main/java/io/rsocket/core/DefaultRSocketClient.java index 24fa8f84c..82a02268d 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/DefaultRSocketClient.java +++ b/rsocket-core/src/main/java/io/rsocket/core/DefaultRSocketClient.java @@ -1,10 +1,24 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package io.rsocket.core; import io.netty.util.IllegalReferenceCountException; import io.netty.util.ReferenceCounted; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketClient; import io.rsocket.frame.FrameType; import java.util.AbstractMap; import java.util.Map; @@ -21,6 +35,7 @@ import reactor.core.publisher.Mono; import reactor.core.publisher.MonoOperator; import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; import reactor.util.annotation.Nullable; import reactor.util.context.Context; @@ -31,13 +46,16 @@ */ class DefaultRSocketClient extends ResolvingOperator implements CoreSubscriber, CorePublisher, RSocketClient { - static final Consumer DISCARD_ELEMENTS_CONSUMER = - referenceCounted -> { - if (referenceCounted.refCnt() > 0) { - try { - referenceCounted.release(); - } catch (IllegalReferenceCountException e) { - // ignored + static final Consumer DISCARD_ELEMENTS_CONSUMER = + data -> { + if (data instanceof ReferenceCounted) { + ReferenceCounted referenceCounted = ((ReferenceCounted) data); + if (referenceCounted.refCnt() > 0) { + try { + referenceCounted.release(); + } catch (IllegalReferenceCountException e) { + // ignored + } } } }; @@ -51,13 +69,25 @@ class DefaultRSocketClient extends ResolvingOperator final Mono source; + final Sinks.Empty onDisposeSink; + volatile Subscription s; static final AtomicReferenceFieldUpdater S = AtomicReferenceFieldUpdater.newUpdater(DefaultRSocketClient.class, Subscription.class, "s"); DefaultRSocketClient(Mono source) { - this.source = source; + this.source = unwrapReconnectMono(source); + this.onDisposeSink = Sinks.empty(); + } + + private Mono unwrapReconnectMono(Mono source) { + return source instanceof ReconnectMono ? ((ReconnectMono) source).getSource() : source; + } + + @Override + public Mono onClose() { + return this.onDisposeSink.asMono(); } @Override @@ -176,6 +206,12 @@ protected void doOnValueExpired(RSocket value) { @Override protected void doOnDispose() { Operators.terminate(S, this); + final RSocket value = this.value; + if (value != null) { + value.onClose().subscribe(null, onDisposeSink::tryEmitError, onDisposeSink::tryEmitEmpty); + } else { + onDisposeSink.tryEmitEmpty(); + } } static final class FlatMapMain implements CoreSubscriber, Context, Scannable { @@ -417,8 +453,8 @@ public void accept(RSocket rSocket, Throwable t) { @Override public void request(long n) { - this.main.request(n); super.request(n); + this.main.request(n); } public void cancel() { diff --git a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java new file mode 100644 index 000000000..a5d527f5c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java @@ -0,0 +1,295 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.*; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.RequestInterceptor; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; + +final class FireAndForgetRequesterMono extends Mono implements Subscription, Scannable { + + volatile long state; + + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(FireAndForgetRequesterMono.class, "state"); + + final Payload payload; + + final ByteBufAllocator allocator; + final int mtu; + final int maxFrameLength; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + + @Nullable final RequestInterceptor requestInterceptor; + + FireAndForgetRequesterMono(Payload payload, RequesterResponderSupport requesterResponderSupport) { + this.allocator = requesterResponderSupport.getAllocator(); + this.payload = payload; + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + final IllegalStateException e = + new IllegalStateException("FireAndForgetMono allows only a single Subscriber"); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + + Operators.error(actual, e); + return; + } + + actual.onSubscribe(this); + + final Payload p = this.payload; + int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, p, false)) { + lazyTerminate(STATE, this); + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, p.metadata()); + } + + p.release(); + + actual.onError(e); + return; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + + actual.onError(e); + return; + } + + final int streamId; + try { + streamId = this.requesterResponderSupport.getNextStreamId(); + } catch (Throwable t) { + lazyTerminate(STATE, this); + + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_FNF, p.metadata()); + } + + p.release(); + + actual.onError(ut); + return; + } + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onStart(streamId, FrameType.REQUEST_FNF, p.metadata()); + } + + try { + if (isTerminated(this.state)) { + p.release(); + + if (interceptor != null) { + interceptor.onCancel(streamId, FrameType.REQUEST_FNF); + } + + return; + } + + sendReleasingPayload( + streamId, FrameType.REQUEST_FNF, mtu, p, this.connection, this.allocator, true); + } catch (Throwable e) { + lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, e); + } + + actual.onError(e); + return; + } + + lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, null); + } + + actual.onComplete(); + } + + @Override + public void request(long n) { + // no ops + } + + @Override + public void cancel() { + markTerminated(STATE, this); + } + + @Override + @Nullable + public Void block(Duration m) { + return block(); + } + + /** + * This method is deliberately non-blocking regardless it is named as `.block`. The main intent to + * keep this method along with the {@link #subscribe()} is to eliminate redundancy which comes + * with a default block method implementation. + */ + @Override + @Nullable + public Void block() { + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + final IllegalStateException e = + new IllegalStateException("FireAndForgetMono allows only a single Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + throw e; + } + + final Payload p = this.payload; + try { + if (!isValid(this.mtu, this.maxFrameLength, p, false)) { + lazyTerminate(STATE, this); + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, p.metadata()); + } + + p.release(); + + throw e; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + + throw Exceptions.propagate(e); + } + + final int streamId; + try { + streamId = this.requesterResponderSupport.getNextStreamId(); + } catch (Throwable t) { + lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(Exceptions.unwrap(t), FrameType.REQUEST_FNF, p.metadata()); + } + + p.release(); + + throw Exceptions.propagate(t); + } + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onStart(streamId, FrameType.REQUEST_FNF, p.metadata()); + } + + try { + sendReleasingPayload( + streamId, + FrameType.REQUEST_FNF, + this.mtu, + this.payload, + this.connection, + this.allocator, + true); + } catch (Throwable e) { + lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, e); + } + + throw Exceptions.propagate(e); + } + + lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, null); + } + + return null; + } + + @Override + public Object scanUnsafe(Scannable.Attr key) { + return null; // no particular key to be represented, still useful in hooks + } + + @Override + @NonNull + public String stepName() { + return "source(FireAndForgetMono)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java new file mode 100644 index 000000000..e76fdf9ed --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java @@ -0,0 +1,183 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; + +final class FireAndForgetResponderSubscriber + implements CoreSubscriber, ResponderFrameHandler { + + static final Logger logger = LoggerFactory.getLogger(FireAndForgetResponderSubscriber.class); + + static final FireAndForgetResponderSubscriber INSTANCE = new FireAndForgetResponderSubscriber(); + + final int streamId; + final ByteBufAllocator allocator; + final PayloadDecoder payloadDecoder; + final RequesterResponderSupport requesterResponderSupport; + final RSocket handler; + final int maxInboundPayloadSize; + + @Nullable final RequestInterceptor requestInterceptor; + + CompositeByteBuf frames; + + private FireAndForgetResponderSubscriber() { + this.streamId = 0; + this.allocator = null; + this.payloadDecoder = null; + this.maxInboundPayloadSize = 0; + this.requesterResponderSupport = null; + this.handler = null; + this.requestInterceptor = null; + this.frames = null; + } + + FireAndForgetResponderSubscriber( + int streamId, RequesterResponderSupport requesterResponderSupport) { + this.streamId = streamId; + this.allocator = null; + this.payloadDecoder = null; + this.maxInboundPayloadSize = 0; + this.requesterResponderSupport = null; + this.handler = null; + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + this.frames = null; + } + + FireAndForgetResponderSubscriber( + int streamId, + ByteBuf firstFrame, + RequesterResponderSupport requesterResponderSupport, + RSocket handler) { + this.streamId = streamId; + this.allocator = requesterResponderSupport.getAllocator(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.handler = handler; + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + + this.frames = + ReassemblyUtils.addFollowingFrame( + allocator.compositeBuffer(), firstFrame, true, maxInboundPayloadSize); + } + + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(Void voidVal) {} + + @Override + public void onError(Throwable t) { + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(this.streamId, FrameType.REQUEST_FNF, t); + } + + logger.debug("Dropped Outbound error", t); + } + + @Override + public void onComplete() { + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(this.streamId, FrameType.REQUEST_FNF, null); + } + } + + @Override + public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLastPayload) { + final CompositeByteBuf frames = this.frames; + + try { + ReassemblyUtils.addFollowingFrame( + frames, followingFrame, hasFollows, this.maxInboundPayloadSize); + } catch (IllegalStateException t) { + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + this.frames = null; + frames.release(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_FNF, t); + } + + logger.debug("Reassembly has failed", t); + return; + } + + if (!hasFollows) { + this.requesterResponderSupport.remove(this.streamId, this); + this.frames = null; + + Payload payload; + try { + payload = this.payloadDecoder.apply(frames); + frames.release(); + } catch (Throwable t) { + ReferenceCountUtil.safeRelease(frames); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(this.streamId, FrameType.REQUEST_FNF, t); + } + + logger.debug("Reassembly has failed", t); + return; + } + + Mono source = this.handler.fireAndForget(payload); + source.subscribe(this); + } + } + + @Override + public final void handleCancel() { + final CompositeByteBuf frames = this.frames; + if (frames != null) { + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + this.frames = null; + frames.release(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_FNF); + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/FragmentationUtils.java b/rsocket-core/src/main/java/io/rsocket/core/FragmentationUtils.java new file mode 100644 index 000000000..03b6f9e09 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/FragmentationUtils.java @@ -0,0 +1,224 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import reactor.util.annotation.Nullable; + +class FragmentationUtils { + + static final int MIN_MTU_SIZE = 64; + + static final int FRAME_OFFSET = // 9 bytes in total + FrameLengthCodec.FRAME_LENGTH_SIZE // includes encoded frame length bytes size + + FrameHeaderCodec.size(); // includes encoded frame headers info bytes size + static final int FRAME_OFFSET_WITH_METADATA = // 12 bytes in total + FRAME_OFFSET + + FrameLengthCodec.FRAME_LENGTH_SIZE; // include encoded metadata length bytes size + + static final int FRAME_OFFSET_WITH_INITIAL_REQUEST_N = // 13 bytes in total + FRAME_OFFSET + Integer.BYTES; // includes extra space for initialRequestN bytes size + static final int FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N = // 16 bytes in total + FRAME_OFFSET_WITH_METADATA + + Integer.BYTES; // includes extra space for initialRequestN bytes size + + static boolean isFragmentable( + int mtu, ByteBuf data, @Nullable ByteBuf metadata, boolean hasInitialRequestN) { + if (mtu == 0) { + return false; + } + + if (metadata != null) { + int remaining = + mtu + - (hasInitialRequestN + ? FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N + : FRAME_OFFSET_WITH_METADATA); + + return (metadata.readableBytes() + data.readableBytes()) > remaining; + } else { + int remaining = + mtu - (hasInitialRequestN ? FRAME_OFFSET_WITH_INITIAL_REQUEST_N : FRAME_OFFSET); + + return data.readableBytes() > remaining; + } + } + + static ByteBuf encodeFollowsFragment( + ByteBufAllocator allocator, + int mtu, + int streamId, + boolean complete, + ByteBuf metadata, + ByteBuf data) { + // subtract the header bytes + frame length size + int remaining = mtu - FRAME_OFFSET; + + ByteBuf metadataFragment = null; + if (metadata.isReadable()) { + // subtract the metadata frame length + remaining -= FrameLengthCodec.FRAME_LENGTH_SIZE; + int r = Math.min(remaining, metadata.readableBytes()); + remaining -= r; + metadataFragment = metadata.readRetainedSlice(r); + } + + ByteBuf dataFragment = Unpooled.EMPTY_BUFFER; + try { + if (remaining > 0 && data.isReadable()) { + int r = Math.min(remaining, data.readableBytes()); + dataFragment = data.readRetainedSlice(r); + } + } catch (IllegalReferenceCountException | NullPointerException e) { + if (metadataFragment != null) { + metadataFragment.release(); + } + throw e; + } + + boolean follows = data.isReadable() || metadata.isReadable(); + return PayloadFrameCodec.encode( + allocator, streamId, follows, (!follows && complete), true, metadataFragment, dataFragment); + } + + static ByteBuf encodeFirstFragment( + ByteBufAllocator allocator, + int mtu, + FrameType frameType, + int streamId, + boolean hasMetadata, + ByteBuf metadata, + ByteBuf data) { + // subtract the header bytes + frame length size + int remaining = mtu - FRAME_OFFSET; + + ByteBuf metadataFragment = hasMetadata ? Unpooled.EMPTY_BUFFER : null; + if (hasMetadata) { + // subtract the metadata frame length + remaining -= FrameLengthCodec.FRAME_LENGTH_SIZE; + if (metadata.isReadable()) { + int r = Math.min(remaining, metadata.readableBytes()); + remaining -= r; + metadataFragment = metadata.readRetainedSlice(r); + } + } + + ByteBuf dataFragment = Unpooled.EMPTY_BUFFER; + try { + if (remaining > 0 && data.isReadable()) { + int r = Math.min(remaining, data.readableBytes()); + dataFragment = data.readRetainedSlice(r); + } + } catch (IllegalReferenceCountException | NullPointerException e) { + if (metadataFragment != null) { + metadataFragment.release(); + } + throw e; + } + + switch (frameType) { + case REQUEST_FNF: + return RequestFireAndForgetFrameCodec.encode( + allocator, streamId, true, metadataFragment, dataFragment); + case REQUEST_RESPONSE: + return RequestResponseFrameCodec.encode( + allocator, streamId, true, metadataFragment, dataFragment); + // Payload and synthetic types from the responder side + case PAYLOAD: + return PayloadFrameCodec.encode( + allocator, streamId, true, false, false, metadataFragment, dataFragment); + case NEXT: + // see https://github.com/rsocket/rsocket/blob/master/Protocol.md#handling-the-unexpected + // point 7 + case NEXT_COMPLETE: + return PayloadFrameCodec.encode( + allocator, streamId, true, false, true, metadataFragment, dataFragment); + default: + throw new IllegalStateException("unsupported fragment type: " + frameType); + } + } + + static ByteBuf encodeFirstFragment( + ByteBufAllocator allocator, + int mtu, + long initialRequestN, + FrameType frameType, + int streamId, + boolean hasMetadata, + ByteBuf metadata, + ByteBuf data) { + // subtract the header bytes + frame length bytes + initial requestN bytes + int remaining = mtu - FRAME_OFFSET_WITH_INITIAL_REQUEST_N; + + ByteBuf metadataFragment = hasMetadata ? Unpooled.EMPTY_BUFFER : null; + if (hasMetadata) { + // subtract the metadata frame length + remaining -= FrameLengthCodec.FRAME_LENGTH_SIZE; + if (metadata.isReadable()) { + int r = Math.min(remaining, metadata.readableBytes()); + remaining -= r; + metadataFragment = metadata.readRetainedSlice(r); + } + } + + ByteBuf dataFragment = Unpooled.EMPTY_BUFFER; + try { + if (remaining > 0 && data.isReadable()) { + int r = Math.min(remaining, data.readableBytes()); + dataFragment = data.readRetainedSlice(r); + } + } catch (IllegalReferenceCountException | NullPointerException e) { + if (metadataFragment != null) { + metadataFragment.release(); + } + throw e; + } + + switch (frameType) { + // Requester Side + case REQUEST_STREAM: + return RequestStreamFrameCodec.encode( + allocator, streamId, true, initialRequestN, metadataFragment, dataFragment); + case REQUEST_CHANNEL: + return RequestChannelFrameCodec.encode( + allocator, streamId, true, false, initialRequestN, metadataFragment, dataFragment); + default: + throw new IllegalStateException("unsupported fragment type: " + frameType); + } + } + + static int assertMtu(int mtu) { + if (mtu > 0 && mtu < MIN_MTU_SIZE || mtu < 0) { + String msg = + String.format( + "The smallest allowed mtu size is %d bytes, provided: %d", MIN_MTU_SIZE, mtu); + throw new IllegalArgumentException(msg); + } else { + return mtu; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/lease/LeaseStats.java b/rsocket-core/src/main/java/io/rsocket/core/FrameHandler.java similarity index 63% rename from rsocket-core/src/main/java/io/rsocket/lease/LeaseStats.java rename to rsocket-core/src/main/java/io/rsocket/core/FrameHandler.java index 791f5a023..6d1ee1b09 100644 --- a/rsocket-core/src/main/java/io/rsocket/lease/LeaseStats.java +++ b/rsocket-core/src/main/java/io/rsocket/core/FrameHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2019 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,16 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package io.rsocket.core; -package io.rsocket.lease; +import io.netty.buffer.ByteBuf; -public interface LeaseStats { +interface FrameHandler { - void onEvent(EventType eventType); + void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload); - enum EventType { - ACCEPT, - REJECT, - TERMINATE - } + void handleError(Throwable t); + + void handleComplete(); + + void handleCancel(); + + void handleRequestN(long n); } diff --git a/rsocket-core/src/main/java/io/rsocket/core/LeasePermitHandler.java b/rsocket-core/src/main/java/io/rsocket/core/LeasePermitHandler.java new file mode 100644 index 000000000..03ab7c257 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/LeasePermitHandler.java @@ -0,0 +1,20 @@ +package io.rsocket.core; + +/** Handler which enables async lease permits issuing */ +interface LeasePermitHandler { + + /** + * Called by {@link RequesterLeaseTracker} when there is an available lease + * + * @return {@code true} to indicate that lease permit was consumed successfully + */ + boolean handlePermit(); + + /** + * Called by {@link RequesterLeaseTracker} when there are no lease permit available at the moment + * and the list of awaiting {@link LeasePermitHandler} reached the configured limit + * + * @param t associated lease permit rejection exception + */ + void handlePermitError(Throwable t); +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/LeaseSpec.java b/rsocket-core/src/main/java/io/rsocket/core/LeaseSpec.java new file mode 100644 index 000000000..ad4b36e3a --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/LeaseSpec.java @@ -0,0 +1,44 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.rsocket.lease.LeaseSender; +import reactor.core.publisher.Flux; + +public final class LeaseSpec { + + LeaseSender sender = Flux::never; + int maxPendingRequests = 256; + + LeaseSpec() {} + + public LeaseSpec sender(LeaseSender sender) { + this.sender = sender; + return this; + } + + /** + * Setup the maximum queued requests waiting for lease to be available. The default value is 256 + * + * @param maxPendingRequests if set to 0 the requester will terminate the request immediately if + * no leases is available + */ + public LeaseSpec maxPendingRequests(int maxPendingRequests) { + this.maxPendingRequests = maxPendingRequests; + return this; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/LoggingDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/core/LoggingDuplexConnection.java new file mode 100644 index 000000000..7b5d8f6c2 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/LoggingDuplexConnection.java @@ -0,0 +1,72 @@ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.FrameUtil; +import java.net.SocketAddress; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +class LoggingDuplexConnection implements DuplexConnection { + + private static final Logger LOGGER = LoggerFactory.getLogger("io.rsocket.FrameLogger"); + + final DuplexConnection source; + + LoggingDuplexConnection(DuplexConnection source) { + this.source = source; + } + + @Override + public void dispose() { + source.dispose(); + } + + @Override + public Mono onClose() { + return source.onClose(); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + LOGGER.debug("sending -> " + FrameUtil.toString(frame)); + + source.sendFrame(streamId, frame); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + LOGGER.debug("sending -> " + e.getClass().getSimpleName() + ": " + e.getMessage()); + + source.sendErrorAndClose(e); + } + + @Override + public Flux receive() { + return source + .receive() + .doOnNext(frame -> LOGGER.debug("receiving -> " + FrameUtil.toString(frame))); + } + + @Override + public ByteBufAllocator alloc() { + return source.alloc(); + } + + @Override + public SocketAddress remoteAddress() { + return source.remoteAddress(); + } + + static DuplexConnection wrapIfEnabled(DuplexConnection source) { + if (LOGGER.isDebugEnabled()) { + return new LoggingDuplexConnection(source); + } + + return source; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/MetadataPushRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/MetadataPushRequesterMono.java new file mode 100644 index 000000000..e2512e995 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/MetadataPushRequesterMono.java @@ -0,0 +1,190 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValidMetadata; +import static io.rsocket.core.StateUtils.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.frame.MetadataPushFrameCodec; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import reactor.core.CoreSubscriber; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; + +final class MetadataPushRequesterMono extends Mono implements Scannable { + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(MetadataPushRequesterMono.class, "state"); + + final ByteBufAllocator allocator; + final Payload payload; + final int maxFrameLength; + final DuplexConnection connection; + + MetadataPushRequesterMono(Payload payload, RequesterResponderSupport requesterResponderSupport) { + this.allocator = requesterResponderSupport.getAllocator(); + this.payload = payload; + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.connection = requesterResponderSupport.getDuplexConnection(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + Operators.error( + actual, new IllegalStateException("MetadataPushMono allows only a single Subscriber")); + return; + } + + final Payload p = this.payload; + final ByteBuf metadata; + try { + final boolean hasMetadata = p.hasMetadata(); + metadata = p.metadata(); + if (!hasMetadata) { + lazyTerminate(STATE, this); + p.release(); + Operators.error( + actual, + new IllegalArgumentException("Metadata push should have metadata field present")); + return; + } + if (!isValidMetadata(this.maxFrameLength, metadata)) { + lazyTerminate(STATE, this); + p.release(); + Operators.error( + actual, + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + return; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + Operators.error(actual, e); + return; + } + + final ByteBuf metadataRetainedSlice; + try { + metadataRetainedSlice = metadata.retainedSlice(); + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + Operators.error(actual, e); + return; + } + + try { + p.release(); + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + metadataRetainedSlice.release(); + Operators.error(actual, e); + return; + } + + final ByteBuf requestFrame = + MetadataPushFrameCodec.encode(this.allocator, metadataRetainedSlice); + this.connection.sendFrame(0, requestFrame); + + Operators.complete(actual); + } + + @Override + @Nullable + public Void block(Duration m) { + return block(); + } + + /** + * This method is deliberately non-blocking regardless it is named as `.block`. The main intent to + * keep this method along with the {@link #subscribe()} is to eliminate redundancy which comes + * with a default block method implementation. + */ + @Override + @Nullable + public Void block() { + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + throw new IllegalStateException("MetadataPushMono allows only a single Subscriber"); + } + + final Payload p = this.payload; + final ByteBuf metadata; + try { + final boolean hasMetadata = p.hasMetadata(); + metadata = p.metadata(); + if (!hasMetadata) { + lazyTerminate(STATE, this); + p.release(); + throw new IllegalArgumentException("Metadata push should have metadata field present"); + } + if (!isValidMetadata(this.maxFrameLength, metadata)) { + lazyTerminate(STATE, this); + p.release(); + throw new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + throw e; + } + + final ByteBuf metadataRetainedSlice; + try { + metadataRetainedSlice = metadata.retainedSlice(); + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + throw e; + } + + try { + p.release(); + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + metadataRetainedSlice.release(); + throw e; + } + + final ByteBuf requestFrame = + MetadataPushFrameCodec.encode(this.allocator, metadataRetainedSlice); + this.connection.sendFrame(0, requestFrame); + + return null; + } + + @Override + public Object scanUnsafe(Attr key) { + return null; // no particular key to be represented, still useful in hooks + } + + @Override + @NonNull + public String stepName() { + return "source(MetadataPushMono)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/MetadataPushResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/MetadataPushResponderSubscriber.java new file mode 100644 index 000000000..4c69934e8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/MetadataPushResponderSubscriber.java @@ -0,0 +1,45 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; + +final class MetadataPushResponderSubscriber implements CoreSubscriber { + static final Logger logger = LoggerFactory.getLogger(MetadataPushResponderSubscriber.class); + + static final MetadataPushResponderSubscriber INSTANCE = new MetadataPushResponderSubscriber(); + + private MetadataPushResponderSubscriber() {} + + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(Void voidVal) {} + + @Override + public void onError(Throwable t) { + logger.debug("Dropped error", t); + } + + @Override + public void onComplete() {} +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java b/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java index 5e62105c9..6ece319c9 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java +++ b/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java @@ -1,33 +1,48 @@ package io.rsocket.core; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_INITIAL_REQUEST_N; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N; import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import io.netty.buffer.ByteBuf; import io.rsocket.Payload; -import io.rsocket.frame.FrameHeaderCodec; -import io.rsocket.frame.FrameLengthCodec; final class PayloadValidationUtils { static final String INVALID_PAYLOAD_ERROR_MESSAGE = - "The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory."; + "The payload is too big to be send as a single frame with a max frame length %s. Consider enabling fragmentation."; + + static boolean isValid(int mtu, int maxFrameLength, Payload payload, boolean hasInitialRequestN) { - static boolean isValid(int mtu, Payload payload, int maxFrameLength) { if (mtu > 0) { return true; } - if (payload.hasMetadata()) { - return ((FrameHeaderCodec.size() - + FrameLengthCodec.FRAME_LENGTH_SIZE - + FrameHeaderCodec.size() - + payload.data().readableBytes() - + payload.metadata().readableBytes()) - <= maxFrameLength); + final boolean hasMetadata = payload.hasMetadata(); + final ByteBuf data = payload.data(); + + int unitSize; + if (hasMetadata) { + final ByteBuf metadata = payload.metadata(); + unitSize = + (hasInitialRequestN + ? FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N + : FRAME_OFFSET_WITH_METADATA) + + metadata.readableBytes() + + // metadata payload bytes + data.readableBytes(); // data payload bytes } else { - return ((FrameHeaderCodec.size() - + payload.data().readableBytes() - + FrameLengthCodec.FRAME_LENGTH_SIZE) - <= maxFrameLength); + unitSize = + (hasInitialRequestN ? FRAME_OFFSET_WITH_INITIAL_REQUEST_N : FRAME_OFFSET) + + data.readableBytes(); // data payload bytes } + + return unitSize <= maxFrameLength; + } + + static boolean isValidMetadata(int maxFrameLength, ByteBuf metadata) { + return FRAME_OFFSET + metadata.readableBytes() <= maxFrameLength; } static void assertValidateSetup(int maxFrameLength, int maxInboundPayloadSize, int mtu) { diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketClient.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketClient.java new file mode 100644 index 000000000..32e3c229d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketClient.java @@ -0,0 +1,153 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import sun.reflect.generics.reflectiveObjects.NotImplementedException; + +/** + * Contract for performing RSocket requests. + * + *

{@link RSocketClient} differs from {@link RSocket} in a number of ways: + * + *

    + *
  • {@code RSocket} represents a "live" connection that is transient and needs to be obtained + * typically from a {@code Mono} source via {@code flatMap} or block. By contrast, + * {@code RSocketClient} is a higher level layer that contains such a {@link #source() source} + * of connections and transparently obtains and re-obtains a shared connection as needed when + * requests are made concurrently. That means an {@code RSocketClient} can simply be created + * once, even before a connection is established, and shared as a singleton across multiple + * places as you would with any other client. + *
  • For request input {@code RSocket} accepts an instance of {@code Payload} and does not allow + * more than one subscription per request because there is no way to safely re-use that input. + * By contrast {@code RSocketClient} accepts {@code Publisher} and allow + * re-subscribing which repeats the request. + *
  • {@code RSocket} can be used for sending and it can also be implemented for receiving. By + * contrast {@code RSocketClient} is used only for sending, typically from the client side + * which allows obtaining and re-obtaining connections from a source as needed. However it can + * also be used from the server side by {@link #from(RSocket) wrapping} the "live" {@code + * RSocket} for a given connection. + *
+ * + *

The example below shows how to create an {@code RSocketClient}: + * + *

{@code
+ * Mono source =
+ *         RSocketConnector.create()
+ *                 .metadataMimeType("message/x.rsocket.composite-metadata.v0")
+ *                 .dataMimeType("application/cbor")
+ *                 .connect(TcpClientTransport.create("localhost", 7000));
+ *
+ * RSocketClient client = RSocketClient.from(source);
+ * }
+ * + *

The below configures retry logic to use when a shared {@code RSocket} connection is obtained: + * + *

{@code
+ * Mono source =
+ *         RSocketConnector.create()
+ *                 .metadataMimeType("message/x.rsocket.composite-metadata.v0")
+ *                 .dataMimeType("application/cbor")
+ *                 .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+ *                 .connect(TcpClientTransport.create("localhost", 7000));
+ *
+ * RSocketClient client = RSocketClient.from(source);
+ * }
+ * + * @since 1.1 + * @see io.rsocket.loadbalance.LoadbalanceRSocketClient + */ +public interface RSocketClient extends Closeable { + + /** + * Connect to the remote rsocket endpoint, if not yet connected. This method is a shortcut for + * {@code RSocketClient#source().subscribe()}. + * + * @return {@code true} if an attempt to connect was triggered or if already connected, or {@code + * false} if the client is terminated. + */ + default boolean connect() { + throw new NotImplementedException(); + } + + default Mono onClose() { + return Mono.error(new NotImplementedException()); + } + + /** Return the underlying source used to obtain a shared {@link RSocket} connection. */ + Mono source(); + + /** + * Perform a Fire-and-Forget interaction via {@link RSocket#fireAndForget(Payload)}. Allows + * multiple subscriptions and performs a request per subscriber. + */ + Mono fireAndForget(Mono payloadMono); + + /** + * Perform a Request-Response interaction via {@link RSocket#requestResponse(Payload)}. Allows + * multiple subscriptions and performs a request per subscriber. + */ + Mono requestResponse(Mono payloadMono); + + /** + * Perform a Request-Stream interaction via {@link RSocket#requestStream(Payload)}. Allows + * multiple subscriptions and performs a request per subscriber. + */ + Flux requestStream(Mono payloadMono); + + /** + * Perform a Request-Channel interaction via {@link RSocket#requestChannel(Publisher)}. Allows + * multiple subscriptions and performs a request per subscriber. + */ + Flux requestChannel(Publisher payloads); + + /** + * Perform a Metadata Push via {@link RSocket#metadataPush(Payload)}. Allows multiple + * subscriptions and performs a request per subscriber. + */ + Mono metadataPush(Mono payloadMono); + + /** + * Create an {@link RSocketClient} that obtains shared connections as needed, when requests are + * made, from the given {@code Mono} source. + * + * @param source the source for connections, typically prepared via {@link RSocketConnector}. + * @return the created client instance + */ + static RSocketClient from(Mono source) { + return new DefaultRSocketClient(source); + } + + /** + * Adapt the given {@link RSocket} to use as {@link RSocketClient}. This is useful to wrap the + * sending {@code RSocket} in a server. + * + *

Note: unlike an {@code RSocketClient} created via {@link + * RSocketClient#from(Mono)}, the instance returned from this factory method can only perform + * requests for as long as the given {@code RSocket} remains "live". + * + * @param rsocket the {@code RSocket} to perform requests with + * @return the created client instance + */ + static RSocketClient from(RSocket rsocket) { + return new RSocketClientAdapter(rsocket); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketClientAdapter.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketClientAdapter.java new file mode 100644 index 000000000..ae8b7da97 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketClientAdapter.java @@ -0,0 +1,88 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Simple adapter from {@link RSocket} to {@link RSocketClient}. This is useful in code that needs + * to deal with both in the same way. When connecting to a server, typically {@link RSocketClient} + * is expected to be used, but in a responder (client or server), it is necessary to interact with + * {@link RSocket} to make requests to the remote end. + * + * @since 1.1 + */ +class RSocketClientAdapter implements RSocketClient { + + private final RSocket rsocket; + + public RSocketClientAdapter(RSocket rsocket) { + this.rsocket = rsocket; + } + + public RSocket rsocket() { + return rsocket; + } + + @Override + public boolean connect() { + throw new UnsupportedOperationException("Connect does not apply to a server side RSocket"); + } + + @Override + public Mono source() { + return Mono.just(rsocket); + } + + @Override + public Mono onClose() { + return rsocket.onClose(); + } + + @Override + public Mono fireAndForget(Mono payloadMono) { + return payloadMono.flatMap(rsocket::fireAndForget); + } + + @Override + public Mono requestResponse(Mono payloadMono) { + return payloadMono.flatMap(rsocket::requestResponse); + } + + @Override + public Flux requestStream(Mono payloadMono) { + return payloadMono.flatMapMany(rsocket::requestStream); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return rsocket.requestChannel(payloads); + } + + @Override + public Mono metadataPush(Mono payloadMono) { + return payloadMono.flatMap(rsocket::metadataPush); + } + + @Override + public void dispose() { + rsocket.dispose(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java index fdb4859cf..de494c4e3 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java @@ -15,7 +15,9 @@ */ package io.rsocket.core; +import static io.rsocket.core.FragmentationUtils.assertMtu; import static io.rsocket.core.PayloadValidationUtils.assertValidateSetup; +import static io.rsocket.core.ReassemblyUtils.assertInboundPayloadSize; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; @@ -23,21 +25,18 @@ import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketClient; import io.rsocket.SocketAcceptor; -import io.rsocket.fragmentation.FragmentationDuplexConnection; -import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.frame.SetupFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.ClientServerInputMultiplexer; import io.rsocket.keepalive.KeepAliveHandler; -import io.rsocket.lease.LeaseStats; -import io.rsocket.lease.Leases; -import io.rsocket.lease.RequesterLeaseHandler; -import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.lease.TrackingLeaseSender; +import io.rsocket.plugins.DuplexConnectionInterceptor; import io.rsocket.plugins.InitializingInterceptorRegistry; import io.rsocket.plugins.InterceptorRegistry; +import io.rsocket.plugins.RequestInterceptor; import io.rsocket.resume.ClientRSocketSession; +import io.rsocket.resume.ResumableDuplexConnection; +import io.rsocket.resume.ResumableFramesStore; import io.rsocket.transport.ClientTransport; import io.rsocket.util.DefaultPayload; import io.rsocket.util.EmptyPayload; @@ -48,7 +47,7 @@ import java.util.function.Supplier; import reactor.core.Disposable; import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; +import reactor.core.publisher.Sinks; import reactor.util.annotation.Nullable; import reactor.util.function.Tuples; import reactor.util.retry.Retry; @@ -61,18 +60,20 @@ *

{@code
  * import io.rsocket.transport.netty.client.TcpClientTransport;
  *
- * RSocketClient client =
- *         RSocketConnector.createRSocketClient(TcpClientTransport.create("localhost", 7000));
+ * Mono source =
+ *         RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000));
+ * RSocketClient client = RSocketClient.from(source);
  * }
* *

To customize connection settings before connecting: * *

{@code
- * RSocketClient client =
+ * Mono source =
  *         RSocketConnector.create()
  *                 .metadataMimeType("message/x.rsocket.composite-metadata.v0")
  *                 .dataMimeType("application/cbor")
- *                 .toRSocketClient(TcpClientTransport.create("localhost", 7000));
+ *                 .connect(TcpClientTransport.create("localhost", 7000));
+ * RSocketClient client = RSocketClient.from(source);
  * }
*/ public class RSocketConnector { @@ -92,7 +93,8 @@ public class RSocketConnector { private Retry retrySpec; private Resume resume; - private Supplier> leasesSupplier; + + @Nullable private Consumer leaseConfigurer; private int mtu = 0; private int maxInboundPayloadSize = Integer.MAX_VALUE; @@ -112,7 +114,7 @@ public static RSocketConnector create() { * Static factory method to connect with default settings, effectively a shortcut for: * *
-   * RSocketConnector.create().connectWith(transport);
+   * RSocketConnector.create().connect(transport);
    * 
* * @param transport the transport of choice to connect with @@ -122,15 +124,6 @@ public static Mono connectWith(ClientTransport transport) { return RSocketConnector.create().connect(() -> transport); } - /** - * @param transport - * @return - * @since 1.0.1 - */ - public static RSocketClient createRSocketClient(ClientTransport transport) { - return RSocketConnector.create().toRSocketClient(transport); - } - /** * Provide a {@code Mono} from which to obtain the {@code Payload} for the initial SETUP frame. * Data and metadata should be formatted according to the MIME types specified via {@link @@ -192,9 +185,9 @@ public RSocketConnector dataMimeType(String dataMimeType) { *

For metadata encoding, consider using one of the following encoders: * *

    - *
  • {@link io.rsocket.metadata.CompositeMetadataFlyweight Composite Metadata} - *
  • {@link io.rsocket.metadata.TaggingMetadataFlyweight Routing} - *
  • {@link io.rsocket.metadata.security.AuthMetadataFlyweight Authentication} + *
  • {@link io.rsocket.metadata.CompositeMetadataCodec Composite Metadata} + *
  • {@link io.rsocket.metadata.TaggingMetadataCodec Routing} + *
  • {@link io.rsocket.metadata.AuthMetadataCodec Authentication} *
* *

For more on the above metadata formats, see the corresponding For server-to-server connections, a reasonable time interval between client {@code * KEEPALIVE} frames is 500ms. *

  • For mobile-to-server connections, the time interval between client {@code KEEPALIVE} - * frames is often > 30,000ms. + * frames is often {@code >} 30,000ms. * * *

    By default these are set to 20 seconds and 90 seconds respectively. @@ -410,18 +403,43 @@ public RSocketConnector resume(Resume resume) { * *

    {@code
        * Mono rocketMono =
    -   *         RSocketConnector.create().lease(Leases::new).connect(transport);
    +   *         RSocketConnector.create()
    +   *                         .lease()
    +   *                         .connect(transport);
    +   * }
    + * + *

    By default this is not enabled. + * + * @return the same instance for method chaining + * @see Lease + * Semantics + */ + public RSocketConnector lease() { + return lease((config -> {})); + } + + /** + * Enables the Lease feature of the RSocket protocol where the number of requests that can be + * performed from either side are rationed via {@code LEASE} frames from the responder side. + * + *

    Example usage: + * + *

    {@code
    +   * Mono rocketMono =
    +   *         RSocketConnector.create()
    +   *                         .lease(spec -> spec.maxPendingRequests(128))
    +   *                         .connect(transport);
        * }
    * *

    By default this is not enabled. * - * @param supplier supplier for a {@link Leases} + * @param leaseConfigurer consumer which accepts {@link LeaseSpec} and use it for configuring * @return the same instance for method chaining * @see Lease * Semantics */ - public RSocketConnector lease(Supplier> supplier) { - this.leasesSupplier = supplier; + public RSocketConnector lease(Consumer leaseConfigurer) { + this.leaseConfigurer = leaseConfigurer; return this; } @@ -441,8 +459,7 @@ public RSocketConnector lease(Supplier> supplier) { * and Reassembly */ public RSocketConnector maxInboundPayloadSize(int maxInboundPayloadSize) { - this.maxInboundPayloadSize = - ReassemblyDuplexConnection.assertInboundPayloadSize(maxInboundPayloadSize); + this.maxInboundPayloadSize = assertInboundPayloadSize(maxInboundPayloadSize); return this; } @@ -460,7 +477,7 @@ public RSocketConnector maxInboundPayloadSize(int maxInboundPayloadSize) { * and Reassembly */ public RSocketConnector fragment(int mtu) { - this.mtu = FragmentationDuplexConnection.assertMtu(mtu); + this.mtu = assertMtu(mtu); return this; } @@ -488,37 +505,6 @@ public RSocketConnector payloadDecoder(PayloadDecoder decoder) { return this; } - /** - * Create {@link RSocketClient} that will use {@link #connect(ClientTransport)} as its source to - * obtain a live, shared {@code RSocket} when the first request is made, and also on subsequent - * requests after the connection is lost. - * - *

    The following transports are available through additional RSocket Java modules: - * - *

      - *
    • {@link io.rsocket.transport.netty.client.TcpClientTransport TcpClientTransport} via - * {@code rsocket-transport-netty}. - *
    • {@link io.rsocket.transport.netty.client.WebsocketClientTransport - * WebsocketClientTransport} via {@code rsocket-transport-netty}. - *
    • {@link io.rsocket.transport.local.LocalClientTransport LocalClientTransport} via {@code - * rsocket-transport-local} - *
    - * - * @param transport the transport of choice to connect with - * @return a {@code RSocketClient} with not established connection. Note, connection will be - * established on the first request - * @since 1.0.1 - */ - public RSocketClient toRSocketClient(ClientTransport transport) { - Mono source = connect0(() -> transport); - - if (retrySpec != null) { - source = source.retryWhen(retrySpec); - } - - return new DefaultRSocketClient(source); - } - /** * Connect with the given transport and obtain a live {@link RSocket} to use for making requests. * Each subscriber to the returned {@code Mono} receives a new connection, if neither {@link @@ -552,19 +538,6 @@ public Mono connect(ClientTransport transport) { * @return a {@code Mono} with the connected RSocket */ public Mono connect(Supplier transportSupplier) { - return this.connect0(transportSupplier) - .as( - source -> { - if (retrySpec != null) { - return new ReconnectMono<>( - source.retryWhen(retrySpec), Disposable::dispose, INVALIDATE_FUNCTION); - } else { - return source; - } - }); - } - - private Mono connect0(Supplier transportSupplier) { return Mono.fromSupplier(transportSupplier) .flatMap( ct -> { @@ -578,12 +551,10 @@ private Mono connect0(Supplier transportSupplier) { }) .flatMap(transport -> transport.connect()) .map( - connection -> - mtu > 0 - ? new FragmentationDuplexConnection( - connection, mtu, maxInboundPayloadSize, "client") - : new ReassemblyDuplexConnection( - connection, maxInboundPayloadSize)); + sourceConnection -> + interceptors.initConnection( + DuplexConnectionInterceptor.Type.SOURCE, sourceConnection)) + .map(source -> LoggingDuplexConnection.wrapIfEnabled(source)); return connectionMono .flatMap( @@ -594,65 +565,24 @@ private Mono connect0(Supplier transportSupplier) { .doOnError(ex -> connection.dispose()) .doOnCancel(connection::dispose)) .flatMap( - tuple -> { - DuplexConnection connection = tuple.getT1(); - Payload setupPayload = tuple.getT2(); + tuple2 -> { + DuplexConnection sourceConnection = tuple2.getT1(); + Payload setupPayload = tuple2.getT2(); + boolean leaseEnabled = leaseConfigurer != null; + boolean resumeEnabled = resume != null; + // TODO: add LeaseClientSetup + ClientSetup clientSetup = new DefaultClientSetup(); ByteBuf resumeToken; - KeepAliveHandler keepAliveHandler; - DuplexConnection wrappedConnection; - if (resume != null) { + if (resumeEnabled) { resumeToken = resume.getTokenSupplier().get(); - ClientRSocketSession session = - new ClientRSocketSession( - connection, - resume.getSessionDuration(), - resume.getRetry(), - resume.getStoreFactory(CLIENT_TAG).apply(resumeToken), - resume.getStreamTimeout(), - resume.isCleanupStoreOnKeepAlive()) - .continueWith(connectionMono) - .resumeToken(resumeToken); - keepAliveHandler = - new KeepAliveHandler.ResumableKeepAliveHandler( - session.resumableConnection()); - wrappedConnection = session.resumableConnection(); } else { resumeToken = Unpooled.EMPTY_BUFFER; - keepAliveHandler = - new KeepAliveHandler.DefaultKeepAliveHandler(connection); - wrappedConnection = connection; } - ClientServerInputMultiplexer multiplexer = - new ClientServerInputMultiplexer(wrappedConnection, interceptors, true); - - boolean leaseEnabled = leasesSupplier != null; - Leases leases = leaseEnabled ? leasesSupplier.get() : null; - RequesterLeaseHandler requesterLeaseHandler = - leaseEnabled - ? new RequesterLeaseHandler.Impl(CLIENT_TAG, leases.receiver()) - : RequesterLeaseHandler.None; - - RSocket rSocketRequester = - new RSocketRequester( - multiplexer.asClientConnection(), - payloadDecoder, - StreamIdSupplier.clientSupplier(), - mtu, - maxFrameLength, - (int) keepAliveInterval.toMillis(), - (int) keepAliveMaxLifeTime.toMillis(), - keepAliveHandler, - requesterLeaseHandler, - Schedulers.single(Schedulers.parallel())); - - RSocket wrappedRSocketRequester = - interceptors.initRequester(rSocketRequester); - ByteBuf setupFrame = SetupFrameCodec.encode( - wrappedConnection.alloc(), + sourceConnection.alloc(), leaseEnabled, (int) keepAliveInterval.toMillis(), (int) keepAliveMaxLifeTime.toMillis(), @@ -661,46 +591,156 @@ private Mono connect0(Supplier transportSupplier) { dataMimeType, setupPayload); - SocketAcceptor acceptor = - this.acceptor != null - ? this.acceptor - : SocketAcceptor.with(new RSocket() {}); + sourceConnection.sendFrame(0, setupFrame.retainedSlice()); - ConnectionSetupPayload setup = - new DefaultConnectionSetupPayload(setupFrame); - - return interceptors - .initSocketAcceptor(acceptor) - .accept(setup, wrappedRSocketRequester) + return clientSetup + .init(sourceConnection) .flatMap( - rSocketHandler -> { - RSocket wrappedRSocketHandler = - interceptors.initResponder(rSocketHandler); - - ResponderLeaseHandler responderLeaseHandler = - leaseEnabled - ? new ResponderLeaseHandler.Impl<>( - CLIENT_TAG, - wrappedConnection.alloc(), - leases.sender(), - leases.stats()) - : ResponderLeaseHandler.None; - - RSocket rSocketResponder = - new RSocketResponder( - multiplexer.asServerConnection(), - wrappedRSocketHandler, + tuple -> { + // should be used if lease setup sequence; + // See: + // https://github.com/rsocket/rsocket/blob/master/Protocol.md#sequences-with-lease + final ByteBuf serverResponse = tuple.getT1(); + final DuplexConnection clientServerConnection = tuple.getT2(); + final KeepAliveHandler keepAliveHandler; + final DuplexConnection wrappedConnection; + final InitializingInterceptorRegistry interceptors = + this.interceptors; + + if (resumeEnabled) { + final ResumableFramesStore resumableFramesStore = + resume.getStoreFactory(CLIENT_TAG).apply(resumeToken); + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + CLIENT_TAG, + resumeToken, + clientServerConnection, + resumableFramesStore); + final ResumableClientSetup resumableClientSetup = + new ResumableClientSetup(); + final ClientRSocketSession session = + new ClientRSocketSession( + resumeToken, + resumableDuplexConnection, + connectionMono, + resumableClientSetup::init, + resumableFramesStore, + resume.getSessionDuration(), + resume.getRetry(), + resume.isCleanupStoreOnKeepAlive()); + keepAliveHandler = + new KeepAliveHandler.ResumableKeepAliveHandler( + resumableDuplexConnection, session, session); + wrappedConnection = resumableDuplexConnection; + } else { + keepAliveHandler = + new KeepAliveHandler.DefaultKeepAliveHandler(); + wrappedConnection = clientServerConnection; + } + + ClientServerInputMultiplexer multiplexer = + new ClientServerInputMultiplexer( + wrappedConnection, interceptors, true); + + final LeaseSpec leases; + final RequesterLeaseTracker requesterLeaseTracker; + if (leaseEnabled) { + leases = new LeaseSpec(); + leaseConfigurer.accept(leases); + requesterLeaseTracker = + new RequesterLeaseTracker( + CLIENT_TAG, leases.maxPendingRequests); + } else { + leases = null; + requesterLeaseTracker = null; + } + + final Sinks.Empty requesterOnAllClosedSink = + Sinks.unsafe().empty(); + final Sinks.Empty responderOnAllClosedSink = + Sinks.unsafe().empty(); + + RSocket rSocketRequester = + new RSocketRequester( + multiplexer.asClientConnection(), payloadDecoder, - responderLeaseHandler, + StreamIdSupplier.clientSupplier(), mtu, - maxFrameLength); - - return wrappedConnection - .sendOne(setupFrame.retain()) - .thenReturn(wrappedRSocketRequester); - }) - .doFinally(signalType -> setup.release()); + maxFrameLength, + maxInboundPayloadSize, + (int) keepAliveInterval.toMillis(), + (int) keepAliveMaxLifeTime.toMillis(), + keepAliveHandler, + interceptors::initRequesterRequestInterceptor, + requesterLeaseTracker, + requesterOnAllClosedSink, + Mono.whenDelayError( + responderOnAllClosedSink.asMono(), + requesterOnAllClosedSink.asMono())); + + RSocket wrappedRSocketRequester = + interceptors.initRequester(rSocketRequester); + + SocketAcceptor acceptor = + this.acceptor != null + ? this.acceptor + : SocketAcceptor.with(new RSocket() {}); + + ConnectionSetupPayload setup = + new DefaultConnectionSetupPayload(setupFrame); + + return interceptors + .initSocketAcceptor(acceptor) + .accept(setup, wrappedRSocketRequester) + .map( + rSocketHandler -> { + RSocket wrappedRSocketHandler = + interceptors.initResponder(rSocketHandler); + + ResponderLeaseTracker responderLeaseTracker = + leaseEnabled + ? new ResponderLeaseTracker( + CLIENT_TAG, + wrappedConnection, + leases.sender) + : null; + + RSocket rSocketResponder = + new RSocketResponder( + multiplexer.asServerConnection(), + wrappedRSocketHandler, + payloadDecoder, + responderLeaseTracker, + mtu, + maxFrameLength, + maxInboundPayloadSize, + leaseEnabled + && leases.sender + instanceof TrackingLeaseSender + ? rSocket -> + interceptors + .initResponderRequestInterceptor( + rSocket, + (RequestInterceptor) + leases.sender) + : interceptors + ::initResponderRequestInterceptor, + responderOnAllClosedSink); + + return wrappedRSocketRequester; + }) + .doFinally(signalType -> setup.release()); + }); }); + }) + .as( + source -> { + if (retrySpec != null) { + return new ReconnectMono<>( + source.retryWhen(retrySpec), Disposable::dispose, INVALIDATE_FUNCTION); + } else { + return source; + } }); } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java index 56d301ebd..b8a9c00ff 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,102 +16,60 @@ package io.rsocket.core; -import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; import static io.rsocket.keepalive.KeepAliveSupport.ClientKeepAliveSupport; -import static io.rsocket.keepalive.KeepAliveSupport.KeepAlive; import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.util.IllegalReferenceCountException; -import io.netty.util.ReferenceCountUtil; -import io.netty.util.ReferenceCounted; import io.netty.util.collection.IntObjectMap; import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.exceptions.ConnectionErrorException; import io.rsocket.exceptions.Exceptions; -import io.rsocket.frame.CancelFrameCodec; import io.rsocket.frame.ErrorFrameCodec; import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; -import io.rsocket.frame.MetadataPushFrameCodec; -import io.rsocket.frame.PayloadFrameCodec; -import io.rsocket.frame.RequestChannelFrameCodec; -import io.rsocket.frame.RequestFireAndForgetFrameCodec; import io.rsocket.frame.RequestNFrameCodec; -import io.rsocket.frame.RequestResponseFrameCodec; -import io.rsocket.frame.RequestStreamFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.SynchronizedIntObjectHashMap; -import io.rsocket.internal.UnboundedProcessor; import io.rsocket.keepalive.KeepAliveFramesAcceptor; import io.rsocket.keepalive.KeepAliveHandler; import io.rsocket.keepalive.KeepAliveSupport; -import io.rsocket.lease.RequesterLeaseHandler; +import io.rsocket.plugins.RequestInterceptor; import java.nio.channels.ClosedChannelException; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.ArrayList; +import java.util.Collection; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; -import java.util.function.Consumer; +import java.util.function.Function; import java.util.function.Supplier; -import org.reactivestreams.Processor; import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; -import reactor.core.publisher.Operators; -import reactor.core.publisher.SignalType; -import reactor.core.publisher.UnicastProcessor; -import reactor.core.scheduler.Scheduler; +import reactor.core.publisher.Sinks; import reactor.util.annotation.Nullable; -import reactor.util.concurrent.Queues; /** * Requester Side of a RSocket socket. Sends {@link ByteBuf}s to a {@link RSocketResponder} of peer */ -class RSocketRequester implements RSocket { +class RSocketRequester extends RequesterResponderSupport implements RSocket { private static final Logger LOGGER = LoggerFactory.getLogger(RSocketRequester.class); private static final Exception CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException(); - private static final Consumer DROPPED_ELEMENTS_CONSUMER = - referenceCounted -> { - if (referenceCounted.refCnt() > 0) { - try { - referenceCounted.release(); - } catch (IllegalReferenceCountException e) { - // ignored - } - } - }; static { CLOSED_CHANNEL_EXCEPTION.setStackTrace(new StackTraceElement[0]); } private volatile Throwable terminationError; - private static final AtomicReferenceFieldUpdater TERMINATION_ERROR = AtomicReferenceFieldUpdater.newUpdater( RSocketRequester.class, Throwable.class, "terminationError"); - private final DuplexConnection connection; - private final PayloadDecoder payloadDecoder; - private final StreamIdSupplier streamIdSupplier; - private final IntObjectMap senders; - private final IntObjectMap> receivers; - private final UnboundedProcessor sendProcessor; - private final int mtu; - private final int maxFrameLength; - private final RequesterLeaseHandler leaseHandler; - private final ByteBufAllocator allocator; + @Nullable private final RequesterLeaseTracker requesterLeaseTracker; + + private final Sinks.Empty onThisSideClosedSink; + private final Mono onAllClosed; private final KeepAliveFramesAcceptor keepAliveFramesAcceptor; - private final MonoProcessor onClose; - private final Scheduler serialScheduler; RSocketRequester( DuplexConnection connection, @@ -119,37 +77,40 @@ class RSocketRequester implements RSocket { StreamIdSupplier streamIdSupplier, int mtu, int maxFrameLength, + int maxInboundPayloadSize, int keepAliveTickPeriod, int keepAliveAckTimeout, @Nullable KeepAliveHandler keepAliveHandler, - RequesterLeaseHandler leaseHandler, - Scheduler serialScheduler) { - this.connection = connection; - this.allocator = connection.alloc(); - this.payloadDecoder = payloadDecoder; - this.streamIdSupplier = streamIdSupplier; - this.mtu = mtu; - this.maxFrameLength = maxFrameLength; - this.leaseHandler = leaseHandler; - this.senders = new SynchronizedIntObjectHashMap<>(); - this.receivers = new SynchronizedIntObjectHashMap<>(); - this.onClose = MonoProcessor.create(); - this.serialScheduler = serialScheduler; + Function requestInterceptorFunction, + @Nullable RequesterLeaseTracker requesterLeaseTracker, + Sinks.Empty onThisSideClosedSink, + Mono onAllClosed) { + super( + mtu, + maxFrameLength, + maxInboundPayloadSize, + payloadDecoder, + connection, + streamIdSupplier, + requestInterceptorFunction); + + this.requesterLeaseTracker = requesterLeaseTracker; + this.onThisSideClosedSink = onThisSideClosedSink; + this.onAllClosed = onAllClosed; // DO NOT Change the order here. The Send processor must be subscribed to before receiving - this.sendProcessor = new UnboundedProcessor<>(); - - connection.onClose().subscribe(null, this::tryTerminateOnConnectionError, this::tryShutdown); - connection.send(sendProcessor).subscribe(null, this::handleSendProcessorError); + connection.onClose().subscribe(null, this::tryShutdown, this::tryShutdown); connection.receive().subscribe(this::handleIncomingFrames, e -> {}); if (keepAliveTickPeriod != 0 && keepAliveHandler != null) { KeepAliveSupport keepAliveSupport = - new ClientKeepAliveSupport(this.allocator, keepAliveTickPeriod, keepAliveAckTimeout); + new ClientKeepAliveSupport(this.getAllocator(), keepAliveTickPeriod, keepAliveAckTimeout); this.keepAliveFramesAcceptor = keepAliveHandler.start( - keepAliveSupport, sendProcessor::onNextPrioritized, this::tryTerminateOnKeepAlive); + keepAliveSupport, + (keepAliveFrame) -> connection.sendFrame(0, keepAliveFrame), + this::tryTerminateOnKeepAlive); } else { keepAliveFramesAcceptor = null; } @@ -157,473 +118,96 @@ class RSocketRequester implements RSocket { @Override public Mono fireAndForget(Payload payload) { - return handleFireAndForget(payload); + if (this.requesterLeaseTracker == null) { + return new FireAndForgetRequesterMono(payload, this); + } else { + return new SlowFireAndForgetRequesterMono(payload, this); + } } @Override public Mono requestResponse(Payload payload) { - return handleRequestResponse(payload); + return new RequestResponseRequesterMono(payload, this); } @Override public Flux requestStream(Payload payload) { - return handleRequestStream(payload); + return new RequestStreamRequesterFlux(payload, this); } @Override public Flux requestChannel(Publisher payloads) { - return handleChannel(Flux.from(payloads)); + return new RequestChannelRequesterFlux(payloads, this); } @Override public Mono metadataPush(Payload payload) { - return handleMetadataPush(payload); - } + Throwable terminationError = this.terminationError; + if (terminationError != null) { + payload.release(); + return Mono.error(terminationError); + } - @Override - public double availability() { - return Math.min(connection.availability(), leaseHandler.availability()); + return new MetadataPushRequesterMono(payload, this); } @Override - public void dispose() { - tryShutdown(); + public RequesterLeaseTracker getRequesterLeaseTracker() { + return this.requesterLeaseTracker; } @Override - public boolean isDisposed() { - return terminationError != null; - } + public int getNextStreamId() { + int nextStreamId = super.getNextStreamId(); - @Override - public Mono onClose() { - return onClose; - } - - private Mono handleFireAndForget(Payload payload) { - if (payload.refCnt() <= 0) { - return Mono.error(new IllegalReferenceCountException()); - } - - if (isDisposed()) { - payload.release(); - final Throwable t = terminationError; - return Mono.error(t); + Throwable terminationError = this.terminationError; + if (terminationError != null) { + throw reactor.core.Exceptions.propagate(terminationError); } - if (!PayloadValidationUtils.isValid(this.mtu, payload, maxFrameLength)) { - payload.release(); - return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); - } - - final AtomicBoolean once = new AtomicBoolean(); - - return Mono.defer( - () -> { - if (once.getAndSet(true)) { - return Mono.error( - new IllegalStateException("FireAndForgetMono allows only a single subscriber")); - } - - if (isDisposed()) { - payload.release(); - final Throwable t = terminationError; - return Mono.error(t); - } - - RequesterLeaseHandler lh = leaseHandler; - if (!lh.useLease()) { - payload.release(); - return Mono.error(lh.leaseError()); - } - - final int streamId = streamIdSupplier.nextStreamId(receivers); - final ByteBuf requestFrame = - RequestFireAndForgetFrameCodec.encodeReleasingPayload( - allocator, streamId, payload); - - sendProcessor.onNext(requestFrame); - - return Mono.empty(); - }) - .subscribeOn(serialScheduler); + return nextStreamId; } - private Mono handleRequestResponse(final Payload payload) { - if (payload.refCnt() <= 0) { - return Mono.error(new IllegalReferenceCountException()); - } - - if (isDisposed()) { - payload.release(); - final Throwable t = terminationError; - return Mono.error(t); - } + @Override + public int addAndGetNextStreamId(FrameHandler frameHandler) { + int nextStreamId = super.addAndGetNextStreamId(frameHandler); - if (!PayloadValidationUtils.isValid(this.mtu, payload, maxFrameLength)) { - payload.release(); - return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); + Throwable terminationError = this.terminationError; + if (terminationError != null) { + super.remove(nextStreamId, frameHandler); + throw reactor.core.Exceptions.propagate(terminationError); } - final UnboundedProcessor sendProcessor = this.sendProcessor; - final UnicastProcessor receiver = UnicastProcessor.create(Queues.one().get()); - final AtomicBoolean once = new AtomicBoolean(); - - return Mono.defer( - () -> { - if (once.getAndSet(true)) { - return Mono.error( - new IllegalStateException("RequestResponseMono allows only a single subscriber")); - } - - return receiver - .next() - .transform( - Operators.lift( - (s, actual) -> - new RequestOperator(actual) { - - @Override - void hookOnFirstRequest(long n) { - if (isDisposed()) { - payload.release(); - final Throwable t = terminationError; - receiver.onError(t); - return; - } - - RequesterLeaseHandler lh = leaseHandler; - if (!lh.useLease()) { - payload.release(); - receiver.onError(lh.leaseError()); - return; - } - - int streamId = streamIdSupplier.nextStreamId(receivers); - this.streamId = streamId; - - ByteBuf requestResponseFrame = - RequestResponseFrameCodec.encodeReleasingPayload( - allocator, streamId, payload); - - receivers.put(streamId, receiver); - sendProcessor.onNext(requestResponseFrame); - } - - @Override - void hookOnCancel() { - if (receivers.remove(streamId, receiver)) { - sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); - } else { - payload.release(); - } - } - - @Override - public void hookOnTerminal(SignalType signalType) { - receivers.remove(streamId, receiver); - } - })) - .subscribeOn(serialScheduler) - .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); - }); + return nextStreamId; } - private Flux handleRequestStream(final Payload payload) { - if (payload.refCnt() <= 0) { - return Flux.error(new IllegalReferenceCountException()); - } - - if (isDisposed()) { - payload.release(); - final Throwable t = terminationError; - return Flux.error(t); - } - - if (!PayloadValidationUtils.isValid(this.mtu, payload, maxFrameLength)) { - payload.release(); - return Flux.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); + @Override + public double availability() { + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + if (requesterLeaseTracker != null) { + return Math.min(getDuplexConnection().availability(), requesterLeaseTracker.availability()); + } else { + return getDuplexConnection().availability(); } - - final UnboundedProcessor sendProcessor = this.sendProcessor; - final UnicastProcessor receiver = UnicastProcessor.create(); - final AtomicBoolean once = new AtomicBoolean(); - - return Flux.defer( - () -> { - if (once.getAndSet(true)) { - return Flux.error( - new IllegalStateException("RequestStreamFlux allows only a single subscriber")); - } - - return receiver - .transform( - Operators.lift( - (s, actual) -> - new RequestOperator(actual) { - - @Override - void hookOnFirstRequest(long n) { - if (isDisposed()) { - payload.release(); - final Throwable t = terminationError; - receiver.onError(t); - return; - } - - RequesterLeaseHandler lh = leaseHandler; - if (!lh.useLease()) { - payload.release(); - receiver.onError(lh.leaseError()); - return; - } - - int streamId = streamIdSupplier.nextStreamId(receivers); - this.streamId = streamId; - - ByteBuf requestStreamFrame = - RequestStreamFrameCodec.encodeReleasingPayload( - allocator, streamId, n, payload); - - receivers.put(streamId, receiver); - - sendProcessor.onNext(requestStreamFrame); - } - - @Override - void hookOnRemainingRequests(long n) { - if (receiver.isDisposed()) { - return; - } - - sendProcessor.onNext( - RequestNFrameCodec.encode(allocator, streamId, n)); - } - - @Override - void hookOnCancel() { - if (receivers.remove(streamId, receiver)) { - sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); - } else { - payload.release(); - } - } - - @Override - void hookOnTerminal(SignalType signalType) { - receivers.remove(streamId); - } - })) - .subscribeOn(serialScheduler, false) - .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); - }); } - private Flux handleChannel(Flux request) { - if (isDisposed()) { - final Throwable t = terminationError; - return Flux.error(t); + @Override + public void dispose() { + if (terminationError != null) { + return; } - return request - .switchOnFirst( - (s, flux) -> { - Payload payload = s.get(); - if (payload != null) { - if (payload.refCnt() <= 0) { - return Mono.error(new IllegalReferenceCountException()); - } - - if (!PayloadValidationUtils.isValid(mtu, payload, maxFrameLength)) { - payload.release(); - final IllegalArgumentException t = - new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); - return Mono.error(t); - } - return handleChannel(payload, flux); - } else { - return flux; - } - }, - false) - .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); + getDuplexConnection().sendErrorAndClose(new ConnectionErrorException("Disposed")); } - private Flux handleChannel(Payload initialPayload, Flux inboundFlux) { - final UnboundedProcessor sendProcessor = this.sendProcessor; - - final UnicastProcessor receiver = UnicastProcessor.create(); - - return receiver - .transform( - Operators.lift( - (s, actual) -> - new RequestOperator(actual) { - - final BaseSubscriber upstreamSubscriber = - new BaseSubscriber() { - - boolean first = true; - - @Override - protected void hookOnSubscribe(Subscription subscription) { - // noops - } - - @Override - protected void hookOnNext(Payload payload) { - if (first) { - // need to skip first since we have already sent it - // no need to release it since it was released earlier on the - // request - // establishment - // phase - first = false; - request(1); - return; - } - if (!PayloadValidationUtils.isValid(mtu, payload, maxFrameLength)) { - payload.release(); - cancel(); - final IllegalArgumentException t = - new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); - // no need to send any errors. - sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); - receiver.onError(t); - return; - } - final ByteBuf frame = - PayloadFrameCodec.encodeNextReleasingPayload( - allocator, streamId, payload); - - sendProcessor.onNext(frame); - } - - @Override - protected void hookOnComplete() { - ByteBuf frame = PayloadFrameCodec.encodeComplete(allocator, streamId); - sendProcessor.onNext(frame); - } - - @Override - protected void hookOnError(Throwable t) { - ByteBuf frame = ErrorFrameCodec.encode(allocator, streamId, t); - sendProcessor.onNext(frame); - receiver.onError(t); - } - - @Override - protected void hookFinally(SignalType type) { - senders.remove(streamId, this); - } - }; - - @Override - void hookOnFirstRequest(long n) { - if (isDisposed()) { - initialPayload.release(); - final Throwable t = terminationError; - upstreamSubscriber.cancel(); - receiver.onError(t); - return; - } - - RequesterLeaseHandler lh = leaseHandler; - if (!lh.useLease()) { - initialPayload.release(); - receiver.onError(lh.leaseError()); - return; - } - - final int streamId = streamIdSupplier.nextStreamId(receivers); - this.streamId = streamId; - - final ByteBuf frame = - RequestChannelFrameCodec.encodeReleasingPayload( - allocator, streamId, false, n, initialPayload); - - senders.put(streamId, upstreamSubscriber); - receivers.put(streamId, receiver); - - inboundFlux - .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER) - .subscribe(upstreamSubscriber); - - sendProcessor.onNext(frame); - } - - @Override - void hookOnRemainingRequests(long n) { - if (receiver.isDisposed()) { - return; - } - - sendProcessor.onNext(RequestNFrameCodec.encode(allocator, streamId, n)); - } - - @Override - void hookOnCancel() { - senders.remove(streamId, upstreamSubscriber); - if (receivers.remove(streamId, receiver)) { - sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); - } - } - - @Override - void hookOnTerminal(SignalType signalType) { - if (signalType == SignalType.ON_ERROR) { - upstreamSubscriber.cancel(); - } - receivers.remove(streamId, receiver); - } - - @Override - public void cancel() { - upstreamSubscriber.cancel(); - super.cancel(); - } - })) - .subscribeOn(serialScheduler, false); + @Override + public boolean isDisposed() { + return terminationError != null; } - private Mono handleMetadataPush(Payload payload) { - if (payload.refCnt() <= 0) { - return Mono.error(new IllegalReferenceCountException()); - } - - if (isDisposed()) { - Throwable err = this.terminationError; - payload.release(); - return Mono.error(err); - } - - if (!PayloadValidationUtils.isValid(this.mtu, payload, maxFrameLength)) { - payload.release(); - return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); - } - - final AtomicBoolean once = new AtomicBoolean(); - - return Mono.defer( - () -> { - if (once.getAndSet(true)) { - return Mono.error( - new IllegalStateException("MetadataPushMono allows only a single subscriber")); - } - - if (isDisposed()) { - payload.release(); - final Throwable t = terminationError; - return Mono.error(t); - } - - ByteBuf metadataPushFrame = - MetadataPushFrameCodec.encodeReleasingPayload(allocator, payload); - - sendProcessor.onNextPrioritized(metadataPushFrame); - - return Mono.empty(); - }); + @Override + public Mono onClose() { + return onAllClosed; } private void handleIncomingFrames(ByteBuf frame) { @@ -635,10 +219,11 @@ private void handleIncomingFrames(ByteBuf frame) { } else { handleFrame(streamId, type, frame); } - frame.release(); } catch (Throwable t) { - ReferenceCountUtil.safeRelease(frame); - throw reactor.core.Exceptions.propagate(t); + LOGGER.error("Unexpected error during frame handling", t); + final ConnectionErrorException error = + new ConnectionErrorException("Unexpected error during frame handling", t); + getDuplexConnection().sendErrorAndClose(error); } } @@ -648,7 +233,7 @@ private void handleStreamZero(FrameType type, ByteBuf frame) { tryTerminateOnZeroError(frame); break; case LEASE: - leaseHandler.receive(frame); + requesterLeaseTracker.handleLeaseFrame(frame); break; case KEEPALIVE: if (keepAliveFramesAcceptor != null) { @@ -664,79 +249,42 @@ private void handleStreamZero(FrameType type, ByteBuf frame) { } private void handleFrame(int streamId, FrameType type, ByteBuf frame) { - Subscriber receiver = receivers.get(streamId); + FrameHandler receiver = this.get(streamId); + if (receiver == null) { + handleMissingResponseProcessor(streamId, type, frame); + return; + } + switch (type) { - case NEXT: - if (receiver == null) { - handleMissingResponseProcessor(streamId, type, frame); - return; - } - receiver.onNext(payloadDecoder.apply(frame)); - break; case NEXT_COMPLETE: - if (receiver == null) { - handleMissingResponseProcessor(streamId, type, frame); - return; - } - receiver.onNext(payloadDecoder.apply(frame)); - receiver.onComplete(); + receiver.handleNext(frame, false, true); + break; + case NEXT: + boolean hasFollows = FrameHeaderCodec.hasFollows(frame); + receiver.handleNext(frame, hasFollows, false); break; case COMPLETE: - if (receiver == null) { - handleMissingResponseProcessor(streamId, type, frame); - return; - } - receiver.onComplete(); - receivers.remove(streamId); + receiver.handleComplete(); break; case ERROR: - if (receiver == null) { - handleMissingResponseProcessor(streamId, type, frame); - return; - } - - // FIXME: when https://github.com/reactor/reactor-core/issues/2176 is resolved - // This is workaround to handle specific Reactor related case when - // onError call may not return normally - try { - receiver.onError(Exceptions.from(streamId, frame)); - } catch (RuntimeException e) { - if (reactor.core.Exceptions.isBubbling(e) - || reactor.core.Exceptions.isErrorCallbackNotImplemented(e)) { - if (LOGGER.isDebugEnabled()) { - Throwable unwrapped = reactor.core.Exceptions.unwrap(e); - LOGGER.debug("Unhandled dropped exception", unwrapped); - } - } - } - - receivers.remove(streamId); + receiver.handleError(Exceptions.from(streamId, frame)); break; case CANCEL: - { - Subscription sender = senders.remove(streamId); - if (sender != null) { - sender.cancel(); - } - break; - } + receiver.handleCancel(); + break; case REQUEST_N: - { - Subscription sender = senders.get(streamId); - if (sender != null) { - long n = RequestNFrameCodec.requestN(frame); - sender.request(n); - } - break; - } + long n = RequestNFrameCodec.requestN(frame); + receiver.handleRequestN(n); + break; default: throw new IllegalStateException( "Requester received unsupported frame on stream " + streamId + ": " + frame.toString()); } } + @SuppressWarnings("ConstantConditions") private void handleMissingResponseProcessor(int streamId, FrameType type, ByteBuf frame) { - if (!streamIdSupplier.isBeforeOrCurrent(streamId)) { + if (!super.streamIdSupplier.isBeforeOrCurrent(streamId)) { if (type == FrameType.ERROR) { // message for stream that has never existed, we have a problem with // the overall connection and must tear down @@ -759,15 +307,39 @@ private void handleMissingResponseProcessor(int streamId, FrameType type, ByteBu // so ignore (cancellation is async so there is a race condition) } - private void tryTerminateOnKeepAlive(KeepAlive keepAlive) { + private void tryTerminateOnKeepAlive(KeepAliveSupport.KeepAlive keepAlive) { tryTerminate( () -> new ConnectionErrorException( String.format("No keep-alive acks for %d ms", keepAlive.getTimeout().toMillis()))); + getDuplexConnection().dispose(); } - private void tryTerminateOnConnectionError(Throwable e) { - tryTerminate(() -> e); + private void tryShutdown(Throwable e) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("trying to close requester " + getDuplexConnection()); + } + if (terminationError == null) { + if (TERMINATION_ERROR.compareAndSet(this, null, e)) { + terminate(CLOSED_CHANNEL_EXCEPTION); + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); + } + } + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.info( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); + } + } } private void tryTerminateOnZeroError(ByteBuf errorFrame) { @@ -775,64 +347,99 @@ private void tryTerminateOnZeroError(ByteBuf errorFrame) { } private void tryTerminate(Supplier errorSupplier) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("trying to close requester " + getDuplexConnection()); + } if (terminationError == null) { Throwable e = errorSupplier.get(); if (TERMINATION_ERROR.compareAndSet(this, null, e)) { - serialScheduler.schedule(() -> terminate(e)); + terminate(e); + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); + } + } + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); } } } private void tryShutdown() { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("trying to close requester " + getDuplexConnection()); + } if (terminationError == null) { if (TERMINATION_ERROR.compareAndSet(this, null, CLOSED_CHANNEL_EXCEPTION)) { - serialScheduler.schedule(() -> terminate(CLOSED_CHANNEL_EXCEPTION)); + terminate(CLOSED_CHANNEL_EXCEPTION); + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); + } + } + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); } } } private void terminate(Throwable e) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("closing requester " + getDuplexConnection() + " due to " + e); + } if (keepAliveFramesAcceptor != null) { keepAliveFramesAcceptor.dispose(); } - connection.dispose(); - leaseHandler.dispose(); - - receivers - .values() - .forEach( - receiver -> { - try { - receiver.onError(e); - } catch (Throwable t) { - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Dropped exception", t); - } - } - }); - senders - .values() - .forEach( - sender -> { - try { - sender.cancel(); - } catch (Throwable t) { - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Dropped exception", t); - } - } - }); - senders.clear(); - receivers.clear(); - sendProcessor.dispose(); + final RequestInterceptor requestInterceptor = getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.dispose(); + } + + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + if (requesterLeaseTracker != null) { + requesterLeaseTracker.dispose(e); + } + + final Collection activeStreamsCopy; + synchronized (this) { + final IntObjectMap activeStreams = this.activeStreams; + activeStreamsCopy = new ArrayList<>(activeStreams.values()); + } + + for (FrameHandler handler : activeStreamsCopy) { + if (handler != null) { + try { + handler.handleError(e); + } catch (Throwable ignored) { + } + } + } + if (e == CLOSED_CHANNEL_EXCEPTION) { - onClose.onComplete(); + onThisSideClosedSink.tryEmitEmpty(); } else { - onClose.onError(e); + onThisSideClosedSink.tryEmitError(e); + } + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("requester closed " + getDuplexConnection()); } - } - - private void handleSendProcessorError(Throwable t) { - connection.dispose(); } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java index 581605ff4..50c5ba54c 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,149 +16,97 @@ package io.rsocket.core; -import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; - import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.util.IllegalReferenceCountException; -import io.netty.util.ReferenceCountUtil; -import io.netty.util.ReferenceCounted; import io.netty.util.collection.IntObjectMap; import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.frame.*; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.SynchronizedIntObjectHashMap; -import io.rsocket.internal.UnboundedProcessor; -import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.plugins.RequestInterceptor; import java.nio.channels.ClosedChannelException; +import java.util.ArrayList; +import java.util.Collection; import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; -import java.util.function.Consumer; -import java.util.function.LongConsumer; +import java.util.function.Function; import java.util.function.Supplier; -import org.reactivestreams.Processor; import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import reactor.core.Disposable; -import reactor.core.Exceptions; -import reactor.core.publisher.*; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; import reactor.util.annotation.Nullable; /** Responder side of RSocket. Receives {@link ByteBuf}s from a peer's {@link RSocketRequester} */ -class RSocketResponder implements RSocket { +class RSocketResponder extends RequesterResponderSupport implements RSocket { + private static final Logger LOGGER = LoggerFactory.getLogger(RSocketResponder.class); - private static final Consumer DROPPED_ELEMENTS_CONSUMER = - referenceCounted -> { - if (referenceCounted.refCnt() > 0) { - try { - referenceCounted.release(); - } catch (IllegalReferenceCountException e) { - // ignored - } - } - }; private static final Exception CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException(); - private final DuplexConnection connection; private final RSocket requestHandler; + private final Sinks.Empty onThisSideClosedSink; - @SuppressWarnings("deprecation") - private final io.rsocket.ResponderRSocket responderRSocket; - - private final PayloadDecoder payloadDecoder; - private final ResponderLeaseHandler leaseHandler; - private final Disposable leaseHandlerDisposable; + @Nullable private final ResponderLeaseTracker leaseHandler; private volatile Throwable terminationError; private static final AtomicReferenceFieldUpdater TERMINATION_ERROR = AtomicReferenceFieldUpdater.newUpdater( RSocketResponder.class, Throwable.class, "terminationError"); - private final int mtu; - private final int maxFrameLength; - - private final IntObjectMap sendingSubscriptions; - private final IntObjectMap> channelProcessors; - - private final UnboundedProcessor sendProcessor; - private final ByteBufAllocator allocator; - RSocketResponder( DuplexConnection connection, RSocket requestHandler, PayloadDecoder payloadDecoder, - ResponderLeaseHandler leaseHandler, + @Nullable ResponderLeaseTracker leaseHandler, int mtu, - int maxFrameLength) { - this.connection = connection; - this.allocator = connection.alloc(); - this.mtu = mtu; - this.maxFrameLength = maxFrameLength; + int maxFrameLength, + int maxInboundPayloadSize, + Function requestInterceptorFunction, + Sinks.Empty onThisSideClosedSink) { + super( + mtu, + maxFrameLength, + maxInboundPayloadSize, + payloadDecoder, + connection, + null, + requestInterceptorFunction); this.requestHandler = requestHandler; - this.responderRSocket = - (requestHandler instanceof io.rsocket.ResponderRSocket) - ? (io.rsocket.ResponderRSocket) requestHandler - : null; - this.payloadDecoder = payloadDecoder; this.leaseHandler = leaseHandler; - this.sendingSubscriptions = new SynchronizedIntObjectHashMap<>(); - this.channelProcessors = new SynchronizedIntObjectHashMap<>(); - - // DO NOT Change the order here. The Send processor must be subscribed to before receiving - // connections - this.sendProcessor = new UnboundedProcessor<>(); - - connection.send(sendProcessor).subscribe(null, this::handleSendProcessorError); + this.onThisSideClosedSink = onThisSideClosedSink; - connection.receive().subscribe(this::handleFrame, e -> {}); - leaseHandlerDisposable = leaseHandler.send(sendProcessor::onNextPrioritized); - - this.connection + connection .onClose() .subscribe(null, this::tryTerminateOnConnectionError, this::tryTerminateOnConnectionClose); - } - private void handleSendProcessorError(Throwable t) { - sendingSubscriptions - .values() - .forEach( - subscription -> { - try { - subscription.cancel(); - } catch (Throwable e) { - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Dropped exception", t); - } - } - }); - - channelProcessors - .values() - .forEach( - subscription -> { - try { - subscription.onError(t); - } catch (Throwable e) { - if (LOGGER.isDebugEnabled()) { - LOGGER.debug("Dropped exception", t); - } - } - }); + connection.receive().subscribe(this::handleFrame, e -> {}); } private void tryTerminateOnConnectionError(Throwable e) { + if (LOGGER.isDebugEnabled()) { + + LOGGER.debug("Try terminate connection on responder side"); + } tryTerminate(() -> e); } private void tryTerminateOnConnectionClose() { + if (LOGGER.isDebugEnabled()) { + LOGGER.info("Try terminate connection on responder side"); + } tryTerminate(() -> CLOSED_CHANNEL_EXCEPTION); } @@ -166,7 +114,7 @@ private void tryTerminate(Supplier errorSupplier) { if (terminationError == null) { Throwable e = errorSupplier.get(); if (TERMINATION_ERROR.compareAndSet(this, null, e)) { - cleanup(e); + doOnDispose(); } } } @@ -174,12 +122,7 @@ private void tryTerminate(Supplier errorSupplier) { @Override public Mono fireAndForget(Payload payload) { try { - if (leaseHandler.useLease()) { - return requestHandler.fireAndForget(payload); - } else { - payload.release(); - return Mono.error(leaseHandler.leaseError()); - } + return requestHandler.fireAndForget(payload); } catch (Throwable t) { return Mono.error(t); } @@ -188,12 +131,7 @@ public Mono fireAndForget(Payload payload) { @Override public Mono requestResponse(Payload payload) { try { - if (leaseHandler.useLease()) { - return requestHandler.requestResponse(payload); - } else { - payload.release(); - return Mono.error(leaseHandler.leaseError()); - } + return requestHandler.requestResponse(payload); } catch (Throwable t) { return Mono.error(t); } @@ -202,12 +140,7 @@ public Mono requestResponse(Payload payload) { @Override public Flux requestStream(Payload payload) { try { - if (leaseHandler.useLease()) { - return requestHandler.requestStream(payload); - } else { - payload.release(); - return Flux.error(leaseHandler.leaseError()); - } + return requestHandler.requestStream(payload); } catch (Throwable t) { return Flux.error(t); } @@ -216,24 +149,7 @@ public Flux requestStream(Payload payload) { @Override public Flux requestChannel(Publisher payloads) { try { - if (leaseHandler.useLease()) { - return requestHandler.requestChannel(payloads); - } else { - return Flux.error(leaseHandler.leaseError()); - } - } catch (Throwable t) { - return Flux.error(t); - } - } - - private Flux requestChannel(Payload payload, Publisher payloads) { - try { - if (leaseHandler.useLease()) { - return responderRSocket.requestChannel(payload, payloads); - } else { - payload.release(); - return Flux.error(leaseHandler.leaseError()); - } + return requestHandler.requestChannel(payloads); } catch (Throwable t) { return Flux.error(t); } @@ -255,378 +171,307 @@ public void dispose() { @Override public boolean isDisposed() { - return connection.isDisposed(); + return getDuplexConnection().isDisposed(); } @Override public Mono onClose() { - return connection.onClose(); + return getDuplexConnection().onClose(); } - private void cleanup(Throwable e) { + final void doOnDispose() { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("closing responder " + getDuplexConnection()); + } cleanUpSendingSubscriptions(); - cleanUpChannelProcessors(e); - connection.dispose(); - leaseHandlerDisposable.dispose(); + getDuplexConnection().dispose(); + final RequestInterceptor requestInterceptor = getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.dispose(); + } + + final ResponderLeaseTracker handler = leaseHandler; + if (handler != null) { + handler.dispose(); + } + requestHandler.dispose(); - sendProcessor.dispose(); + onThisSideClosedSink.tryEmitEmpty(); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("responder closed " + getDuplexConnection()); + } } - private synchronized void cleanUpSendingSubscriptions() { - sendingSubscriptions.values().forEach(Subscription::cancel); - sendingSubscriptions.clear(); - } + private void cleanUpSendingSubscriptions() { + final Collection activeStreamsCopy; + synchronized (this) { + final IntObjectMap activeStreams = this.activeStreams; + activeStreamsCopy = new ArrayList<>(activeStreams.values()); + } - private synchronized void cleanUpChannelProcessors(Throwable e) { - channelProcessors - .values() - .forEach( - payloadPayloadProcessor -> { - try { - payloadPayloadProcessor.onError(e); - } catch (Throwable t) { - // noops - } - }); - channelProcessors.clear(); + for (FrameHandler handler : activeStreamsCopy) { + if (handler != null) { + handler.handleCancel(); + } + } } - private void handleFrame(ByteBuf frame) { + final void handleFrame(ByteBuf frame) { try { int streamId = FrameHeaderCodec.streamId(frame); - Subscriber receiver; + FrameHandler receiver; FrameType frameType = FrameHeaderCodec.frameType(frame); switch (frameType) { case REQUEST_FNF: - handleFireAndForget(streamId, fireAndForget(payloadDecoder.apply(frame))); + handleFireAndForget(streamId, frame); break; case REQUEST_RESPONSE: - handleRequestResponse(streamId, requestResponse(payloadDecoder.apply(frame))); - break; - case CANCEL: - handleCancelFrame(streamId); - break; - case REQUEST_N: - handleRequestN(streamId, frame); + handleRequestResponse(streamId, frame); break; case REQUEST_STREAM: long streamInitialRequestN = RequestStreamFrameCodec.initialRequestN(frame); - Payload streamPayload = payloadDecoder.apply(frame); - handleStream(streamId, requestStream(streamPayload), streamInitialRequestN, null); + handleStream(streamId, frame, streamInitialRequestN); break; case REQUEST_CHANNEL: long channelInitialRequestN = RequestChannelFrameCodec.initialRequestN(frame); - Payload channelPayload = payloadDecoder.apply(frame); - handleChannel(streamId, channelPayload, channelInitialRequestN); + handleChannel( + streamId, frame, channelInitialRequestN, FrameHeaderCodec.hasComplete(frame)); break; case METADATA_PUSH: - handleMetadataPush(metadataPush(payloadDecoder.apply(frame))); + handleMetadataPush(metadataPush(super.getPayloadDecoder().apply(frame))); + break; + case CANCEL: + receiver = super.get(streamId); + if (receiver != null) { + receiver.handleCancel(); + } + break; + case REQUEST_N: + receiver = super.get(streamId); + if (receiver != null) { + long n = RequestNFrameCodec.requestN(frame); + receiver.handleRequestN(n); + } break; case PAYLOAD: // TODO: Hook in receiving socket. break; case NEXT: - receiver = channelProcessors.get(streamId); + receiver = super.get(streamId); if (receiver != null) { - receiver.onNext(payloadDecoder.apply(frame)); + boolean hasFollows = FrameHeaderCodec.hasFollows(frame); + receiver.handleNext(frame, hasFollows, false); } break; case COMPLETE: - receiver = channelProcessors.get(streamId); + receiver = super.get(streamId); if (receiver != null) { - receiver.onComplete(); + receiver.handleComplete(); } break; case ERROR: - receiver = channelProcessors.get(streamId); + receiver = super.get(streamId); if (receiver != null) { - // FIXME: when https://github.com/reactor/reactor-core/issues/2176 is resolved - // This is workaround to handle specific Reactor related case when - // onError call may not return normally - try { - receiver.onError(io.rsocket.exceptions.Exceptions.from(streamId, frame)); - } catch (RuntimeException e) { - if (reactor.core.Exceptions.isBubbling(e) - || reactor.core.Exceptions.isErrorCallbackNotImplemented(e)) { - if (LOGGER.isDebugEnabled()) { - Throwable unwrapped = reactor.core.Exceptions.unwrap(e); - LOGGER.debug("Unhandled dropped exception", unwrapped); - } - } - } + receiver.handleError(io.rsocket.exceptions.Exceptions.from(streamId, frame)); } break; case NEXT_COMPLETE: - receiver = channelProcessors.get(streamId); + receiver = super.get(streamId); if (receiver != null) { - receiver.onNext(payloadDecoder.apply(frame)); - receiver.onComplete(); + receiver.handleNext(frame, false, true); } break; case SETUP: - handleError(streamId, new IllegalStateException("Setup frame received post setup.")); + getDuplexConnection() + .sendFrame( + streamId, + ErrorFrameCodec.encode( + super.getAllocator(), + streamId, + new IllegalStateException("Setup frame received post setup."))); break; case LEASE: default: - handleError( - streamId, - new IllegalStateException("ServerRSocket: Unexpected frame type: " + frameType)); + getDuplexConnection() + .sendFrame( + streamId, + ErrorFrameCodec.encode( + super.getAllocator(), + streamId, + new IllegalStateException( + "ServerRSocket: Unexpected frame type: " + frameType))); break; } - ReferenceCountUtil.safeRelease(frame); } catch (Throwable t) { - ReferenceCountUtil.safeRelease(frame); - throw Exceptions.propagate(t); + LOGGER.error("Unexpected error during frame handling", t); + getDuplexConnection() + .sendFrame( + 0, + ErrorFrameCodec.encode( + super.getAllocator(), + 0, + new ConnectionErrorException("Unexpected error during frame handling", t))); + this.tryTerminateOnConnectionError(t); } } - private void handleFireAndForget(int streamId, Mono result) { - result.subscribe( - new BaseSubscriber() { - @Override - protected void hookOnSubscribe(Subscription subscription) { - sendingSubscriptions.put(streamId, subscription); - subscription.request(Long.MAX_VALUE); - } + final void handleFireAndForget(int streamId, ByteBuf frame) { + ResponderLeaseTracker leaseHandler = this.leaseHandler; + Throwable leaseError; + if (leaseHandler == null || (leaseError = leaseHandler.use()) == null) { + if (FrameHeaderCodec.hasFollows(frame)) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_FNF, RequestFireAndForgetFrameCodec.metadata(frame)); + } - @Override - protected void hookOnError(Throwable throwable) {} + FireAndForgetResponderSubscriber subscriber = + new FireAndForgetResponderSubscriber(streamId, frame, this, this); - @Override - protected void hookFinally(SignalType type) { - sendingSubscriptions.remove(streamId); - } - }); + this.add(streamId, subscriber); + } else { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_FNF, RequestFireAndForgetFrameCodec.metadata(frame)); + + fireAndForget(super.getPayloadDecoder().apply(frame)) + .subscribe(new FireAndForgetResponderSubscriber(streamId, this)); + } else { + fireAndForget(super.getPayloadDecoder().apply(frame)) + .subscribe(FireAndForgetResponderSubscriber.INSTANCE); + } + } + } else { + final RequestInterceptor requestTracker = this.getRequestInterceptor(); + if (requestTracker != null) { + requestTracker.onReject( + leaseError, FrameType.REQUEST_FNF, RequestFireAndForgetFrameCodec.metadata(frame)); + } + } } - private void handleRequestResponse(int streamId, Mono response) { - final BaseSubscriber subscriber = - new BaseSubscriber() { - private boolean isEmpty = true; - - @Override - protected void hookOnNext(Payload payload) { - if (isEmpty) { - isEmpty = false; - } - - if (!PayloadValidationUtils.isValid(mtu, payload, maxFrameLength)) { - payload.release(); - cancel(); - final IllegalArgumentException t = - new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); - handleError(streamId, t); - return; - } - - ByteBuf byteBuf = - PayloadFrameCodec.encodeNextCompleteReleasingPayload(allocator, streamId, payload); - sendProcessor.onNext(byteBuf); - } + final void handleRequestResponse(int streamId, ByteBuf frame) { + ResponderLeaseTracker leaseHandler = this.leaseHandler; + Throwable leaseError; + if (leaseHandler == null || (leaseError = leaseHandler.use()) == null) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_RESPONSE, RequestResponseFrameCodec.metadata(frame)); + } - @Override - protected void hookOnError(Throwable throwable) { - if (sendingSubscriptions.remove(streamId, this)) { - handleError(streamId, throwable); - } - } + if (FrameHeaderCodec.hasFollows(frame)) { + RequestResponseResponderSubscriber subscriber = + new RequestResponseResponderSubscriber(streamId, frame, this, this); - @Override - protected void hookOnComplete() { - if (isEmpty) { - if (sendingSubscriptions.remove(streamId, this)) { - sendProcessor.onNext(PayloadFrameCodec.encodeComplete(allocator, streamId)); - } - } - } - }; + this.add(streamId, subscriber); + } else { + RequestResponseResponderSubscriber subscriber = + new RequestResponseResponderSubscriber(streamId, this); - sendingSubscriptions.put(streamId, subscriber); - response.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER).subscribe(subscriber); + if (this.add(streamId, subscriber)) { + this.requestResponse(super.getPayloadDecoder().apply(frame)).subscribe(subscriber); + } + } + } else { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onReject( + leaseError, FrameType.REQUEST_RESPONSE, RequestResponseFrameCodec.metadata(frame)); + } + sendLeaseRejection(streamId, leaseError); + } } - private void handleStream( - int streamId, - Flux response, - long initialRequestN, - @Nullable UnicastProcessor requestChannel) { - final BaseSubscriber subscriber = - new BaseSubscriber() { - - @Override - protected void hookOnSubscribe(Subscription s) { - s.request(initialRequestN); - } - - @Override - protected void hookOnNext(Payload payload) { - try { - if (!PayloadValidationUtils.isValid(mtu, payload, maxFrameLength)) { - payload.release(); - final IllegalArgumentException t = - new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); - - cancelStream(t); - return; - } - - ByteBuf byteBuf = - PayloadFrameCodec.encodeNextReleasingPayload(allocator, streamId, payload); - sendProcessor.onNext(byteBuf); - } catch (Throwable e) { - cancelStream(e); - } - } - - private void cancelStream(Throwable t) { - // Cancel the output stream and send an ERROR frame but do not dispose the - // requestChannel (i.e. close the connection) since the spec allows to leave - // the channel in half-closed state. - // specifically for requestChannel case so when Payload is invalid we will not be - // sending CancelFrame and ErrorFrame - // Note: CancelFrame is redundant and due to spec - // (https://github.com/rsocket/rsocket/blob/master/Protocol.md#request-channel) - // Upon receiving an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream - // is terminated on both Requester and Responder. - // Upon sending an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream is - // terminated on both the Requester and Responder. - if (requestChannel != null) { - channelProcessors.remove(streamId, requestChannel); - } - cancel(); - handleError(streamId, t); - } + final void handleStream(int streamId, ByteBuf frame, long initialRequestN) { + ResponderLeaseTracker leaseHandler = this.leaseHandler; + Throwable leaseError; + if (leaseHandler == null || (leaseError = leaseHandler.use()) == null) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_STREAM, RequestStreamFrameCodec.metadata(frame)); + } - @Override - protected void hookOnComplete() { - if (sendingSubscriptions.remove(streamId, this)) { - sendProcessor.onNext(PayloadFrameCodec.encodeComplete(allocator, streamId)); - } - } + if (FrameHeaderCodec.hasFollows(frame)) { + RequestStreamResponderSubscriber subscriber = + new RequestStreamResponderSubscriber(streamId, initialRequestN, frame, this, this); - @Override - protected void hookOnError(Throwable throwable) { - if (sendingSubscriptions.remove(streamId, this)) { - // specifically for requestChannel case so when Payload is invalid we will not be - // sending CancelFrame and ErrorFrame - // Note: CancelFrame is redundant and due to spec - // (https://github.com/rsocket/rsocket/blob/master/Protocol.md#request-channel) - // Upon receiving an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream - // is terminated on both Requester and Responder. - // Upon sending an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream is - // terminated on both the Requester and Responder. - if (requestChannel != null && !requestChannel.isDisposed()) { - if (channelProcessors.remove(streamId, requestChannel)) { - try { - requestChannel.dispose(); - } catch (Throwable e) { - // ignore to ensure it does not blows up if it racing with async - // cancel - } - } - } - - handleError(streamId, throwable); - } - } - }; - - sendingSubscriptions.put(streamId, subscriber); - response.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER).subscribe(subscriber); - } + this.add(streamId, subscriber); + } else { + RequestStreamResponderSubscriber subscriber = + new RequestStreamResponderSubscriber(streamId, initialRequestN, this); - private void handleChannel(int streamId, Payload payload, long initialRequestN) { - UnicastProcessor frames = UnicastProcessor.create(); - channelProcessors.put(streamId, frames); - - Flux payloads = - frames - .doOnRequest( - new LongConsumer() { - boolean first = true; - - @Override - public void accept(long l) { - long n; - if (first) { - first = false; - n = l - 1L; - } else { - n = l; - } - if (n > 0) { - sendProcessor.onNext(RequestNFrameCodec.encode(allocator, streamId, n)); - } - } - }) - .doFinally( - signalType -> { - if (channelProcessors.remove(streamId, frames)) { - if (signalType == SignalType.CANCEL) { - sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); - } else if (signalType == SignalType.ON_ERROR) { - Subscription subscription = sendingSubscriptions.remove(streamId); - if (subscription != null) { - subscription.cancel(); - } - } - } - }) - .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); - - // not chained, as the payload should be enqueued in the Unicast processor before this method - // returns - // and any later payload can be processed - frames.onNext(payload); - - if (responderRSocket != null) { - handleStream(streamId, requestChannel(payload, payloads), initialRequestN, frames); + if (this.add(streamId, subscriber)) { + this.requestStream(super.getPayloadDecoder().apply(frame)).subscribe(subscriber); + } + } } else { - handleStream(streamId, requestChannel(payloads), initialRequestN, frames); + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onReject( + leaseError, FrameType.REQUEST_STREAM, RequestStreamFrameCodec.metadata(frame)); + } + sendLeaseRejection(streamId, leaseError); } } - private void handleMetadataPush(Mono result) { - result.subscribe( - new BaseSubscriber() { - @Override - protected void hookOnSubscribe(Subscription subscription) { - subscription.request(Long.MAX_VALUE); - } - - @Override - protected void hookOnError(Throwable throwable) {} - }); - } + final void handleChannel(int streamId, ByteBuf frame, long initialRequestN, boolean complete) { + ResponderLeaseTracker leaseHandler = this.leaseHandler; + Throwable leaseError; + if (leaseHandler == null || (leaseError = leaseHandler.use()) == null) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_CHANNEL, RequestChannelFrameCodec.metadata(frame)); + } - private void handleCancelFrame(int streamId) { - Subscription subscription = sendingSubscriptions.remove(streamId); - Processor processor = channelProcessors.remove(streamId); + if (FrameHeaderCodec.hasFollows(frame)) { + RequestChannelResponderSubscriber subscriber = + new RequestChannelResponderSubscriber(streamId, initialRequestN, frame, this, this); - if (processor != null) { - try { - processor.onError(new CancellationException("Disposed")); - } catch (Exception e) { - // ignore + this.add(streamId, subscriber); + } else { + final Payload firstPayload = super.getPayloadDecoder().apply(frame); + RequestChannelResponderSubscriber subscriber = + new RequestChannelResponderSubscriber(streamId, initialRequestN, firstPayload, this); + + if (this.add(streamId, subscriber)) { + this.requestChannel(subscriber).subscribe(subscriber); + if (complete) { + subscriber.handleComplete(); + } + } } - } - - if (subscription != null) { - subscription.cancel(); + } else { + final RequestInterceptor requestTracker = this.getRequestInterceptor(); + if (requestTracker != null) { + requestTracker.onReject( + leaseError, FrameType.REQUEST_CHANNEL, RequestChannelFrameCodec.metadata(frame)); + } + sendLeaseRejection(streamId, leaseError); } } - private void handleError(int streamId, Throwable t) { - sendProcessor.onNext(ErrorFrameCodec.encode(allocator, streamId, t)); + private void sendLeaseRejection(int streamId, Throwable leaseError) { + getDuplexConnection() + .sendFrame(streamId, ErrorFrameCodec.encode(getAllocator(), streamId, leaseError)); } - private void handleRequestN(int streamId, ByteBuf frame) { - Subscription subscription = sendingSubscriptions.get(streamId); + private void handleMetadataPush(Mono result) { + result.subscribe(MetadataPushResponderSubscriber.INSTANCE); + } - if (subscription != null) { - long n = RequestNFrameCodec.requestN(frame); - subscription.request(n); + @Override + public boolean add(int streamId, FrameHandler frameHandler) { + if (!super.add(streamId, frameHandler)) { + frameHandler.handleCancel(); + return false; } + + return true; } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java index 610636f02..e969c39d2 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,9 @@ package io.rsocket.core; +import static io.rsocket.core.FragmentationUtils.assertMtu; import static io.rsocket.core.PayloadValidationUtils.assertValidateSetup; +import static io.rsocket.core.ReassemblyUtils.assertInboundPayloadSize; import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; import io.netty.buffer.ByteBuf; @@ -25,27 +27,26 @@ import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.RSocket; +import io.rsocket.RSocketErrorException; import io.rsocket.SocketAcceptor; import io.rsocket.exceptions.InvalidSetupException; import io.rsocket.exceptions.RejectedSetupException; -import io.rsocket.fragmentation.FragmentationDuplexConnection; -import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.SetupFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.ClientServerInputMultiplexer; -import io.rsocket.lease.Leases; -import io.rsocket.lease.RequesterLeaseHandler; -import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.lease.TrackingLeaseSender; +import io.rsocket.plugins.DuplexConnectionInterceptor; import io.rsocket.plugins.InitializingInterceptorRegistry; import io.rsocket.plugins.InterceptorRegistry; +import io.rsocket.plugins.RequestInterceptor; import io.rsocket.resume.SessionManager; import io.rsocket.transport.ServerTransport; +import java.time.Duration; import java.util.Objects; import java.util.function.Consumer; import java.util.function.Supplier; import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; +import reactor.core.publisher.Sinks; /** * The main class for starting an RSocket server. @@ -66,11 +67,12 @@ public final class RSocketServer { private InitializingInterceptorRegistry interceptors = new InitializingInterceptorRegistry(); private Resume resume; - private Supplier> leasesSupplier = null; + private Consumer leaseConfigurer = null; private int mtu = 0; private int maxInboundPayloadSize = Integer.MAX_VALUE; private PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT; + private Duration timeout = Duration.ofMinutes(1); private RSocketServer() {} @@ -184,21 +186,23 @@ public RSocketServer resume(Resume resume) { * *
    {@code
        * RSocketServer.create(SocketAcceptor.with(new RSocket() {...}))
    -   *         .lease(Leases::new)
    +   *         .lease(spec ->
    +   *            spec.sender(() -> Flux.interval(ofSeconds(1))
    +   *                                  .map(__ -> Lease.create(ofSeconds(1), 1)))
    +   *         )
        *         .bind(TcpServerTransport.create("localhost", 7000))
        *         .subscribe();
        * }
    * *

    By default this is not enabled. * - * @param supplier supplier for a {@link Leases} - * @return the same instance for method chaining + * @param leaseConfigurer consumer which accepts {@link LeaseSpec} and use it for configuring * @return the same instance for method chaining * @see Lease * Semantics */ - public RSocketServer lease(Supplier> supplier) { - this.leasesSupplier = supplier; + public RSocketServer lease(Consumer leaseConfigurer) { + this.leaseConfigurer = leaseConfigurer; return this; } @@ -218,8 +222,24 @@ public RSocketServer lease(Supplier> supplier) { * and Reassembly */ public RSocketServer maxInboundPayloadSize(int maxInboundPayloadSize) { - this.maxInboundPayloadSize = - ReassemblyDuplexConnection.assertInboundPayloadSize(maxInboundPayloadSize); + this.maxInboundPayloadSize = assertInboundPayloadSize(maxInboundPayloadSize); + return this; + } + + /** + * Specify the max time to wait for the first frame (e.g. {@code SETUP}) on an accepted + * connection. + * + *

    By default this is set to 1 minute. + * + * @param timeout duration + * @return the same instance for method chaining + */ + public RSocketServer maxTimeToFirstFrame(Duration timeout) { + if (timeout.isNegative() || timeout.isZero()) { + throw new IllegalArgumentException("Setup Handling Timeout should be greater than zero"); + } + this.timeout = timeout; return this; } @@ -237,7 +257,7 @@ public RSocketServer maxInboundPayloadSize(int maxInboundPayloadSize) { * and Reassembly */ public RSocketServer fragment(int mtu) { - this.mtu = FragmentationDuplexConnection.assertMtu(mtu); + this.mtu = assertMtu(mtu); return this; } @@ -287,7 +307,7 @@ public RSocketServer payloadDecoder(PayloadDecoder decoder) { public Mono bind(ServerTransport transport) { return Mono.defer( new Supplier>() { - final ServerSetup serverSetup = serverSetup(); + final ServerSetup serverSetup = serverSetup(timeout); @Override public Mono get() { @@ -326,7 +346,7 @@ public ServerTransport.ConnectionAcceptor asConnectionAcceptor() { public ServerTransport.ConnectionAcceptor asConnectionAcceptor(int maxFrameLength) { assertValidateSetup(maxFrameLength, maxInboundPayloadSize, mtu); return new ServerTransport.ConnectionAcceptor() { - private final ServerSetup serverSetup = serverSetup(); + private final ServerSetup serverSetup = serverSetup(timeout); @Override public Mono apply(DuplexConnection connection) { @@ -336,105 +356,107 @@ public Mono apply(DuplexConnection connection) { } private Mono acceptor( - ServerSetup serverSetup, DuplexConnection connection, int maxFrameLength) { - connection = - mtu > 0 - ? new FragmentationDuplexConnection(connection, mtu, maxInboundPayloadSize, "server") - : new ReassemblyDuplexConnection(connection, maxInboundPayloadSize); - - ClientServerInputMultiplexer multiplexer = - new ClientServerInputMultiplexer(connection, interceptors, false); - - return multiplexer - .asSetupConnection() - .receive() - .next() - .flatMap(startFrame -> accept(serverSetup, startFrame, multiplexer, maxFrameLength)); + ServerSetup serverSetup, DuplexConnection sourceConnection, int maxFrameLength) { + + final DuplexConnection interceptedConnection = + interceptors.initConnection(DuplexConnectionInterceptor.Type.SOURCE, sourceConnection); + + return serverSetup + .init(LoggingDuplexConnection.wrapIfEnabled(interceptedConnection)) + .flatMap( + tuple2 -> { + final ByteBuf startFrame = tuple2.getT1(); + final DuplexConnection clientServerConnection = tuple2.getT2(); + + return accept(serverSetup, startFrame, clientServerConnection, maxFrameLength); + }); } private Mono acceptResume( - ServerSetup serverSetup, ByteBuf resumeFrame, ClientServerInputMultiplexer multiplexer) { - return serverSetup.acceptRSocketResume(resumeFrame, multiplexer); + ServerSetup serverSetup, ByteBuf resumeFrame, DuplexConnection clientServerConnection) { + return serverSetup.acceptRSocketResume(resumeFrame, clientServerConnection); } private Mono accept( ServerSetup serverSetup, ByteBuf startFrame, - ClientServerInputMultiplexer multiplexer, + DuplexConnection clientServerConnection, int maxFrameLength) { switch (FrameHeaderCodec.frameType(startFrame)) { case SETUP: - return acceptSetup(serverSetup, startFrame, multiplexer, maxFrameLength); + return acceptSetup(serverSetup, startFrame, clientServerConnection, maxFrameLength); case RESUME: - return acceptResume(serverSetup, startFrame, multiplexer); + return acceptResume(serverSetup, startFrame, clientServerConnection); default: - return serverSetup - .sendError( - multiplexer, - new InvalidSetupException( - "invalid setup frame: " + FrameHeaderCodec.frameType(startFrame))) - .doFinally( - signalType -> { - startFrame.release(); - multiplexer.dispose(); - }); + serverSetup.sendError( + clientServerConnection, + new InvalidSetupException("SETUP or RESUME frame must be received before any others")); + return clientServerConnection.onClose(); } } private Mono acceptSetup( ServerSetup serverSetup, ByteBuf setupFrame, - ClientServerInputMultiplexer multiplexer, + DuplexConnection clientServerConnection, int maxFrameLength) { if (!SetupFrameCodec.isSupportedVersion(setupFrame)) { - return serverSetup - .sendError( - multiplexer, - new InvalidSetupException( - "Unsupported version: " + SetupFrameCodec.humanReadableVersion(setupFrame))) - .doFinally( - signalType -> { - setupFrame.release(); - multiplexer.dispose(); - }); + serverSetup.sendError( + clientServerConnection, + new InvalidSetupException( + "Unsupported version: " + SetupFrameCodec.humanReadableVersion(setupFrame))); + return clientServerConnection.onClose(); } - boolean leaseEnabled = leasesSupplier != null; + boolean leaseEnabled = leaseConfigurer != null; if (SetupFrameCodec.honorLease(setupFrame) && !leaseEnabled) { - return serverSetup - .sendError(multiplexer, new InvalidSetupException("lease is not supported")) - .doFinally( - signalType -> { - setupFrame.release(); - multiplexer.dispose(); - }); + serverSetup.sendError( + clientServerConnection, new InvalidSetupException("lease is not supported")); + return clientServerConnection.onClose(); } return serverSetup.acceptRSocketSetup( setupFrame, - multiplexer, - (keepAliveHandler, wrappedMultiplexer) -> { - ConnectionSetupPayload setupPayload = new DefaultConnectionSetupPayload(setupFrame); + clientServerConnection, + (keepAliveHandler, wrappedDuplexConnection) -> { + ConnectionSetupPayload setupPayload = + new DefaultConnectionSetupPayload(setupFrame.retain()); + final InitializingInterceptorRegistry interceptors = this.interceptors; + final ClientServerInputMultiplexer multiplexer = + new ClientServerInputMultiplexer(wrappedDuplexConnection, interceptors, false); + + final LeaseSpec leases; + final RequesterLeaseTracker requesterLeaseTracker; + if (leaseEnabled) { + leases = new LeaseSpec(); + leaseConfigurer.accept(leases); + requesterLeaseTracker = + new RequesterLeaseTracker(SERVER_TAG, leases.maxPendingRequests); + } else { + leases = null; + requesterLeaseTracker = null; + } - Leases leases = leaseEnabled ? leasesSupplier.get() : null; - RequesterLeaseHandler requesterLeaseHandler = - leaseEnabled - ? new RequesterLeaseHandler.Impl(SERVER_TAG, leases.receiver()) - : RequesterLeaseHandler.None; + final Sinks.Empty requesterOnAllClosedSink = Sinks.unsafe().empty(); + final Sinks.Empty responderOnAllClosedSink = Sinks.unsafe().empty(); RSocket rSocketRequester = new RSocketRequester( - wrappedMultiplexer.asServerConnection(), + multiplexer.asServerConnection(), payloadDecoder, StreamIdSupplier.serverSupplier(), mtu, maxFrameLength, + maxInboundPayloadSize, setupPayload.keepAliveInterval(), setupPayload.keepAliveMaxLifetime(), keepAliveHandler, - requesterLeaseHandler, - Schedulers.single(Schedulers.parallel())); + interceptors::initRequesterRequestInterceptor, + requesterLeaseTracker, + requesterOnAllClosedSink, + Mono.whenDelayError( + responderOnAllClosedSink.asMono(), requesterOnAllClosedSink.asMono())); RSocket wrappedRSocketRequester = interceptors.initRequester(rSocketRequester); @@ -443,40 +465,50 @@ private Mono acceptSetup( .accept(setupPayload, wrappedRSocketRequester) .onErrorResume( err -> - serverSetup - .sendError(multiplexer, rejectedSetupError(err)) + Mono.fromRunnable( + () -> + serverSetup.sendError( + wrappedDuplexConnection, rejectedSetupError(err))) + .then(wrappedDuplexConnection.onClose()) .then(Mono.error(err))) .doOnNext( rSocketHandler -> { RSocket wrappedRSocketHandler = interceptors.initResponder(rSocketHandler); - DuplexConnection connection = wrappedMultiplexer.asClientConnection(); + DuplexConnection clientConnection = multiplexer.asClientConnection(); - ResponderLeaseHandler responderLeaseHandler = + ResponderLeaseTracker responderLeaseTracker = leaseEnabled - ? new ResponderLeaseHandler.Impl<>( - SERVER_TAG, connection.alloc(), leases.sender(), leases.stats()) - : ResponderLeaseHandler.None; + ? new ResponderLeaseTracker(SERVER_TAG, clientConnection, leases.sender) + : null; RSocket rSocketResponder = new RSocketResponder( - connection, + clientConnection, wrappedRSocketHandler, payloadDecoder, - responderLeaseHandler, + responderLeaseTracker, mtu, - maxFrameLength); + maxFrameLength, + maxInboundPayloadSize, + leaseEnabled && leases.sender instanceof TrackingLeaseSender + ? rSocket -> + interceptors.initResponderRequestInterceptor( + rSocket, (RequestInterceptor) leases.sender) + : interceptors::initResponderRequestInterceptor, + responderOnAllClosedSink); }) .doFinally(signalType -> setupPayload.release()) .then(); }); } - private ServerSetup serverSetup() { - return resume != null ? createSetup() : new ServerSetup.DefaultServerSetup(); + private ServerSetup serverSetup(Duration timeout) { + return resume != null ? createSetup(timeout) : new ServerSetup.DefaultServerSetup(timeout); } - ServerSetup createSetup() { + ServerSetup createSetup(Duration timeout) { return new ServerSetup.ResumableServerSetup( + timeout, new SessionManager(), resume.getSessionDuration(), resume.getStreamTimeout(), @@ -484,7 +516,7 @@ ServerSetup createSetup() { resume.isCleanupStoreOnKeepAlive()); } - private Exception rejectedSetupError(Throwable err) { + private RSocketErrorException rejectedSetupError(Throwable err) { String msg = err.getMessage(); return new RejectedSetupException(msg == null ? "rejected by server acceptor" : msg); } diff --git a/rsocket-core/src/main/java/io/rsocket/core/ReassemblyUtils.java b/rsocket-core/src/main/java/io/rsocket/core/ReassemblyUtils.java new file mode 100644 index 000000000..8e084fe9c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ReassemblyUtils.java @@ -0,0 +1,247 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.MIN_MTU_SIZE; +import static io.rsocket.core.StateUtils.isReassembling; +import static io.rsocket.core.StateUtils.isTerminated; +import static io.rsocket.core.StateUtils.markReassembled; +import static io.rsocket.core.StateUtils.markReassembling; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.decoder.PayloadDecoder; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; + +class ReassemblyUtils { + static final String ILLEGAL_REASSEMBLED_PAYLOAD_SIZE = + "Reassembled payload size went out of allowed %s bytes"; + + @SuppressWarnings("ConstantConditions") + static void release(RequesterFrameHandler framesHolder, long state) { + if (isReassembling(state)) { + final CompositeByteBuf frames = framesHolder.getFrames(); + framesHolder.setFrames(null); + frames.release(); + } + } + + @SuppressWarnings({"ConstantConditions", "SynchronizationOnLocalVariableOrMethodParameter"}) + static void synchronizedRelease(RequesterFrameHandler framesHolder, long state) { + if (isReassembling(state)) { + final CompositeByteBuf frames = framesHolder.getFrames(); + framesHolder.setFrames(null); + + synchronized (frames) { + frames.release(); + } + } + } + + static void handleNextSupport( + AtomicLongFieldUpdater updater, + T instance, + Subscription subscription, + CoreSubscriber inboundSubscriber, + PayloadDecoder payloadDecoder, + ByteBufAllocator allocator, + int maxInboundPayloadSize, + ByteBuf frame, + boolean hasFollows, + boolean isLastPayload) { + + long state = updater.get(instance); + if (isTerminated(state)) { + return; + } + + if (!hasFollows && !isReassembling(state)) { + Payload payload; + try { + payload = payloadDecoder.apply(frame); + } catch (Throwable t) { + // sends cancel frame to prevent any further frames + subscription.cancel(); + // terminates downstream + inboundSubscriber.onError(t); + + return; + } + + instance.handlePayload(payload); + if (isLastPayload) { + instance.handleComplete(); + } + return; + } + + CompositeByteBuf frames = instance.getFrames(); + if (frames == null) { + frames = + ReassemblyUtils.addFollowingFrame( + allocator.compositeBuffer(), frame, hasFollows, maxInboundPayloadSize); + instance.setFrames(frames); + + long previousState = markReassembling(updater, instance); + if (isTerminated(previousState)) { + instance.setFrames(null); + frames.release(); + return; + } + } else { + try { + frames = + ReassemblyUtils.addFollowingFrame(frames, frame, hasFollows, maxInboundPayloadSize); + } catch (IllegalStateException t) { + if (isTerminated(updater.get(instance))) { + return; + } + + // sends cancel frame to prevent any further frames + subscription.cancel(); + // terminates downstream + inboundSubscriber.onError(t); + + return; + } + } + + if (!hasFollows) { + long previousState = markReassembled(updater, instance); + if (isTerminated(previousState)) { + return; + } + + instance.setFrames(null); + + Payload payload; + try { + payload = payloadDecoder.apply(frames); + frames.release(); + } catch (Throwable t) { + ReferenceCountUtil.safeRelease(frames); + + // sends cancel frame to prevent any further frames + subscription.cancel(); + // terminates downstream + inboundSubscriber.onError(t); + + return; + } + + instance.handlePayload(payload); + + if (isLastPayload) { + instance.handleComplete(); + } + } + } + + static CompositeByteBuf addFollowingFrame( + CompositeByteBuf frames, + ByteBuf followingFrame, + boolean hasFollows, + int maxInboundPayloadSize) { + int readableBytes = frames.readableBytes(); + if (readableBytes == 0) { + return frames.addComponent(true, followingFrame.retain()); + } else if (maxInboundPayloadSize != Integer.MAX_VALUE + && readableBytes + followingFrame.readableBytes() - FrameHeaderCodec.size() + > maxInboundPayloadSize) { + throw new IllegalStateException( + String.format(ILLEGAL_REASSEMBLED_PAYLOAD_SIZE, maxInboundPayloadSize)); + } else if (followingFrame.readableBytes() < MIN_MTU_SIZE - 3 && hasFollows) { + // FIXME: check MIN_MTU_SIZE only (currently fragments have size of 61) + throw new IllegalStateException("Fragment is too small."); + } + + final boolean hasMetadata = FrameHeaderCodec.hasMetadata(followingFrame); + + // skip headers + followingFrame.skipBytes(FrameHeaderCodec.size()); + + // if has metadata, then we have to increase metadata length in containing frames + // CompositeByteBuf + if (hasMetadata) { + final FrameType frameType = FrameHeaderCodec.frameType(frames); + final int lengthFieldPosition = + FrameHeaderCodec.size() + (frameType.hasInitialRequestN() ? Integer.BYTES : 0); + + frames.markReaderIndex(); + frames.skipBytes(lengthFieldPosition); + + final int nextMetadataLength = decodeLength(frames) + decodeLength(followingFrame); + + frames.resetReaderIndex(); + + frames.markWriterIndex(); + frames.writerIndex(lengthFieldPosition); + + encodeLength(frames, nextMetadataLength); + + frames.resetWriterIndex(); + } + + synchronized (frames) { + if (frames.refCnt() > 0) { + followingFrame.retain(); + return frames.addComponent(true, followingFrame); + } else { + throw new IllegalReferenceCountException(0); + } + } + } + + private static void encodeLength(final ByteBuf byteBuf, final int length) { + if ((length & ~FRAME_LENGTH_MASK) != 0) { + throw new IllegalArgumentException("Length is larger than 24 bits"); + } + // Write each byte separately in reverse order, this mean we can write 1 << 23 without + // overflowing. + byteBuf.writeByte(length >> 16); + byteBuf.writeByte(length >> 8); + byteBuf.writeByte(length); + } + + private static int decodeLength(final ByteBuf byteBuf) { + int length = (byteBuf.readByte() & 0xFF) << 16; + length |= (byteBuf.readByte() & 0xFF) << 8; + length |= byteBuf.readByte() & 0xFF; + return length; + } + + static int assertInboundPayloadSize(int inboundPayloadSize) { + if (inboundPayloadSize < MIN_MTU_SIZE) { + String msg = + String.format( + "The min allowed inboundPayloadSize size is %d bytes, provided: %d", + FrameLengthCodec.FRAME_LENGTH_MASK, inboundPayloadSize); + throw new IllegalArgumentException(msg); + } else { + return inboundPayloadSize; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ReconnectMono.java b/rsocket-core/src/main/java/io/rsocket/core/ReconnectMono.java index 44e4ffa81..afad6e0df 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/ReconnectMono.java +++ b/rsocket-core/src/main/java/io/rsocket/core/ReconnectMono.java @@ -48,6 +48,10 @@ final class ReconnectMono extends Mono implements Invalidatable, Disposabl this.resolvingInner = new ResolvingInner<>(this); } + public Mono getSource() { + return source; + } + @Override public Object scanUnsafe(Attr key) { if (key == Attr.PARENT) return source; diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java new file mode 100644 index 000000000..aab491793 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java @@ -0,0 +1,829 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.ReassemblyUtils.handleNextSupport; +import static io.rsocket.core.SendUtils.DISCARD_CONTEXT; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.util.Objects; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; +import reactor.util.context.ContextView; + +final class RequestChannelRequesterFlux extends Flux + implements RequesterFrameHandler, + LeasePermitHandler, + CoreSubscriber, + Subscription, + Scannable { + + final ByteBufAllocator allocator; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + final PayloadDecoder payloadDecoder; + + final Publisher payloadsPublisher; + + @Nullable final RequesterLeaseTracker requesterLeaseTracker; + @Nullable final RequestInterceptor requestInterceptor; + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(RequestChannelRequesterFlux.class, "state"); + + int streamId; + + boolean isFirstSignal = true; + Payload firstPayload; + + Subscription outboundSubscription; + boolean outboundDone; + Throwable outboundError; + + Context cachedContext; + CoreSubscriber inboundSubscriber; + boolean inboundDone; + long requested; + long produced; + + CompositeByteBuf frames; + + RequestChannelRequesterFlux( + Publisher payloadsPublisher, RequesterResponderSupport requesterResponderSupport) { + this.allocator = requesterResponderSupport.getAllocator(); + this.payloadsPublisher = payloadsPublisher; + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requesterLeaseTracker = requesterResponderSupport.getRequesterLeaseTracker(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + Objects.requireNonNull(actual, "subscribe"); + + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + final IllegalStateException e = + new IllegalStateException("RequestChannelFlux allows only a single Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_CHANNEL, null); + } + + Operators.error(actual, e); + return; + } + + this.inboundSubscriber = actual; + this.payloadsPublisher.subscribe(this); + } + + @Override + public void onSubscribe(Subscription outboundSubscription) { + if (Operators.validate(this.outboundSubscription, outboundSubscription)) { + this.outboundSubscription = outboundSubscription; + this.inboundSubscriber.onSubscribe(this); + } + } + + @Override + public final void request(long n) { + if (!Operators.validate(n)) { + return; + } + + this.requested = Operators.addCap(this.requested, n); + + long previousState = addRequestN(STATE, this, n, this.requesterLeaseTracker == null); + if (isTerminated(previousState)) { + return; + } + + if (hasRequested(previousState)) { + if (isFirstFrameSent(previousState) + && !isMaxAllowedRequestN(extractRequestN(previousState))) { + final int streamId = this.streamId; + final ByteBuf requestNFrame = RequestNFrameCodec.encode(this.allocator, streamId, n); + this.connection.sendFrame(streamId, requestNFrame); + } + return; + } + + // do first request + this.outboundSubscription.request(1); + } + + @Override + public void onNext(Payload p) { + if (this.outboundDone) { + p.release(); + return; + } + + if (this.isFirstSignal) { + this.isFirstSignal = false; + + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + final boolean leaseEnabled = requesterLeaseTracker != null; + + if (leaseEnabled) { + this.firstPayload = p; + + final long previousState = markFirstPayloadReceived(STATE, this); + if (isTerminated(previousState)) { + this.firstPayload = null; + p.release(); + return; + } + + requesterLeaseTracker.issue(this); + } else { + final long state = this.state; + if (isTerminated(state)) { + p.release(); + return; + } + // TODO: check if source is Scalar | Callable | Mono + sendFirstPayload(p, extractRequestN(state), false); + } + } else { + sendFollowingPayload(p); + } + } + + @Override + public boolean handlePermit() { + final long previousState = markReadyToSendFirstFrame(STATE, this); + + if (isTerminated(previousState)) { + return false; + } + + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + + sendFirstPayload( + firstPayload, extractRequestN(previousState), isOutboundTerminated(previousState)); + return true; + } + + void sendFirstPayload(Payload firstPayload, long initialRequestN, boolean completed) { + int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, firstPayload, true)) { + final long previousState = markTerminated(STATE, this); + + if (isTerminated(previousState)) { + return; + } + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_CHANNEL, firstPayload.metadata()); + } + + firstPayload.release(); + + this.inboundDone = true; + this.inboundSubscriber.onError(e); + return; + } + } catch (IllegalReferenceCountException e) { + final long previousState = markTerminated(STATE, this); + + if (isTerminated(previousState)) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_CHANNEL, null); + } + + this.inboundDone = true; + this.inboundSubscriber.onError(e); + return; + } + + final RequesterResponderSupport sm = this.requesterResponderSupport; + final DuplexConnection connection = this.connection; + final ByteBufAllocator allocator = this.allocator; + + final int streamId; + try { + streamId = sm.addAndGetNextStreamId(this); + this.streamId = streamId; + } catch (Throwable t) { + final long previousState = markTerminated(STATE, this); + + firstPayload.release(); + + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_CHANNEL, firstPayload.metadata()); + } + + this.inboundDone = true; + this.inboundSubscriber.onError(ut); + + return; + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, FrameType.REQUEST_CHANNEL, firstPayload.metadata()); + } + + try { + sendReleasingPayload( + streamId, + FrameType.REQUEST_CHANNEL, + initialRequestN, + mtu, + firstPayload, + connection, + allocator, + completed); + } catch (Throwable t) { + final long previousState = markTerminated(STATE, this); + + firstPayload.release(); + + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + sm.remove(streamId, this); + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } + + this.inboundDone = true; + this.inboundSubscriber.onError(t); + return; + } + + long previousState = markFirstFrameSent(STATE, this); + if (isTerminated(previousState)) { + // now, this can be terminated in case of the following scenarios: + // + // 1) SendFirst is called synchronously from onNext, thus we can have + // handleError called before we marked first frame sent, thus we may check if + // inboundDone flag is true and exit execution without any further actions: + if (this.inboundDone) { + return; + } + + sm.remove(streamId, this); + + // 2) SendFirst is called asynchronously on the connection event-loop. Thus, we + // need to check if outbound error is present. Note, we check outboundError since + // in the last scenario, cancellation may terminate the state and async + // onComplete may set outboundDone to true. Thus, we explicitly check for + // outboundError + final Throwable outboundError = this.outboundError; + if (outboundError != null) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, streamId, outboundError); + connection.sendFrame(streamId, errorFrame); + + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, outboundError); + } + + this.inboundDone = true; + this.inboundSubscriber.onError(outboundError); + } else { + // 3) SendFirst is interleaving with cancel. Thus, we need to generate cancel + // frame + final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); + connection.sendFrame(streamId, cancelFrame); + + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_CHANNEL); + } + } + + return; + } + + if (!completed && isOutboundTerminated(previousState)) { + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); + connection.sendFrame(streamId, completeFrame); + } + + if (isMaxAllowedRequestN(initialRequestN)) { + return; + } + + long requestN = extractRequestN(previousState); + if (isMaxAllowedRequestN(requestN)) { + final ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, streamId, requestN); + connection.sendFrame(streamId, requestNFrame); + return; + } + + if (requestN > initialRequestN) { + final ByteBuf requestNFrame = + RequestNFrameCodec.encode(allocator, streamId, requestN - initialRequestN); + connection.sendFrame(streamId, requestNFrame); + } + } + + final void sendFollowingPayload(Payload followingPayload) { + int streamId = this.streamId; + int mtu = this.mtu; + + try { + if (!isValid(mtu, this.maxFrameLength, followingPayload, true)) { + followingPayload.release(); + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + if (!this.tryCancel()) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + this.propagateErrorSafely(e); + return; + } + } catch (IllegalReferenceCountException e) { + if (!this.tryCancel()) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + this.propagateErrorSafely(e); + + return; + } + + try { + sendReleasingPayload( + streamId, + + // TODO: Should be a different flag in case of the scalar + // source or if we know in advance upstream is mono + FrameType.NEXT, + mtu, + followingPayload, + this.connection, + allocator, + true); + } catch (Throwable e) { + if (!this.tryCancel()) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + this.propagateErrorSafely(e); + } + } + + void propagateErrorSafely(Throwable t) { + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + if (!this.inboundDone) { + synchronized (this) { + if (!this.inboundDone) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, t); + } + + this.inboundDone = true; + this.inboundSubscriber.onError(t); + } else { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + } + } + } else { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + } + } + + @Override + public final void cancel() { + if (!tryCancel()) { + return; + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(this.streamId, FrameType.REQUEST_CHANNEL); + } + } + + boolean tryCancel() { + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return false; + } + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + if (!isReadyToSendFirstFrame(previousState) && isFirstPayloadReceived(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + // no need to send anything, since we have not started a stream yet (no logical wire) + return false; + } + + ReassemblyUtils.synchronizedRelease(this, previousState); + + final boolean firstFrameSent = isFirstFrameSent(previousState); + if (firstFrameSent) { + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final ByteBuf cancelFrame = CancelFrameCodec.encode(this.allocator, streamId); + this.connection.sendFrame(streamId, cancelFrame); + } + + return firstFrameSent; + } + + @Override + public void onError(Throwable t) { + if (this.outboundDone) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + this.outboundError = t; + this.outboundDone = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + if (this.isFirstSignal) { + this.inboundDone = true; + this.inboundSubscriber.onError(t); + return; + } else if (!isReadyToSendFirstFrame(previousState)) { + // first signal is received but we are still waiting for lease permit to be issued, + // thus, just propagates error to actual subscriber + + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + + firstPayload.release(); + + this.inboundDone = true; + this.inboundSubscriber.onError(t); + + return; + } + + ReassemblyUtils.synchronizedRelease(this, previousState); + + if (isFirstFrameSent(previousState)) { + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + // propagates error to remote responder + final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); + this.connection.sendFrame(streamId, errorFrame); + + if (!isInboundTerminated(previousState)) { + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + synchronized (this) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } + + this.inboundDone = true; + this.inboundSubscriber.onError(t); + } + } else { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + } + } + } + + @Override + public void onComplete() { + if (this.outboundDone) { + return; + } + + this.outboundDone = true; + + long previousState = markOutboundTerminated(STATE, this, true); + if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + return; + } + + if (!isFirstFrameSent(previousState)) { + if (!isFirstPayloadReceived(previousState)) { + // first signal, thus, just propagates error to actual subscriber + this.inboundSubscriber.onError(new CancellationException("Empty Source")); + } + return; + } + + final int streamId = this.streamId; + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); + + this.connection.sendFrame(streamId, completeFrame); + + if (isInboundTerminated(previousState)) { + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, null); + } + } + } + + @Override + public final void handleComplete() { + if (this.inboundDone) { + return; + } + + this.inboundDone = true; + + long previousState = markInboundTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + if (isOutboundTerminated(previousState)) { + this.requesterResponderSupport.remove(this.streamId, this); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, null); + } + } + + this.inboundSubscriber.onComplete(); + } + + @Override + public final void handlePermitError(Throwable cause) { + this.inboundDone = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState) || isInboundTerminated(previousState)) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + final Payload p = this.firstPayload; + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onReject(cause, FrameType.REQUEST_CHANNEL, p.metadata()); + } + p.release(); + + this.inboundSubscriber.onError(cause); + } + + @Override + public final void handleError(Throwable cause) { + if (this.inboundDone) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + this.inboundDone = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState) || isInboundTerminated(previousState)) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + ReassemblyUtils.release(this, previousState); + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, cause); + } + + this.inboundSubscriber.onError(cause); + } + + @Override + public final void handlePayload(Payload value) { + synchronized (this) { + if (this.inboundDone) { + value.release(); + return; + } + + final long produced = this.produced; + if (this.requested == produced) { + value.release(); + if (!tryCancel()) { + return; + } + + final Throwable cause = + Exceptions.failWithOverflow( + "The number of messages received exceeds the number requested"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, cause); + } + + this.inboundSubscriber.onError(cause); + return; + } + + this.produced = produced + 1; + + this.inboundSubscriber.onNext(value); + } + } + + @Override + public void handleRequestN(long n) { + this.outboundSubscription.request(n); + } + + @Override + public void handleCancel() { + if (this.outboundDone) { + return; + } + + long previousState = markOutboundTerminated(STATE, this, false); + if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + return; + } + + final boolean inboundTerminated = isInboundTerminated(previousState); + if (inboundTerminated) { + this.requesterResponderSupport.remove(this.streamId, this); + } + + this.outboundSubscription.cancel(); + + if (inboundTerminated) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, null); + } + } + } + + @Override + public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) { + handleNextSupport( + STATE, + this, + this, + this.inboundSubscriber, + this.payloadDecoder, + this.allocator, + this.maxInboundPayloadSize, + frame, + hasFollows, + isLastPayload); + } + + @Override + @NonNull + public Context currentContext() { + long state = this.state; + + if (isSubscribedOrTerminated(state)) { + Context cachedContext = this.cachedContext; + if (cachedContext == null) { + cachedContext = + this.inboundSubscriber.currentContext().putAll((ContextView) DISCARD_CONTEXT); + this.cachedContext = cachedContext; + } + return cachedContext; + } + + return Context.empty(); + } + + @Override + public CompositeByteBuf getFrames() { + return this.frames; + } + + @Override + public void setFrames(CompositeByteBuf byteBuf) { + this.frames = byteBuf; + } + + @Override + @Nullable + public Object scanUnsafe(Attr key) { + // touch guard + long state = this.state; + + if (key == Attr.TERMINATED) return isTerminated(state); + if (key == Attr.REQUESTED_FROM_DOWNSTREAM) return state; + + return null; + } + + @Override + @NonNull + public String stepName() { + return "source(RequestChannelFlux)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java new file mode 100644 index 000000000..32128fee4 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java @@ -0,0 +1,922 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.*; +import static reactor.core.Exceptions.TERMINATED; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.exceptions.CanceledException; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +final class RequestChannelResponderSubscriber extends Flux + implements ResponderFrameHandler, Subscription, CoreSubscriber { + + static final Logger logger = LoggerFactory.getLogger(RequestChannelResponderSubscriber.class); + + final int streamId; + final ByteBufAllocator allocator; + final PayloadDecoder payloadDecoder; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + final long firstRequest; + + @Nullable final RequestInterceptor requestInterceptor; + + final RSocket handler; + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(RequestChannelResponderSubscriber.class, "state"); + + Payload firstPayload; + + Subscription outboundSubscription; + CoreSubscriber inboundSubscriber; + + CompositeByteBuf frames; + + volatile Throwable inboundError; + static final AtomicReferenceFieldUpdater + INBOUND_ERROR = + AtomicReferenceFieldUpdater.newUpdater( + RequestChannelResponderSubscriber.class, Throwable.class, "inboundError"); + + boolean inboundDone; + boolean outboundDone; + long requested; + long produced; + + public RequestChannelResponderSubscriber( + int streamId, + long firstRequestN, + ByteBuf firstFrame, + RequesterResponderSupport requesterResponderSupport, + RSocket handler) { + this.streamId = streamId; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + this.handler = handler; + this.firstRequest = firstRequestN; + + this.frames = + ReassemblyUtils.addFollowingFrame( + allocator.compositeBuffer(), firstFrame, true, maxInboundPayloadSize); + STATE.lazySet(this, REASSEMBLING_FLAG); + } + + public RequestChannelResponderSubscriber( + int streamId, + long firstRequestN, + Payload firstPayload, + RequesterResponderSupport requesterResponderSupport) { + this.streamId = streamId; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + this.firstRequest = firstRequestN; + this.firstPayload = firstPayload; + + this.handler = null; + this.frames = null; + } + + @Override + // subscriber from the requestChannel method + public void subscribe(CoreSubscriber actual) { + + long previousState = markSubscribed(STATE, this); + if (isTerminated(previousState)) { + Throwable t = Exceptions.terminate(INBOUND_ERROR, this); + if (t != TERMINATED) { + //noinspection ConstantConditions + Operators.error(actual, t); + } else { + Operators.error( + actual, + new CancellationException("RequestChannelSubscriber has already been terminated")); + } + return; + } + + if (isSubscribed(previousState)) { + Operators.error( + actual, new IllegalStateException("RequestChannelSubscriber allows only one Subscriber")); + return; + } + + this.inboundSubscriber = actual; + // sends sender as a subscription since every request|cancel signal should be encoded to + // requestNFrame|cancelFrame + actual.onSubscribe(this); + } + + @Override + // subscription to the outbound + public void onSubscribe(Subscription outboundSubscription) { + if (Operators.validate(this.outboundSubscription, outboundSubscription)) { + this.outboundSubscription = outboundSubscription; + outboundSubscription.request(this.firstRequest); + } + } + + @Override + public void request(long n) { + if (!Operators.validate(n)) { + return; + } + + this.requested = Operators.addCap(this.requested, n); + + long previousState = StateUtils.addRequestN(STATE, this, n); + if (isTerminated(previousState)) { + // full termination can be the result of both sides completion / cancelFrame / remote or local + // error + // therefore, we need to check inbound error value, to see what should be done + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError == TERMINATED) { + // means inbound was already terminated + return; + } + + if (inboundError != null || this.inboundDone) { + final CoreSubscriber inboundSubscriber = this.inboundSubscriber; + + Payload firstPayload = this.firstPayload; + if (firstPayload != null) { + this.firstPayload = null; + + this.produced++; + + inboundSubscriber.onNext(firstPayload); + } + + if (inboundError != null) { + inboundSubscriber.onError(inboundError); + } else { + inboundSubscriber.onComplete(); + } + } + return; + } + + if (isInboundTerminated(previousState)) { + // inbound only can be terminated in case of cancellation or complete frame + if (!hasRequested(previousState) && !isFirstFrameSent(previousState) && this.inboundDone) { + final CoreSubscriber inboundSubscriber = this.inboundSubscriber; + + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + + this.produced++; + + inboundSubscriber.onNext(firstPayload); + inboundSubscriber.onComplete(); + + markFirstFrameSent(STATE, this); + } + return; + } + + if (hasRequested(previousState)) { + if (isFirstFrameSent(previousState) + && !isMaxAllowedRequestN(StateUtils.extractRequestN(previousState))) { + final int streamId = this.streamId; + final ByteBuf requestNFrame = RequestNFrameCodec.encode(this.allocator, streamId, n); + this.connection.sendFrame(streamId, requestNFrame); + } + return; + } + + final CoreSubscriber inboundSubscriber = this.inboundSubscriber; + + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + + this.produced++; + + inboundSubscriber.onNext(firstPayload); + + previousState = markFirstFrameSent(STATE, this); + if (isTerminated(previousState)) { + // full termination can be the result of both sides completion / cancelFrame / remote or local + // error + // therefore, we need to check inbound error value, to see what should be done + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError == TERMINATED) { + // means inbound was already terminated + return; + } + + if (inboundError != null) { + inboundSubscriber.onError(inboundError); + } else if (this.inboundDone) { + inboundSubscriber.onComplete(); + } + return; + } + + if (isInboundTerminated(previousState)) { + // inbound only can be terminated in case of cancellation or complete frame + if (this.inboundDone) { + inboundSubscriber.onComplete(); + } + return; + } + + long requestN = StateUtils.extractRequestN(previousState); + if (isMaxAllowedRequestN(requestN)) { + final int streamId = this.streamId; + final ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, streamId, requestN); + this.connection.sendFrame(streamId, requestNFrame); + } else { + long firstRequestN = requestN - 1; + if (firstRequestN > 0) { + final int streamId = this.streamId; + final ByteBuf requestNFrame = + RequestNFrameCodec.encode(this.allocator, streamId, firstRequestN); + this.connection.sendFrame(streamId, requestNFrame); + } + } + } + + @Override + // inbound cancellation + public void cancel() { + long previousState = markInboundTerminated(STATE, this); + if (isTerminated(previousState) || isInboundTerminated(previousState)) { + INBOUND_ERROR.lazySet(this, TERMINATED); + return; + } + + if (!isFirstFrameSent(previousState) && !hasRequested(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } + + final int streamId = this.streamId; + + final boolean isOutboundTerminated = isOutboundTerminated(previousState); + if (isOutboundTerminated) { + this.requesterResponderSupport.remove(streamId, this); + } + + final ByteBuf cancelFrame = CancelFrameCodec.encode(this.allocator, streamId); + this.connection.sendFrame(streamId, cancelFrame); + + if (isOutboundTerminated) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, null); + } + } + } + + @Override + public final void handleCancel() { + Subscription outboundSubscription = this.outboundSubscription; + if (outboundSubscription == null) { + // if subscription is null, it means that streams has not yet reassembled all the fragments + // and fragmentation of the first frame was cancelled before + lazyTerminate(STATE, this); + + this.requesterResponderSupport.remove(this.streamId, this); + + final CompositeByteBuf frames = this.frames; + if (frames != null) { + this.frames = null; + frames.release(); + } else { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onCancel(this.streamId, FrameType.REQUEST_CHANNEL); + } + return; + } + + long previousState = this.tryTerminate(true); + if (isTerminated(previousState)) { + return; + } + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onCancel(this.streamId, FrameType.REQUEST_CHANNEL); + } + } + + final long tryTerminate(boolean isFromInbound) { + Exceptions.addThrowable( + INBOUND_ERROR, this, new CancellationException("Inbound has been canceled")); + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return previousState; + } + + this.requesterResponderSupport.remove(this.streamId, this); + + if (isReassembling(previousState)) { + final CompositeByteBuf frames = this.frames; + this.frames = null; + if (isFromInbound) { + frames.release(); + } else { + synchronized (frames) { + frames.release(); + } + } + } + + final Subscription outboundSubscription = this.outboundSubscription; + if (outboundSubscription == null) { + return previousState; + } + + outboundSubscription.cancel(); + + if (!isSubscribed(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } else if (isFirstFrameSent(previousState) && !isInboundTerminated(previousState)) { + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError != TERMINATED) { + if (isFromInbound) { + this.inboundDone = true; + //noinspection ConstantConditions + this.inboundSubscriber.onError(inboundError); + } else { + synchronized (this) { + this.inboundDone = true; + //noinspection ConstantConditions + this.inboundSubscriber.onError(inboundError); + } + } + } + } + + return previousState; + } + + final void handlePayload(Payload p) { + synchronized (this) { + if (this.inboundDone) { + // payload from network so it has refCnt > 0 + p.release(); + return; + } + + final long produced = this.produced; + if (this.requested == produced) { + p.release(); + + this.inboundDone = true; + + final Throwable cause = + Exceptions.failWithOverflow( + "The number of messages received exceeds the number requested"); + boolean wasThrowableAdded = Exceptions.addThrowable(INBOUND_ERROR, this, cause); + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + if (!wasThrowableAdded) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + } + return; + } + + this.requesterResponderSupport.remove(this.streamId, this); + + this.connection.sendFrame( + streamId, + ErrorFrameCodec.encode( + this.allocator, streamId, new CanceledException(cause.getMessage()))); + + if (!isSubscribed(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } else if (isFirstFrameSent(previousState) && !isInboundTerminated(previousState)) { + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError != TERMINATED) { + //noinspection ConstantConditions + this.inboundSubscriber.onError(inboundError); + } + } + + // this is downstream subscription so need to cancel it just in case error signal has not + // reached it + // needs for disconnected upstream and downstream case + this.outboundSubscription.cancel(); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, cause); + } + return; + } + + this.produced = produced + 1; + + this.inboundSubscriber.onNext(p); + } + } + + @Override + public final void handleError(Throwable t) { + if (this.inboundDone) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + this.inboundDone = true; + boolean wasThrowableAdded = Exceptions.addThrowable(INBOUND_ERROR, this, t); + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + if (!wasThrowableAdded) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + } + return; + } + + this.requesterResponderSupport.remove(this.streamId, this); + + if (isReassembling(previousState)) { + final CompositeByteBuf frames = this.frames; + this.frames = null; + frames.release(); + } + + if (!isSubscribed(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } else if (isFirstFrameSent(previousState) && !isInboundTerminated(previousState)) { + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError != TERMINATED) { + //noinspection ConstantConditions + this.inboundSubscriber.onError(inboundError); + } + } + + // this is downstream subscription so need to cancel it just in case error signal has not + // reached it + // needs for disconnected upstream and downstream case + this.outboundSubscription.cancel(); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, t); + } + } + + @Override + public void handleComplete() { + if (this.inboundDone) { + return; + } + + this.inboundDone = true; + + long previousState = markInboundTerminated(STATE, this); + + final boolean isOutboundTerminated = isOutboundTerminated(previousState); + if (isOutboundTerminated) { + this.requesterResponderSupport.remove(this.streamId, this); + } + + if (isFirstFrameSent(previousState)) { + this.inboundSubscriber.onComplete(); + } + + if (isOutboundTerminated) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, null); + } + } + } + + @Override + public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) { + long state = this.state; + if (isTerminated(state)) { + return; + } + + if (!hasFollows && !isReassembling(state)) { + Payload payload; + try { + payload = this.payloadDecoder.apply(frame); + } catch (Throwable t) { + long previousState = this.tryTerminate(true); + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, t); + } + + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + this.outboundDone = true; + // send error to terminate interaction + final int streamId = this.streamId; + final ByteBuf errorFrame = + ErrorFrameCodec.encode(this.allocator, streamId, new CanceledException(t.getMessage())); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } + return; + } + + this.handlePayload(payload); + if (isLastPayload) { + this.handleComplete(); + } + return; + } + + CompositeByteBuf frames = this.frames; + if (frames == null) { + frames = + ReassemblyUtils.addFollowingFrame( + this.allocator.compositeBuffer(), frame, hasFollows, this.maxInboundPayloadSize); + this.frames = frames; + + long previousState = markReassembling(STATE, this); + if (isTerminated(previousState)) { + this.frames = null; + frames.release(); + return; + } + } else { + try { + frames = + ReassemblyUtils.addFollowingFrame( + frames, frame, hasFollows, this.maxInboundPayloadSize); + } catch (IllegalStateException e) { + if (isTerminated(this.state)) { + return; + } + + long previousState = this.tryTerminate(true); + if (isTerminated(previousState)) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, e); + } + + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + this.outboundDone = true; + // send error to terminate interaction + final int streamId = this.streamId; + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + this.allocator, + streamId, + new CanceledException("Failed to reassemble payload. Cause: " + e.getMessage())); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); + } + + return; + } + } + + if (!hasFollows) { + long previousState = markReassembled(STATE, this); + if (isTerminated(previousState)) { + return; + } + + this.frames = null; + + Payload payload; + try { + payload = this.payloadDecoder.apply(frames); + frames.release(); + } catch (Throwable t) { + ReferenceCountUtil.safeRelease(frames); + + previousState = this.tryTerminate(true); + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, t); + } + + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + // send error to terminate interaction + final int streamId = this.streamId; + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + this.allocator, + streamId, + new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } + + return; + } + + if (this.outboundSubscription == null) { + this.firstPayload = payload; + Flux source = this.handler.requestChannel(this); + source.subscribe(this); + } else { + this.handlePayload(payload); + } + + if (isLastPayload) { + this.handleComplete(); + } + } + } + + @Override + public void onNext(Payload p) { + if (this.outboundDone) { + ReferenceCountUtil.safeRelease(p); + return; + } + + final int streamId = this.streamId; + final DuplexConnection connection = this.connection; + final ByteBufAllocator allocator = this.allocator; + + final int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, p, false)) { + p.release(); + + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + long previousState = this.tryTerminate(false); + if (isTerminated(previousState)) { + Operators.onErrorDropped( + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)), + this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); + } + + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + final CanceledException e = + new CanceledException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, streamId, e); + connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); + } + return; + } + } catch (IllegalReferenceCountException e) { + + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + long previousState = this.tryTerminate(false); + if (isTerminated(previousState)) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); + } + + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + allocator, + streamId, + new CanceledException("Failed to validate payload. Cause:" + e.getMessage())); + connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); + } + return; + } + + try { + sendReleasingPayload(streamId, FrameType.NEXT, mtu, p, connection, allocator, false); + } catch (Throwable t) { + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + long previousState = this.tryTerminate(false); + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null && !isTerminated(previousState)) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } + } + } + + @Override + public void onError(Throwable t) { + if (this.outboundDone) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + boolean wasThrowableAdded = + Exceptions.addThrowable( + INBOUND_ERROR, + this, + new CancellationException("Outbound has terminated with an error")); + this.outboundDone = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + final int streamId = this.streamId; + + this.requesterResponderSupport.remove(streamId, this); + + if (isReassembling(previousState)) { + final CompositeByteBuf frames = this.frames; + this.frames = null; + synchronized (frames) { + frames.release(); + } + } + + if (!isSubscribed(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } else if (wasThrowableAdded + && isFirstFrameSent(previousState) + && !isInboundTerminated(previousState)) { + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError != TERMINATED) { + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + synchronized (this) { + this.inboundDone = true; + //noinspection ConstantConditions + this.inboundSubscriber.onError(inboundError); + } + } + } + + final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } + } + + @Override + public void onComplete() { + if (this.outboundDone) { + return; + } + + this.outboundDone = true; + + long previousState = markOutboundTerminated(STATE, this, false); + if (isTerminated(previousState)) { + return; + } + + final int streamId = this.streamId; + + final boolean isInboundTerminated = isInboundTerminated(previousState); + if (isInboundTerminated) { + this.requesterResponderSupport.remove(streamId, this); + } + + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); + this.connection.sendFrame(streamId, completeFrame); + + if (isInboundTerminated) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, null); + } + } + } + + @Override + public final void handleRequestN(long n) { + this.outboundSubscription.request(n); + } + + @Override + public Context currentContext() { + return SendUtils.DISCARD_CONTEXT; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestOperator.java b/rsocket-core/src/main/java/io/rsocket/core/RequestOperator.java deleted file mode 100644 index 05f8d6b3c..000000000 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestOperator.java +++ /dev/null @@ -1,188 +0,0 @@ -package io.rsocket.core; - -import io.rsocket.Payload; -import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; -import org.reactivestreams.Subscription; -import reactor.core.CoreSubscriber; -import reactor.core.Fuseable; -import reactor.core.publisher.Operators; -import reactor.core.publisher.SignalType; -import reactor.util.context.Context; - -/** - * This is a support class for handling of request input, intended for use with {@link - * Operators#lift}. It ensures serial execution of cancellation vs first request signals and also - * provides hooks for separate handling of first vs subsequent {@link Subscription#request} - * invocations. - */ -abstract class RequestOperator - implements CoreSubscriber, Fuseable.QueueSubscription { - - final CoreSubscriber actual; - - Subscription s; - Fuseable.QueueSubscription qs; - - int streamId; - boolean firstRequest = true; - - volatile int wip; - static final AtomicIntegerFieldUpdater WIP = - AtomicIntegerFieldUpdater.newUpdater(RequestOperator.class, "wip"); - - RequestOperator(CoreSubscriber actual) { - this.actual = actual; - } - - /** - * Optional hook executed exactly once on the first {@link Subscription#request) invocation - * and right after the {@link Subscription#request} was propagated to the upstream subscription. - * - *

    Note: this hook may not be invoked if cancellation happened before this invocation - */ - void hookOnFirstRequest(long n) {} - - /** - * Optional hook executed after the {@link Subscription#request} was propagated to the upstream - * subscription and excludes the first {@link Subscription#request} invocation. - */ - void hookOnRemainingRequests(long n) {} - - /** Optional hook executed after this {@link Subscription} cancelling. */ - void hookOnCancel() {} - - /** - * Optional hook executed after {@link org.reactivestreams.Subscriber} termination events - * (onError, onComplete). - * - * @param signalType the type of termination event that triggered the hook ({@link - * SignalType#ON_ERROR} or {@link SignalType#ON_COMPLETE}) - */ - void hookOnTerminal(SignalType signalType) {} - - @Override - public Context currentContext() { - return actual.currentContext(); - } - - @Override - public void request(long n) { - this.s.request(n); - if (!firstRequest) { - try { - this.hookOnRemainingRequests(n); - } catch (Throwable throwable) { - onError(throwable); - } - return; - } - this.firstRequest = false; - - if (WIP.getAndIncrement(this) != 0) { - return; - } - int missed = 1; - - boolean firstLoop = true; - for (; ; ) { - if (firstLoop) { - firstLoop = false; - try { - this.hookOnFirstRequest(n); - } catch (Throwable throwable) { - onError(throwable); - return; - } - } else { - try { - this.hookOnCancel(); - } catch (Throwable throwable) { - onError(throwable); - } - return; - } - - missed = WIP.addAndGet(this, -missed); - if (missed == 0) { - return; - } - } - } - - @Override - public void cancel() { - this.s.cancel(); - - if (WIP.getAndIncrement(this) != 0) { - return; - } - - hookOnCancel(); - } - - @Override - @SuppressWarnings("unchecked") - public void onSubscribe(Subscription s) { - if (Operators.validate(this.s, s)) { - this.s = s; - if (s instanceof Fuseable.QueueSubscription) { - this.qs = (Fuseable.QueueSubscription) s; - } - this.actual.onSubscribe(this); - } - } - - @Override - public void onNext(Payload t) { - this.actual.onNext(t); - } - - @Override - public void onError(Throwable t) { - this.actual.onError(t); - try { - this.hookOnTerminal(SignalType.ON_ERROR); - } catch (Throwable throwable) { - Operators.onErrorDropped(throwable, currentContext()); - } - } - - @Override - public void onComplete() { - this.actual.onComplete(); - try { - this.hookOnTerminal(SignalType.ON_COMPLETE); - } catch (Throwable throwable) { - Operators.onErrorDropped(throwable, currentContext()); - } - } - - @Override - public int requestFusion(int requestedMode) { - if (this.qs != null) { - return this.qs.requestFusion(requestedMode); - } else { - return Fuseable.NONE; - } - } - - @Override - public Payload poll() { - return this.qs.poll(); - } - - @Override - public int size() { - return this.qs.size(); - } - - @Override - public boolean isEmpty() { - return this.qs.isEmpty(); - } - - @Override - public void clear() { - this.qs.clear(); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java new file mode 100644 index 000000000..a13b105b5 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java @@ -0,0 +1,400 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.ReassemblyUtils.handleNextSupport; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; + +final class RequestResponseRequesterMono extends Mono + implements RequesterFrameHandler, LeasePermitHandler, Subscription, Scannable { + + final ByteBufAllocator allocator; + final Payload payload; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + final PayloadDecoder payloadDecoder; + + @Nullable final RequesterLeaseTracker requesterLeaseTracker; + @Nullable final RequestInterceptor requestInterceptor; + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(RequestResponseRequesterMono.class, "state"); + + int streamId; + CoreSubscriber actual; + CompositeByteBuf frames; + boolean done; + + RequestResponseRequesterMono( + Payload payload, RequesterResponderSupport requesterResponderSupport) { + + this.allocator = requesterResponderSupport.getAllocator(); + this.payload = payload; + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requesterLeaseTracker = requesterResponderSupport.getRequesterLeaseTracker(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + final IllegalStateException e = + new IllegalStateException("RequestResponseMono allows only a single " + "Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_RESPONSE, null); + } + + Operators.error(actual, e); + return; + } + + final Payload p = this.payload; + try { + if (!isValid(this.mtu, this.maxFrameLength, p, false)) { + lazyTerminate(STATE, this); + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_RESPONSE, p.metadata()); + } + + p.release(); + + Operators.error(actual, e); + return; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_RESPONSE, null); + } + + Operators.error(actual, e); + return; + } + + this.actual = actual; + actual.onSubscribe(this); + } + + @Override + public final void request(long n) { + if (!Operators.validate(n)) { + return; + } + + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + final boolean leaseEnabled = requesterLeaseTracker != null; + final long previousState = addRequestN(STATE, this, n, !leaseEnabled); + + if (isTerminated(previousState) || hasRequested(previousState)) { + return; + } + + if (leaseEnabled) { + requesterLeaseTracker.issue(this); + return; + } + + sendFirstPayload(this.payload); + } + + @Override + public boolean handlePermit() { + final long previousState = markReadyToSendFirstFrame(STATE, this); + + if (isTerminated(previousState)) { + return false; + } + + sendFirstPayload(this.payload); + return true; + } + + void sendFirstPayload(Payload payload) { + + final RequesterResponderSupport sm = this.requesterResponderSupport; + final DuplexConnection connection = this.connection; + final ByteBufAllocator allocator = this.allocator; + + final int streamId; + try { + streamId = sm.addAndGetNextStreamId(this); + this.streamId = streamId; + } catch (Throwable t) { + this.done = true; + final long previousState = markTerminated(STATE, this); + + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_RESPONSE, payload.metadata()); + } + + payload.release(); + + if (!isTerminated(previousState)) { + this.actual.onError(ut); + } + return; + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, FrameType.REQUEST_RESPONSE, payload.metadata()); + } + + try { + sendReleasingPayload( + streamId, FrameType.REQUEST_RESPONSE, this.mtu, payload, connection, allocator, true); + } catch (Throwable e) { + this.done = true; + lazyTerminate(STATE, this); + + sm.remove(streamId, this); + + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, e); + } + + this.actual.onError(e); + return; + } + + long previousState = markFirstFrameSent(STATE, this); + if (isTerminated(previousState)) { + if (this.done) { + return; + } + + sm.remove(streamId, this); + + final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); + connection.sendFrame(streamId, cancelFrame); + + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_RESPONSE); + } + } + } + + @Override + public final void cancel() { + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + if (isFirstFrameSent(previousState)) { + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + ReassemblyUtils.synchronizedRelease(this, previousState); + + this.connection.sendFrame(streamId, CancelFrameCodec.encode(this.allocator, streamId)); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_RESPONSE); + } + } else if (!isReadyToSendFirstFrame(previousState)) { + this.payload.release(); + } + } + + @Override + public final void handlePayload(Payload value) { + if (this.done) { + value.release(); + return; + } + + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + value.release(); + return; + } + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, null); + } + + final CoreSubscriber a = this.actual; + a.onNext(value); + a.onComplete(); + } + + @Override + public final void handleComplete() { + if (this.done) { + return; + } + + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, null); + } + + this.actual.onComplete(); + } + + @Override + public final void handlePermitError(Throwable cause) { + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(cause, this.actual.currentContext()); + return; + } + + final Payload p = this.payload; + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(cause, FrameType.REQUEST_RESPONSE, p.metadata()); + } + p.release(); + + this.actual.onError(cause); + } + + @Override + public final void handleError(Throwable cause) { + if (this.done) { + Operators.onErrorDropped(cause, this.actual.currentContext()); + return; + } + + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(cause, this.actual.currentContext()); + return; + } + + ReassemblyUtils.synchronizedRelease(this, previousState); + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, cause); + } + + this.actual.onError(cause); + } + + @Override + public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) { + handleNextSupport( + STATE, + this, + this, + this.actual, + this.payloadDecoder, + this.allocator, + this.maxInboundPayloadSize, + frame, + hasFollows, + isLastPayload); + } + + @Override + public CompositeByteBuf getFrames() { + return this.frames; + } + + @Override + public void setFrames(CompositeByteBuf byteBuf) { + this.frames = byteBuf; + } + + @Override + @Nullable + public Object scanUnsafe(Attr key) { + // touch guard + long state = this.state; + + if (key == Attr.TERMINATED) return isTerminated(state); + if (key == Attr.PREFETCH) return 0; + + return null; + } + + @Override + @NonNull + public String stepName() { + return "source(RequestResponseMono)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java new file mode 100644 index 000000000..3d9d020ff --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java @@ -0,0 +1,358 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.SendUtils.sendReleasingPayload; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.exceptions.CanceledException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +final class RequestResponseResponderSubscriber + implements ResponderFrameHandler, CoreSubscriber { + + static final Logger logger = LoggerFactory.getLogger(RequestResponseResponderSubscriber.class); + + final int streamId; + final ByteBufAllocator allocator; + final PayloadDecoder payloadDecoder; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + final RSocket handler; + + @Nullable final RequestInterceptor requestInterceptor; + + boolean done; + CompositeByteBuf frames; + + volatile Subscription s; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater( + RequestResponseResponderSubscriber.class, Subscription.class, "s"); + + public RequestResponseResponderSubscriber( + int streamId, + ByteBuf firstFrame, + RequesterResponderSupport requesterResponderSupport, + RSocket handler) { + this.streamId = streamId; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + this.handler = handler; + + this.frames = + ReassemblyUtils.addFollowingFrame( + allocator.compositeBuffer(), firstFrame, true, maxInboundPayloadSize); + } + + public RequestResponseResponderSubscriber( + int streamId, RequesterResponderSupport requesterResponderSupport) { + this.streamId = streamId; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + + this.payloadDecoder = null; + this.handler = null; + this.frames = null; + } + + @Override + public void onSubscribe(Subscription subscription) { + if (Operators.validate(this.s, subscription)) { + S.lazySet(this, subscription); + subscription.request(Long.MAX_VALUE); + } + } + + @Override + public void onNext(@Nullable Payload p) { + if (this.done) { + if (p != null) { + p.release(); + } + return; + } + + final Subscription currentSubscription = this.s; + if (currentSubscription == Operators.cancelledSubscription() + || !S.compareAndSet(this, currentSubscription, Operators.cancelledSubscription())) { + if (p != null) { + p.release(); + } + return; + } + + this.done = true; + + final int streamId = this.streamId; + final DuplexConnection connection = this.connection; + final ByteBufAllocator allocator = this.allocator; + + this.requesterResponderSupport.remove(streamId, this); + + if (p == null) { + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(allocator, streamId); + connection.sendFrame(streamId, completeFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, null); + } + return; + } + + final int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, p, false)) { + currentSubscription.cancel(); + + p.release(); + + final CanceledException e = + new CanceledException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, streamId, e); + connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, e); + } + return; + } + } catch (IllegalReferenceCountException e) { + currentSubscription.cancel(); + + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + allocator, + streamId, + new CanceledException("Failed to validate payload. Cause" + e.getMessage())); + connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, e); + } + return; + } + + try { + sendReleasingPayload(streamId, FrameType.NEXT_COMPLETE, mtu, p, connection, allocator, false); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, null); + } + } catch (Throwable t) { + currentSubscription.cancel(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, t); + } + } + } + + @Override + public void onError(Throwable t) { + if (this.done) { + logger.debug("Dropped error", t); + return; + } + + final Subscription currentSubscription = this.s; + if (currentSubscription == Operators.cancelledSubscription() + || !S.compareAndSet(this, currentSubscription, Operators.cancelledSubscription())) { + logger.debug("Dropped error", t); + return; + } + + this.done = true; + + final int streamId = this.streamId; + + this.requesterResponderSupport.remove(streamId, this); + + final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, t); + } + } + + @Override + public void onComplete() { + onNext(null); + } + + @Override + public void handleCancel() { + final Subscription currentSubscription = this.s; + if (currentSubscription == Operators.cancelledSubscription()) { + return; + } + + if (currentSubscription == null) { + // if subscription is null, it means that streams has not yet reassembled all the fragments + // and fragmentation of the first frame was cancelled before + S.lazySet(this, Operators.cancelledSubscription()); + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final CompositeByteBuf frames = this.frames; + if (frames != null) { + this.frames = null; + frames.release(); + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_RESPONSE); + } + return; + } + + if (!S.compareAndSet(this, currentSubscription, Operators.cancelledSubscription())) { + return; + } + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + currentSubscription.cancel(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_RESPONSE); + } + } + + @Override + public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) { + final CompositeByteBuf frames = this.frames; + if (frames == null) { + return; + } + + try { + ReassemblyUtils.addFollowingFrame(frames, frame, hasFollows, this.maxInboundPayloadSize); + } catch (IllegalStateException t) { + S.lazySet(this, Operators.cancelledSubscription()); + + this.requesterResponderSupport.remove(this.streamId, this); + + this.frames = null; + frames.release(); + + logger.debug("Reassembly has failed", t); + + // sends error frame from the responder side to tell that something went wrong + final int streamId = this.streamId; + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + this.allocator, + streamId, + new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, t); + } + return; + } + + if (!hasFollows) { + this.frames = null; + Payload payload; + try { + payload = this.payloadDecoder.apply(frames); + frames.release(); + } catch (Throwable t) { + S.lazySet(this, Operators.cancelledSubscription()); + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + ReferenceCountUtil.safeRelease(frames); + + logger.debug("Reassembly has failed", t); + + // sends error frame from the responder side to tell that something went wrong + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + this.allocator, + streamId, + new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, t); + } + return; + } + + final Mono source = this.handler.requestResponse(payload); + source.subscribe(this); + } + } + + @Override + public Context currentContext() { + return SendUtils.DISCARD_CONTEXT; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java new file mode 100644 index 000000000..6182ca506 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java @@ -0,0 +1,449 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.ReassemblyUtils.handleNextSupport; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; + +final class RequestStreamRequesterFlux extends Flux + implements RequesterFrameHandler, LeasePermitHandler, Subscription, Scannable { + + final ByteBufAllocator allocator; + final Payload payload; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + final PayloadDecoder payloadDecoder; + + @Nullable final RequesterLeaseTracker requesterLeaseTracker; + @Nullable final RequestInterceptor requestInterceptor; + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(RequestStreamRequesterFlux.class, "state"); + + int streamId; + CoreSubscriber inboundSubscriber; + CompositeByteBuf frames; + boolean done; + long requested; + long produced; + + RequestStreamRequesterFlux(Payload payload, RequesterResponderSupport requesterResponderSupport) { + this.allocator = requesterResponderSupport.getAllocator(); + this.payload = payload; + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requesterLeaseTracker = requesterResponderSupport.getRequesterLeaseTracker(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + final IllegalStateException e = + new IllegalStateException("RequestStreamFlux allows only a single Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_STREAM, null); + } + + Operators.error(actual, e); + return; + } + + final Payload p = this.payload; + try { + if (!isValid(this.mtu, this.maxFrameLength, p, false)) { + lazyTerminate(STATE, this); + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_STREAM, p.metadata()); + } + + p.release(); + + Operators.error(actual, e); + return; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_STREAM, null); + } + + Operators.error(actual, e); + return; + } + + this.inboundSubscriber = actual; + actual.onSubscribe(this); + } + + @Override + public final void request(long n) { + if (!Operators.validate(n)) { + return; + } + + this.requested = Operators.addCap(this.requested, n); + + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + final boolean leaseEnabled = requesterLeaseTracker != null; + final long previousState = addRequestN(STATE, this, n, !leaseEnabled); + if (isTerminated(previousState)) { + return; + } + + if (hasRequested(previousState)) { + if (isFirstFrameSent(previousState) + && !isMaxAllowedRequestN(extractRequestN(previousState))) { + final int streamId = this.streamId; + final ByteBuf requestNFrame = RequestNFrameCodec.encode(this.allocator, streamId, n); + this.connection.sendFrame(streamId, requestNFrame); + } + return; + } + + if (leaseEnabled) { + requesterLeaseTracker.issue(this); + return; + } + + sendFirstPayload(this.payload, n); + } + + @Override + public boolean handlePermit() { + final long previousState = markReadyToSendFirstFrame(STATE, this); + + if (isTerminated(previousState)) { + return false; + } + + sendFirstPayload(this.payload, extractRequestN(previousState)); + return true; + } + + void sendFirstPayload(Payload payload, long initialRequestN) { + + final RequesterResponderSupport sm = this.requesterResponderSupport; + final DuplexConnection connection = this.connection; + final ByteBufAllocator allocator = this.allocator; + + final int streamId; + try { + streamId = sm.addAndGetNextStreamId(this); + this.streamId = streamId; + } catch (Throwable t) { + this.done = true; + final long previousState = markTerminated(STATE, this); + + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_STREAM, payload.metadata()); + } + + payload.release(); + + if (!isTerminated(previousState)) { + this.inboundSubscriber.onError(ut); + } + return; + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, FrameType.REQUEST_STREAM, payload.metadata()); + } + + try { + sendReleasingPayload( + streamId, + FrameType.REQUEST_STREAM, + initialRequestN, + this.mtu, + payload, + connection, + allocator, + false); + } catch (Throwable t) { + this.done = true; + lazyTerminate(STATE, this); + + sm.remove(streamId, this); + + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, t); + } + + this.inboundSubscriber.onError(t); + return; + } + + long previousState = markFirstFrameSent(STATE, this); + if (isTerminated(previousState)) { + if (this.done) { + return; + } + + final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); + connection.sendFrame(streamId, cancelFrame); + + sm.remove(streamId, this); + + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_STREAM); + } + return; + } + + if (isMaxAllowedRequestN(initialRequestN)) { + return; + } + + long requestN = extractRequestN(previousState); + if (isMaxAllowedRequestN(requestN)) { + final ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, streamId, requestN); + connection.sendFrame(streamId, requestNFrame); + return; + } + + if (requestN > initialRequestN) { + final ByteBuf requestNFrame = + RequestNFrameCodec.encode(allocator, streamId, requestN - initialRequestN); + connection.sendFrame(streamId, requestNFrame); + } + } + + @Override + public final void cancel() { + final long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + if (isFirstFrameSent(previousState)) { + final int streamId = this.streamId; + + ReassemblyUtils.synchronizedRelease(this, previousState); + + this.connection.sendFrame(streamId, CancelFrameCodec.encode(this.allocator, streamId)); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_STREAM); + } + } else if (!isReadyToSendFirstFrame(previousState)) { + // no need to send anything, since the first request has not happened + this.payload.release(); + } + } + + @Override + public final void handlePayload(Payload p) { + if (this.done) { + p.release(); + return; + } + + final long produced = this.produced; + if (this.requested == produced) { + p.release(); + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + final int streamId = this.streamId; + + final IllegalStateException cause = + Exceptions.failWithOverflow( + "The number of messages received exceeds the number requested"); + this.connection.sendFrame(streamId, CancelFrameCodec.encode(this.allocator, streamId)); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, cause); + } + + this.inboundSubscriber.onError(cause); + return; + } + + this.produced = produced + 1; + + this.inboundSubscriber.onNext(p); + } + + @Override + public final void handleComplete() { + if (this.done) { + return; + } + + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, null); + } + + this.inboundSubscriber.onComplete(); + } + + @Override + public final void handlePermitError(Throwable cause) { + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + final Payload p = this.payload; + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(cause, FrameType.REQUEST_STREAM, p.metadata()); + } + p.release(); + + this.inboundSubscriber.onError(cause); + } + + @Override + public final void handleError(Throwable cause) { + if (this.done) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + ReassemblyUtils.synchronizedRelease(this, previousState); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, cause); + } + + this.inboundSubscriber.onError(cause); + } + + @Override + public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) { + handleNextSupport( + STATE, + this, + this, + this.inboundSubscriber, + this.payloadDecoder, + this.allocator, + this.maxInboundPayloadSize, + frame, + hasFollows, + isLastPayload); + } + + @Override + public CompositeByteBuf getFrames() { + return this.frames; + } + + @Override + public void setFrames(CompositeByteBuf byteBuf) { + this.frames = byteBuf; + } + + @Override + @Nullable + public Object scanUnsafe(Attr key) { + // touch guard + long state = this.state; + + if (key == Attr.TERMINATED) return isTerminated(state); + if (key == Attr.REQUESTED_FROM_DOWNSTREAM) return extractRequestN(state); + + return null; + } + + @Override + @NonNull + public String stepName() { + return "source(RequestStreamFlux)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java new file mode 100644 index 000000000..48903ae38 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java @@ -0,0 +1,395 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.SendUtils.sendReleasingPayload; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.exceptions.CanceledException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +final class RequestStreamResponderSubscriber + implements ResponderFrameHandler, CoreSubscriber { + + static final Logger logger = LoggerFactory.getLogger(RequestStreamResponderSubscriber.class); + + final int streamId; + final long firstRequest; + final ByteBufAllocator allocator; + final PayloadDecoder payloadDecoder; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + + @Nullable final RequestInterceptor requestInterceptor; + + final RSocket handler; + + volatile Subscription s; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater( + RequestStreamResponderSubscriber.class, Subscription.class, "s"); + + CompositeByteBuf frames; + boolean done; + + public RequestStreamResponderSubscriber( + int streamId, + long firstRequest, + ByteBuf firstFrame, + RequesterResponderSupport requesterResponderSupport, + RSocket handler) { + this.streamId = streamId; + this.firstRequest = firstRequest; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + this.handler = handler; + this.frames = + ReassemblyUtils.addFollowingFrame( + allocator.compositeBuffer(), firstFrame, true, maxInboundPayloadSize); + } + + public RequestStreamResponderSubscriber( + int streamId, long firstRequest, RequesterResponderSupport requesterResponderSupport) { + this.streamId = streamId; + this.firstRequest = firstRequest; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + + this.payloadDecoder = null; + this.handler = null; + this.frames = null; + } + + @Override + public void onSubscribe(Subscription subscription) { + if (Operators.validate(this.s, subscription)) { + final long firstRequest = this.firstRequest; + S.lazySet(this, subscription); + subscription.request(firstRequest); + } + } + + @Override + public void onNext(Payload p) { + if (this.done) { + ReferenceCountUtil.safeRelease(p); + return; + } + + final int streamId = this.streamId; + final DuplexConnection sender = this.connection; + final ByteBufAllocator allocator = this.allocator; + + final int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, p, false)) { + p.release(); + + if (!this.tryTerminateOnError()) { + return; + } + + final CanceledException e = + new CanceledException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, streamId, e); + sender.sendFrame(streamId, errorFrame); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, e); + } + return; + } + } catch (IllegalReferenceCountException e) { + if (!this.tryTerminateOnError()) { + return; + } + + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + allocator, + streamId, + new CanceledException("Failed to validate payload. Cause" + e.getMessage())); + sender.sendFrame(streamId, errorFrame); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, e); + } + return; + } + + try { + sendReleasingPayload(streamId, FrameType.NEXT, mtu, p, sender, allocator, false); + } catch (Throwable t) { + if (!this.tryTerminateOnError()) { + return; + } + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, t); + } + } + } + + boolean tryTerminateOnError() { + final Subscription currentSubscription = this.s; + if (currentSubscription == Operators.cancelledSubscription()) { + return false; + } + + this.done = true; + + if (!S.compareAndSet(this, currentSubscription, Operators.cancelledSubscription())) { + return false; + } + + currentSubscription.cancel(); + + return true; + } + + @Override + public void onError(Throwable t) { + if (this.done) { + logger.debug("Dropped error", t); + return; + } + + this.done = true; + + if (S.getAndSet(this, Operators.cancelledSubscription()) == Operators.cancelledSubscription()) { + logger.debug("Dropped error", t); + return; + } + + final CompositeByteBuf frames = this.frames; + if (frames != null && frames.refCnt() > 0) { + frames.release(); + } + + final int streamId = this.streamId; + + final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); + this.connection.sendFrame(streamId, errorFrame); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, t); + } + } + + @Override + public void onComplete() { + if (this.done) { + return; + } + + this.done = true; + + if (S.getAndSet(this, Operators.cancelledSubscription()) == Operators.cancelledSubscription()) { + return; + } + + final int streamId = this.streamId; + + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); + this.connection.sendFrame(streamId, completeFrame); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, null); + } + } + + @Override + public void handleRequestN(long n) { + this.s.request(n); + } + + @Override + public final void handleCancel() { + final Subscription currentSubscription = this.s; + if (currentSubscription == Operators.cancelledSubscription()) { + return; + } + + if (currentSubscription == null) { + // if subscription is null, it means that streams has not yet reassembled all the fragments + // and fragmentation of the first frame was cancelled before + S.lazySet(this, Operators.cancelledSubscription()); + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final CompositeByteBuf frames = this.frames; + if (frames != null) { + this.frames = null; + frames.release(); + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_STREAM); + } + return; + } + + if (!S.compareAndSet(this, currentSubscription, Operators.cancelledSubscription())) { + return; + } + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + currentSubscription.cancel(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_STREAM); + } + } + + @Override + public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLastPayload) { + final CompositeByteBuf frames = this.frames; + if (frames == null) { + return; + } + + try { + ReassemblyUtils.addFollowingFrame( + frames, followingFrame, hasFollows, this.maxInboundPayloadSize); + } catch (IllegalStateException e) { + // if subscription is null, it means that streams has not yet reassembled all the fragments + // and fragmentation of the first frame was cancelled before + S.lazySet(this, Operators.cancelledSubscription()); + + final int streamId = this.streamId; + + this.frames = null; + frames.release(); + + // sends error frame from the responder side to tell that something went wrong + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + this.allocator, + streamId, + new CanceledException("Failed to reassemble payload. Cause: " + e.getMessage())); + this.connection.sendFrame(streamId, errorFrame); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, e); + } + + logger.debug("Reassembly has failed", e); + return; + } + + if (!hasFollows) { + this.frames = null; + Payload payload; + try { + payload = this.payloadDecoder.apply(frames); + frames.release(); + } catch (Throwable t) { + S.lazySet(this, Operators.cancelledSubscription()); + this.done = true; + + final int streamId = this.streamId; + + ReferenceCountUtil.safeRelease(frames); + + // sends error frame from the responder side to tell that something went wrong + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + this.allocator, + streamId, + new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); + this.connection.sendFrame(streamId, errorFrame); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, t); + } + + logger.debug("Reassembly has failed", t); + return; + } + + Flux source = this.handler.requestStream(payload); + source.subscribe(this); + } + } + + @Override + public Context currentContext() { + return SendUtils.DISCARD_CONTEXT; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequesterFrameHandler.java b/rsocket-core/src/main/java/io/rsocket/core/RequesterFrameHandler.java new file mode 100644 index 000000000..1f7b09af8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequesterFrameHandler.java @@ -0,0 +1,43 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.Payload; +import java.util.concurrent.CancellationException; +import reactor.util.annotation.Nullable; + +interface RequesterFrameHandler extends FrameHandler { + + void handlePayload(Payload payload); + + @Override + default void handleCancel() { + handleError( + new CancellationException( + "Cancellation was received but should not be possible for current request type")); + } + + @Override + default void handleRequestN(long n) { + // no ops + } + + @Nullable + CompositeByteBuf getFrames(); + + void setFrames(@Nullable CompositeByteBuf reassembledFrames); +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequesterLeaseTracker.java b/rsocket-core/src/main/java/io/rsocket/core/RequesterLeaseTracker.java new file mode 100644 index 000000000..50da83b8f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequesterLeaseTracker.java @@ -0,0 +1,135 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.rsocket.Availability; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.lease.Lease; +import io.rsocket.lease.MissingLeaseException; +import java.time.Duration; +import java.util.ArrayDeque; +import java.util.Queue; + +final class RequesterLeaseTracker implements Availability { + + final String tag; + final int maximumAllowedAwaitingPermitHandlersNumber; + final Queue awaitingPermitHandlersQueue; + + Lease currentLease = null; + int availableRequests; + + boolean isDisposed; + Throwable t; + + RequesterLeaseTracker(String tag, int maximumAllowedAwaitingPermitHandlersNumber) { + this.tag = tag; + this.maximumAllowedAwaitingPermitHandlersNumber = maximumAllowedAwaitingPermitHandlersNumber; + this.awaitingPermitHandlersQueue = new ArrayDeque<>(); + } + + synchronized void issue(LeasePermitHandler leasePermitHandler) { + if (this.isDisposed) { + leasePermitHandler.handlePermitError(this.t); + return; + } + + final int availableRequests = this.availableRequests; + final Lease l = this.currentLease; + final boolean leaseReceived = l != null; + final boolean isExpired = leaseReceived && isExpired(l); + + if (leaseReceived && availableRequests > 0 && !isExpired) { + if (leasePermitHandler.handlePermit()) { + this.availableRequests = availableRequests - 1; + } + } else { + final Queue queue = this.awaitingPermitHandlersQueue; + if (this.maximumAllowedAwaitingPermitHandlersNumber > queue.size()) { + queue.offer(leasePermitHandler); + } else { + final String tag = this.tag; + final String message; + if (!leaseReceived) { + message = String.format("[%s] Lease was not received yet", tag); + } else if (isExpired) { + message = String.format("[%s] Missing leases. Lease is expired", tag); + } else { + message = + String.format( + "[%s] Missing leases. Issued [%s] request allowance is used", + tag, availableRequests); + } + + final Throwable t = new MissingLeaseException(message); + leasePermitHandler.handlePermitError(t); + } + } + } + + void handleLeaseFrame(ByteBuf leaseFrame) { + final int numberOfRequests = LeaseFrameCodec.numRequests(leaseFrame); + final int timeToLiveMillis = LeaseFrameCodec.ttl(leaseFrame); + final ByteBuf metadata = LeaseFrameCodec.metadata(leaseFrame); + + synchronized (this) { + final Lease lease = + Lease.create(Duration.ofMillis(timeToLiveMillis), numberOfRequests, metadata); + final Queue queue = this.awaitingPermitHandlersQueue; + + int availableRequests = lease.numberOfRequests(); + + this.currentLease = lease; + if (queue.size() > 0) { + do { + final LeasePermitHandler handler = queue.poll(); + if (handler.handlePermit()) { + availableRequests--; + } + } while (availableRequests > 0 && queue.size() > 0); + } + + this.availableRequests = availableRequests; + } + } + + public synchronized void dispose(Throwable t) { + this.isDisposed = true; + this.t = t; + + final Queue queue = this.awaitingPermitHandlersQueue; + final int size = queue.size(); + + for (int i = 0; i < size; i++) { + final LeasePermitHandler leasePermitHandler = queue.poll(); + + //noinspection ConstantConditions + leasePermitHandler.handlePermitError(t); + } + } + + @Override + public synchronized double availability() { + final Lease lease = this.currentLease; + return lease != null ? this.availableRequests / (double) lease.numberOfRequests() : 0.0d; + } + + static boolean isExpired(Lease currentLease) { + return System.currentTimeMillis() >= currentLease.expirationTime(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java b/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java new file mode 100644 index 000000000..bea7dc1aa --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java @@ -0,0 +1,161 @@ +package io.rsocket.core; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.collection.IntObjectHashMap; +import io.netty.util.collection.IntObjectMap; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocket; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.util.Objects; +import java.util.function.Function; +import reactor.util.annotation.Nullable; + +class RequesterResponderSupport { + + private final int mtu; + private final int maxFrameLength; + private final int maxInboundPayloadSize; + private final PayloadDecoder payloadDecoder; + private final ByteBufAllocator allocator; + private final DuplexConnection connection; + @Nullable private final RequestInterceptor requestInterceptor; + + @Nullable final StreamIdSupplier streamIdSupplier; + final IntObjectMap activeStreams; + + public RequesterResponderSupport( + int mtu, + int maxFrameLength, + int maxInboundPayloadSize, + PayloadDecoder payloadDecoder, + DuplexConnection connection, + @Nullable StreamIdSupplier streamIdSupplier, + Function requestInterceptorFunction) { + + this.activeStreams = new IntObjectHashMap<>(); + this.mtu = mtu; + this.maxFrameLength = maxFrameLength; + this.maxInboundPayloadSize = maxInboundPayloadSize; + this.payloadDecoder = payloadDecoder; + this.allocator = connection.alloc(); + this.streamIdSupplier = streamIdSupplier; + this.connection = connection; + this.requestInterceptor = requestInterceptorFunction.apply((RSocket) this); + } + + public int getMtu() { + return mtu; + } + + public int getMaxFrameLength() { + return maxFrameLength; + } + + public int getMaxInboundPayloadSize() { + return maxInboundPayloadSize; + } + + public PayloadDecoder getPayloadDecoder() { + return payloadDecoder; + } + + public ByteBufAllocator getAllocator() { + return allocator; + } + + public DuplexConnection getDuplexConnection() { + return connection; + } + + @Nullable + public RequesterLeaseTracker getRequesterLeaseTracker() { + return null; + } + + @Nullable + public RequestInterceptor getRequestInterceptor() { + return requestInterceptor; + } + + /** + * Issues next {@code streamId} + * + * @return issued {@code streamId} + * @throws RuntimeException if the {@link RequesterResponderSupport} is terminated for any reason + */ + public int getNextStreamId() { + final StreamIdSupplier streamIdSupplier = this.streamIdSupplier; + if (streamIdSupplier != null) { + synchronized (this) { + return streamIdSupplier.nextStreamId(this.activeStreams); + } + } else { + throw new UnsupportedOperationException("Responder can not issue id"); + } + } + + /** + * Adds frameHandler and returns issued {@code streamId} back + * + * @param frameHandler to store + * @return issued {@code streamId} + * @throws RuntimeException if the {@link RequesterResponderSupport} is terminated for any reason + */ + public int addAndGetNextStreamId(FrameHandler frameHandler) { + final StreamIdSupplier streamIdSupplier = this.streamIdSupplier; + if (streamIdSupplier != null) { + final IntObjectMap activeStreams = this.activeStreams; + synchronized (this) { + final int streamId = streamIdSupplier.nextStreamId(activeStreams); + + activeStreams.put(streamId, frameHandler); + + return streamId; + } + } else { + throw new UnsupportedOperationException("Responder can not issue id"); + } + } + + public synchronized boolean add(int streamId, FrameHandler frameHandler) { + final IntObjectMap activeStreams = this.activeStreams; + // copy of Map.putIfAbsent(key, value) without `streamId` boxing + final FrameHandler previousHandler = activeStreams.get(streamId); + if (previousHandler == null) { + activeStreams.put(streamId, frameHandler); + return true; + } + return false; + } + + /** + * Resolves {@link FrameHandler} by {@code streamId} + * + * @param streamId used to resolve {@link FrameHandler} + * @return {@link FrameHandler} or {@code null} + */ + @Nullable + public synchronized FrameHandler get(int streamId) { + return this.activeStreams.get(streamId); + } + + /** + * Removes {@link FrameHandler} if it is present and equals to the given one + * + * @param streamId to lookup for {@link FrameHandler} + * @param frameHandler instance to check with the found one + * @return {@code true} if there is {@link FrameHandler} for the given {@code streamId} and the + * instance equals to the passed one + */ + public synchronized boolean remove(int streamId, FrameHandler frameHandler) { + final IntObjectMap activeStreams = this.activeStreams; + // copy of Map.remove(key, value) without `streamId` boxing + final FrameHandler curValue = activeStreams.get(streamId); + if (!Objects.equals(curValue, frameHandler)) { + return false; + } + activeStreams.remove(streamId); + return true; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ResolvingOperator.java b/rsocket-core/src/main/java/io/rsocket/core/ResolvingOperator.java index c431b3f3f..50bef5b70 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/ResolvingOperator.java +++ b/rsocket-core/src/main/java/io/rsocket/core/ResolvingOperator.java @@ -1,3 +1,18 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package io.rsocket.core; import java.time.Duration; @@ -15,6 +30,8 @@ import reactor.util.annotation.Nullable; import reactor.util.context.Context; +// A copy of this class exists in io.rsocket.loadbalance + class ResolvingOperator implements Disposable { static final CancellationException ON_DISPOSE = new CancellationException("Disposed"); @@ -153,19 +170,19 @@ public T block(@Nullable Duration timeout) { delay = System.nanoTime() + timeout.toNanos(); } for (; ; ) { - BiConsumer[] inners = this.subscribers; + subscribers = this.subscribers; - if (inners == READY) { + if (subscribers == READY) { final T value = this.value; if (value != null) { return value; } else { // value == null means racing between invalidate and this block // thus, we have to update the state again and see what happened - inners = this.subscribers; + subscribers = this.subscribers; } } - if (inners == TERMINATED) { + if (subscribers == TERMINATED) { RuntimeException re = Exceptions.propagate(this.t); re = Exceptions.addSuppressed(re, new Exception("Terminated with an error")); throw re; @@ -174,6 +191,12 @@ public T block(@Nullable Duration timeout) { throw new IllegalStateException("Timeout on Mono blocking read"); } + // connect again since invalidate() has happened in between + if (subscribers == EMPTY_UNSUBSCRIBED + && SUBSCRIBERS.compareAndSet(this, EMPTY_UNSUBSCRIBED, EMPTY_SUBSCRIBED)) { + this.doSubscribe(); + } + Thread.sleep(1); } } catch (InterruptedException ie) { @@ -186,6 +209,7 @@ public T block(@Nullable Duration timeout) { @SuppressWarnings("unchecked") final void terminate(Throwable t) { if (isDisposed()) { + Operators.onErrorDropped(t, Context.empty()); return; } @@ -307,6 +331,30 @@ protected void doOnDispose() { // no ops } + public final boolean connect() { + for (; ; ) { + final BiConsumer[] a = this.subscribers; + + if (a == TERMINATED) { + return false; + } + + if (a == READY) { + return true; + } + + if (a != EMPTY_UNSUBSCRIBED) { + // do nothing if already started + return true; + } + + if (SUBSCRIBERS.compareAndSet(this, a, EMPTY_SUBSCRIBED)) { + this.doSubscribe(); + return true; + } + } + } + final int add(BiConsumer ps) { for (; ; ) { BiConsumer[] a = this.subscribers; diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumeStrategy.java b/rsocket-core/src/main/java/io/rsocket/core/ResponderFrameHandler.java similarity index 57% rename from rsocket-core/src/main/java/io/rsocket/resume/ResumeStrategy.java rename to rsocket-core/src/main/java/io/rsocket/core/ResponderFrameHandler.java index d9dec9f54..27cc8db9a 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/ResumeStrategy.java +++ b/rsocket-core/src/main/java/io/rsocket/core/ResponderFrameHandler.java @@ -13,17 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package io.rsocket.core; -package io.rsocket.resume; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; -import java.util.function.BiFunction; -import org.reactivestreams.Publisher; -import reactor.util.retry.Retry; +interface ResponderFrameHandler extends FrameHandler { -/** - * @deprecated as of 1.0 RC7 in favor of using {@link io.rsocket.core.Resume#retry(Retry)} via - * {@link io.rsocket.core.RSocketConnector} or {@link io.rsocket.core.RSocketServer}. - */ -@Deprecated -@FunctionalInterface -public interface ResumeStrategy extends BiFunction> {} + Logger logger = LoggerFactory.getLogger(ResponderFrameHandler.class); + + @Override + default void handleComplete() {} + + @Override + default void handleError(Throwable t) { + logger.debug("Dropped error", t); + handleCancel(); + } + + @Override + default void handleRequestN(long n) { + // no ops + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ResponderLeaseTracker.java b/rsocket-core/src/main/java/io/rsocket/core/ResponderLeaseTracker.java new file mode 100644 index 000000000..fc7442f4a --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ResponderLeaseTracker.java @@ -0,0 +1,112 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Availability; +import io.rsocket.DuplexConnection; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.lease.Lease; +import io.rsocket.lease.LeaseSender; +import io.rsocket.lease.MissingLeaseException; +import reactor.core.Disposable; +import reactor.core.publisher.BaseSubscriber; +import reactor.util.annotation.Nullable; + +final class ResponderLeaseTracker extends BaseSubscriber + implements Disposable, Availability { + + final String tag; + final ByteBufAllocator allocator; + final DuplexConnection connection; + + @Nullable volatile MutableLease currentLease; + + ResponderLeaseTracker(String tag, DuplexConnection connection, LeaseSender leaseSender) { + this.tag = tag; + this.connection = connection; + this.allocator = connection.alloc(); + + leaseSender.send().subscribe(this); + } + + @Nullable + Throwable use() { + final MutableLease lease = this.currentLease; + final String tag = this.tag; + + if (lease == null) { + return new MissingLeaseException(String.format("[%s] Lease was not issued yet", tag)); + } + + if (isExpired(lease)) { + return new MissingLeaseException(String.format("[%s] Missing leases. Lease is expired", tag)); + } + + final int allowedRequests = lease.allowedRequests; + final int remainingRequests = lease.remainingRequests; + if (remainingRequests <= 0) { + return new MissingLeaseException( + String.format( + "[%s] Missing leases. Issued [%s] request allowance is used", tag, allowedRequests)); + } + + lease.remainingRequests = remainingRequests - 1; + + return null; + } + + @Override + protected void hookOnNext(Lease lease) { + final int allowedRequests = lease.numberOfRequests(); + final int ttl = lease.timeToLiveInMillis(); + final long expireAt = lease.expirationTime(); + + this.currentLease = new MutableLease(allowedRequests, expireAt); + this.connection.sendFrame( + 0, LeaseFrameCodec.encode(this.allocator, ttl, allowedRequests, lease.metadata())); + } + + @Override + public double availability() { + final MutableLease lease = this.currentLease; + + if (lease == null || isExpired(lease)) { + return 0; + } + + return lease.remainingRequests / (double) lease.allowedRequests; + } + + static boolean isExpired(MutableLease currentLease) { + return System.currentTimeMillis() >= currentLease.expireAt; + } + + static final class MutableLease { + final int allowedRequests; + final long expireAt; + + int remainingRequests; + + MutableLease(int allowedRequests, long expireAt) { + this.allowedRequests = allowedRequests; + this.expireAt = expireAt; + + this.remainingRequests = allowedRequests; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/Resume.java b/rsocket-core/src/main/java/io/rsocket/core/Resume.java index 48133af98..fa0eedbfa 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/Resume.java +++ b/rsocket-core/src/main/java/io/rsocket/core/Resume.java @@ -160,7 +160,7 @@ boolean isCleanupStoreOnKeepAlive() { Function getStoreFactory(String tag) { return storeFactory != null ? storeFactory - : token -> new InMemoryResumableFramesStore(tag, 100_000); + : token -> new InMemoryResumableFramesStore(tag, token, 100_000); } Duration getStreamTimeout() { diff --git a/rsocket-core/src/main/java/io/rsocket/core/SendUtils.java b/rsocket-core/src/main/java/io/rsocket/core/SendUtils.java new file mode 100644 index 000000000..568dada2e --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/SendUtils.java @@ -0,0 +1,335 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.isFragmentable; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCounted; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.exceptions.CanceledException; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import java.util.function.Consumer; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; + +final class SendUtils { + private static final Consumer DROPPED_ELEMENTS_CONSUMER = + data -> { + if (data instanceof ReferenceCounted) { + try { + ReferenceCounted referenceCounted = (ReferenceCounted) data; + referenceCounted.release(); + } catch (Throwable e) { + // ignored + } + } + }; + + static final Context DISCARD_CONTEXT = Operators.enableOnDiscard(null, DROPPED_ELEMENTS_CONSUMER); + + static void sendReleasingPayload( + int streamId, + FrameType frameType, + int mtu, + Payload payload, + DuplexConnection connection, + ByteBufAllocator allocator, + boolean requester) { + + final boolean hasMetadata = payload.hasMetadata(); + final ByteBuf metadata = hasMetadata ? payload.metadata() : null; + final ByteBuf data = payload.data(); + + boolean fragmentable; + try { + fragmentable = isFragmentable(mtu, data, metadata, false); + } catch (IllegalReferenceCountException | NullPointerException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, requester, false, e); + throw e; + } + + if (fragmentable) { + final ByteBuf slicedData = data.slice(); + final ByteBuf slicedMetadata = hasMetadata ? metadata.slice() : Unpooled.EMPTY_BUFFER; + + final ByteBuf first; + try { + first = + FragmentationUtils.encodeFirstFragment( + allocator, mtu, frameType, streamId, hasMetadata, slicedMetadata, slicedData); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, requester, false, e); + throw e; + } + + connection.sendFrame(streamId, first); + + boolean complete = frameType == FrameType.NEXT_COMPLETE; + while (slicedData.isReadable() || slicedMetadata.isReadable()) { + final ByteBuf following; + try { + following = + FragmentationUtils.encodeFollowsFragment( + allocator, mtu, streamId, complete, slicedMetadata, slicedData); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, requester, true, e); + throw e; + } + connection.sendFrame(streamId, following); + } + + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, true, true, e); + throw e; + } + } else { + final ByteBuf dataRetainedSlice = data.retainedSlice(); + + final ByteBuf metadataRetainedSlice; + try { + metadataRetainedSlice = hasMetadata ? metadata.retainedSlice() : null; + } catch (IllegalReferenceCountException e) { + dataRetainedSlice.release(); + + sendTerminalFrame(streamId, frameType, connection, allocator, requester, false, e); + throw e; + } + + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + dataRetainedSlice.release(); + if (hasMetadata) { + metadataRetainedSlice.release(); + } + + sendTerminalFrame(streamId, frameType, connection, allocator, requester, false, e); + throw e; + } + + final ByteBuf requestFrame; + switch (frameType) { + case REQUEST_FNF: + requestFrame = + RequestFireAndForgetFrameCodec.encode( + allocator, streamId, false, metadataRetainedSlice, dataRetainedSlice); + break; + case REQUEST_RESPONSE: + requestFrame = + RequestResponseFrameCodec.encode( + allocator, streamId, false, metadataRetainedSlice, dataRetainedSlice); + break; + case PAYLOAD: + case NEXT: + case NEXT_COMPLETE: + requestFrame = + PayloadFrameCodec.encode( + allocator, + streamId, + false, + frameType == FrameType.NEXT_COMPLETE, + frameType != FrameType.PAYLOAD, + metadataRetainedSlice, + dataRetainedSlice); + break; + default: + throw new IllegalArgumentException("Unsupported frame type " + frameType); + } + + connection.sendFrame(streamId, requestFrame); + } + } + + static void sendReleasingPayload( + int streamId, + FrameType frameType, + long initialRequestN, + int mtu, + Payload payload, + DuplexConnection connection, + ByteBufAllocator allocator, + boolean complete) { + + final boolean hasMetadata = payload.hasMetadata(); + final ByteBuf metadata = hasMetadata ? payload.metadata() : null; + final ByteBuf data = payload.data(); + + boolean fragmentable; + try { + fragmentable = isFragmentable(mtu, data, metadata, true); + } catch (IllegalReferenceCountException | NullPointerException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, true, false, e); + throw e; + } + + if (fragmentable) { + final ByteBuf slicedData = data.slice(); + final ByteBuf slicedMetadata = hasMetadata ? metadata.slice() : Unpooled.EMPTY_BUFFER; + + final ByteBuf first; + try { + first = + FragmentationUtils.encodeFirstFragment( + allocator, + mtu, + initialRequestN, + frameType, + streamId, + hasMetadata, + slicedMetadata, + slicedData); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, true, false, e); + throw e; + } + + connection.sendFrame(streamId, first); + + while (slicedData.isReadable() || slicedMetadata.isReadable()) { + final ByteBuf following; + try { + following = + FragmentationUtils.encodeFollowsFragment( + allocator, mtu, streamId, complete, slicedMetadata, slicedData); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, true, true, e); + throw e; + } + connection.sendFrame(streamId, following); + } + + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, true, true, e); + throw e; + } + } else { + final ByteBuf dataRetainedSlice = data.retainedSlice(); + + final ByteBuf metadataRetainedSlice; + try { + metadataRetainedSlice = hasMetadata ? metadata.retainedSlice() : null; + } catch (IllegalReferenceCountException e) { + dataRetainedSlice.release(); + + sendTerminalFrame(streamId, frameType, connection, allocator, true, false, e); + throw e; + } + + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + dataRetainedSlice.release(); + if (hasMetadata) { + metadataRetainedSlice.release(); + } + + sendTerminalFrame(streamId, frameType, connection, allocator, true, false, e); + throw e; + } + + final ByteBuf requestFrame; + switch (frameType) { + case REQUEST_STREAM: + requestFrame = + RequestStreamFrameCodec.encode( + allocator, + streamId, + false, + initialRequestN, + metadataRetainedSlice, + dataRetainedSlice); + break; + case REQUEST_CHANNEL: + requestFrame = + RequestChannelFrameCodec.encode( + allocator, + streamId, + false, + complete, + initialRequestN, + metadataRetainedSlice, + dataRetainedSlice); + break; + default: + throw new IllegalArgumentException("Unsupported frame type " + frameType); + } + + connection.sendFrame(streamId, requestFrame); + } + } + + static void sendTerminalFrame( + int streamId, + FrameType frameType, + DuplexConnection connection, + ByteBufAllocator allocator, + boolean requester, + boolean onFollowingFrame, + Throwable t) { + + if (onFollowingFrame) { + if (requester) { + final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); + connection.sendFrame(streamId, cancelFrame); + } else { + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + allocator, + streamId, + new CanceledException( + "Failed to encode fragmented " + + frameType + + " frame. Cause: " + + t.getMessage())); + connection.sendFrame(streamId, errorFrame); + } + } else { + switch (frameType) { + case NEXT_COMPLETE: + case NEXT: + case PAYLOAD: + if (requester) { + final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); + connection.sendFrame(streamId, cancelFrame); + } else { + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + allocator, + streamId, + new CanceledException( + "Failed to encode " + frameType + " frame. Cause: " + t.getMessage())); + connection.sendFrame(streamId, errorFrame); + } + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java b/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java index 337d17c64..5aae22e89 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java +++ b/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java @@ -20,66 +20,73 @@ import io.netty.buffer.ByteBuf; import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; import io.rsocket.exceptions.RejectedResumeException; import io.rsocket.exceptions.UnsupportedSetupException; -import io.rsocket.frame.ErrorFrameCodec; import io.rsocket.frame.ResumeFrameCodec; import io.rsocket.frame.SetupFrameCodec; -import io.rsocket.internal.ClientServerInputMultiplexer; import io.rsocket.keepalive.KeepAliveHandler; import io.rsocket.resume.*; +import java.nio.channels.ClosedChannelException; import java.time.Duration; import java.util.function.BiFunction; import java.util.function.Function; import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; abstract class ServerSetup { + final Duration timeout; + + protected ServerSetup(Duration timeout) { + this.timeout = timeout; + } + + Mono> init(DuplexConnection connection) { + return Mono.>create( + sink -> sink.onRequest(__ -> new SetupHandlingDuplexConnection(connection, sink))) + .timeout(this.timeout) + .or(connection.onClose().then(Mono.error(ClosedChannelException::new))); + } + abstract Mono acceptRSocketSetup( ByteBuf frame, - ClientServerInputMultiplexer multiplexer, - BiFunction> then); + DuplexConnection clientServerConnection, + BiFunction> then); - abstract Mono acceptRSocketResume(ByteBuf frame, ClientServerInputMultiplexer multiplexer); + abstract Mono acceptRSocketResume(ByteBuf frame, DuplexConnection connection); void dispose() {} - Mono sendError(ClientServerInputMultiplexer multiplexer, Exception exception) { - DuplexConnection duplexConnection = multiplexer.asSetupConnection(); - return duplexConnection - .sendOne(ErrorFrameCodec.encode(duplexConnection.alloc(), 0, exception)) - .onErrorResume(err -> Mono.empty()); + void sendError(DuplexConnection duplexConnection, RSocketErrorException exception) { + duplexConnection.sendErrorAndClose(exception); + duplexConnection.receive().subscribe(); } static class DefaultServerSetup extends ServerSetup { + DefaultServerSetup(Duration timeout) { + super(timeout); + } + @Override public Mono acceptRSocketSetup( ByteBuf frame, - ClientServerInputMultiplexer multiplexer, - BiFunction> then) { + DuplexConnection duplexConnection, + BiFunction> then) { if (SetupFrameCodec.resumeEnabled(frame)) { - return sendError(multiplexer, new UnsupportedSetupException("resume not supported")) - .doFinally( - signalType -> { - frame.release(); - multiplexer.dispose(); - }); + sendError(duplexConnection, new UnsupportedSetupException("resume not supported")); + return duplexConnection.onClose(); } else { - return then.apply(new DefaultKeepAliveHandler(multiplexer), multiplexer); + return then.apply(new DefaultKeepAliveHandler(), duplexConnection); } } @Override - public Mono acceptRSocketResume(ByteBuf frame, ClientServerInputMultiplexer multiplexer) { - - return sendError(multiplexer, new RejectedResumeException("resume not supported")) - .doFinally( - signalType -> { - frame.release(); - multiplexer.dispose(); - }); + public Mono acceptRSocketResume(ByteBuf frame, DuplexConnection duplexConnection) { + sendError(duplexConnection, new RejectedResumeException("resume not supported")); + return duplexConnection.onClose(); } } @@ -91,11 +98,13 @@ static class ResumableServerSetup extends ServerSetup { private final boolean cleanupStoreOnKeepAlive; ResumableServerSetup( + Duration timeout, SessionManager sessionManager, Duration resumeSessionDuration, Duration resumeStreamTimeout, Function resumeStoreFactory, boolean cleanupStoreOnKeepAlive) { + super(timeout); this.sessionManager = sessionManager; this.resumeSessionDuration = resumeSessionDuration; this.resumeStreamTimeout = resumeStreamTimeout; @@ -106,47 +115,45 @@ static class ResumableServerSetup extends ServerSetup { @Override public Mono acceptRSocketSetup( ByteBuf frame, - ClientServerInputMultiplexer multiplexer, - BiFunction> then) { + DuplexConnection duplexConnection, + BiFunction> then) { if (SetupFrameCodec.resumeEnabled(frame)) { ByteBuf resumeToken = SetupFrameCodec.resumeToken(frame); - ResumableDuplexConnection connection = - sessionManager - .save( - new ServerRSocketSession( - multiplexer.asClientServerConnection(), - resumeSessionDuration, - resumeStreamTimeout, - resumeStoreFactory, - resumeToken, - cleanupStoreOnKeepAlive)) - .resumableConnection(); + final ResumableFramesStore resumableFramesStore = resumeStoreFactory.apply(resumeToken); + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "server", resumeToken, duplexConnection, resumableFramesStore); + final ServerRSocketSession serverRSocketSession = + new ServerRSocketSession( + resumeToken, + resumableDuplexConnection, + duplexConnection, + resumableFramesStore, + resumeSessionDuration, + cleanupStoreOnKeepAlive); + + sessionManager.save(serverRSocketSession, resumeToken); + return then.apply( - new ResumableKeepAliveHandler(connection), - new ClientServerInputMultiplexer(connection)); + new ResumableKeepAliveHandler( + resumableDuplexConnection, serverRSocketSession, serverRSocketSession), + resumableDuplexConnection); } else { - return then.apply(new DefaultKeepAliveHandler(multiplexer), multiplexer); + return then.apply(new DefaultKeepAliveHandler(), duplexConnection); } } @Override - public Mono acceptRSocketResume(ByteBuf frame, ClientServerInputMultiplexer multiplexer) { + public Mono acceptRSocketResume(ByteBuf frame, DuplexConnection duplexConnection) { ServerRSocketSession session = sessionManager.get(ResumeFrameCodec.token(frame)); if (session != null) { - return session - .continueWith(multiplexer.asClientServerConnection()) - .resumeWith(frame) - .onClose() - .then(); + session.resumeWith(frame, duplexConnection); + return duplexConnection.onClose(); } else { - return sendError(multiplexer, new RejectedResumeException("unknown resume token")) - .doFinally( - s -> { - frame.release(); - multiplexer.dispose(); - }); + sendError(duplexConnection, new RejectedResumeException("unknown resume token")); + return duplexConnection.onClose(); } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/SetupHandlingDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/core/SetupHandlingDuplexConnection.java new file mode 100644 index 000000000..3beedf97f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/SetupHandlingDuplexConnection.java @@ -0,0 +1,176 @@ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import java.net.SocketAddress; +import java.nio.channels.ClosedChannelException; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; + +class SetupHandlingDuplexConnection extends Flux + implements DuplexConnection, CoreSubscriber, Subscription { + + final DuplexConnection source; + final MonoSink> sink; + + Subscription s; + boolean firstFrameReceived = false; + + CoreSubscriber actual; + + boolean done; + Throwable t; + + SetupHandlingDuplexConnection( + DuplexConnection source, MonoSink> sink) { + this.source = source; + this.sink = sink; + + source.receive().subscribe(this); + } + + @Override + public void dispose() { + source.dispose(); + } + + @Override + public boolean isDisposed() { + return source.isDisposed(); + } + + @Override + public Mono onClose() { + return source.onClose(); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + source.sendFrame(streamId, frame); + } + + @Override + public Flux receive() { + return this; + } + + @Override + public SocketAddress remoteAddress() { + return source.remoteAddress(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + if (done) { + final Throwable t = this.t; + if (t == null) { + Operators.complete(actual); + } else { + Operators.error(actual, t); + } + return; + } + + this.actual = actual; + actual.onSubscribe(this); + } + + @Override + public void request(long n) { + if (n != Long.MAX_VALUE) { + actual.onError(new IllegalArgumentException("Only unbounded request is allowed")); + return; + } + + s.request(Long.MAX_VALUE); + } + + @Override + public void cancel() { + source.dispose(); + s.cancel(); + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + s.request(1); + } + } + + @Override + public void onNext(ByteBuf frame) { + if (!firstFrameReceived) { + firstFrameReceived = true; + sink.success(Tuples.of(frame, this)); + return; + } + + actual.onNext(frame); + } + + @Override + public void onError(Throwable t) { + if (done) { + Operators.onErrorDropped(t, Context.empty()); + return; + } + + this.done = true; + this.t = t; + + if (!firstFrameReceived) { + sink.error(t); + return; + } + + final CoreSubscriber actual = this.actual; + if (actual != null) { + actual.onError(t); + } + } + + @Override + public void onComplete() { + if (done) { + return; + } + + this.done = true; + + if (!firstFrameReceived) { + sink.error(new ClosedChannelException()); + return; + } + + final CoreSubscriber actual = this.actual; + if (actual != null) { + actual.onComplete(); + } + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + source.sendErrorAndClose(e); + } + + @Override + public ByteBufAllocator alloc() { + return source.alloc(); + } + + @Override + public String toString() { + return "SetupHandlingDuplexConnection{" + "source=" + source + ", done=" + done + '}'; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/SlowFireAndForgetRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/SlowFireAndForgetRequesterMono.java new file mode 100644 index 000000000..3035696b3 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/SlowFireAndForgetRequesterMono.java @@ -0,0 +1,255 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.isReadyToSendFirstFrame; +import static io.rsocket.core.StateUtils.isSubscribedOrTerminated; +import static io.rsocket.core.StateUtils.isTerminated; +import static io.rsocket.core.StateUtils.lazyTerminate; +import static io.rsocket.core.StateUtils.markReadyToSendFirstFrame; +import static io.rsocket.core.StateUtils.markSubscribed; +import static io.rsocket.core.StateUtils.markTerminated; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.RequestInterceptor; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; + +final class SlowFireAndForgetRequesterMono extends Mono + implements LeasePermitHandler, Subscription, Scannable { + + volatile long state; + + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(SlowFireAndForgetRequesterMono.class, "state"); + + final Payload payload; + + final ByteBufAllocator allocator; + final int mtu; + final int maxFrameLength; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + + @Nullable final RequesterLeaseTracker requesterLeaseTracker; + @Nullable final RequestInterceptor requestInterceptor; + + CoreSubscriber actual; + + SlowFireAndForgetRequesterMono( + Payload payload, RequesterResponderSupport requesterResponderSupport) { + this.allocator = requesterResponderSupport.getAllocator(); + this.payload = payload; + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + this.requesterLeaseTracker = requesterResponderSupport.getRequesterLeaseTracker(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + final boolean leaseEnabled = requesterLeaseTracker != null; + long previousState = markSubscribed(STATE, this, !leaseEnabled); + if (isSubscribedOrTerminated(previousState)) { + final IllegalStateException e = + new IllegalStateException("FireAndForgetMono allows only a single Subscriber"); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + + Operators.error(actual, e); + return; + } + + final Payload p = this.payload; + int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, p, false)) { + lazyTerminate(STATE, this); + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, p.metadata()); + } + + p.release(); + + Operators.error(actual, e); + return; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + + Operators.error(actual, e); + return; + } + + this.actual = actual; + actual.onSubscribe(this); + + if (leaseEnabled) { + requesterLeaseTracker.issue(this); + return; + } + + sendFirstFrame(p); + } + + @Override + public boolean handlePermit() { + final long previousState = markReadyToSendFirstFrame(STATE, this); + + if (isTerminated(previousState)) { + return false; + } + + sendFirstFrame(this.payload); + return true; + } + + void sendFirstFrame(Payload p) { + final CoreSubscriber actual = this.actual; + final int streamId; + try { + streamId = this.requesterResponderSupport.getNextStreamId(); + } catch (Throwable t) { + lazyTerminate(STATE, this); + + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_FNF, p.metadata()); + } + + p.release(); + + actual.onError(ut); + return; + } + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onStart(streamId, FrameType.REQUEST_FNF, p.metadata()); + } + + try { + if (isTerminated(this.state)) { + p.release(); + + if (interceptor != null) { + interceptor.onCancel(streamId, FrameType.REQUEST_FNF); + } + + return; + } + + sendReleasingPayload( + streamId, FrameType.REQUEST_FNF, mtu, p, this.connection, this.allocator, true); + } catch (Throwable e) { + lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, e); + } + + actual.onError(e); + return; + } + + lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, null); + } + + actual.onComplete(); + } + + @Override + public void request(long n) { + // no ops + } + + @Override + public void cancel() { + final long previousState = markTerminated(STATE, this); + + if (isTerminated(previousState)) { + return; + } + + if (!isReadyToSendFirstFrame(previousState)) { + this.payload.release(); + } + } + + @Override + public final void handlePermitError(Throwable cause) { + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(cause, this.actual.currentContext()); + return; + } + + final Payload p = this.payload; + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(cause, FrameType.REQUEST_RESPONSE, p.metadata()); + } + + p.release(); + + this.actual.onError(cause); + } + + @Override + public Object scanUnsafe(Attr key) { + return null; // no particular key to be represented, still useful in hooks + } + + @Override + @NonNull + public String stepName() { + return "source(FireAndForgetMono)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/StateUtils.java b/rsocket-core/src/main/java/io/rsocket/core/StateUtils.java new file mode 100644 index 000000000..2b6a0e09a --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/StateUtils.java @@ -0,0 +1,493 @@ +package io.rsocket.core; + +import java.util.concurrent.atomic.AtomicLongFieldUpdater; + +final class StateUtils { + + /** Volatile Long Field bit mask that allows extract flags stored in the field */ + static final long FLAGS_MASK = + 0b111111111111111111111111111111111_0000000000000000000000000000000L; + /** Volatile Long Field bit mask that allows extract int RequestN stored in the field */ + static final long REQUEST_MASK = + 0b000000000000000000000000000000000_1111111111111111111111111111111L; + /** Bit Flag that indicates Requester Producer has been subscribed once */ + static final long SUBSCRIBED_FLAG = + 0b000000000000000000000000000000001_0000000000000000000000000000000L; + /** Bit Flag that indicates that the first payload in RequestChannel scenario is received */ + static final long FIRST_PAYLOAD_RECEIVED_FLAG = + 0b000000000000000000000000000000010_0000000000000000000000000000000L; + /** + * Bit Flag that indicates that the logical stream is ready to send the first initial frame + * (applicable for requester only) + */ + static final long READY_TO_SEND_FIRST_FRAME_FLAG = + 0b000000000000000000000000000000100_0000000000000000000000000000000L; + /** + * Bit Flag that indicates that sent first initial frame was sent (in case of requester) or + * consumed (if responder) + */ + static final long FIRST_FRAME_SENT_FLAG = + 0b000000000000000000000000000001000_0000000000000000000000000000000L; + /** Bit Flag that indicates that there is a frame being reassembled */ + static final long REASSEMBLING_FLAG = + 0b000000000000000000000000000010000_0000000000000000000000000000000L; + /** + * Bit Flag that indicates requestChannel stream is half terminated. In this case flag indicates + * that the inbound is terminated + */ + static final long INBOUND_TERMINATED_FLAG = + 0b000000000000000000000000000100000_0000000000000000000000000000000L; + /** + * Bit Flag that indicates requestChannel stream is half terminated. In this case flag indicates + * that the outbound is terminated + */ + static final long OUTBOUND_TERMINATED_FLAG = + 0b000000000000000000000000001000000_0000000000000000000000000000000L; + /** Initial state for any request operator */ + static final long UNSUBSCRIBED_STATE = + 0b000000000000000000000000000000000_0000000000000000000000000000000L; + /** State that indicates request operator was terminated */ + static final long TERMINATED_STATE = + 0b100000000000000000000000000000000_0000000000000000000000000000000L; + + /** + * Adds (if possible) to the given state the {@link #SUBSCRIBED_FLAG} flag which indicates that + * the given stream has already been subscribed once + * + *

    Note, the flag will not be added if the stream has already been terminated or if the stream + * has already been subscribed once + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markSubscribed(AtomicLongFieldUpdater updater, T instance) { + return markSubscribed(updater, instance, false); + } + + /** + * Adds (if possible) to the given state the {@link #SUBSCRIBED_FLAG} flag which indicates that + * the given stream has already been subscribed once + * + *

    Note, the flag will not be added if the stream has already been terminated or if the stream + * has already been subscribed once + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param markPrepared indicates whether the given instance should be marked as prepared + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markSubscribed( + AtomicLongFieldUpdater updater, T instance, boolean markPrepared) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & SUBSCRIBED_FLAG) == SUBSCRIBED_FLAG) { + return state; + } + + if (updater.compareAndSet( + instance, + state, + state | SUBSCRIBED_FLAG | (markPrepared ? READY_TO_SEND_FIRST_FRAME_FLAG : 0))) { + return state; + } + } + } + + /** + * Indicates that the given stream has already been subscribed once + * + * @param state to check whether stream is subscribed + * @return true if the {@link #SUBSCRIBED_FLAG} flag is set + */ + static boolean isSubscribed(long state) { + return (state & SUBSCRIBED_FLAG) == SUBSCRIBED_FLAG; + } + + /** + * Adds (if possible) to the given state the {@link #FIRST_FRAME_SENT_FLAG} flag which indicates + * that the first frame has already set and logical stream has already been established. + * + *

    Note, the flag will not be added if the stream has already been terminated or if the stream + * has already been established once + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markFirstFrameSent(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & FIRST_FRAME_SENT_FLAG) == FIRST_FRAME_SENT_FLAG) { + return state; + } + + if (updater.compareAndSet(instance, state, state | FIRST_FRAME_SENT_FLAG)) { + return state; + } + } + } + + /** + * Indicates that the first frame which established logical stream has already been sent + * + * @param state to check whether stream is established + * @return true if the {@link #FIRST_FRAME_SENT_FLAG} flag is set + */ + static boolean isFirstFrameSent(long state) { + return (state & FIRST_FRAME_SENT_FLAG) == FIRST_FRAME_SENT_FLAG; + } + + /** + * Adds (if possible) to the given state the {@link #READY_TO_SEND_FIRST_FRAME_FLAG} flag which + * indicates that the logical stream is ready for initial frame sending. + * + *

    Note, the flag will not be added if the stream has already been terminated or if the stream + * has already been marked as prepared + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markReadyToSendFirstFrame(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & READY_TO_SEND_FIRST_FRAME_FLAG) == READY_TO_SEND_FIRST_FRAME_FLAG) { + return state; + } + + if (updater.compareAndSet(instance, state, state | READY_TO_SEND_FIRST_FRAME_FLAG)) { + return state; + } + } + } + + /** + * Indicates that the logical stream is ready for initial frame sending + * + * @param state to check whether stream is prepared for initial frame sending + * @return true if the {@link #READY_TO_SEND_FIRST_FRAME_FLAG} flag is set + */ + static boolean isReadyToSendFirstFrame(long state) { + return (state & READY_TO_SEND_FIRST_FRAME_FLAG) == READY_TO_SEND_FIRST_FRAME_FLAG; + } + + /** + * Adds (if possible) to the given state the {@link #FIRST_PAYLOAD_RECEIVED_FLAG} flag which + * indicates that the logical stream is ready for initial frame sending. + * + *

    Note, the flag will not be added if the stream has already been terminated or if the stream + * has already been marked as prepared + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markFirstPayloadReceived(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & FIRST_PAYLOAD_RECEIVED_FLAG) == FIRST_PAYLOAD_RECEIVED_FLAG) { + return state; + } + + if (updater.compareAndSet(instance, state, state | FIRST_PAYLOAD_RECEIVED_FLAG)) { + return state; + } + } + } + + /** + * Indicates that the logical stream is ready for initial frame sending + * + * @param state to check whether stream is established + * @return true if the {@link #FIRST_PAYLOAD_RECEIVED_FLAG} flag is set + */ + static boolean isFirstPayloadReceived(long state) { + return (state & FIRST_PAYLOAD_RECEIVED_FLAG) == FIRST_PAYLOAD_RECEIVED_FLAG; + } + + /** + * Adds (if possible) to the given state the {@link #REASSEMBLING_FLAG} flag which indicates that + * there is a payload reassembling in progress. + * + *

    Note, the flag will not be added if the stream has already been terminated + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markReassembling(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if (updater.compareAndSet(instance, state, state | REASSEMBLING_FLAG)) { + return state; + } + } + } + + /** + * Removes (if possible) from the given state the {@link #REASSEMBLING_FLAG} flag which indicates + * that a payload reassembly process is completed. + * + *

    Note, the flag will not be removed if the stream has already been terminated + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markReassembled(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if (updater.compareAndSet(instance, state, state & ~REASSEMBLING_FLAG)) { + return state; + } + } + } + + /** + * Indicates that a payload reassembly process is completed. + * + * @param state to check whether there is reassembly in progress + * @return true if the {@link #REASSEMBLING_FLAG} flag is set + */ + static boolean isReassembling(long state) { + return (state & REASSEMBLING_FLAG) == REASSEMBLING_FLAG; + } + + /** + * Adds (if possible) to the given state the {@link #INBOUND_TERMINATED_FLAG} flag which indicates + * that an inbound channel of a bidirectional stream is terminated. + * + *

    Note, this action will have no effect if the stream has already been terminated or if + * the {@link #INBOUND_TERMINATED_FLAG} flag has already been set.
    + * Note, if the outbound stream has already been terminated, then the result state will be + * {@link #TERMINATED_STATE} + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markInboundTerminated(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & INBOUND_TERMINATED_FLAG) == INBOUND_TERMINATED_FLAG) { + return state; + } + + if ((state & OUTBOUND_TERMINATED_FLAG) == OUTBOUND_TERMINATED_FLAG) { + if (updater.compareAndSet(instance, state, TERMINATED_STATE)) { + return state; + } + } else { + if (updater.compareAndSet(instance, state, state | INBOUND_TERMINATED_FLAG)) { + return state; + } + } + } + } + + /** + * Indicates that a the inbound channel of a bidirectional stream is terminated. + * + * @param state to check whether it has {@link #INBOUND_TERMINATED_FLAG} set + * @return true if the {@link #INBOUND_TERMINATED_FLAG} flag is set + */ + static boolean isInboundTerminated(long state) { + return (state & INBOUND_TERMINATED_FLAG) == INBOUND_TERMINATED_FLAG; + } + + /** + * Adds (if possible) to the given state the {@link #OUTBOUND_TERMINATED_FLAG} flag which + * indicates that an outbound channel of a bidirectional stream is terminated. + * + *

    Note, this action will have no effect if the stream has already been terminated or if + * the {@link #OUTBOUND_TERMINATED_FLAG} flag has already been set.
    + * Note, if the {@code checkEstablishment} parameter is {@code true} and the logical stream + * is not established, then the result state will be {@link #TERMINATED_STATE}
    + * Note, if the inbound stream has already been terminated, then the result state will be + * {@link #TERMINATED_STATE} + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param checkEstablishment indicates whether {@link #FIRST_FRAME_SENT_FLAG} should be checked to + * make final decision + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markOutboundTerminated( + AtomicLongFieldUpdater updater, T instance, boolean checkEstablishment) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & OUTBOUND_TERMINATED_FLAG) == OUTBOUND_TERMINATED_FLAG) { + return state; + } + + if ((checkEstablishment && !isFirstFrameSent(state)) + || (state & INBOUND_TERMINATED_FLAG) == INBOUND_TERMINATED_FLAG) { + if (updater.compareAndSet(instance, state, TERMINATED_STATE)) { + return state; + } + } else { + if (updater.compareAndSet(instance, state, state | OUTBOUND_TERMINATED_FLAG)) { + return state; + } + } + } + } + + /** + * Indicates that a the outbound channel of a bidirectional stream is terminated. + * + * @param state to check whether it has {@link #OUTBOUND_TERMINATED_FLAG} set + * @return true if the {@link #OUTBOUND_TERMINATED_FLAG} flag is set + */ + static boolean isOutboundTerminated(long state) { + return (state & OUTBOUND_TERMINATED_FLAG) == OUTBOUND_TERMINATED_FLAG; + } + + /** + * Makes current state a {@link #TERMINATED_STATE} + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markTerminated(AtomicLongFieldUpdater updater, T instance) { + return updater.getAndSet(instance, TERMINATED_STATE); + } + + /** + * Makes current state a {@link #TERMINATED_STATE} using {@link + * AtomicLongFieldUpdater#lazySet(Object, long)} + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + */ + static void lazyTerminate(AtomicLongFieldUpdater updater, T instance) { + updater.lazySet(instance, TERMINATED_STATE); + } + + /** + * Indicates that a the outbound channel of a bidirectional stream is terminated. + * + * @param state to check whether it has {@link #OUTBOUND_TERMINATED_FLAG} set + * @return true if the {@link #OUTBOUND_TERMINATED_FLAG} flag is set + */ + static boolean isTerminated(long state) { + return state == TERMINATED_STATE; + } + + /** + * Shortcut for {@link #isSubscribed} {@code ||} {@link #isTerminated} methods + * + * @param state to check flags on + * @return true if state is terminated or has flag subscribed + */ + static boolean isSubscribedOrTerminated(long state) { + return state == TERMINATED_STATE || (state & SUBSCRIBED_FLAG) == SUBSCRIBED_FLAG; + } + + static long addRequestN(AtomicLongFieldUpdater updater, T instance, long toAdd) { + return addRequestN(updater, instance, toAdd, false); + } + + static long addRequestN( + AtomicLongFieldUpdater updater, T instance, long toAdd, boolean markPrepared) { + long currentState, flags, requestN, nextRequestN; + for (; ; ) { + currentState = updater.get(instance); + + if (currentState == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + requestN = currentState & REQUEST_MASK; + if (requestN == REQUEST_MASK) { + return currentState; + } + + flags = (currentState & FLAGS_MASK) | (markPrepared ? READY_TO_SEND_FIRST_FRAME_FLAG : 0); + nextRequestN = addRequestN(requestN, toAdd); + + if (updater.compareAndSet(instance, currentState, nextRequestN | flags)) { + return currentState; + } + } + } + + static long addRequestN(long a, long b) { + long res = a + b; + if (res < 0 || res > REQUEST_MASK) { + return REQUEST_MASK; + } + return res; + } + + static boolean hasRequested(long state) { + return (state & REQUEST_MASK) > 0; + } + + static long extractRequestN(long state) { + long requestN = state & REQUEST_MASK; + + if (requestN == REQUEST_MASK) { + return REQUEST_MASK; + } + + return requestN; + } + + static boolean isMaxAllowedRequestN(long n) { + return n >= REQUEST_MASK; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationErrorException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationErrorException.java index cd0d46754..40cb15dd6 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationErrorException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationErrorException.java @@ -16,6 +16,7 @@ package io.rsocket.exceptions; +import io.rsocket.RSocketErrorException; import io.rsocket.frame.ErrorFrameCodec; import reactor.util.annotation.Nullable; @@ -25,7 +26,7 @@ * @see Error * Codes */ -public final class ApplicationErrorException extends RSocketException { +public final class ApplicationErrorException extends RSocketErrorException { private static final long serialVersionUID = 7873267740343446585L; diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/CanceledException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/CanceledException.java index d51ba0fb7..144ef94c6 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/CanceledException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/CanceledException.java @@ -16,6 +16,7 @@ package io.rsocket.exceptions; +import io.rsocket.RSocketErrorException; import io.rsocket.frame.ErrorFrameCodec; import reactor.util.annotation.Nullable; @@ -26,7 +27,7 @@ * @see Error * Codes */ -public final class CanceledException extends RSocketException { +public final class CanceledException extends RSocketErrorException { private static final long serialVersionUID = 5074789326089722770L; diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionCloseException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionCloseException.java index 80324aa90..1e0167bdd 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionCloseException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionCloseException.java @@ -16,6 +16,7 @@ package io.rsocket.exceptions; +import io.rsocket.RSocketErrorException; import io.rsocket.frame.ErrorFrameCodec; import reactor.util.annotation.Nullable; @@ -26,7 +27,7 @@ * @see Error * Codes */ -public final class ConnectionCloseException extends RSocketException { +public final class ConnectionCloseException extends RSocketErrorException { private static final long serialVersionUID = -2214953527482377471L; diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionErrorException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionErrorException.java index b44714f7e..5cf7cff66 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionErrorException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionErrorException.java @@ -16,6 +16,7 @@ package io.rsocket.exceptions; +import io.rsocket.RSocketErrorException; import io.rsocket.frame.ErrorFrameCodec; import reactor.util.annotation.Nullable; @@ -26,7 +27,7 @@ * @see Error * Codes */ -public final class ConnectionErrorException extends RSocketException implements Retryable { +public final class ConnectionErrorException extends RSocketErrorException implements Retryable { private static final long serialVersionUID = 512325887785119744L; diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/CustomRSocketException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/CustomRSocketException.java index 079b561f9..a72c0ba3b 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/CustomRSocketException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/CustomRSocketException.java @@ -16,10 +16,11 @@ package io.rsocket.exceptions; +import io.rsocket.RSocketErrorException; import io.rsocket.frame.ErrorFrameCodec; import reactor.util.annotation.Nullable; -public class CustomRSocketException extends RSocketException { +public class CustomRSocketException extends RSocketErrorException { private static final long serialVersionUID = 7873267740343446585L; /** diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidException.java index a1b77b8dd..c556423b9 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidException.java @@ -16,6 +16,7 @@ package io.rsocket.exceptions; +import io.rsocket.RSocketErrorException; import io.rsocket.frame.ErrorFrameCodec; import reactor.util.annotation.Nullable; @@ -25,7 +26,7 @@ * @see Error * Codes */ -public final class InvalidException extends RSocketException { +public final class InvalidException extends RSocketErrorException { private static final long serialVersionUID = 8279420324864928243L; diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/RSocketException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/RSocketException.java deleted file mode 100644 index 2b137282f..000000000 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/RSocketException.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.exceptions; - -import io.rsocket.RSocketErrorException; -import io.rsocket.frame.ErrorFrameCodec; -import reactor.util.annotation.Nullable; - -/** - * The root of the RSocket exception hierarchy. - * - * @deprecated please use {@link RSocketErrorException} instead - */ -@Deprecated -public abstract class RSocketException extends RSocketErrorException { - - private static final long serialVersionUID = 2912815394105575423L; - - /** - * Constructs a new exception with the specified message and error code 0x201 (Application error). - * - * @param message the message - */ - public RSocketException(String message) { - this(message, null); - } - - /** - * Constructs a new exception with the specified message and cause and error code 0x201 - * (Application error). - * - * @param message the message - * @param cause the cause of this exception - */ - public RSocketException(String message, @Nullable Throwable cause) { - super(ErrorFrameCodec.APPLICATION_ERROR, message, cause); - } - - /** - * Constructs a new exception with the specified error code, message and cause. - * - * @param errorCode the RSocket protocol error code - * @param message the message - * @param cause the cause of this exception - */ - public RSocketException(int errorCode, String message, @Nullable Throwable cause) { - super(errorCode, message, cause); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedException.java index baed84e1b..8bc946e3d 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedException.java @@ -16,6 +16,7 @@ package io.rsocket.exceptions; +import io.rsocket.RSocketErrorException; import io.rsocket.frame.ErrorFrameCodec; import reactor.util.annotation.Nullable; @@ -27,7 +28,7 @@ * @see Error * Codes */ -public class RejectedException extends RSocketException implements Retryable { +public class RejectedException extends RSocketErrorException implements Retryable { private static final long serialVersionUID = 3926231092835143715L; diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedResumeException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedResumeException.java index 8a99fcffb..44cc55710 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedResumeException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedResumeException.java @@ -16,6 +16,7 @@ package io.rsocket.exceptions; +import io.rsocket.RSocketErrorException; import io.rsocket.frame.ErrorFrameCodec; import reactor.util.annotation.Nullable; @@ -25,7 +26,7 @@ * @see Error * Codes */ -public final class RejectedResumeException extends RSocketException { +public final class RejectedResumeException extends RSocketErrorException { private static final long serialVersionUID = -873684362478544811L; diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/SetupException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/SetupException.java index ed979c9e6..76dc39a59 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/SetupException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/SetupException.java @@ -16,37 +16,14 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.RSocketErrorException; import reactor.util.annotation.Nullable; /** The root of the setup exception hierarchy. */ -public abstract class SetupException extends RSocketException { +public abstract class SetupException extends RSocketErrorException { private static final long serialVersionUID = -2928269501877732756L; - /** - * Constructs a new exception with the specified message. - * - * @param message the message - * @deprecated please use {@link #SetupException(int, String, Throwable)} - */ - @Deprecated - public SetupException(String message) { - this(message, null); - } - - /** - * Constructs a new exception with the specified message and cause. - * - * @param message the message - * @param cause the cause of this exception - * @deprecated please use {@link #SetupException(int, String, Throwable)} - */ - @Deprecated - public SetupException(String message, @Nullable Throwable cause) { - this(ErrorFrameCodec.INVALID_SETUP, message, cause); - } - /** * Constructs a new exception with the specified error code, message and cause. * diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java deleted file mode 100644 index 6eebd676c..000000000 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.fragmentation; - -import static io.rsocket.fragmentation.FrameFragmenter.fragmentFrame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; -import io.rsocket.DuplexConnection; -import io.rsocket.frame.FrameHeaderCodec; -import io.rsocket.frame.FrameType; -import java.util.Objects; -import org.reactivestreams.Publisher; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -/** - * A {@link DuplexConnection} implementation that fragments and reassembles {@link ByteBuf}s. - * - * @see Fragmentation - * and Reassembly - */ -public final class FragmentationDuplexConnection extends ReassemblyDuplexConnection - implements DuplexConnection { - - public static final int MIN_MTU_SIZE = 64; - - private static final Logger logger = LoggerFactory.getLogger(FragmentationDuplexConnection.class); - - final DuplexConnection delegate; - final int mtu; - final String type; - - /** - * Class constructor. - * - * @param delegate the underlying connection - * @param mtu the fragment size, greater than {@link #MIN_MTU_SIZE} - * @param maxInboundPayloadSize the maximum payload size, which can be reassembled from multiple - * fragments - * @param type a label to use for logging purposes - */ - public FragmentationDuplexConnection( - DuplexConnection delegate, int mtu, int maxInboundPayloadSize, String type) { - super(delegate, maxInboundPayloadSize); - - Objects.requireNonNull(delegate, "delegate must not be null"); - this.delegate = delegate; - this.mtu = assertMtu(mtu); - this.type = type; - } - - private boolean shouldFragment(FrameType frameType, int readableBytes) { - return frameType.isFragmentable() && readableBytes > mtu; - } - - public static int assertMtu(int mtu) { - if (mtu > 0 && mtu < MIN_MTU_SIZE || mtu < 0) { - String msg = - String.format( - "The smallest allowed mtu size is %d bytes, provided: %d", MIN_MTU_SIZE, mtu); - throw new IllegalArgumentException(msg); - } else { - return mtu; - } - } - - @Override - public Mono send(Publisher frames) { - return Flux.from(frames).concatMap(this::sendOne).then(); - } - - @Override - public Mono sendOne(ByteBuf frame) { - FrameType frameType = FrameHeaderCodec.frameType(frame); - int readableBytes = frame.readableBytes(); - if (!shouldFragment(frameType, readableBytes)) { - return delegate.sendOne(frame); - } - Flux fragments = Flux.from(fragmentFrame(alloc(), mtu, frame, frameType)); - if (logger.isDebugEnabled()) { - fragments = - fragments.doOnNext( - byteBuf -> { - logger.debug( - "{} - stream id {} - frame type {} - \n {}", - type, - FrameHeaderCodec.streamId(byteBuf), - FrameHeaderCodec.frameType(byteBuf), - ByteBufUtil.prettyHexDump(byteBuf)); - }); - } - return delegate.send(fragments); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java deleted file mode 100644 index fcb6198a3..000000000 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java +++ /dev/null @@ -1,235 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.fragmentation; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.netty.util.ReferenceCountUtil; -import io.rsocket.frame.FrameHeaderCodec; -import io.rsocket.frame.FrameType; -import io.rsocket.frame.PayloadFrameCodec; -import io.rsocket.frame.RequestChannelFrameCodec; -import io.rsocket.frame.RequestFireAndForgetFrameCodec; -import io.rsocket.frame.RequestResponseFrameCodec; -import io.rsocket.frame.RequestStreamFrameCodec; -import java.util.function.Consumer; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; -import reactor.core.publisher.SynchronousSink; - -/** - * The implementation of the RSocket fragmentation behavior. - * - * @see Fragmentation - * and Reassembly - */ -final class FrameFragmenter { - static Publisher fragmentFrame( - ByteBufAllocator allocator, int mtu, final ByteBuf frame, FrameType frameType) { - ByteBuf metadata = getMetadata(frame, frameType); - ByteBuf data = getData(frame, frameType); - int streamId = FrameHeaderCodec.streamId(frame); - return Flux.generate( - new Consumer>() { - boolean first = true; - - @Override - public void accept(SynchronousSink sink) { - ByteBuf byteBuf; - if (first) { - first = false; - byteBuf = - encodeFirstFragment( - allocator, mtu, frame, frameType, streamId, metadata, data); - } else { - byteBuf = encodeFollowsFragment(allocator, mtu, streamId, metadata, data); - } - - sink.next(byteBuf); - if (!metadata.isReadable() && !data.isReadable()) { - sink.complete(); - } - } - }) - .doFinally(signalType -> ReferenceCountUtil.safeRelease(frame)); - } - - static ByteBuf encodeFirstFragment( - ByteBufAllocator allocator, - int mtu, - ByteBuf frame, - FrameType frameType, - int streamId, - ByteBuf metadata, - ByteBuf data) { - // subtract the header bytes - int remaining = mtu - FrameHeaderCodec.size(); - - // substract the initial request n - switch (frameType) { - case REQUEST_STREAM: - case REQUEST_CHANNEL: - remaining -= Integer.BYTES; - break; - default: - } - - ByteBuf metadataFragment = null; - if (metadata.isReadable()) { - // subtract the metadata frame length - remaining -= 3; - int r = Math.min(remaining, metadata.readableBytes()); - remaining -= r; - metadataFragment = metadata.readRetainedSlice(r); - } - - ByteBuf dataFragment = Unpooled.EMPTY_BUFFER; - if (remaining > 0 && data.isReadable()) { - int r = Math.min(remaining, data.readableBytes()); - dataFragment = data.readRetainedSlice(r); - } - - switch (frameType) { - case REQUEST_FNF: - return RequestFireAndForgetFrameCodec.encode( - allocator, streamId, true, metadataFragment, dataFragment); - case REQUEST_STREAM: - return RequestStreamFrameCodec.encode( - allocator, - streamId, - true, - RequestStreamFrameCodec.initialRequestN(frame), - metadataFragment, - dataFragment); - case REQUEST_RESPONSE: - return RequestResponseFrameCodec.encode( - allocator, streamId, true, metadataFragment, dataFragment); - case REQUEST_CHANNEL: - return RequestChannelFrameCodec.encode( - allocator, - streamId, - true, - false, - RequestChannelFrameCodec.initialRequestN(frame), - metadataFragment, - dataFragment); - // Payload and synthetic types - case PAYLOAD: - return PayloadFrameCodec.encode( - allocator, streamId, true, false, false, metadataFragment, dataFragment); - case NEXT: - return PayloadFrameCodec.encode( - allocator, streamId, true, false, true, metadataFragment, dataFragment); - case NEXT_COMPLETE: - return PayloadFrameCodec.encode( - allocator, streamId, true, true, true, metadataFragment, dataFragment); - case COMPLETE: - return PayloadFrameCodec.encode( - allocator, streamId, true, true, false, metadataFragment, dataFragment); - default: - throw new IllegalStateException("unsupported fragment type: " + frameType); - } - } - - static ByteBuf encodeFollowsFragment( - ByteBufAllocator allocator, int mtu, int streamId, ByteBuf metadata, ByteBuf data) { - // subtract the header bytes - int remaining = mtu - FrameHeaderCodec.size(); - - ByteBuf metadataFragment = null; - if (metadata.isReadable()) { - // subtract the metadata frame length - remaining -= 3; - int r = Math.min(remaining, metadata.readableBytes()); - remaining -= r; - metadataFragment = metadata.readRetainedSlice(r); - } - - ByteBuf dataFragment = Unpooled.EMPTY_BUFFER; - if (remaining > 0 && data.isReadable()) { - int r = Math.min(remaining, data.readableBytes()); - dataFragment = data.readRetainedSlice(r); - } - - boolean follows = data.isReadable() || metadata.isReadable(); - return PayloadFrameCodec.encode( - allocator, streamId, follows, false, true, metadataFragment, dataFragment); - } - - static ByteBuf getMetadata(ByteBuf frame, FrameType frameType) { - boolean hasMetadata = FrameHeaderCodec.hasMetadata(frame); - if (hasMetadata) { - ByteBuf metadata; - switch (frameType) { - case REQUEST_FNF: - metadata = RequestFireAndForgetFrameCodec.metadata(frame); - break; - case REQUEST_STREAM: - metadata = RequestStreamFrameCodec.metadata(frame); - break; - case REQUEST_RESPONSE: - metadata = RequestResponseFrameCodec.metadata(frame); - break; - case REQUEST_CHANNEL: - metadata = RequestChannelFrameCodec.metadata(frame); - break; - // Payload and synthetic types - case PAYLOAD: - case NEXT: - case NEXT_COMPLETE: - case COMPLETE: - metadata = PayloadFrameCodec.metadata(frame); - break; - default: - throw new IllegalStateException("unsupported fragment type"); - } - return metadata; - } else { - return Unpooled.EMPTY_BUFFER; - } - } - - static ByteBuf getData(ByteBuf frame, FrameType frameType) { - ByteBuf data; - switch (frameType) { - case REQUEST_FNF: - data = RequestFireAndForgetFrameCodec.data(frame); - break; - case REQUEST_STREAM: - data = RequestStreamFrameCodec.data(frame); - break; - case REQUEST_RESPONSE: - data = RequestResponseFrameCodec.data(frame); - break; - case REQUEST_CHANNEL: - data = RequestChannelFrameCodec.data(frame); - break; - // Payload and synthetic types - case PAYLOAD: - case NEXT: - case NEXT_COMPLETE: - case COMPLETE: - data = PayloadFrameCodec.data(frame); - break; - default: - throw new IllegalStateException("unsupported fragment type"); - } - return data; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java deleted file mode 100644 index d1adbfdf7..000000000 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java +++ /dev/null @@ -1,342 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.fragmentation; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.CompositeByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.util.ReferenceCountUtil; -import io.netty.util.collection.IntObjectHashMap; -import io.netty.util.collection.IntObjectMap; -import io.rsocket.frame.*; -import java.util.concurrent.atomic.AtomicBoolean; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.Disposable; -import reactor.core.publisher.SynchronousSink; -import reactor.util.annotation.Nullable; - -/** - * The implementation of the RSocket reassembly behavior. - * - * @see Fragmentation - * and Reassembly - */ -final class FrameReassembler extends AtomicBoolean implements Disposable { - - private static final long serialVersionUID = -4394598098863449055L; - - private static final Logger logger = LoggerFactory.getLogger(FrameReassembler.class); - - final IntObjectMap headers; - final IntObjectMap metadata; - final IntObjectMap data; - - final ByteBufAllocator allocator; - final int maxInboundPayloadSize; - - public FrameReassembler(ByteBufAllocator allocator, int maxInboundPayloadSize) { - this.allocator = allocator; - this.maxInboundPayloadSize = maxInboundPayloadSize; - this.headers = new IntObjectHashMap<>(); - this.metadata = new IntObjectHashMap<>(); - this.data = new IntObjectHashMap<>(); - } - - @Override - public void dispose() { - if (compareAndSet(false, true)) { - synchronized (FrameReassembler.this) { - for (ByteBuf byteBuf : headers.values()) { - ReferenceCountUtil.safeRelease(byteBuf); - } - headers.clear(); - - for (ByteBuf byteBuf : metadata.values()) { - ReferenceCountUtil.safeRelease(byteBuf); - } - metadata.clear(); - - for (ByteBuf byteBuf : data.values()) { - ReferenceCountUtil.safeRelease(byteBuf); - } - data.clear(); - } - } - } - - @Override - public boolean isDisposed() { - return get(); - } - - @Nullable - synchronized ByteBuf getHeader(int streamId) { - return headers.get(streamId); - } - - synchronized CompositeByteBuf getMetadata(int streamId) { - CompositeByteBuf byteBuf = metadata.get(streamId); - - if (byteBuf == null) { - byteBuf = allocator.compositeBuffer(); - metadata.put(streamId, byteBuf); - } - - return byteBuf; - } - - synchronized int getMetadataSize(int streamId) { - CompositeByteBuf byteBuf = metadata.get(streamId); - - if (byteBuf == null) { - return 0; - } - - return byteBuf.readableBytes(); - } - - synchronized CompositeByteBuf getData(int streamId) { - CompositeByteBuf byteBuf = data.get(streamId); - - if (byteBuf == null) { - byteBuf = allocator.compositeBuffer(); - data.put(streamId, byteBuf); - } - - return byteBuf; - } - - synchronized int getDataSize(int streamId) { - CompositeByteBuf byteBuf = data.get(streamId); - - if (byteBuf == null) { - return 0; - } - - return byteBuf.readableBytes(); - } - - @Nullable - synchronized ByteBuf removeHeader(int streamId) { - return headers.remove(streamId); - } - - @Nullable - synchronized CompositeByteBuf removeMetadata(int streamId) { - return metadata.remove(streamId); - } - - @Nullable - synchronized CompositeByteBuf removeData(int streamId) { - return data.remove(streamId); - } - - synchronized void putHeader(int streamId, ByteBuf header) { - headers.put(streamId, header); - } - - void cancelAssemble(int streamId) { - ByteBuf header = removeHeader(streamId); - CompositeByteBuf metadata = removeMetadata(streamId); - CompositeByteBuf data = removeData(streamId); - - if (header != null) { - ReferenceCountUtil.safeRelease(header); - } - - if (metadata != null) { - ReferenceCountUtil.safeRelease(metadata); - } - - if (data != null) { - ReferenceCountUtil.safeRelease(data); - } - } - - void handleNoFollowsFlag(ByteBuf frame, SynchronousSink sink, int streamId) { - ByteBuf header = removeHeader(streamId); - if (header != null) { - - int maxReassemblySize = this.maxInboundPayloadSize; - if (maxReassemblySize != Integer.MAX_VALUE) { - int currentPayloadSize = getMetadataSize(streamId) + getDataSize(streamId); - if (currentPayloadSize + frame.readableBytes() - FrameHeaderCodec.size() - > maxReassemblySize) { - frame.release(); - throw new IllegalStateException("Reassembled payload went out of allowed size"); - } - } - - if (FrameHeaderCodec.hasMetadata(header)) { - ByteBuf assembledFrame = assembleFrameWithMetadata(frame, streamId, header); - sink.next(assembledFrame); - } else { - ByteBuf data = assembleData(frame, streamId); - ByteBuf assembledFrame = FragmentationCodec.encode(allocator, header, data); - sink.next(assembledFrame); - } - frame.release(); - } else { - sink.next(frame); - } - } - - void handleFollowsFlag(ByteBuf frame, int streamId, FrameType frameType) { - - int maxReassemblySize = this.maxInboundPayloadSize; - if (maxReassemblySize != Integer.MAX_VALUE) { - int currentPayloadSize = getMetadataSize(streamId) + getDataSize(streamId); - if (currentPayloadSize + frame.readableBytes() - FrameHeaderCodec.size() - > maxReassemblySize) { - frame.release(); - throw new IllegalStateException("Reassembled payload went out of allowed size"); - } - } - - ByteBuf header = getHeader(streamId); - if (header == null) { - header = frame.copy(frame.readerIndex(), FrameHeaderCodec.size()); - - if (frameType == FrameType.REQUEST_CHANNEL || frameType == FrameType.REQUEST_STREAM) { - long i = RequestChannelFrameCodec.initialRequestN(frame); - header.writeInt(i > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) i); - } - putHeader(streamId, header); - } - - if (FrameHeaderCodec.hasMetadata(frame)) { - CompositeByteBuf metadata = getMetadata(streamId); - switch (frameType) { - case REQUEST_FNF: - metadata.addComponents(true, RequestFireAndForgetFrameCodec.metadata(frame).retain()); - break; - case REQUEST_STREAM: - metadata.addComponents(true, RequestStreamFrameCodec.metadata(frame).retain()); - break; - case REQUEST_RESPONSE: - metadata.addComponents(true, RequestResponseFrameCodec.metadata(frame).retain()); - break; - case REQUEST_CHANNEL: - metadata.addComponents(true, RequestChannelFrameCodec.metadata(frame).retain()); - break; - // Payload and synthetic types - case PAYLOAD: - case NEXT: - case NEXT_COMPLETE: - case COMPLETE: - metadata.addComponents(true, PayloadFrameCodec.metadata(frame).retain()); - break; - default: - throw new IllegalStateException("unsupported fragment type"); - } - } - - ByteBuf data; - switch (frameType) { - case REQUEST_FNF: - data = RequestFireAndForgetFrameCodec.data(frame).retain(); - break; - case REQUEST_STREAM: - data = RequestStreamFrameCodec.data(frame).retain(); - break; - case REQUEST_RESPONSE: - data = RequestResponseFrameCodec.data(frame).retain(); - break; - case REQUEST_CHANNEL: - data = RequestChannelFrameCodec.data(frame).retain(); - break; - // Payload and synthetic types - case PAYLOAD: - case NEXT: - case NEXT_COMPLETE: - case COMPLETE: - data = PayloadFrameCodec.data(frame).retain(); - break; - default: - frame.release(); - throw new IllegalStateException("unsupported fragment type"); - } - - getData(streamId).addComponents(true, data); - frame.release(); - } - - void reassembleFrame(ByteBuf frame, SynchronousSink sink) { - try { - FrameType frameType = FrameHeaderCodec.frameType(frame); - int streamId = FrameHeaderCodec.streamId(frame); - switch (frameType) { - case CANCEL: - case ERROR: - cancelAssemble(streamId); - } - - if (!frameType.isFragmentable()) { - sink.next(frame); - return; - } - - boolean hasFollows = FrameHeaderCodec.hasFollows(frame); - - if (hasFollows) { - handleFollowsFlag(frame, streamId, frameType); - } else { - handleNoFollowsFlag(frame, sink, streamId); - } - - } catch (Throwable t) { - logger.error("error reassemble frame", t); - sink.error(t); - } - } - - private ByteBuf assembleFrameWithMetadata(ByteBuf frame, int streamId, ByteBuf header) { - ByteBuf metadata; - CompositeByteBuf cm = removeMetadata(streamId); - - ByteBuf decodedMetadata = PayloadFrameCodec.metadata(frame); - if (decodedMetadata != null) { - if (cm != null) { - metadata = cm.addComponents(true, decodedMetadata.retain()); - } else { - metadata = PayloadFrameCodec.metadata(frame).retain(); - } - } else { - metadata = cm; - } - - ByteBuf data = assembleData(frame, streamId); - - return FragmentationCodec.encode(allocator, header, metadata, data); - } - - private ByteBuf assembleData(ByteBuf frame, int streamId) { - ByteBuf data; - CompositeByteBuf cd = removeData(streamId); - if (cd != null) { - cd.addComponents(true, PayloadFrameCodec.data(frame).retain()); - data = cd; - } else { - data = Unpooled.EMPTY_BUFFER; - } - - return data; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/ReassemblyDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/ReassemblyDuplexConnection.java deleted file mode 100644 index 03f97c75d..000000000 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/ReassemblyDuplexConnection.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.fragmentation; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.DuplexConnection; -import io.rsocket.frame.FrameLengthCodec; -import java.util.Objects; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -/** - * A {@link DuplexConnection} implementation that reassembles {@link ByteBuf}s. - * - * @see Fragmentation - * and Reassembly - */ -public class ReassemblyDuplexConnection implements DuplexConnection { - private final DuplexConnection delegate; - private final FrameReassembler frameReassembler; - - /** Constructor with the underlying delegate to receive frames from. */ - public ReassemblyDuplexConnection(DuplexConnection delegate, int maxInboundPayloadSize) { - Objects.requireNonNull(delegate, "delegate must not be null"); - this.delegate = delegate; - this.frameReassembler = new FrameReassembler(delegate.alloc(), maxInboundPayloadSize); - - delegate.onClose().doFinally(s -> frameReassembler.dispose()).subscribe(); - } - - public static int assertInboundPayloadSize(int inboundPayloadSize) { - if (inboundPayloadSize < FragmentationDuplexConnection.MIN_MTU_SIZE) { - String msg = - String.format( - "The min allowed inboundPayloadSize size is %d bytes, provided: %d", - FrameLengthCodec.FRAME_LENGTH_MASK, inboundPayloadSize); - throw new IllegalArgumentException(msg); - } else { - return inboundPayloadSize; - } - } - - @Override - public Mono send(Publisher frames) { - return delegate.send(frames); - } - - @Override - public Mono sendOne(ByteBuf frame) { - return delegate.sendOne(frame); - } - - @Override - public Flux receive() { - return delegate.receive().handle(frameReassembler::reassembleFrame); - } - - @Override - public ByteBufAllocator alloc() { - return delegate.alloc(); - } - - @Override - public Mono onClose() { - return delegate.onClose(); - } - - @Override - public void dispose() { - delegate.dispose(); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java index 28f39459d..fc146c935 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java @@ -60,6 +60,10 @@ public static boolean hasFollows(ByteBuf byteBuf) { return (flags(byteBuf) & FLAGS_F) == FLAGS_F; } + public static boolean hasComplete(ByteBuf byteBuf) { + return (flags(byteBuf) & FLAGS_C) == FLAGS_C; + } + public static int streamId(ByteBuf byteBuf) { byteBuf.markReaderIndex(); int streamId = byteBuf.readInt(); diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java index 66d18c8a7..d581731a3 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java @@ -1,3 +1,18 @@ +/* + * Copyright 2015-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package io.rsocket.frame; import io.netty.buffer.ByteBuf; @@ -99,8 +114,9 @@ private static ByteBuf getData(ByteBuf frame, FrameType frameType) { case REQUEST_CHANNEL: data = RequestChannelFrameCodec.data(frame); break; - // Payload and synthetic types + // Payload, KeepAlive and synthetic types case PAYLOAD: + case KEEPALIVE: case NEXT: case NEXT_COMPLETE: case COMPLETE: diff --git a/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java b/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java index e6874c097..0d8063e0b 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java @@ -52,12 +52,12 @@ public Payload apply(ByteBuf byteBuf) { throw new IllegalArgumentException("unsupported frame type: " + type); } - ByteBuffer data = ByteBuffer.allocateDirect(d.readableBytes()); + ByteBuffer data = ByteBuffer.allocate(d.readableBytes()); data.put(d.nioBuffer()); data.flip(); if (m != null) { - ByteBuffer metadata = ByteBuffer.allocateDirect(m.readableBytes()); + ByteBuffer metadata = ByteBuffer.allocate(m.readableBytes()); metadata.put(m.nioBuffer()); metadata.flip(); diff --git a/rsocket-core/src/main/java/io/rsocket/internal/BaseDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/internal/BaseDuplexConnection.java index 9668e5e18..0296b0a07 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/BaseDuplexConnection.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/BaseDuplexConnection.java @@ -1,30 +1,56 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package io.rsocket.internal; +import io.netty.buffer.ByteBuf; import io.rsocket.DuplexConnection; +import reactor.core.Scannable; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Sinks; public abstract class BaseDuplexConnection implements DuplexConnection { - private MonoProcessor onClose = MonoProcessor.create(); + protected final Sinks.Empty onClose = Sinks.empty(); + protected final UnboundedProcessor sender = new UnboundedProcessor(onClose::tryEmitEmpty); - public BaseDuplexConnection() { - onClose.doFinally(s -> doOnClose()).subscribe(); + public BaseDuplexConnection() {} + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + if (streamId == 0) { + sender.tryEmitPrioritized(frame); + } else { + sender.tryEmitNormal(frame); + } } protected abstract void doOnClose(); @Override - public final Mono onClose() { - return onClose; + public Mono onClose() { + return onClose.asMono(); } @Override public final void dispose() { - onClose.onComplete(); + doOnClose(); } @Override + @SuppressWarnings("ConstantConditions") public final boolean isDisposed() { - return onClose.isDisposed(); + return onClose.scan(Scannable.Attr.TERMINATED) || onClose.scan(Scannable.Attr.CANCELLED); } } diff --git a/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java b/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java index 038120efc..8b1378917 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java @@ -1,228 +1 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.internal; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.Closeable; -import io.rsocket.DuplexConnection; -import io.rsocket.frame.FrameHeaderCodec; -import io.rsocket.frame.FrameUtil; -import io.rsocket.plugins.DuplexConnectionInterceptor.Type; -import io.rsocket.plugins.InitializingInterceptorRegistry; -import org.reactivestreams.Publisher; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; - -/** - * {@link DuplexConnection#receive()} is a single stream on which the following type of frames - * arrive: - * - *

      - *
    • Frames for streams initiated by the initiator of the connection (client). - *
    • Frames for streams initiated by the acceptor of the connection (server). - *
    - * - *

    The only way to differentiate these two frames is determining whether the stream Id is odd or - * even. Even IDs are for the streams initiated by server and odds are for streams initiated by the - * client. - */ -public class ClientServerInputMultiplexer implements Closeable { - private static final Logger LOGGER = LoggerFactory.getLogger("io.rsocket.FrameLogger"); - private static final InitializingInterceptorRegistry emptyInterceptorRegistry = - new InitializingInterceptorRegistry(); - - private final DuplexConnection setupConnection; - private final DuplexConnection serverConnection; - private final DuplexConnection clientConnection; - private final DuplexConnection source; - private final DuplexConnection clientServerConnection; - - public ClientServerInputMultiplexer(DuplexConnection source) { - this(source, emptyInterceptorRegistry, false); - } - - public ClientServerInputMultiplexer( - DuplexConnection source, InitializingInterceptorRegistry registry, boolean isClient) { - this.source = source; - final MonoProcessor> setup = MonoProcessor.create(); - final MonoProcessor> server = MonoProcessor.create(); - final MonoProcessor> client = MonoProcessor.create(); - - source = registry.initConnection(Type.SOURCE, source); - setupConnection = - registry.initConnection(Type.SETUP, new InternalDuplexConnection(source, setup)); - serverConnection = - registry.initConnection(Type.SERVER, new InternalDuplexConnection(source, server)); - clientConnection = - registry.initConnection(Type.CLIENT, new InternalDuplexConnection(source, client)); - clientServerConnection = new InternalDuplexConnection(source, client, server); - - source - .receive() - .groupBy( - frame -> { - int streamId = FrameHeaderCodec.streamId(frame); - final Type type; - if (streamId == 0) { - switch (FrameHeaderCodec.frameType(frame)) { - case SETUP: - case RESUME: - case RESUME_OK: - type = Type.SETUP; - break; - case LEASE: - case KEEPALIVE: - case ERROR: - type = isClient ? Type.CLIENT : Type.SERVER; - break; - default: - type = isClient ? Type.SERVER : Type.CLIENT; - } - } else if ((streamId & 0b1) == 0) { - type = Type.SERVER; - } else { - type = Type.CLIENT; - } - return type; - }) - .subscribe( - group -> { - switch (group.key()) { - case SETUP: - setup.onNext(group); - break; - - case SERVER: - server.onNext(group); - break; - - case CLIENT: - client.onNext(group); - break; - } - }, - t -> {}); - } - - public DuplexConnection asClientServerConnection() { - return clientServerConnection; - } - - public DuplexConnection asServerConnection() { - return serverConnection; - } - - public DuplexConnection asClientConnection() { - return clientConnection; - } - - public DuplexConnection asSetupConnection() { - return setupConnection; - } - - @Override - public void dispose() { - source.dispose(); - } - - @Override - public boolean isDisposed() { - return source.isDisposed(); - } - - @Override - public Mono onClose() { - return source.onClose(); - } - - private static class InternalDuplexConnection implements DuplexConnection { - private final DuplexConnection source; - private final MonoProcessor>[] processors; - private final boolean debugEnabled; - - @SafeVarargs - public InternalDuplexConnection( - DuplexConnection source, MonoProcessor>... processors) { - this.source = source; - this.processors = processors; - this.debugEnabled = LOGGER.isDebugEnabled(); - } - - @Override - public Mono send(Publisher frame) { - if (debugEnabled) { - frame = Flux.from(frame).doOnNext(f -> LOGGER.debug("sending -> " + FrameUtil.toString(f))); - } - - return source.send(frame); - } - - @Override - public Mono sendOne(ByteBuf frame) { - if (debugEnabled) { - LOGGER.debug("sending -> " + FrameUtil.toString(frame)); - } - - return source.sendOne(frame); - } - - @Override - public Flux receive() { - return Flux.fromArray(processors) - .flatMap( - p -> - p.flatMapMany( - f -> { - if (debugEnabled) { - return f.doOnNext( - frame -> LOGGER.debug("receiving -> " + FrameUtil.toString(frame))); - } else { - return f; - } - })); - } - - @Override - public ByteBufAllocator alloc() { - return source.alloc(); - } - - @Override - public void dispose() { - source.dispose(); - } - - @Override - public boolean isDisposed() { - return source.isDisposed(); - } - - @Override - public Mono onClose() { - return source.onClose(); - } - - @Override - public double availability() { - return source.availability(); - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/SynchronizedIntObjectHashMap.java b/rsocket-core/src/main/java/io/rsocket/internal/SynchronizedIntObjectHashMap.java deleted file mode 100644 index fd6bf0aed..000000000 --- a/rsocket-core/src/main/java/io/rsocket/internal/SynchronizedIntObjectHashMap.java +++ /dev/null @@ -1,748 +0,0 @@ -/* - * Copyright 2014 The Netty Project - * - * The Netty Project licenses this file to you under the Apache License, version 2.0 (the - * "License"); you may not use this file except in compliance with the License. You may obtain a - * copy of the License at: - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License - * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express - * or implied. See the License for the specific language governing permissions and limitations under - * the License. - */ - -package io.rsocket.internal; - -import static io.netty.util.internal.MathUtil.safeFindNextPositivePowerOfTwo; - -import io.netty.util.collection.IntObjectMap; -import java.util.AbstractCollection; -import java.util.AbstractSet; -import java.util.Arrays; -import java.util.Collection; -import java.util.Iterator; -import java.util.Map; -import java.util.NoSuchElementException; -import java.util.Set; - -/** - * A hash map implementation of {@link IntObjectMap} that uses open addressing for keys. To minimize - * the memory footprint, this class uses open addressing rather than chaining. Collisions are - * resolved using linear probing. Deletions implement compaction, so cost of remove can approach - * O(N) for full maps, which makes a small loadFactor recommended. - * - * @param The value type stored in the map. - */ -public class SynchronizedIntObjectHashMap implements IntObjectMap { - - /** Default initial capacity. Used if not specified in the constructor */ - public static final int DEFAULT_CAPACITY = 8; - - /** Default load factor. Used if not specified in the constructor */ - public static final float DEFAULT_LOAD_FACTOR = 0.5f; - - /** - * Placeholder for null values, so we can use the actual null to mean available. (Better than - * using a placeholder for available: less references for GC processing.) - */ - private static final Object NULL_VALUE = new Object(); - - /** The maximum number of elements allowed without allocating more space. */ - private int maxSize; - - /** The load factor for the map. Used to calculate {@link #maxSize}. */ - private final float loadFactor; - - private int[] keys; - private V[] values; - private int size; - private int mask; - - private final Set keySet = new KeySet(); - private final Set> entrySet = new EntrySet(); - private final Iterable> entries = PrimitiveIterator::new; - - public SynchronizedIntObjectHashMap() { - this(DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR); - } - - public SynchronizedIntObjectHashMap(int initialCapacity) { - this(initialCapacity, DEFAULT_LOAD_FACTOR); - } - - public SynchronizedIntObjectHashMap(int initialCapacity, float loadFactor) { - if (loadFactor <= 0.0f || loadFactor > 1.0f) { - // Cannot exceed 1 because we can never store more than capacity elements; - // using a bigger loadFactor would trigger rehashing before the desired load is reached. - throw new IllegalArgumentException("loadFactor must be > 0 and <= 1"); - } - - this.loadFactor = loadFactor; - - // Adjust the initial capacity if necessary. - int capacity = safeFindNextPositivePowerOfTwo(initialCapacity); - mask = capacity - 1; - - // Allocate the arrays. - keys = new int[capacity]; - @SuppressWarnings({"unchecked", "SuspiciousArrayCast"}) - V[] temp = (V[]) new Object[capacity]; - values = temp; - - // Initialize the maximum size value. - maxSize = calcMaxSize(capacity); - } - - private static T toExternal(T value) { - assert value != null : "null is not a legitimate internal value. Concurrent Modification?"; - return value == NULL_VALUE ? null : value; - } - - @SuppressWarnings("unchecked") - private static T toInternal(T value) { - return value == null ? (T) NULL_VALUE : value; - } - - public synchronized V[] getValuesCopy() { - V[] values = this.values; - return Arrays.copyOf(values, values.length); - } - - @Override - public synchronized V get(int key) { - int index = indexOf(key); - return index == -1 ? null : toExternal(values[index]); - } - - @Override - public synchronized V put(int key, V value) { - int startIndex = hashIndex(key); - int index = startIndex; - - for (; ; ) { - if (values[index] == null) { - // Found empty slot, use it. - keys[index] = key; - values[index] = toInternal(value); - growSize(); - return null; - } - if (keys[index] == key) { - // Found existing entry with this key, just replace the value. - V previousValue = values[index]; - values[index] = toInternal(value); - return toExternal(previousValue); - } - - // Conflict, keep probing ... - if ((index = probeNext(index)) == startIndex) { - // Can only happen if the map was full at MAX_ARRAY_SIZE and couldn't grow. - throw new IllegalStateException("Unable to insert"); - } - } - } - - @Override - public synchronized void putAll(Map sourceMap) { - if (sourceMap instanceof SynchronizedIntObjectHashMap) { - // Optimization - iterate through the arrays. - @SuppressWarnings("unchecked") - SynchronizedIntObjectHashMap source = (SynchronizedIntObjectHashMap) sourceMap; - for (int i = 0; i < source.values.length; ++i) { - V sourceValue = source.values[i]; - if (sourceValue != null) { - put(source.keys[i], sourceValue); - } - } - return; - } - - // Otherwise, just add each entry. - for (Entry entry : sourceMap.entrySet()) { - put(entry.getKey(), entry.getValue()); - } - } - - @Override - public synchronized V remove(int key) { - int index = indexOf(key); - if (index == -1) { - return null; - } - - V prev = values[index]; - removeAt(index); - return toExternal(prev); - } - - @Override - public synchronized int size() { - return size; - } - - @Override - public synchronized boolean isEmpty() { - return size == 0; - } - - @Override - public synchronized void clear() { - Arrays.fill(keys, 0); - Arrays.fill(values, null); - size = 0; - } - - @Override - public synchronized boolean containsKey(int key) { - return indexOf(key) >= 0; - } - - @Override - public synchronized boolean containsValue(Object value) { - @SuppressWarnings("unchecked") - V v1 = toInternal((V) value); - for (V v2 : values) { - // The map supports null values; this will be matched as NULL_VALUE.equals(NULL_VALUE). - if (v2 != null && v2.equals(v1)) { - return true; - } - } - return false; - } - - @Override - public synchronized Iterable> entries() { - return entries; - } - - @Override - public synchronized Collection values() { - return new AbstractCollection() { - @Override - public Iterator iterator() { - return new Iterator() { - final PrimitiveIterator iter = new PrimitiveIterator(); - - @Override - public boolean hasNext() { - return iter.hasNext(); - } - - @Override - public V next() { - return iter.next().value(); - } - - @Override - public void remove() { - throw new UnsupportedOperationException(); - } - }; - } - - @Override - public int size() { - return size; - } - }; - } - - @Override - public synchronized int hashCode() { - // Hashcode is based on all non-zero, valid keys. We have to scan the whole keys - // array, which may have different lengths for two maps of same size(), so the - // capacity cannot be used as input for hashing but the size can. - int hash = size; - for (int key : keys) { - // 0 can be a valid key or unused slot, but won't impact the hashcode in either case. - // This way we can use a cheap loop without conditionals, or hard-to-unroll operations, - // or the devastatingly bad memory locality of visiting value objects. - // Also, it's important to use a hash function that does not depend on the ordering - // of terms, only their values; since the map is an unordered collection and - // entries can end up in different positions in different maps that have the same - // elements, but with different history of puts/removes, due to conflicts. - hash ^= hashCode(key); - } - return hash; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (!(obj instanceof IntObjectMap)) { - return false; - } - @SuppressWarnings("rawtypes") - IntObjectMap other = (IntObjectMap) obj; - synchronized (this) { - if (size != other.size()) { - return false; - } - for (int i = 0; i < values.length; ++i) { - V value = values[i]; - if (value != null) { - int key = keys[i]; - Object otherValue = other.get(key); - if (value == NULL_VALUE) { - if (otherValue != null) { - return false; - } - } else if (!value.equals(otherValue)) { - return false; - } - } - } - } - return true; - } - - @Override - public synchronized boolean containsKey(Object key) { - return containsKey(objectToKey(key)); - } - - @Override - public synchronized V get(Object key) { - return get(objectToKey(key)); - } - - @Override - public synchronized V put(Integer key, V value) { - return put(objectToKey(key), value); - } - - @Override - public synchronized V remove(Object key) { - return remove(objectToKey(key)); - } - - @Override - public synchronized Set keySet() { - return keySet; - } - - @Override - public synchronized Set> entrySet() { - return entrySet; - } - - private int objectToKey(Object key) { - return ((Integer) key).intValue(); - } - - /** - * Locates the index for the given key. This method probes using double hashing. - * - * @param key the key for an entry in the map. - * @return the index where the key was found, or {@code -1} if no entry is found for that key. - */ - private int indexOf(int key) { - int startIndex = hashIndex(key); - int index = startIndex; - - for (; ; ) { - if (values[index] == null) { - // It's available, so no chance that this value exists anywhere in the map. - return -1; - } - if (key == keys[index]) { - return index; - } - - // Conflict, keep probing ... - if ((index = probeNext(index)) == startIndex) { - return -1; - } - } - } - - /** Returns the hashed index for the given key. */ - private int hashIndex(int key) { - // The array lengths are always a power of two, so we can use a bitmask to stay inside the array - // bounds. - return hashCode(key) & mask; - } - - /** Returns the hash code for the key. */ - private static int hashCode(int key) { - return key; - } - - /** Get the next sequential index after {@code index} and wraps if necessary. */ - private int probeNext(int index) { - // The array lengths are always a power of two, so we can use a bitmask to stay inside the array - // bounds. - return (index + 1) & mask; - } - - /** Grows the map size after an insertion. If necessary, performs a rehash of the map. */ - private void growSize() { - size++; - - if (size > maxSize) { - if (keys.length == Integer.MAX_VALUE) { - throw new IllegalStateException("Max capacity reached at size=" + size); - } - - // Double the capacity. - rehash(keys.length << 1); - } - } - - /** - * Removes entry at the given index position. Also performs opportunistic, incremental rehashing - * if necessary to not break conflict chains. - * - * @param index the index position of the element to remove. - * @return {@code true} if the next item was moved back. {@code false} otherwise. - */ - private boolean removeAt(final int index) { - --size; - // Clearing the key is not strictly necessary (for GC like in a regular collection), - // but recommended for security. The memory location is still fresh in the cache anyway. - keys[index] = 0; - values[index] = null; - - // In the interval from index to the next available entry, the arrays may have entries - // that are displaced from their base position due to prior conflicts. Iterate these - // entries and move them back if possible, optimizing future lookups. - // Knuth Section 6.4 Algorithm R, also used by the JDK's IdentityHashMap. - - int nextFree = index; - int i = probeNext(index); - for (V value = values[i]; value != null; value = values[i = probeNext(i)]) { - int key = keys[i]; - int bucket = hashIndex(key); - if (i < bucket && (bucket <= nextFree || nextFree <= i) - || bucket <= nextFree && nextFree <= i) { - // Move the displaced entry "back" to the first available position. - keys[nextFree] = key; - values[nextFree] = value; - // Put the first entry after the displaced entry - keys[i] = 0; - values[i] = null; - nextFree = i; - } - } - return nextFree != index; - } - - /** Calculates the maximum size allowed before rehashing. */ - private int calcMaxSize(int capacity) { - // Clip the upper bound so that there will always be at least one available slot. - int upperBound = capacity - 1; - return Math.min(upperBound, (int) (capacity * loadFactor)); - } - - /** - * Rehashes the map for the given capacity. - * - * @param newCapacity the new capacity for the map. - */ - private void rehash(int newCapacity) { - int[] oldKeys = keys; - V[] oldVals = values; - - keys = new int[newCapacity]; - @SuppressWarnings({"unchecked", "SuspiciousArrayCast"}) - V[] temp = (V[]) new Object[newCapacity]; - values = temp; - - maxSize = calcMaxSize(newCapacity); - mask = newCapacity - 1; - - // Insert to the new arrays. - for (int i = 0; i < oldVals.length; ++i) { - V oldVal = oldVals[i]; - if (oldVal != null) { - // Inlined put(), but much simpler: we don't need to worry about - // duplicated keys, growing/rehashing, or failing to insert. - int oldKey = oldKeys[i]; - int index = hashIndex(oldKey); - - for (; ; ) { - if (values[index] == null) { - keys[index] = oldKey; - values[index] = oldVal; - break; - } - - // Conflict, keep probing. Can wrap around, but never reaches startIndex again. - index = probeNext(index); - } - } - } - } - - @Override - public synchronized String toString() { - if (isEmpty()) { - return "{}"; - } - StringBuilder sb = new StringBuilder(4 * size); - sb.append('{'); - boolean first = true; - for (int i = 0; i < values.length; ++i) { - V value = values[i]; - if (value != null) { - if (!first) { - sb.append(", "); - } - sb.append(keyToString(keys[i])) - .append('=') - .append(value == this ? "(this Map)" : toExternal(value)); - first = false; - } - } - return sb.append('}').toString(); - } - - /** - * Helper method called by {@link #toString()} in order to convert a single map key into a string. - * This is protected to allow subclasses to override the appearance of a given key. - */ - protected String keyToString(int key) { - return Integer.toString(key); - } - - /** Set implementation for iterating over the entries of the map. */ - private final class EntrySet extends AbstractSet> { - @Override - public Iterator> iterator() { - return new MapIterator(); - } - - @Override - public int size() { - return SynchronizedIntObjectHashMap.this.size(); - } - } - - /** Set implementation for iterating over the keys. */ - private final class KeySet extends AbstractSet { - @Override - public int size() { - return SynchronizedIntObjectHashMap.this.size(); - } - - @Override - public boolean contains(Object o) { - return SynchronizedIntObjectHashMap.this.containsKey(o); - } - - @Override - public boolean remove(Object o) { - return SynchronizedIntObjectHashMap.this.remove(o) != null; - } - - @Override - public boolean retainAll(Collection retainedKeys) { - synchronized (SynchronizedIntObjectHashMap.this) { - boolean changed = false; - for (Iterator> iter = entries().iterator(); iter.hasNext(); ) { - PrimitiveEntry entry = iter.next(); - if (!retainedKeys.contains(entry.key())) { - changed = true; - iter.remove(); - } - } - return changed; - } - } - - @Override - public void clear() { - SynchronizedIntObjectHashMap.this.clear(); - } - - @Override - public Iterator iterator() { - synchronized (SynchronizedIntObjectHashMap.this) { - final Iterator> iter = entrySet.iterator(); - return new Iterator() { - @Override - public boolean hasNext() { - synchronized (SynchronizedIntObjectHashMap.this) { - return iter.hasNext(); - } - } - - @Override - public Integer next() { - synchronized (SynchronizedIntObjectHashMap.this) { - return iter.next().getKey(); - } - } - - @Override - public void remove() { - synchronized (SynchronizedIntObjectHashMap.this) { - iter.remove(); - } - } - }; - } - } - } - - /** - * Iterator over primitive entries. Entry key/values are overwritten by each call to {@link - * #next()}. - */ - private final class PrimitiveIterator implements Iterator>, PrimitiveEntry { - private int prevIndex = -1; - private int nextIndex = -1; - private int entryIndex = -1; - - private void scanNext() { - while (++nextIndex != values.length && values[nextIndex] == null) {} - } - - @Override - public boolean hasNext() { - synchronized (SynchronizedIntObjectHashMap.this) { - if (nextIndex == -1) { - scanNext(); - } - return nextIndex != values.length; - } - } - - @Override - public PrimitiveEntry next() { - synchronized (SynchronizedIntObjectHashMap.this) { - if (!hasNext()) { - throw new NoSuchElementException(); - } - - prevIndex = nextIndex; - scanNext(); - - // Always return the same Entry object, just change its index each time. - entryIndex = prevIndex; - return this; - } - } - - @Override - public void remove() { - synchronized (SynchronizedIntObjectHashMap.this) { - if (prevIndex == -1) { - throw new IllegalStateException("next must be called before each remove."); - } - if (removeAt(prevIndex)) { - // removeAt may move elements "back" in the array if they have been displaced because - // their - // spot in the - // array was occupied when they were inserted. If this occurs then the nextIndex is now - // invalid and - // should instead point to the prevIndex which now holds an element which was "moved - // back". - nextIndex = prevIndex; - } - prevIndex = -1; - } - } - - // Entry implementation. Since this implementation uses a single Entry, we coalesce that - // into the Iterator object (potentially making loop optimization much easier). - - @Override - public int key() { - synchronized (SynchronizedIntObjectHashMap.this) { - return keys[entryIndex]; - } - } - - @Override - public V value() { - synchronized (SynchronizedIntObjectHashMap.this) { - return toExternal(values[entryIndex]); - } - } - - @Override - public void setValue(V value) { - synchronized (SynchronizedIntObjectHashMap.this) { - values[entryIndex] = toInternal(value); - } - } - } - - /** Iterator used by the {@link Map} interface. */ - private final class MapIterator implements Iterator> { - private final PrimitiveIterator iter = new PrimitiveIterator(); - - @Override - public boolean hasNext() { - synchronized (SynchronizedIntObjectHashMap.this) { - return iter.hasNext(); - } - } - - @Override - public Entry next() { - synchronized (SynchronizedIntObjectHashMap.this) { - if (!hasNext()) { - throw new NoSuchElementException(); - } - - iter.next(); - - return new MapEntry(iter.entryIndex); - } - } - - @Override - public void remove() { - synchronized (SynchronizedIntObjectHashMap.this) { - iter.remove(); - } - } - } - - /** A single entry in the map. */ - final class MapEntry implements Entry { - private final int entryIndex; - - MapEntry(int entryIndex) { - this.entryIndex = entryIndex; - } - - @Override - public Integer getKey() { - synchronized (SynchronizedIntObjectHashMap.this) { - verifyExists(); - return keys[entryIndex]; - } - } - - @Override - public V getValue() { - synchronized (SynchronizedIntObjectHashMap.this) { - verifyExists(); - return toExternal(values[entryIndex]); - } - } - - @Override - public V setValue(V value) { - synchronized (SynchronizedIntObjectHashMap.this) { - verifyExists(); - V prevValue = toExternal(values[entryIndex]); - values[entryIndex] = toInternal(value); - return prevValue; - } - } - - private void verifyExists() { - if (values[entryIndex] == null) { - throw new IllegalStateException("The map entry has been removed"); - } - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java b/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java index cb8b5d63d..c96a7aed2 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java @@ -16,19 +16,23 @@ package io.rsocket.internal; -import io.netty.util.ReferenceCounted; +import io.netty.buffer.ByteBuf; import io.rsocket.internal.jctools.queues.MpscUnboundedArrayQueue; import java.util.Objects; import java.util.Queue; +import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.concurrent.atomic.AtomicLongFieldUpdater; -import org.reactivestreams.Subscriber; +import java.util.stream.Stream; import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; +import reactor.core.Disposable; import reactor.core.Exceptions; import reactor.core.Fuseable; -import reactor.core.publisher.FluxProcessor; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; import reactor.core.publisher.Operators; +import reactor.util.Logger; import reactor.util.annotation.Nullable; import reactor.util.concurrent.Queues; import reactor.util.context.Context; @@ -37,194 +41,399 @@ * A Processor implementation that takes a custom queue and allows only a single subscriber. * *

    The implementation keeps the order of signals. - * - * @param the input and output type */ -public final class UnboundedProcessor extends FluxProcessor - implements Fuseable.QueueSubscription, Fuseable { - - final Queue queue; - final Queue priorityQueue; - - volatile boolean done; +public final class UnboundedProcessor extends Flux + implements Scannable, + Disposable, + CoreSubscriber, + Fuseable.QueueSubscription, + Fuseable { + + final Queue queue; + final Queue priorityQueue; + final Runnable onFinalizedHook; + @Nullable final Logger logger; + + boolean cancelled; + boolean done; Throwable error; - // important to not loose the downstream too early and miss discard hook, while - // having relevant hasDownstreams() - boolean hasDownstream; - volatile CoreSubscriber actual; - - volatile boolean cancelled; - - volatile int once; - - @SuppressWarnings("rawtypes") - static final AtomicIntegerFieldUpdater ONCE = - AtomicIntegerFieldUpdater.newUpdater(UnboundedProcessor.class, "once"); - - volatile int wip; - - @SuppressWarnings("rawtypes") - static final AtomicIntegerFieldUpdater WIP = - AtomicIntegerFieldUpdater.newUpdater(UnboundedProcessor.class, "wip"); + CoreSubscriber actual; + + static final long FLAG_FINALIZED = + 0b1000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_DISPOSED = + 0b0100_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_TERMINATED = + 0b0010_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_CANCELLED = + 0b0001_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_HAS_VALUE = + 0b0000_1000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_HAS_REQUEST = + 0b0000_0100_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_SUBSCRIBER_READY = + 0b0000_0010_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_SUBSCRIBED_ONCE = + 0b0000_0001_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long MAX_WIP_VALUE = + 0b0000_0000_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111L; + + volatile long state; + + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(UnboundedProcessor.class, "state"); volatile int discardGuard; - @SuppressWarnings("rawtypes") static final AtomicIntegerFieldUpdater DISCARD_GUARD = AtomicIntegerFieldUpdater.newUpdater(UnboundedProcessor.class, "discardGuard"); volatile long requested; - @SuppressWarnings("rawtypes") static final AtomicLongFieldUpdater REQUESTED = AtomicLongFieldUpdater.newUpdater(UnboundedProcessor.class, "requested"); + ByteBuf last; + boolean outputFused; public UnboundedProcessor() { + this(() -> {}); + } + + UnboundedProcessor(Logger logger) { + this(() -> {}, logger); + } + + public UnboundedProcessor(Runnable onFinalizedHook) { + this(onFinalizedHook, null); + } + + UnboundedProcessor(Runnable onFinalizedHook, @Nullable Logger logger) { + this.onFinalizedHook = onFinalizedHook; this.queue = new MpscUnboundedArrayQueue<>(Queues.SMALL_BUFFER_SIZE); this.priorityQueue = new MpscUnboundedArrayQueue<>(Queues.SMALL_BUFFER_SIZE); + this.logger = logger; } @Override - public int getBufferSize() { - return Integer.MAX_VALUE; + public Stream inners() { + return hasDownstreams() ? Stream.of(Scannable.from(this.actual)) : Stream.empty(); } @Override public Object scanUnsafe(Attr key) { - if (Attr.BUFFERED == key) return queue.size(); + if (Attr.ACTUAL == key) return isSubscriberReady(this.state) ? this.actual : null; + if (Attr.BUFFERED == key) return this.queue.size() + this.priorityQueue.size(); if (Attr.PREFETCH == key) return Integer.MAX_VALUE; - return super.scanUnsafe(key); + if (Attr.CANCELLED == key) { + final long state = this.state; + return isCancelled(state) || isDisposed(state); + } + + return null; } - void drainRegular(Subscriber a) { - int missed = 1; + public boolean tryEmitPrioritized(ByteBuf t) { + if (this.done || this.cancelled) { + release(t); + return false; + } - final Queue q = queue; - final Queue pq = priorityQueue; + if (!this.priorityQueue.offer(t)) { + onError(Operators.onOperatorError(null, Exceptions.failWithOverflow(), t, currentContext())); + release(t); + return false; + } - for (; ; ) { + final long previousState = markValueAdded(this); + if (isFinalized(previousState)) { + this.clearSafely(); + return false; + } - long r = requested; - long e = 0L; + if (isSubscriberReady(previousState)) { + if (this.outputFused) { + // fast path for fusion + this.actual.onNext(null); + return true; + } - while (r != e) { - boolean d = done; + if (isWorkInProgress(previousState)) { + return true; + } - T t; - boolean empty; + if (hasRequest(previousState)) { + drainRegular((previousState | FLAG_HAS_VALUE) + 1); + } + } + return true; + } - if (!pq.isEmpty()) { - t = pq.poll(); - empty = false; - } else { - t = q.poll(); - empty = t == null; - } + public boolean tryEmitNormal(ByteBuf t) { + if (this.done || this.cancelled) { + release(t); + return false; + } - if (checkTerminated(d, empty, a)) { - return; - } + if (!this.queue.offer(t)) { + onError(Operators.onOperatorError(null, Exceptions.failWithOverflow(), t, currentContext())); + release(t); + return false; + } - if (empty) { - break; - } + final long previousState = markValueAdded(this); + if (isFinalized(previousState)) { + this.clearSafely(); + return false; + } - a.onNext(t); + if (isSubscriberReady(previousState)) { + if (this.outputFused) { + // fast path for fusion + this.actual.onNext(null); + return true; + } - e++; + if (isWorkInProgress(previousState)) { + return true; } - if (r == e) { - if (checkTerminated(done, q.isEmpty() && pq.isEmpty(), a)) { - return; - } + if (hasRequest(previousState)) { + drainRegular((previousState | FLAG_HAS_VALUE) + 1); } + } - if (e != 0 && r != Long.MAX_VALUE) { - REQUESTED.addAndGet(this, -e); + return true; + } + + public boolean tryEmitFinal(ByteBuf t) { + if (this.done || this.cancelled) { + release(t); + return false; + } + + this.last = t; + this.done = true; + + final long previousState = markValueAddedAndTerminated(this); + if (isFinalized(previousState)) { + this.clearSafely(); + return false; + } + + if (isSubscriberReady(previousState)) { + if (this.outputFused) { + // fast path for fusion + this.actual.onNext(null); + this.actual.onComplete(); + return true; } - missed = WIP.addAndGet(this, -missed); - if (missed == 0) { - break; + if (isWorkInProgress(previousState)) { + return true; } + + drainRegular((previousState | FLAG_TERMINATED | FLAG_HAS_VALUE) + 1); } + + return true; } - void drainFused(Subscriber a) { - int missed = 1; + @Deprecated + public void onNextPrioritized(ByteBuf t) { + tryEmitPrioritized(t); + } - for (; ; ) { + @Override + @Deprecated + public void onNext(ByteBuf t) { + tryEmitNormal(t); + } - if (cancelled) { - this.clear(); - hasDownstream = false; - return; - } + @Override + @Deprecated + public void onError(Throwable t) { + if (this.done || this.cancelled) { + Operators.onErrorDropped(t, currentContext()); + return; + } - boolean d = done; + this.error = t; + this.done = true; - a.onNext(null); + final long previousState = markTerminatedOrFinalized(this); + if (isFinalized(previousState) + || isDisposed(previousState) + || isCancelled(previousState) + || isTerminated(previousState)) { + Operators.onErrorDropped(t, currentContext()); + return; + } - if (d) { - hasDownstream = false; + if (isSubscriberReady(previousState)) { + if (this.outputFused) { + // fast path for fusion scenario + this.actual.onError(t); + return; + } - Throwable ex = error; - if (ex != null) { - a.onError(ex); - } else { - a.onComplete(); - } + if (isWorkInProgress(previousState)) { return; } - missed = WIP.addAndGet(this, -missed); - if (missed == 0) { - break; + if (!hasValue(previousState)) { + // fast path no-values scenario + this.actual.onError(t); + return; } + + drainRegular((previousState | FLAG_TERMINATED) + 1); } } - public void drain() { - if (WIP.getAndIncrement(this) != 0) { - if (cancelled) { - this.clear(); - } + @Override + @Deprecated + public void onComplete() { + if (this.done || this.cancelled) { return; } - int missed = 1; + this.done = true; + + final long previousState = markTerminatedOrFinalized(this); + if (isFinalized(previousState) + || isDisposed(previousState) + || isCancelled(previousState) + || isTerminated(previousState)) { + return; + } + + if (isSubscriberReady(previousState)) { + if (this.outputFused) { + // fast path for fusion scenario + this.actual.onComplete(); + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + if (!hasValue(previousState)) { + this.actual.onComplete(); + return; + } + + drainRegular((previousState | FLAG_TERMINATED) + 1); + } + } + + void drainRegular(long expectedState) { + final CoreSubscriber a = this.actual; + final Queue q = this.queue; + final Queue pq = this.priorityQueue; for (; ; ) { - Subscriber a = actual; - if (a != null) { - if (outputFused) { - drainFused(a); - } else { - drainRegular(a); + long r = this.requested; + long e = 0L; + + boolean empty = false; + boolean done; + while (r != e) { + // done has to be read before queue.poll to ensure there was no racing: + // Thread1: <#drain>: queue.poll(null) --------------------> this.done(true) + // Thread2: ------------------> <#onNext(V)> --> <#onComplete()> + done = this.done; + + ByteBuf t = pq.poll(); + empty = t == null; + + if (empty) { + t = q.poll(); + empty = t == null; } + + if (checkTerminated(done, empty, true, a)) { + if (!empty) { + release(t); + } + return; + } + + if (empty) { + break; + } + + a.onNext(t); + + e++; + } + + if (r == e) { + // done has to be read before queue.isEmpty to ensure there was no racing: + // Thread1: <#drain>: queue.isEmpty(true) --------------------> this.done(true) + // Thread2: --------------------> <#onNext(V)> ---> <#onComplete()> + done = this.done; + empty = q.isEmpty() && pq.isEmpty(); + + if (checkTerminated(done, empty, false, a)) { + return; + } + } + + if (e != 0 && r != Long.MAX_VALUE) { + r = REQUESTED.addAndGet(this, -e); + } + + expectedState = markWorkDone(this, expectedState, r > 0, !empty); + if (isCancelled(expectedState)) { + clearAndFinalize(this); return; } - missed = WIP.addAndGet(this, -missed); - if (missed == 0) { + if (isDisposed(expectedState)) { + clearAndFinalize(this); + a.onError(new CancellationException("Disposed")); + return; + } + + if (!isWorkInProgress(expectedState)) { break; } } } - boolean checkTerminated(boolean d, boolean empty, Subscriber a) { - if (cancelled) { - this.clear(); - hasDownstream = false; + boolean checkTerminated( + boolean done, boolean empty, boolean hasDemand, CoreSubscriber a) { + final long state = this.state; + if (isCancelled(state)) { + clearAndFinalize(this); + return true; + } + + if (isDisposed(state)) { + clearAndFinalize(this); + a.onError(new CancellationException("Disposed")); return true; } - if (d && empty) { - Throwable e = error; - hasDownstream = false; + + if (done && empty) { + if (!isTerminated(state)) { + // proactively return if volatile field is not yet set to needed state + return false; + } + final ByteBuf last = this.last; + if (last != null) { + if (!hasDemand) { + return false; + } + this.last = null; + a.onNext(last); + } + clearAndFinalize(this); + Throwable e = this.error; if (e != null) { a.onError(e); } else { @@ -238,7 +447,8 @@ boolean checkTerminated(boolean d, boolean empty, Subscriber a) { @Override public void onSubscribe(Subscription s) { - if (done || cancelled) { + final long state = this.state; + if (isFinalized(state) || isTerminated(state) || isCancelled(state) || isDisposed(state)) { s.cancel(); } else { s.request(Long.MAX_VALUE); @@ -252,150 +462,216 @@ public int getPrefetch() { @Override public Context currentContext() { - CoreSubscriber actual = this.actual; - return actual != null ? actual.currentContext() : Context.empty(); + return isSubscriberReady(this.state) ? this.actual.currentContext() : Context.empty(); } - public void onNextPrioritized(T t) { - if (done || cancelled) { - Operators.onNextDropped(t, currentContext()); - release(t); + @Override + public void subscribe(CoreSubscriber actual) { + Objects.requireNonNull(actual, "subscribe"); + long previousState = markSubscribedOnce(this); + if (isSubscribedOnce(previousState)) { + Operators.error( + actual, new IllegalStateException("UnboundedProcessor allows only a single Subscriber")); return; } - if (!priorityQueue.offer(t)) { - Throwable ex = - Operators.onOperatorError(null, Exceptions.failWithOverflow(), t, currentContext()); - onError(Operators.onOperatorError(null, ex, t, currentContext())); - release(t); + if (isDisposed(previousState)) { + Operators.error(actual, new CancellationException("Disposed")); return; } - drain(); - } - @Override - public void onNext(T t) { - if (done || cancelled) { - Operators.onNextDropped(t, currentContext()); - release(t); + actual.onSubscribe(this); + this.actual = actual; + + previousState = markSubscriberReady(this); + + if (isSubscriberReady(previousState)) { return; } - if (!queue.offer(t)) { - Throwable ex = - Operators.onOperatorError(null, Exceptions.failWithOverflow(), t, currentContext()); - onError(Operators.onOperatorError(null, ex, t, currentContext())); - release(t); + if (this.outputFused) { + if (isCancelled(previousState)) { + return; + } + + if (isDisposed(previousState)) { + actual.onError(new CancellationException("Disposed")); + return; + } + + if (hasValue(previousState)) { + actual.onNext(null); + } + + if (isTerminated(previousState)) { + final Throwable e = this.error; + if (e != null) { + actual.onError(e); + } else { + actual.onComplete(); + } + } return; } - drain(); - } - @Override - public void onError(Throwable t) { - if (done || cancelled) { - Operators.onErrorDropped(t, currentContext()); + if (isCancelled(previousState)) { + clearAndFinalize(this); return; } - error = t; - done = true; - - drain(); - } - - @Override - public void onComplete() { - if (done || cancelled) { + if (isDisposed(previousState)) { + clearAndFinalize(this); + actual.onError(new CancellationException("Disposed")); return; } - done = true; + if (!hasValue(previousState)) { + if (isTerminated(previousState)) { + clearAndFinalize(this); + final Throwable e = this.error; + if (e != null) { + actual.onError(e); + } else { + actual.onComplete(); + } + } + return; + } - drain(); + if (hasRequest(previousState)) { + drainRegular((previousState | FLAG_SUBSCRIBER_READY) + 1); + } } @Override - public void subscribe(CoreSubscriber actual) { - Objects.requireNonNull(actual, "subscribe"); - if (once == 0 && ONCE.compareAndSet(this, 0, 1)) { + public void request(long n) { + if (Operators.validate(n)) { + if (this.outputFused) { + final long state = this.state; + if (isSubscriberReady(state)) { + this.actual.onNext(null); + } + return; + } - actual.onSubscribe(this); - this.actual = actual; - if (cancelled) { - this.hasDownstream = false; - } else { - drain(); + Operators.addCap(REQUESTED, this, n); + + final long previousState = markRequestAdded(this); + if (isWorkInProgress(previousState) + || isFinalized(previousState) + || isCancelled(previousState) + || isDisposed(previousState)) { + return; + } + + if (isSubscriberReady(previousState) && hasValue(previousState)) { + drainRegular((previousState | FLAG_HAS_REQUEST) + 1); } - } else { - Operators.error( - actual, - new IllegalStateException("UnboundedProcessor " + "allows only a single Subscriber")); } } @Override - public void request(long n) { - if (Operators.validate(n)) { - Operators.addCap(REQUESTED, this, n); - drain(); + public void cancel() { + this.cancelled = true; + + final long previousState = markCancelled(this); + if (isWorkInProgress(previousState) + || isFinalized(previousState) + || isCancelled(previousState) + || isDisposed(previousState)) { + return; + } + + if (!isSubscribedOnce(previousState) || !this.outputFused) { + clearAndFinalize(this); } } @Override - public void cancel() { - if (cancelled) { + @Deprecated + public void dispose() { + this.cancelled = true; + + final long previousState = markDisposed(this); + if (isWorkInProgress(previousState) + || isFinalized(previousState) + || isCancelled(previousState) + || isDisposed(previousState)) { + return; + } + + if (!isSubscribedOnce(previousState)) { + clearAndFinalize(this); + return; + } + + if (!isSubscriberReady(previousState)) { + return; + } + + if (!this.outputFused) { + clearAndFinalize(this); + this.actual.onError(new CancellationException("Disposed")); return; } - cancelled = true; - if (WIP.getAndIncrement(this) == 0) { - this.clear(); - hasDownstream = false; + if (!isTerminated(previousState)) { + this.actual.onError(new CancellationException("Disposed")); } } @Override @Nullable - public T poll() { - Queue pq = this.priorityQueue; - if (!pq.isEmpty()) { - return pq.poll(); + public ByteBuf poll() { + ByteBuf t = this.priorityQueue.poll(); + if (t != null) { + return t; } - return queue.poll(); + + t = this.queue.poll(); + if (t != null) { + return t; + } + + t = this.last; + if (t != null) { + this.last = null; + return t; + } + + return null; } @Override public int size() { - return priorityQueue.size() + queue.size(); + return this.priorityQueue.size() + this.queue.size(); } @Override public boolean isEmpty() { - return priorityQueue.isEmpty() && queue.isEmpty(); + return this.priorityQueue.isEmpty() && this.queue.isEmpty(); } + /** + * Clears all elements from queues and set state to terminate. This method MUST be called only by + * the downstream subscriber which has enabled {@link Fuseable#ASYNC} fusion with the given {@link + * UnboundedProcessor} and is and indicator that the downstream is done with draining, it has + * observed any terminal signal (ON_COMPLETE or ON_ERROR or CANCEL) and will never be interacting + * with SingleConsumer queue anymore. + */ @Override public void clear() { + clearAndFinalize(this); + } + + void clearSafely() { if (DISCARD_GUARD.getAndIncrement(this) != 0) { return; } int missed = 1; - for (; ; ) { - while (!queue.isEmpty()) { - T t = queue.poll(); - if (t != null) { - release(t); - } - } - while (!priorityQueue.isEmpty()) { - T t = priorityQueue.poll(); - if (t != null) { - release(t); - } - } + clearUnsafely(); missed = DISCARD_GUARD.addAndGet(this, -missed); if (missed == 0) { @@ -404,56 +680,488 @@ public void clear() { } } + void clearUnsafely() { + final Queue queue = this.queue; + final Queue priorityQueue = this.priorityQueue; + + final ByteBuf last = this.last; + + if (last != null) { + release(last); + } + + ByteBuf byteBuf; + while ((byteBuf = queue.poll()) != null) { + release(byteBuf); + } + + while ((byteBuf = priorityQueue.poll()) != null) { + release(byteBuf); + } + } + @Override public int requestFusion(int requestedMode) { if ((requestedMode & Fuseable.ASYNC) != 0) { - outputFused = true; + this.outputFused = true; return Fuseable.ASYNC; } return Fuseable.NONE; } @Override - public void dispose() { - cancel(); + public boolean isDisposed() { + return isFinalized(this.state); } - @Override - public boolean isDisposed() { - return cancelled || done; + boolean hasDownstreams() { + final long state = this.state; + return !isTerminated(state) && isSubscriberReady(state); } - @Override - public boolean isTerminated() { - return done; + static void release(ByteBuf byteBuf) { + if (byteBuf.refCnt() > 0) { + try { + byteBuf.release(); + } catch (Throwable ex) { + // no ops + } + } } - @Override - @Nullable - public Throwable getError() { - return error; + /** + * Sets {@link #FLAG_SUBSCRIBED_ONCE} flag if it was not set before and if flags {@link + * #FLAG_FINALIZED}, {@link #FLAG_CANCELLED} or {@link #FLAG_DISPOSED} are unset + * + * @return {@code true} if {@link #FLAG_SUBSCRIBED_ONCE} was successfully set + */ + static long markSubscribedOnce(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isSubscribedOnce(state)) { + return state; + } + + final long nextState = state | FLAG_SUBSCRIBED_ONCE; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mso", state, nextState); + return state; + } + } } - @Override - public long downstreamCount() { - return hasDownstreams() ? 1L : 0L; + /** + * Sets {@link #FLAG_SUBSCRIBER_READY} flag if flags {@link #FLAG_FINALIZED}, {@link + * #FLAG_CANCELLED} or {@link #FLAG_DISPOSED} are unset + * + * @return previous state + */ + static long markSubscriberReady(UnboundedProcessor instance) { + for (; ; ) { + long state = instance.state; + + if (isFinalized(state) + || isCancelled(state) + || isDisposed(state) + || isSubscriberReady(state)) { + return state; + } + + long nextState = state; + if (!instance.outputFused) { + if ((!hasValue(state) && isTerminated(state)) || (hasRequest(state) && hasValue(state))) { + nextState = addWork(state); + } + } + + nextState = nextState | FLAG_SUBSCRIBER_READY; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " msr", state, nextState); + return state; + } + } } - @Override - public boolean hasDownstreams() { - return hasDownstream; - } - - void release(T t) { - if (t instanceof ReferenceCounted) { - ReferenceCounted refCounted = (ReferenceCounted) t; - if (refCounted.refCnt() > 0) { - try { - refCounted.release(); - } catch (Throwable ex) { - // no ops + /** + * Sets {@link #FLAG_HAS_REQUEST} flag if it was not set before and if flags {@link + * #FLAG_FINALIZED}, {@link #FLAG_CANCELLED}, {@link #FLAG_DISPOSED} are unset. Also, this method + * increments number of work in progress (WIP) + * + * @return previous state + */ + static long markRequestAdded(UnboundedProcessor instance) { + for (; ; ) { + long state = instance.state; + + if (isFinalized(state) || isCancelled(state) || isDisposed(state)) { + return state; + } + + long nextState = state; + if (isWorkInProgress(state) || (isSubscriberReady(state) && hasValue(state))) { + nextState = addWork(state); + } + + nextState = nextState | FLAG_HAS_REQUEST; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mra", state, nextState); + return state; + } + } + } + + /** + * Sets {@link #FLAG_HAS_VALUE} flag if it was not set before and if flags {@link + * #FLAG_FINALIZED}, {@link #FLAG_CANCELLED}, {@link #FLAG_DISPOSED} are unset. Also, this method + * increments number of work in progress (WIP) if {@link #FLAG_HAS_REQUEST} is set + * + * @return previous state + */ + static long markValueAdded(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state)) { + return state; + } + + long nextState = state; + if (isWorkInProgress(state)) { + nextState = addWork(state); + } else if (isSubscriberReady(state)) { + if (instance.outputFused) { + // fast path for fusion scenario + return state; + } + + if (hasRequest(state)) { + nextState = addWork(state); } } + + nextState = nextState | FLAG_HAS_VALUE; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mva", state, nextState); + return state; + } } } + + /** + * Sets {@link #FLAG_HAS_VALUE} flag if it was not set before and if flags {@link + * #FLAG_FINALIZED}, {@link #FLAG_CANCELLED}, {@link #FLAG_DISPOSED} are unset. Also, this method + * increments number of work in progress (WIP) if {@link #FLAG_HAS_REQUEST} is set + * + * @return previous state + */ + static long markValueAddedAndTerminated(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state)) { + return state; + } + + long nextState = state; + if (isWorkInProgress(state)) { + nextState = addWork(state); + } else if (isSubscriberReady(state) && !instance.outputFused) { + nextState = addWork(state); + } + + nextState = nextState | FLAG_HAS_VALUE | FLAG_TERMINATED; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, "mva&t", state, nextState); + return state; + } + } + } + + /** + * Sets {@link #FLAG_TERMINATED} flag if it was not set before and if flags {@link + * #FLAG_FINALIZED}, {@link #FLAG_CANCELLED}, {@link #FLAG_DISPOSED} are unset. Also, this method + * increments number of work in progress (WIP) + * + * @return previous state + */ + static long markTerminatedOrFinalized(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state) || isTerminated(state) || isCancelled(state) || isDisposed(state)) { + return state; + } + + long nextState = state; + if (isWorkInProgress(state)) { + nextState = addWork(state); + } else if (isSubscriberReady(state) && !instance.outputFused) { + if (!hasValue(state)) { + // fast path for no values and no work in progress + nextState = FLAG_FINALIZED; + } else { + nextState = addWork(state); + } + } + + nextState = nextState | FLAG_TERMINATED; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mt|f", state, nextState); + if (isFinalized(nextState)) { + instance.onFinalizedHook.run(); + } + return state; + } + } + } + + /** + * Sets {@link #FLAG_CANCELLED} flag if it was not set before and if flag {@link #FLAG_FINALIZED} + * is unset. Also, this method increments number of work in progress (WIP) + * + * @return previous state + */ + static long markCancelled(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state) || isCancelled(state)) { + return state; + } + + final long nextState = addWork(state) | FLAG_CANCELLED; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mc", state, nextState); + return state; + } + } + } + + /** + * Sets {@link #FLAG_DISPOSED} flag if it was not set before and if flags {@link #FLAG_FINALIZED}, + * {@link #FLAG_CANCELLED} are unset. Also, this method increments number of work in progress + * (WIP) + * + * @return previous state + */ + static long markDisposed(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state) || isCancelled(state) || isDisposed(state)) { + return state; + } + + final long nextState = addWork(state) | FLAG_DISPOSED; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " md", state, nextState); + return state; + } + } + } + + static long addWork(long state) { + return (state & MAX_WIP_VALUE) == MAX_WIP_VALUE ? state : state + 1; + } + + /** + * Decrements the amount of work in progress by the given amount on the given state. Fails if flag + * is {@link #FLAG_FINALIZED} is set or if fusion disabled and flags {@link #FLAG_CANCELLED} or + * {@link #FLAG_DISPOSED} are set. + * + *

    Note, if fusion is enabled, the decrement should work if flags {@link #FLAG_CANCELLED} or + * {@link #FLAG_DISPOSED} are set, since, while the operator was not terminate by the downstream, + * we still have to propagate notifications that new elements are enqueued + * + * @return state after changing WIP or current state if update failed + */ + static long markWorkDone( + UnboundedProcessor instance, long expectedState, boolean hasRequest, boolean hasValue) { + for (; ; ) { + final long state = instance.state; + + if (state != expectedState) { + return state; + } + + if (isFinalized(state) || isCancelled(state) || isDisposed(state)) { + return state; + } + + final long nextState = + (state - (expectedState & MAX_WIP_VALUE)) + ^ (hasRequest ? 0 : FLAG_HAS_REQUEST) + ^ (hasValue ? 0 : FLAG_HAS_VALUE); + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mwd", state, nextState); + return nextState; + } + } + } + + /** + * Set flag {@link #FLAG_FINALIZED} and {@link #release(ByteBuf)} all the elements from {@link + * #queue} and {@link #priorityQueue}. + * + *

    This method may be called concurrently only if the given {@link UnboundedProcessor} has no + * output fusion ({@link #outputFused} {@code == true}). Otherwise this method MUST only by the + * downstream calling method {@link #clear()} + */ + static void clearAndFinalize(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state)) { + instance.clearSafely(); + return; + } + + if (!isSubscriberReady(state) || !instance.outputFused) { + instance.clearSafely(); + } else { + instance.clearUnsafely(); + } + + long nextState = (state & ~MAX_WIP_VALUE & ~FLAG_HAS_VALUE) | FLAG_FINALIZED; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " c&f", state, nextState); + instance.onFinalizedHook.run(); + break; + } + } + } + + static boolean hasValue(long state) { + return (state & FLAG_HAS_VALUE) == FLAG_HAS_VALUE; + } + + static boolean hasRequest(long state) { + return (state & FLAG_HAS_REQUEST) == FLAG_HAS_REQUEST; + } + + static boolean isCancelled(long state) { + return (state & FLAG_CANCELLED) == FLAG_CANCELLED; + } + + static boolean isDisposed(long state) { + return (state & FLAG_DISPOSED) == FLAG_DISPOSED; + } + + static boolean isWorkInProgress(long state) { + return (state & MAX_WIP_VALUE) != 0; + } + + static boolean isTerminated(long state) { + return (state & FLAG_TERMINATED) == FLAG_TERMINATED; + } + + static boolean isFinalized(long state) { + return (state & FLAG_FINALIZED) == FLAG_FINALIZED; + } + + static boolean isSubscriberReady(long state) { + return (state & FLAG_SUBSCRIBER_READY) == FLAG_SUBSCRIBER_READY; + } + + static boolean isSubscribedOnce(long state) { + return (state & FLAG_SUBSCRIBED_ONCE) == FLAG_SUBSCRIBED_ONCE; + } + + static void log( + UnboundedProcessor instance, String action, long initialState, long committedState) { + log(instance, action, initialState, committedState, false); + } + + static void log( + UnboundedProcessor instance, + String action, + long initialState, + long committedState, + boolean logStackTrace) { + Logger logger = instance.logger; + if (logger == null || !logger.isTraceEnabled()) { + return; + } + + if (logStackTrace) { + logger.trace( + String.format( + "[%s][%s][%s][%s-%s]", + instance, + action, + action, + Thread.currentThread().getId(), + formatState(initialState, 64), + formatState(committedState, 64)), + new RuntimeException()); + } else { + logger.trace( + String.format( + "[%s][%s][%s][%s-%s]", + instance, + action, + Thread.currentThread().getId(), + formatState(initialState, 64), + formatState(committedState, 64))); + } + } + + static void log( + UnboundedProcessor instance, String action, int initialState, int committedState) { + log(instance, action, initialState, committedState, false); + } + + static void log( + UnboundedProcessor instance, + String action, + int initialState, + int committedState, + boolean logStackTrace) { + Logger logger = instance.logger; + if (logger == null || !logger.isTraceEnabled()) { + return; + } + + if (logStackTrace) { + logger.trace( + String.format( + "[%s][%s][%s][%s-%s]", + instance, + action, + action, + Thread.currentThread().getId(), + formatState(initialState, 32), + formatState(committedState, 32)), + new RuntimeException()); + } else { + logger.trace( + String.format( + "[%s][%s][%s][%s-%s]", + instance, + action, + Thread.currentThread().getId(), + formatState(initialState, 32), + formatState(committedState, 32))); + } + } + + static String formatState(long state, int size) { + final String defaultFormat = Long.toBinaryString(state); + final StringBuilder formatted = new StringBuilder(); + final int toPrepend = size - defaultFormat.length(); + for (int i = 0; i < size; i++) { + if (i != 0 && i % 4 == 0) { + formatted.append("_"); + } + if (i < toPrepend) { + formatted.append("0"); + } else { + formatted.append(defaultFormat.charAt(i - toPrepend)); + } + } + + formatted.insert(0, "0b"); + return formatted.toString(); + } } diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseLinkedQueue.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseLinkedQueue.java index 6939b0f7a..a99ef8a49 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseLinkedQueue.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseLinkedQueue.java @@ -13,15 +13,30 @@ */ package io.rsocket.internal.jctools.queues; -import static io.rsocket.internal.jctools.util.UnsafeAccess.UNSAFE; -import static io.rsocket.internal.jctools.util.UnsafeAccess.fieldOffset; +import static io.rsocket.internal.jctools.queues.UnsafeAccess.UNSAFE; +import static io.rsocket.internal.jctools.queues.UnsafeAccess.fieldOffset; import java.util.AbstractQueue; import java.util.Iterator; abstract class BaseLinkedQueuePad0 extends AbstractQueue implements MessagePassingQueue { - long p00, p01, p02, p03, p04, p05, p06, p07; - long p10, p11, p12, p13, p14, p15, p16; + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + // byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + // * drop 8b as object header acts as padding and is >= 8b * } // $gen:ordered-fields @@ -29,18 +44,20 @@ abstract class BaseLinkedQueueProducerNodeRef extends BaseLinkedQueuePad0 static final long P_NODE_OFFSET = fieldOffset(BaseLinkedQueueProducerNodeRef.class, "producerNode"); - private LinkedQueueNode producerNode; + private volatile LinkedQueueNode producerNode; final void spProducerNode(LinkedQueueNode newValue) { - producerNode = newValue; + UNSAFE.putObject(this, P_NODE_OFFSET, newValue); + } + + final void soProducerNode(LinkedQueueNode newValue) { + UNSAFE.putOrderedObject(this, P_NODE_OFFSET, newValue); } - @SuppressWarnings("unchecked") final LinkedQueueNode lvProducerNode() { - return (LinkedQueueNode) UNSAFE.getObjectVolatile(this, P_NODE_OFFSET); + return producerNode; } - @SuppressWarnings("unchecked") final boolean casProducerNode(LinkedQueueNode expect, LinkedQueueNode newValue) { return UNSAFE.compareAndSwapObject(this, P_NODE_OFFSET, expect, newValue); } @@ -51,8 +68,22 @@ final LinkedQueueNode lpProducerNode() { } abstract class BaseLinkedQueuePad1 extends BaseLinkedQueueProducerNodeRef { - long p01, p02, p03, p04, p05, p06, p07; - long p10, p11, p12, p13, p14, p15, p16, p17; + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b } // $gen:ordered-fields @@ -77,16 +108,27 @@ final LinkedQueueNode lpConsumerNode() { } abstract class BaseLinkedQueuePad2 extends BaseLinkedQueueConsumerNodeRef { - long p01, p02, p03, p04, p05, p06, p07; - long p10, p11, p12, p13, p14, p15, p16, p17; + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b } /** * A base data structure for concurrent linked queues. For convenience also pulled in common single * consumer methods since at this time there's no plan to implement MC. - * - * @param - * @author nitsanw */ abstract class BaseLinkedQueue extends BaseLinkedQueuePad2 { @@ -158,8 +200,10 @@ public final int size() { * @see MessagePassingQueue#isEmpty() */ @Override - public final boolean isEmpty() { - return lvConsumerNode() == lvProducerNode(); + public boolean isEmpty() { + LinkedQueueNode consumerNode = lvConsumerNode(); + LinkedQueueNode producerNode = lvProducerNode(); + return consumerNode == producerNode; } protected E getSingleConsumerNodeValue( diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseMpscLinkedArrayQueue.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseMpscLinkedArrayQueue.java index 635779df3..cfad5ef71 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseMpscLinkedArrayQueue.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseMpscLinkedArrayQueue.java @@ -13,25 +13,38 @@ */ package io.rsocket.internal.jctools.queues; -import static io.rsocket.internal.jctools.queues.CircularArrayOffsetCalculator.allocate; import static io.rsocket.internal.jctools.queues.LinkedArrayQueueUtil.length; -import static io.rsocket.internal.jctools.queues.LinkedArrayQueueUtil.modifiedCalcElementOffset; -import static io.rsocket.internal.jctools.util.UnsafeAccess.UNSAFE; -import static io.rsocket.internal.jctools.util.UnsafeAccess.fieldOffset; -import static io.rsocket.internal.jctools.util.UnsafeRefArrayAccess.calcElementOffset; -import static io.rsocket.internal.jctools.util.UnsafeRefArrayAccess.lvElement; -import static io.rsocket.internal.jctools.util.UnsafeRefArrayAccess.soElement; +import static io.rsocket.internal.jctools.queues.LinkedArrayQueueUtil.modifiedCalcCircularRefElementOffset; +import static io.rsocket.internal.jctools.queues.UnsafeAccess.UNSAFE; +import static io.rsocket.internal.jctools.queues.UnsafeAccess.fieldOffset; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.allocateRefArray; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.calcCircularRefElementOffset; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.calcRefElementOffset; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.lvRefElement; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.soRefElement; import io.rsocket.internal.jctools.queues.IndexedQueueSizeUtil.IndexedQueue; -import io.rsocket.internal.jctools.util.PortableJvmInfo; -import io.rsocket.internal.jctools.util.Pow2; -import io.rsocket.internal.jctools.util.RangeUtil; import java.util.AbstractQueue; import java.util.Iterator; +import java.util.NoSuchElementException; abstract class BaseMpscLinkedArrayQueuePad1 extends AbstractQueue implements IndexedQueue { - long p01, p02, p03, p04, p05, p06, p07; - long p10, p11, p12, p13, p14, p15, p16, p17; + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b } // $gen:ordered-fields @@ -56,8 +69,22 @@ final boolean casProducerIndex(long expect, long newValue) { } abstract class BaseMpscLinkedArrayQueuePad2 extends BaseMpscLinkedArrayQueueProducerFields { - long p01, p02, p03, p04, p05, p06, p07; - long p10, p11, p12, p13, p14, p15, p16, p17; + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b } // $gen:ordered-fields @@ -84,8 +111,22 @@ final void soConsumerIndex(long newValue) { } abstract class BaseMpscLinkedArrayQueuePad3 extends BaseMpscLinkedArrayQueueConsumerFields { - long p0, p1, p2, p3, p4, p5, p6, p7; - long p10, p11, p12, p13, p14, p15, p16, p17; + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b } // $gen:ordered-fields @@ -115,12 +156,9 @@ final void soProducerLimit(long newValue) { * An MPSC array queue which starts at initialCapacity and grows to maxCapacity in * linked chunks of the initial size. The queue grows only when the current buffer is full and * elements are not copied on resize, instead a link to the new buffer is stored in the old buffer - * for the consumer to follow.
    - * - * @param + * for the consumer to follow. */ -public abstract class BaseMpscLinkedArrayQueue - extends BaseMpscLinkedArrayQueueColdProducerFields +abstract class BaseMpscLinkedArrayQueue extends BaseMpscLinkedArrayQueueColdProducerFields implements MessagePassingQueue, QueueProgressIndicators { // No post padding here, subclasses must add private static final Object JUMP = new Object(); @@ -141,7 +179,7 @@ public BaseMpscLinkedArrayQueue(final int initialCapacity) { // leave lower bit of mask clear long mask = (p2capacity - 1) << 1; // need extra element to point at next array - E[] buffer = allocate(p2capacity + 1); + E[] buffer = allocateRefArray(p2capacity + 1); producerBuffer = buffer; producerMask = mask; consumerBuffer = buffer; @@ -150,7 +188,7 @@ public BaseMpscLinkedArrayQueue(final int initialCapacity) { } @Override - public final int size() { + public int size() { // NOTE: because indices are on even numbers we cannot use the size util. /* @@ -181,7 +219,7 @@ public final int size() { } @Override - public final boolean isEmpty() { + public boolean isEmpty() { // Order matters! // Loading consumer before producer allows for producer increments after consumer index is read. // This ensures this method is conservative in it's estimate. Note that as this is an MPMC there @@ -240,8 +278,8 @@ public boolean offer(final E e) { } } // INDEX visible before ELEMENT - final long offset = modifiedCalcElementOffset(pIndex, mask); - soElement(buffer, offset, e); // release element e + final long offset = modifiedCalcCircularRefElementOffset(pIndex, mask); + soRefElement(buffer, offset, e); // release element e return true; } @@ -257,8 +295,8 @@ public E poll() { final long index = lpConsumerIndex(); final long mask = consumerMask; - final long offset = modifiedCalcElementOffset(index, mask); - Object e = lvElement(buffer, offset); // LoadLoad + final long offset = modifiedCalcCircularRefElementOffset(index, mask); + Object e = lvRefElement(buffer, offset); if (e == null) { if (index != lvProducerIndex()) { // poll() == null iff queue is empty, null element is not strong enough indicator, so we @@ -266,7 +304,7 @@ public E poll() { // check the producer index. If the queue is indeed not empty we spin until element is // visible. do { - e = lvElement(buffer, offset); + e = lvRefElement(buffer, offset); } while (e == null); } else { return null; @@ -278,7 +316,7 @@ public E poll() { return newBufferPoll(nextBuffer, index); } - soElement(buffer, offset, null); // release element null + soRefElement(buffer, offset, null); // release element null soConsumerIndex(index + 2); // release cIndex return (E) e; } @@ -295,14 +333,14 @@ public E peek() { final long index = lpConsumerIndex(); final long mask = consumerMask; - final long offset = modifiedCalcElementOffset(index, mask); - Object e = lvElement(buffer, offset); // LoadLoad + final long offset = modifiedCalcCircularRefElementOffset(index, mask); + Object e = lvRefElement(buffer, offset); if (e == null && index != lvProducerIndex()) { // peek() == null iff queue is empty, null element is not strong enough indicator, so we must // check the producer index. If the queue is indeed not empty we spin until element is // visible. do { - e = lvElement(buffer, offset); + e = lvRefElement(buffer, offset); } while (e == null); } if (e == JUMP) { @@ -346,31 +384,31 @@ else if (casProducerIndex(pIndex, pIndex + 1)) { @SuppressWarnings("unchecked") private E[] nextBuffer(final E[] buffer, final long mask) { final long offset = nextArrayOffset(mask); - final E[] nextBuffer = (E[]) lvElement(buffer, offset); + final E[] nextBuffer = (E[]) lvRefElement(buffer, offset); consumerBuffer = nextBuffer; consumerMask = (length(nextBuffer) - 2) << 1; - soElement(buffer, offset, BUFFER_CONSUMED); + soRefElement(buffer, offset, BUFFER_CONSUMED); return nextBuffer; } - private long nextArrayOffset(long mask) { - return modifiedCalcElementOffset(mask + 2, Long.MAX_VALUE); + private static long nextArrayOffset(long mask) { + return modifiedCalcCircularRefElementOffset(mask + 2, Long.MAX_VALUE); } private E newBufferPoll(E[] nextBuffer, long index) { - final long offset = modifiedCalcElementOffset(index, consumerMask); - final E n = lvElement(nextBuffer, offset); // LoadLoad + final long offset = modifiedCalcCircularRefElementOffset(index, consumerMask); + final E n = lvRefElement(nextBuffer, offset); if (n == null) { throw new IllegalStateException("new buffer must have at least one element"); } - soElement(nextBuffer, offset, null); // StoreStore + soRefElement(nextBuffer, offset, null); soConsumerIndex(index + 2); return n; } private E newBufferPeek(E[] nextBuffer, long index) { - final long offset = modifiedCalcElementOffset(index, consumerMask); - final E n = lvElement(nextBuffer, offset); // LoadLoad + final long offset = modifiedCalcCircularRefElementOffset(index, consumerMask); + final E n = lvRefElement(nextBuffer, offset); if (null == n) { throw new IllegalStateException("new buffer must have at least one element"); } @@ -402,8 +440,8 @@ public E relaxedPoll() { final long index = lpConsumerIndex(); final long mask = consumerMask; - final long offset = modifiedCalcElementOffset(index, mask); - Object e = lvElement(buffer, offset); // LoadLoad + final long offset = modifiedCalcCircularRefElementOffset(index, mask); + Object e = lvRefElement(buffer, offset); if (e == null) { return null; } @@ -411,7 +449,7 @@ public E relaxedPoll() { final E[] nextBuffer = nextBuffer(buffer, mask); return newBufferPoll(nextBuffer, index); } - soElement(buffer, offset, null); + soRefElement(buffer, offset, null); soConsumerIndex(index + 2); return (E) e; } @@ -423,8 +461,8 @@ public E relaxedPeek() { final long index = lpConsumerIndex(); final long mask = consumerMask; - final long offset = modifiedCalcElementOffset(index, mask); - Object e = lvElement(buffer, offset); // LoadLoad + final long offset = modifiedCalcCircularRefElementOffset(index, mask); + Object e = lvRefElement(buffer, offset); if (e == JUMP) { return newBufferPeek(nextBuffer(buffer, mask), index); } @@ -447,7 +485,11 @@ public int fill(Supplier s) { } @Override - public int fill(Supplier s, int batchSize) { + public int fill(Supplier s, int limit) { + if (null == s) throw new IllegalArgumentException("supplier is null"); + if (limit < 0) throw new IllegalArgumentException("limit is negative:" + limit); + if (limit == 0) return 0; + long mask; E[] buffer; long pIndex; @@ -471,9 +513,10 @@ public int fill(Supplier s, int batchSize) { // a successful CAS ties the ordering, lv(pIndex) -> [mask/buffer] -> cas(pIndex) // we want 'limit' slots, but will settle for whatever is visible to 'producerLimit' - long batchIndex = Math.min(producerLimit, pIndex + 2 * batchSize); + long batchIndex = + Math.min(producerLimit, pIndex + 2l * limit); // -> producerLimit >= batchIndex - if (pIndex >= producerLimit || producerLimit < batchIndex) { + if (pIndex >= producerLimit) { int result = offerSlowPath(mask, pIndex, producerLimit); switch (result) { case CONTINUE_TO_P_INDEX_CAS: @@ -496,23 +539,15 @@ public int fill(Supplier s, int batchSize) { } for (int i = 0; i < claimedSlots; i++) { - final long offset = modifiedCalcElementOffset(pIndex + 2 * i, mask); - soElement(buffer, offset, s.get()); + final long offset = modifiedCalcCircularRefElementOffset(pIndex + 2l * i, mask); + soRefElement(buffer, offset, s.get()); } return claimedSlots; } @Override - public void fill(Supplier s, WaitStrategy w, ExitCondition exit) { - - while (exit.keepRunning()) { - if (fill(s, PortableJvmInfo.RECOMENDED_OFFER_BATCH) == 0) { - int idleCounter = 0; - while (exit.keepRunning() && fill(s, PortableJvmInfo.RECOMENDED_OFFER_BATCH) == 0) { - idleCounter = w.idle(idleCounter); - } - } - } + public void fill(Supplier s, WaitStrategy wait, ExitCondition exit) { + MessagePassingQueueUtil.fill(this, s, wait, exit); } @Override @@ -521,30 +556,13 @@ public int drain(Consumer c) { } @Override - public int drain(final Consumer c, final int limit) { - // Impl note: there are potentially some small gains to be had by manually inlining - // relaxedPoll() and hoisting - // reused fields out to reduce redundant reads. - int i = 0; - E m; - for (; i < limit && (m = relaxedPoll()) != null; i++) { - c.accept(m); - } - return i; + public int drain(Consumer c, int limit) { + return MessagePassingQueueUtil.drain(this, c, limit); } @Override - public void drain(Consumer c, WaitStrategy w, ExitCondition exit) { - int idleCounter = 0; - while (exit.keepRunning()) { - E e = relaxedPoll(); - if (e == null) { - idleCounter = w.idle(idleCounter); - continue; - } - idleCounter = 0; - c.accept(e); - } + public void drain(Consumer c, WaitStrategy wait, ExitCondition exit) { + MessagePassingQueueUtil.drain(this, c, wait, exit); } /** @@ -559,21 +577,28 @@ public void drain(Consumer c, WaitStrategy w, ExitCondition exit) { */ @Override public Iterator iterator() { - return new WeakIterator(); + return new WeakIterator(consumerBuffer, lvConsumerIndex(), lvProducerIndex()); } - private final class WeakIterator implements Iterator { - + private static class WeakIterator implements Iterator { + private final long pIndex; private long nextIndex; private E nextElement; private E[] currentBuffer; - private int currentBufferLength; + private int mask; - WeakIterator() { - setBuffer(consumerBuffer); + WeakIterator(E[] currentBuffer, long cIndex, long pIndex) { + this.pIndex = pIndex >> 1; + this.nextIndex = cIndex >> 1; + setBuffer(currentBuffer); nextElement = getNext(); } + @Override + public void remove() { + throw new UnsupportedOperationException("remove"); + } + @Override public boolean hasNext() { return nextElement != null; @@ -581,37 +606,54 @@ public boolean hasNext() { @Override public E next() { - E e = nextElement; + final E e = nextElement; + if (e == null) { + throw new NoSuchElementException(); + } nextElement = getNext(); return e; } private void setBuffer(E[] buffer) { this.currentBuffer = buffer; - this.currentBufferLength = length(buffer); - this.nextIndex = 0; + this.mask = length(buffer) - 2; } private E getNext() { - while (true) { - while (nextIndex < currentBufferLength - 1) { - long offset = calcElementOffset(nextIndex++); - E e = lvElement(currentBuffer, offset); - if (e != null && e != JUMP) { - return e; - } + while (nextIndex < pIndex) { + long index = nextIndex++; + E e = lvRefElement(currentBuffer, calcCircularRefElementOffset(index, mask)); + // skip removed/not yet visible elements + if (e == null) { + continue; } - long offset = calcElementOffset(currentBufferLength - 1); - Object nextArray = lvElement(currentBuffer, offset); - if (nextArray == BUFFER_CONSUMED) { - // Consumer may have passed us, just jump to the current consumer buffer - setBuffer(consumerBuffer); - } else if (nextArray != null) { - setBuffer((E[]) nextArray); - } else { + + // not null && not JUMP -> found next element + if (e != JUMP) { + return e; + } + + // need to jump to the next buffer + int nextBufferIndex = mask + 1; + Object nextBuffer = lvRefElement(currentBuffer, calcRefElementOffset(nextBufferIndex)); + + if (nextBuffer == BUFFER_CONSUMED || nextBuffer == null) { + // Consumer may have passed us, or the next buffer is not visible yet: drop out early return null; } + + setBuffer((E[]) nextBuffer); + // now with the new array retry the load, it can't be a JUMP, but we need to repeat same + // index + e = lvRefElement(currentBuffer, calcCircularRefElementOffset(index, mask)); + // skip removed/not yet visible elements + if (e == null) { + continue; + } else { + return e; + } } + return null; } } @@ -620,7 +662,7 @@ private void resize(long oldMask, E[] oldBuffer, long pIndex, E e, Supplier s int newBufferLength = getNextBufferSize(oldBuffer); final E[] newBuffer; try { - newBuffer = allocate(newBufferLength); + newBuffer = allocateRefArray(newBufferLength); } catch (OutOfMemoryError oom) { assert lvProducerIndex() == pIndex + 1; soProducerIndex(pIndex); @@ -631,11 +673,11 @@ private void resize(long oldMask, E[] oldBuffer, long pIndex, E e, Supplier s final int newMask = (newBufferLength - 2) << 1; producerMask = newMask; - final long offsetInOld = modifiedCalcElementOffset(pIndex, oldMask); - final long offsetInNew = modifiedCalcElementOffset(pIndex, newMask); + final long offsetInOld = modifiedCalcCircularRefElementOffset(pIndex, oldMask); + final long offsetInNew = modifiedCalcCircularRefElementOffset(pIndex, newMask); - soElement(newBuffer, offsetInNew, e == null ? s.get() : e); // element in new array - soElement(oldBuffer, nextArrayOffset(oldMask), newBuffer); // buffer linked + soRefElement(newBuffer, offsetInNew, e == null ? s.get() : e); // element in new array + soRefElement(oldBuffer, nextArrayOffset(oldMask), newBuffer); // buffer linked // ASSERT code final long cIndex = lvConsumerIndex(); @@ -652,7 +694,7 @@ private void resize(long oldMask, E[] oldBuffer, long pIndex, E e, Supplier s // INDEX visible before ELEMENT, consistent with consumer expectation // make resize visible to consumer - soElement(oldBuffer, offsetInOld, JUMP); + soRefElement(oldBuffer, offsetInOld, JUMP); } /** @return next buffer size(inclusive of next array pointer) */ diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/CircularArrayOffsetCalculator.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/CircularArrayOffsetCalculator.java deleted file mode 100644 index d746fccbb..000000000 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/CircularArrayOffsetCalculator.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.internal.jctools.queues; - -import static io.rsocket.internal.jctools.util.UnsafeRefArrayAccess.REF_ARRAY_BASE; -import static io.rsocket.internal.jctools.util.UnsafeRefArrayAccess.REF_ELEMENT_SHIFT; - -import io.rsocket.internal.jctools.util.InternalAPI; - -@InternalAPI -public final class CircularArrayOffsetCalculator { - @SuppressWarnings("unchecked") - public static E[] allocate(int capacity) { - return (E[]) new Object[capacity]; - } - - /** - * @param index desirable element index - * @param mask (length - 1) - * @return the offset in bytes within the array for a given index. - */ - public static long calcElementOffset(long index, long mask) { - return REF_ARRAY_BASE + ((index & mask) << REF_ELEMENT_SHIFT); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/IndexedQueueSizeUtil.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/IndexedQueueSizeUtil.java index 1b7d43166..40116bbe1 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/IndexedQueueSizeUtil.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/IndexedQueueSizeUtil.java @@ -13,10 +13,7 @@ */ package io.rsocket.internal.jctools.queues; -import io.rsocket.internal.jctools.util.InternalAPI; - -@InternalAPI -public final class IndexedQueueSizeUtil { +final class IndexedQueueSizeUtil { public static int size(IndexedQueue iq) { /* * It is possible for a thread to be interrupted or reschedule between the read of the producer and @@ -54,7 +51,6 @@ public static boolean isEmpty(IndexedQueue iq) { return (iq.lvConsumerIndex() == iq.lvProducerIndex()); } - @InternalAPI public interface IndexedQueue { long lvConsumerIndex(); diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedArrayQueueUtil.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedArrayQueueUtil.java index 5e7831128..37651f351 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedArrayQueueUtil.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedArrayQueueUtil.java @@ -13,13 +13,11 @@ */ package io.rsocket.internal.jctools.queues; -import static io.rsocket.internal.jctools.util.UnsafeRefArrayAccess.REF_ARRAY_BASE; -import static io.rsocket.internal.jctools.util.UnsafeRefArrayAccess.REF_ELEMENT_SHIFT; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.REF_ARRAY_BASE; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.REF_ELEMENT_SHIFT; /** This is used for method substitution in the LinkedArray classes code generation. */ final class LinkedArrayQueueUtil { - private LinkedArrayQueueUtil() {} - static int length(Object[] buf) { return buf.length; } @@ -29,7 +27,7 @@ static int length(Object[] buf) { * is compensated for by reducing the element shift. The computation is constant folded, so * there's no cost. */ - static long modifiedCalcElementOffset(long index, long mask) { + static long modifiedCalcCircularRefElementOffset(long index, long mask) { return REF_ARRAY_BASE + ((index & mask) << (REF_ELEMENT_SHIFT - 1)); } diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedQueueNode.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedQueueNode.java index 6ea69e330..72e78bb92 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedQueueNode.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedQueueNode.java @@ -13,8 +13,8 @@ */ package io.rsocket.internal.jctools.queues; -import static io.rsocket.internal.jctools.util.UnsafeAccess.UNSAFE; -import static io.rsocket.internal.jctools.util.UnsafeAccess.fieldOffset; +import static io.rsocket.internal.jctools.queues.UnsafeAccess.UNSAFE; +import static io.rsocket.internal.jctools.queues.UnsafeAccess.fieldOffset; final class LinkedQueueNode { private static final long NEXT_OFFSET = fieldOffset(LinkedQueueNode.class, "next"); @@ -53,6 +53,10 @@ public void soNext(LinkedQueueNode n) { UNSAFE.putOrderedObject(this, NEXT_OFFSET, n); } + public void spNext(LinkedQueueNode n) { + UNSAFE.putObject(this, NEXT_OFFSET, n); + } + public LinkedQueueNode lvNext() { return next; } diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueue.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueue.java index e0c3d0ee1..7a0fa901f 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueue.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueue.java @@ -13,55 +13,327 @@ */ package io.rsocket.internal.jctools.queues; +import java.util.Queue; + +/** + * Message passing queues are intended for concurrent method passing. A subset of {@link Queue} + * methods are provided with the same semantics, while further functionality which accomodates the + * concurrent usecase is also on offer. + * + *

    Message passing queues provide happens before semantics to messages passed through, namely + * that writes made by the producer before offering the message are visible to the consuming thread + * after the message has been polled out of the queue. + * + * @param the event/message type + */ public interface MessagePassingQueue { int UNBOUNDED_CAPACITY = -1; interface Supplier { + /** + * This method will return the next value to be written to the queue. As such the queue + * implementations are commited to insert the value once the call is made. + * + *

    Users should be aware that underlying queue implementations may upfront claim parts of the + * queue for batch operations and this will effect the view on the queue from the supplier + * method. In particular size and any offer methods may take the view that the full batch has + * already happened. + * + *

    WARNING: this method is assumed to never throw. Breaking this assumption can lead + * to a broken queue. + * + *

    WARNING: this method is assumed to never return {@code null}. Breaking this + * assumption can lead to a broken queue. + * + * @return new element, NEVER {@code null} + */ T get(); } interface Consumer { + /** + * This method will process an element already removed from the queue. This method is expected + * to never throw an exception. + * + *

    Users should be aware that underlying queue implementations may upfront claim parts of the + * queue for batch operations and this will effect the view on the queue from the accept method. + * In particular size and any poll/peek methods may take the view that the full batch has + * already happened. + * + *

    WARNING: this method is assumed to never throw. Breaking this assumption can lead + * to a broken queue. + * + * @param e not {@code null} + */ void accept(T e); } interface WaitStrategy { + /** + * This method can implement static or dynamic backoff. Dynamic backoff will rely on the counter + * for estimating how long the caller has been idling. The expected usage is: + * + *

    + * + *

    +     * 
    +     * int ic = 0;
    +     * while(true) {
    +     *   if(!isGodotArrived()) {
    +     *     ic = w.idle(ic);
    +     *     continue;
    +     *   }
    +     *   ic = 0;
    +     *   // party with Godot until he goes again
    +     * }
    +     * 
    +     * 
    + * + * @param idleCounter idle calls counter, managed by the idle method until reset + * @return new counter value to be used on subsequent idle cycle + */ int idle(int idleCounter); } interface ExitCondition { + /** + * This method should be implemented such that the flag read or determination cannot be hoisted + * out of a loop which notmally means a volatile load, but with JDK9 VarHandles may mean + * getOpaque. + * + * @return true as long as we should keep running + */ boolean keepRunning(); } + /** + * Called from a producer thread subject to the restrictions appropriate to the implementation and + * according to the {@link Queue#offer(Object)} interface. + * + * @param e not {@code null}, will throw NPE if it is + * @return true if element was inserted into the queue, false iff full + */ boolean offer(T e); + /** + * Called from the consumer thread subject to the restrictions appropriate to the implementation + * and according to the {@link Queue#poll()} interface. + * + * @return a message from the queue if one is available, {@code null} iff empty + */ T poll(); + /** + * Called from the consumer thread subject to the restrictions appropriate to the implementation + * and according to the {@link Queue#peek()} interface. + * + * @return a message from the queue if one is available, {@code null} iff empty + */ T peek(); + /** + * This method's accuracy is subject to concurrent modifications happening as the size is + * estimated and as such is a best effort rather than absolute value. For some implementations + * this method may be O(n) rather than O(1). + * + * @return number of messages in the queue, between 0 and {@link Integer#MAX_VALUE} but less or + * equals to capacity (if bounded). + */ int size(); + /** + * Removes all items from the queue. Called from the consumer thread subject to the restrictions + * appropriate to the implementation and according to the {@link Queue#clear()} interface. + */ void clear(); + /** + * This method's accuracy is subject to concurrent modifications happening as the observation is + * carried out. + * + * @return true if empty, false otherwise + */ boolean isEmpty(); + /** + * @return the capacity of this queue or {@link MessagePassingQueue#UNBOUNDED_CAPACITY} if not + * bounded + */ int capacity(); + /** + * Called from a producer thread subject to the restrictions appropriate to the implementation. As + * opposed to {@link Queue#offer(Object)} this method may return false without the queue being + * full. + * + * @param e not {@code null}, will throw NPE if it is + * @return true if element was inserted into the queue, false if unable to offer + */ boolean relaxedOffer(T e); + /** + * Called from the consumer thread subject to the restrictions appropriate to the implementation. + * As opposed to {@link Queue#poll()} this method may return {@code null} without the queue being + * empty. + * + * @return a message from the queue if one is available, {@code null} if unable to poll + */ T relaxedPoll(); + /** + * Called from the consumer thread subject to the restrictions appropriate to the implementation. + * As opposed to {@link Queue#peek()} this method may return {@code null} without the queue being + * empty. + * + * @return a message from the queue if one is available, {@code null} if unable to peek + */ T relaxedPeek(); - int drain(Consumer c); - - int fill(Supplier s); - + /** + * Remove up to limit elements from the queue and hand to consume. This should be + * semantically similar to: + * + *

    + * + *

    {@code
    +   * M m;
    +   * int i = 0;
    +   * for(;i < limit && (m = relaxedPoll()) != null; i++){
    +   *   c.accept(m);
    +   * }
    +   * return i;
    +   * }
    + * + *

    There's no strong commitment to the queue being empty at the end of a drain. Called from a + * consumer thread subject to the restrictions appropriate to the implementation. + * + *

    WARNING: Explicit assumptions are made with regards to {@link Consumer#accept} make + * sure you have read and understood these before using this method. + * + * @return the number of polled elements + * @throws IllegalArgumentException c is {@code null} + * @throws IllegalArgumentException if limit is negative + */ int drain(Consumer c, int limit); + /** + * Stuff the queue with up to limit elements from the supplier. Semantically similar to: + * + *

    + * + *

    {@code
    +   * for(int i=0; i < limit && relaxedOffer(s.get()); i++);
    +   * }
    + * + *

    There's no strong commitment to the queue being full at the end of a fill. Called from a + * producer thread subject to the restrictions appropriate to the implementation. + * + *

    WARNING: Explicit assumptions are made with regards to {@link Supplier#get} make sure + * you have read and understood these before using this method. + * + * @return the number of offered elements + * @throws IllegalArgumentException s is {@code null} + * @throws IllegalArgumentException if limit is negative + */ int fill(Supplier s, int limit); + /** + * Remove all available item from the queue and hand to consume. This should be semantically + * similar to: + * + *

    +   * M m;
    +   * while((m = relaxedPoll()) != null){
    +   * c.accept(m);
    +   * }
    +   * 
    + * + * There's no strong commitment to the queue being empty at the end of a drain. Called from a + * consumer thread subject to the restrictions appropriate to the implementation. + * + *

    WARNING: Explicit assumptions are made with regards to {@link Consumer#accept} make + * sure you have read and understood these before using this method. + * + * @return the number of polled elements + * @throws IllegalArgumentException c is {@code null} + */ + int drain(Consumer c); + + /** + * Stuff the queue with elements from the supplier. Semantically similar to: + * + *

    +   * while(relaxedOffer(s.get());
    +   * 
    + * + * There's no strong commitment to the queue being full at the end of a fill. Called from a + * producer thread subject to the restrictions appropriate to the implementation. + * + *

    Unbounded queues will fill up the queue with a fixed amount rather than fill up to oblivion. + * + *

    WARNING: Explicit assumptions are made with regards to {@link Supplier#get} make sure + * you have read and understood these before using this method. + * + * @return the number of offered elements + * @throws IllegalArgumentException s is {@code null} + */ + int fill(Supplier s); + + /** + * Remove elements from the queue and hand to consume forever. Semantically similar to: + * + *

    + * + *

    +   *  int idleCounter = 0;
    +   *  while (exit.keepRunning()) {
    +   *      E e = relaxedPoll();
    +   *      if(e==null){
    +   *          idleCounter = wait.idle(idleCounter);
    +   *          continue;
    +   *      }
    +   *      idleCounter = 0;
    +   *      c.accept(e);
    +   *  }
    +   * 
    + * + *

    Called from a consumer thread subject to the restrictions appropriate to the implementation. + * + *

    WARNING: Explicit assumptions are made with regards to {@link Consumer#accept} make + * sure you have read and understood these before using this method. + * + * @throws IllegalArgumentException c OR wait OR exit are {@code null} + */ void drain(Consumer c, WaitStrategy wait, ExitCondition exit); + /** + * Stuff the queue with elements from the supplier forever. Semantically similar to: + * + *

    + * + *

    +   * 
    +   *  int idleCounter = 0;
    +   *  while (exit.keepRunning()) {
    +   *      E e = s.get();
    +   *      while (!relaxedOffer(e)) {
    +   *          idleCounter = wait.idle(idleCounter);
    +   *          continue;
    +   *      }
    +   *      idleCounter = 0;
    +   *  }
    +   * 
    +   * 
    + * + *

    Called from a producer thread subject to the restrictions appropriate to the implementation. + * The main difference being that implementors MUST assure room in the queue is available BEFORE + * calling {@link Supplier#get}. + * + *

    WARNING: Explicit assumptions are made with regards to {@link Supplier#get} make sure + * you have read and understood these before using this method. + * + * @throws IllegalArgumentException s OR wait OR exit are {@code null} + */ void fill(Supplier s, WaitStrategy wait, ExitCondition exit); } diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueueUtil.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueueUtil.java new file mode 100644 index 000000000..cb03364d8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueueUtil.java @@ -0,0 +1,100 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.internal.jctools.queues; + +import io.rsocket.internal.jctools.queues.MessagePassingQueue.Consumer; +import io.rsocket.internal.jctools.queues.MessagePassingQueue.ExitCondition; +import io.rsocket.internal.jctools.queues.MessagePassingQueue.Supplier; +import io.rsocket.internal.jctools.queues.MessagePassingQueue.WaitStrategy; + +final class MessagePassingQueueUtil { + public static int drain(MessagePassingQueue queue, Consumer c, int limit) { + if (null == c) throw new IllegalArgumentException("c is null"); + if (limit < 0) throw new IllegalArgumentException("limit is negative: " + limit); + if (limit == 0) return 0; + E e; + int i = 0; + for (; i < limit && (e = queue.relaxedPoll()) != null; i++) { + c.accept(e); + } + return i; + } + + public static int drain(MessagePassingQueue queue, Consumer c) { + if (null == c) throw new IllegalArgumentException("c is null"); + E e; + int i = 0; + while ((e = queue.relaxedPoll()) != null) { + i++; + c.accept(e); + } + return i; + } + + public static void drain( + MessagePassingQueue queue, Consumer c, WaitStrategy wait, ExitCondition exit) { + if (null == c) throw new IllegalArgumentException("c is null"); + if (null == wait) throw new IllegalArgumentException("wait is null"); + if (null == exit) throw new IllegalArgumentException("exit condition is null"); + + int idleCounter = 0; + while (exit.keepRunning()) { + final E e = queue.relaxedPoll(); + if (e == null) { + idleCounter = wait.idle(idleCounter); + continue; + } + idleCounter = 0; + c.accept(e); + } + } + + public static void fill( + MessagePassingQueue q, Supplier s, WaitStrategy wait, ExitCondition exit) { + if (null == wait) throw new IllegalArgumentException("waiter is null"); + if (null == exit) throw new IllegalArgumentException("exit condition is null"); + + int idleCounter = 0; + while (exit.keepRunning()) { + if (q.fill(s, PortableJvmInfo.RECOMENDED_OFFER_BATCH) == 0) { + idleCounter = wait.idle(idleCounter); + continue; + } + idleCounter = 0; + } + } + + public static int fillBounded(MessagePassingQueue q, Supplier s) { + return fillInBatchesToLimit(q, s, PortableJvmInfo.RECOMENDED_OFFER_BATCH, q.capacity()); + } + + public static int fillInBatchesToLimit( + MessagePassingQueue q, Supplier s, int batch, int limit) { + long result = + 0; // result is a long because we want to have a safepoint check at regular intervals + do { + final int filled = q.fill(s, batch); + if (filled == 0) { + return (int) result; + } + result += filled; + } while (result <= limit); + return (int) result; + } + + public static int fillUnbounded(MessagePassingQueue q, Supplier s) { + return fillInBatchesToLimit(q, s, PortableJvmInfo.RECOMENDED_OFFER_BATCH, 4096); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MpscUnboundedArrayQueue.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MpscUnboundedArrayQueue.java index 59eab33a1..179070be4 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MpscUnboundedArrayQueue.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MpscUnboundedArrayQueue.java @@ -14,20 +14,31 @@ package io.rsocket.internal.jctools.queues; import static io.rsocket.internal.jctools.queues.LinkedArrayQueueUtil.length; - -import io.rsocket.internal.jctools.util.PortableJvmInfo; +import static io.rsocket.internal.jctools.queues.MessagePassingQueueUtil.fillUnbounded; /** * An MPSC array queue which starts at initialCapacity and grows indefinitely in linked * chunks of the initial size. The queue grows only when the current chunk is full and elements are * not copied on resize, instead a link to the new chunk is stored in the old chunk for the consumer - * to follow.
    - * - * @param + * to follow. */ public class MpscUnboundedArrayQueue extends BaseMpscLinkedArrayQueue { - long p0, p1, p2, p3, p4, p5, p6, p7; - long p10, p11, p12, p13, p14, p15, p16, p17; + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b public MpscUnboundedArrayQueue(int chunkSize) { super(chunkSize); @@ -50,17 +61,7 @@ public int drain(Consumer c) { @Override public int fill(Supplier s) { - long result = - 0; // result is a long because we want to have a safepoint check at regular intervals - final int capacity = 4096; - do { - final int filled = fill(s, PortableJvmInfo.RECOMENDED_OFFER_BATCH); - if (filled == 0) { - return (int) result; - } - result += filled; - } while (result <= capacity); - return (int) result; + return fillUnbounded(this, s); } @Override diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/PortableJvmInfo.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/PortableJvmInfo.java similarity index 90% rename from rsocket-core/src/main/java/io/rsocket/internal/jctools/util/PortableJvmInfo.java rename to rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/PortableJvmInfo.java index 2d567d60d..f037857e8 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/PortableJvmInfo.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/PortableJvmInfo.java @@ -11,11 +11,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket.internal.jctools.util; +package io.rsocket.internal.jctools.queues; /** JVM Information that is standard and available on all JVMs (i.e. does not use unsafe) */ -@InternalAPI -public interface PortableJvmInfo { +interface PortableJvmInfo { int CACHE_LINE_SIZE = Integer.getInteger("jctools.cacheLineSize", 64); int CPUs = Runtime.getRuntime().availableProcessors(); int RECOMENDED_OFFER_BATCH = CPUs * 4; diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/Pow2.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/Pow2.java similarity index 96% rename from rsocket-core/src/main/java/io/rsocket/internal/jctools/util/Pow2.java rename to rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/Pow2.java index d8c66d89e..282a22f02 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/Pow2.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/Pow2.java @@ -11,11 +11,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket.internal.jctools.util; +package io.rsocket.internal.jctools.queues; /** Power of 2 utility functions. */ -@InternalAPI -public final class Pow2 { +final class Pow2 { public static final int MAX_POW2 = 1 << 30; /** diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/RangeUtil.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/RangeUtil.java similarity index 94% rename from rsocket-core/src/main/java/io/rsocket/internal/jctools/util/RangeUtil.java rename to rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/RangeUtil.java index 77a0582ca..3adcb2f3c 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/RangeUtil.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/RangeUtil.java @@ -11,10 +11,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket.internal.jctools.util; +package io.rsocket.internal.jctools.queues; -@InternalAPI -public final class RangeUtil { +final class RangeUtil { public static long checkPositive(long n, String name) { if (n <= 0) { throw new IllegalArgumentException(name + ": " + n + " (expected: > 0)"); diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/UnsafeAccess.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeAccess.java old mode 100755 new mode 100644 similarity index 71% rename from rsocket-core/src/main/java/io/rsocket/internal/jctools/util/UnsafeAccess.java rename to rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeAccess.java index 793e64505..c99aeb689 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/UnsafeAccess.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeAccess.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket.internal.jctools.util; +package io.rsocket.internal.jctools.queues; import java.lang.reflect.Constructor; import java.lang.reflect.Field; @@ -33,41 +33,56 @@ * * @author nitsanw */ -@InternalAPI -public class UnsafeAccess { - public static final boolean SUPPORTS_GET_AND_SET; +class UnsafeAccess { + public static final boolean SUPPORTS_GET_AND_SET_REF; + public static final boolean SUPPORTS_GET_AND_ADD_LONG; public static final Unsafe UNSAFE; static { + UNSAFE = getUnsafe(); + SUPPORTS_GET_AND_SET_REF = hasGetAndSetSupport(); + SUPPORTS_GET_AND_ADD_LONG = hasGetAndAddLongSupport(); + } + + private static Unsafe getUnsafe() { Unsafe instance; try { final Field field = Unsafe.class.getDeclaredField("theUnsafe"); field.setAccessible(true); instance = (Unsafe) field.get(null); } catch (Exception ignored) { - // Some platforms, notably Android, might not have a sun.misc.Unsafe - // implementation with a private `theUnsafe` static instance. In this - // case we can try and call the default constructor, which proves - // sufficient for Android usage. + // Some platforms, notably Android, might not have a sun.misc.Unsafe implementation with a + // private + // `theUnsafe` static instance. In this case we can try to call the default constructor, which + // is sufficient + // for Android usage. try { Constructor c = Unsafe.class.getDeclaredConstructor(); c.setAccessible(true); instance = c.newInstance(); } catch (Exception e) { - SUPPORTS_GET_AND_SET = false; throw new RuntimeException(e); } } + return instance; + } - boolean getAndSetSupport = false; + private static boolean hasGetAndSetSupport() { try { Unsafe.class.getMethod("getAndSetObject", Object.class, Long.TYPE, Object.class); - getAndSetSupport = true; + return true; } catch (Exception ignored) { } + return false; + } - UNSAFE = instance; - SUPPORTS_GET_AND_SET = getAndSetSupport; + private static boolean hasGetAndAddLongSupport() { + try { + Unsafe.class.getMethod("getAndAddLong", Object.class, Long.TYPE, Long.TYPE); + return true; + } catch (Exception ignored) { + } + return false; } public static long fieldOffset(Class clz, String fieldName) throws RuntimeException { diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/UnsafeRefArrayAccess.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeRefArrayAccess.java similarity index 57% rename from rsocket-core/src/main/java/io/rsocket/internal/jctools/util/UnsafeRefArrayAccess.java rename to rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeRefArrayAccess.java index d8309c5c5..c734a9914 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/UnsafeRefArrayAccess.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeRefArrayAccess.java @@ -11,32 +11,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket.internal.jctools.util; +package io.rsocket.internal.jctools.queues; -import static io.rsocket.internal.jctools.util.UnsafeAccess.UNSAFE; +import static io.rsocket.internal.jctools.queues.UnsafeAccess.UNSAFE; -/** - * A concurrent access enabling class used by circular array based queues this class exposes an - * offset computation method along with differently memory fenced load/store methods into the - * underlying array. The class is pre-padded and the array is padded on either side to help with - * False sharing prvention. It is expected theat subclasses handle post padding. - * - *

    Offset calculation is separate from access to enable the reuse of a give compute offset. - * - *

    Load/Store methods using a buffer parameter are provided to allow the prevention of - * final field reload after a LoadLoad barrier. - * - *

    - * - * @author nitsanw - */ -@InternalAPI -public final class UnsafeRefArrayAccess { +final class UnsafeRefArrayAccess { public static final long REF_ARRAY_BASE; public static final int REF_ELEMENT_SHIFT; static { - final int scale = UnsafeAccess.UNSAFE.arrayIndexScale(Object[].class); + final int scale = UNSAFE.arrayIndexScale(Object[].class); if (4 == scale) { REF_ELEMENT_SHIFT = 2; } else if (8 == scale) { @@ -44,28 +28,28 @@ public final class UnsafeRefArrayAccess { } else { throw new IllegalStateException("Unknown pointer size: " + scale); } - REF_ARRAY_BASE = UnsafeAccess.UNSAFE.arrayBaseOffset(Object[].class); + REF_ARRAY_BASE = UNSAFE.arrayBaseOffset(Object[].class); } /** * A plain store (no ordering/fences) of an element to a given offset * * @param buffer this.buffer - * @param offset computed via {@link UnsafeRefArrayAccess#calcElementOffset(long)} + * @param offset computed via {@link UnsafeRefArrayAccess#calcRefElementOffset(long)} * @param e an orderly kitty */ - public static void spElement(E[] buffer, long offset, E e) { + public static void spRefElement(E[] buffer, long offset, E e) { UNSAFE.putObject(buffer, offset, e); } /** - * An ordered store(store + StoreStore barrier) of an element to a given offset + * An ordered store of an element to a given offset * * @param buffer this.buffer - * @param offset computed via {@link UnsafeRefArrayAccess#calcElementOffset} + * @param offset computed via {@link UnsafeRefArrayAccess#calcCircularRefElementOffset} * @param e an orderly kitty */ - public static void soElement(E[] buffer, long offset, E e) { + public static void soRefElement(E[] buffer, long offset, E e) { UNSAFE.putOrderedObject(buffer, offset, e); } @@ -73,31 +57,48 @@ public static void soElement(E[] buffer, long offset, E e) { * A plain load (no ordering/fences) of an element from a given offset. * * @param buffer this.buffer - * @param offset computed via {@link UnsafeRefArrayAccess#calcElementOffset(long)} + * @param offset computed via {@link UnsafeRefArrayAccess#calcRefElementOffset(long)} * @return the element at the offset */ @SuppressWarnings("unchecked") - public static E lpElement(E[] buffer, long offset) { + public static E lpRefElement(E[] buffer, long offset) { return (E) UNSAFE.getObject(buffer, offset); } /** - * A volatile load (load + LoadLoad barrier) of an element from a given offset. + * A volatile load of an element from a given offset. * * @param buffer this.buffer - * @param offset computed via {@link UnsafeRefArrayAccess#calcElementOffset(long)} + * @param offset computed via {@link UnsafeRefArrayAccess#calcRefElementOffset(long)} * @return the element at the offset */ @SuppressWarnings("unchecked") - public static E lvElement(E[] buffer, long offset) { + public static E lvRefElement(E[] buffer, long offset) { return (E) UNSAFE.getObjectVolatile(buffer, offset); } /** * @param index desirable element index - * @return the offset in bytes within the array for a given index. + * @return the offset in bytes within the array for a given index */ - public static long calcElementOffset(long index) { + public static long calcRefElementOffset(long index) { return REF_ARRAY_BASE + (index << REF_ELEMENT_SHIFT); } + + /** + * Note: circular arrays are assumed a power of 2 in length and the `mask` is (length - 1). + * + * @param index desirable element index + * @param mask (length - 1) + * @return the offset in bytes within the circular array for a given index + */ + public static long calcCircularRefElementOffset(long index, long mask) { + return REF_ARRAY_BASE + ((index & mask) << REF_ELEMENT_SHIFT); + } + + /** This makes for an easier time generating the atomic queues, and removes some warnings. */ + @SuppressWarnings("unchecked") + public static E[] allocateRefArray(int capacity) { + return (E[]) new Object[capacity]; + } } diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/InternalAPI.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/InternalAPI.java deleted file mode 100644 index f233e9597..000000000 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/InternalAPI.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.internal.jctools.util; - -import java.lang.annotation.ElementType; -import java.lang.annotation.Retention; -import java.lang.annotation.RetentionPolicy; -import java.lang.annotation.Target; - -/** - * This annotation marks classes and methods which may be public for any reason (to support better - * testing or reduce code duplication) but are not intended as public API and may change between - * releases without the change being considered a breaking API change (a major release). - */ -@Target({ElementType.TYPE, ElementType.METHOD, ElementType.FIELD, ElementType.CONSTRUCTOR}) -@Retention(RetentionPolicy.SOURCE) -public @interface InternalAPI {} diff --git a/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveHandler.java b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveHandler.java index 2535c342b..4fd7a772d 100644 --- a/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveHandler.java +++ b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveHandler.java @@ -1,9 +1,10 @@ package io.rsocket.keepalive; import io.netty.buffer.ByteBuf; -import io.rsocket.Closeable; import io.rsocket.keepalive.KeepAliveSupport.KeepAlive; +import io.rsocket.resume.RSocketSession; import io.rsocket.resume.ResumableDuplexConnection; +import io.rsocket.resume.ResumeStateHolder; import java.util.function.Consumer; public interface KeepAliveHandler { @@ -14,18 +15,11 @@ KeepAliveFramesAcceptor start( Consumer onTimeout); class DefaultKeepAliveHandler implements KeepAliveHandler { - private final Closeable duplexConnection; - - public DefaultKeepAliveHandler(Closeable duplexConnection) { - this.duplexConnection = duplexConnection; - } - @Override public KeepAliveFramesAcceptor start( KeepAliveSupport keepAliveSupport, Consumer onSendKeepAliveFrame, Consumer onTimeout) { - duplexConnection.onClose().doFinally(s -> keepAliveSupport.stop()).subscribe(); return keepAliveSupport .onSendKeepAliveFrame(onSendKeepAliveFrame) .onTimeout(onTimeout) @@ -34,10 +28,18 @@ public KeepAliveFramesAcceptor start( } class ResumableKeepAliveHandler implements KeepAliveHandler { + private final ResumableDuplexConnection resumableDuplexConnection; + private final RSocketSession rSocketSession; + private final ResumeStateHolder resumeStateHolder; - public ResumableKeepAliveHandler(ResumableDuplexConnection resumableDuplexConnection) { + public ResumableKeepAliveHandler( + ResumableDuplexConnection resumableDuplexConnection, + RSocketSession rSocketSession, + ResumeStateHolder resumeStateHolder) { this.resumableDuplexConnection = resumableDuplexConnection; + this.rSocketSession = rSocketSession; + this.resumeStateHolder = resumeStateHolder; } @Override @@ -45,10 +47,11 @@ public KeepAliveFramesAcceptor start( KeepAliveSupport keepAliveSupport, Consumer onSendKeepAliveFrame, Consumer onTimeout) { - resumableDuplexConnection.onResume(keepAliveSupport::start); - resumableDuplexConnection.onDisconnect(keepAliveSupport::stop); + + rSocketSession.setKeepAliveSupport(keepAliveSupport); + return keepAliveSupport - .resumeState(resumableDuplexConnection) + .resumeState(resumeStateHolder) .onSendKeepAliveFrame(onSendKeepAliveFrame) .onTimeout(keepAlive -> resumableDuplexConnection.disconnect()) .start(); diff --git a/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java index a67226ada..4fd18d041 100644 --- a/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java +++ b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java @@ -23,7 +23,7 @@ import io.rsocket.resume.ResumeStateHolder; import java.time.Duration; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.function.Consumer; import reactor.core.Disposable; import reactor.core.publisher.Flux; @@ -38,11 +38,19 @@ public abstract class KeepAliveSupport implements KeepAliveFramesAcceptor { final Duration keepAliveTimeout; final long keepAliveTimeoutMillis; - final AtomicBoolean started = new AtomicBoolean(); + volatile int state; + static final AtomicIntegerFieldUpdater STATE = + AtomicIntegerFieldUpdater.newUpdater(KeepAliveSupport.class, "state"); + + static final int STOPPED_STATE = 0; + static final int STARTING_STATE = 1; + static final int STARTED_STATE = 2; + static final int DISPOSED_STATE = -1; volatile Consumer onTimeout; volatile Consumer onFrameSent; - volatile Disposable ticksDisposable; + + Disposable ticksDisposable; volatile ResumeStateHolder resumeStateHolder; volatile long lastReceivedMillis; @@ -57,25 +65,30 @@ private KeepAliveSupport( } public KeepAliveSupport start() { - this.lastReceivedMillis = scheduler.now(TimeUnit.MILLISECONDS); - if (started.compareAndSet(false, true)) { - ticksDisposable = + if (this.state == STOPPED_STATE && STATE.compareAndSet(this, STOPPED_STATE, STARTING_STATE)) { + this.lastReceivedMillis = scheduler.now(TimeUnit.MILLISECONDS); + + final Disposable disposable = Flux.interval(keepAliveInterval, scheduler).subscribe(v -> onIntervalTick()); + this.ticksDisposable = disposable; + + if (this.state != STARTING_STATE + || !STATE.compareAndSet(this, STARTING_STATE, STARTED_STATE)) { + disposable.dispose(); + } } return this; } public void stop() { - if (started.compareAndSet(true, false)) { - ticksDisposable.dispose(); - } + terminate(STOPPED_STATE); } @Override public void receive(ByteBuf keepAliveFrame) { this.lastReceivedMillis = scheduler.now(TimeUnit.MILLISECONDS); if (resumeStateHolder != null) { - long remoteLastReceivedPos = remoteLastReceivedPosition(keepAliveFrame); + final long remoteLastReceivedPos = KeepAliveFrameCodec.lastPosition(keepAliveFrame); resumeStateHolder.onImpliedPosition(remoteLastReceivedPos); } if (KeepAliveFrameCodec.respondFlag(keepAliveFrame)) { @@ -104,6 +117,16 @@ public KeepAliveSupport onTimeout(Consumer onTimeout) { return this; } + @Override + public void dispose() { + terminate(DISPOSED_STATE); + } + + @Override + public boolean isDisposed() { + return ticksDisposable.isDisposed(); + } + abstract void onIntervalTick(); void send(ByteBuf frame) { @@ -122,40 +145,24 @@ void tryTimeout() { } } - long localLastReceivedPosition() { - return resumeStateHolder != null ? resumeStateHolder.impliedPosition() : 0; - } + void terminate(int terminationState) { + for (; ; ) { + final int state = this.state; - long remoteLastReceivedPosition(ByteBuf keepAliveFrame) { - return KeepAliveFrameCodec.lastPosition(keepAliveFrame); - } - - @Override - public void dispose() { - stop(); - } - - @Override - public boolean isDisposed() { - return ticksDisposable.isDisposed(); - } - - /** - * @deprecated since it should not be used anymore and will be completely removed in 1.1. - * Keepalive is symmetric on both side and implemented as a part of RSocketRequester - */ - @Deprecated - public static final class ServerKeepAliveSupport extends KeepAliveSupport { + if (state == STOPPED_STATE || state == DISPOSED_STATE) { + return; + } - public ServerKeepAliveSupport( - ByteBufAllocator allocator, int keepAlivePeriod, int keepAliveTimeout) { - super(allocator, keepAlivePeriod, keepAliveTimeout); + final Disposable disposable = this.ticksDisposable; + if (STATE.compareAndSet(this, state, terminationState)) { + disposable.dispose(); + return; + } } + } - @Override - void onIntervalTick() { - tryTimeout(); - } + long localLastReceivedPosition() { + return resumeStateHolder != null ? resumeStateHolder.impliedPosition() : 0; } public static final class ClientKeepAliveSupport extends KeepAliveSupport { diff --git a/rsocket-core/src/main/java/io/rsocket/lease/Lease.java b/rsocket-core/src/main/java/io/rsocket/lease/Lease.java index 673b4a480..9e76d176d 100644 --- a/rsocket-core/src/main/java/io/rsocket/lease/Lease.java +++ b/rsocket-core/src/main/java/io/rsocket/lease/Lease.java @@ -18,51 +18,62 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; -import io.rsocket.Availability; +import java.time.Duration; import reactor.util.annotation.Nullable; /** A contract for RSocket lease, which is sent by a request acceptor and is time bound. */ -public interface Lease extends Availability { +public final class Lease { - static Lease create(int timeToLiveMillis, int numberOfRequests, @Nullable ByteBuf metadata) { - return LeaseImpl.create(timeToLiveMillis, numberOfRequests, metadata); + public static Lease create( + Duration timeToLive, int numberOfRequests, @Nullable ByteBuf metadata) { + return new Lease(timeToLive, numberOfRequests, metadata); } - static Lease create(int timeToLiveMillis, int numberOfRequests) { - return create(timeToLiveMillis, numberOfRequests, Unpooled.EMPTY_BUFFER); + public static Lease create(Duration timeToLive, int numberOfRequests) { + return create(timeToLive, numberOfRequests, Unpooled.EMPTY_BUFFER); } - /** - * Number of requests allowed by this lease. - * - * @return The number of requests allowed by this lease. - */ - int getAllowedRequests(); + public static Lease unbounded() { + return unbounded(null); + } - /** - * Initial number of requests allowed by this lease. - * - * @return initial number of requests allowed by this lease. - */ - default int getStartingAllowedRequests() { - throw new UnsupportedOperationException("Not implemented"); + public static Lease unbounded(@Nullable ByteBuf metadata) { + return create(Duration.ofMillis(Integer.MAX_VALUE), Integer.MAX_VALUE, metadata); + } + + public static Lease empty() { + return create(Duration.ZERO, 0); + } + + final int timeToLiveMillis; + final int numberOfRequests; + final ByteBuf metadata; + final long expirationTime; + + Lease(Duration timeToLive, int numberOfRequests, @Nullable ByteBuf metadata) { + this.numberOfRequests = numberOfRequests; + this.timeToLiveMillis = (int) Math.min(timeToLive.toMillis(), Integer.MAX_VALUE); + this.metadata = metadata == null ? Unpooled.EMPTY_BUFFER : metadata; + this.expirationTime = + timeToLive.isZero() ? 0 : System.currentTimeMillis() + timeToLive.toMillis(); } /** - * Number of milliseconds that this lease is valid from the time it is received. + * Number of requests allowed by this lease. * - * @return Number of milliseconds that this lease is valid from the time it is received. + * @return The number of requests allowed by this lease. */ - int getTimeToLiveMillis(); + public int numberOfRequests() { + return numberOfRequests; + } /** - * Number of milliseconds that this lease is still valid from now. + * Time to live for the given lease * - * @param now millis since epoch - * @return Number of milliseconds that this lease is still valid from now, or 0 if expired. + * @return relative duration in milliseconds */ - default int getRemainingTimeToLiveMillis(long now) { - return isEmpty() ? 0 : (int) Math.max(0, expiry() - now); + public int timeToLiveInMillis() { + return this.timeToLiveMillis; } /** @@ -70,41 +81,29 @@ default int getRemainingTimeToLiveMillis(long now) { * * @return Absolute time since epoch at which this lease will expire. */ - long expiry(); + public long expirationTime() { + return expirationTime; + } /** * Metadata for the lease. * * @return Metadata for the lease. */ - ByteBuf getMetadata(); - - /** - * Checks if the lease is expired now. - * - * @return {@code true} if the lease has expired. - */ - default boolean isExpired() { - return isExpired(System.currentTimeMillis()); - } - - /** - * Checks if the lease is expired for the passed {@code now}. - * - * @param now current time in millis. - * @return {@code true} if the lease has expired. - */ - default boolean isExpired(long now) { - return now > expiry(); - } - - /** Checks if the lease has not expired and there are allowed requests available */ - default boolean isValid() { - return !isExpired() && getAllowedRequests() > 0; + @Nullable + public ByteBuf metadata() { + return metadata; } - /** Checks if the lease is empty(default value if no lease was received yet) */ - default boolean isEmpty() { - return getAllowedRequests() == 0 && getTimeToLiveMillis() == 0; + @Override + public String toString() { + return "Lease{" + + "timeToLiveMillis=" + + timeToLiveMillis + + ", numberOfRequests=" + + numberOfRequests + + ", expirationTime=" + + expirationTime + + '}'; } } diff --git a/rsocket-core/src/main/java/io/rsocket/lease/LeaseImpl.java b/rsocket-core/src/main/java/io/rsocket/lease/LeaseImpl.java deleted file mode 100644 index 7abb8aab9..000000000 --- a/rsocket-core/src/main/java/io/rsocket/lease/LeaseImpl.java +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Copyright 2015-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.lease; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import java.util.concurrent.atomic.AtomicInteger; -import reactor.util.annotation.Nullable; - -public class LeaseImpl implements Lease { - private final int timeToLiveMillis; - private final AtomicInteger allowedRequests; - private final int startingAllowedRequests; - private final ByteBuf metadata; - private final long expiry; - - static LeaseImpl create(int timeToLiveMillis, int numberOfRequests, @Nullable ByteBuf metadata) { - assertLease(timeToLiveMillis, numberOfRequests); - return new LeaseImpl(timeToLiveMillis, numberOfRequests, metadata); - } - - static LeaseImpl empty() { - return new LeaseImpl(0, 0, null); - } - - private LeaseImpl(int timeToLiveMillis, int allowedRequests, @Nullable ByteBuf metadata) { - this.allowedRequests = new AtomicInteger(allowedRequests); - this.startingAllowedRequests = allowedRequests; - this.timeToLiveMillis = timeToLiveMillis; - this.metadata = metadata == null ? Unpooled.EMPTY_BUFFER : metadata; - this.expiry = timeToLiveMillis == 0 ? 0 : now() + timeToLiveMillis; - } - - public int getTimeToLiveMillis() { - return timeToLiveMillis; - } - - @Override - public int getAllowedRequests() { - return Math.max(0, allowedRequests.get()); - } - - @Override - public int getStartingAllowedRequests() { - return startingAllowedRequests; - } - - @Override - public ByteBuf getMetadata() { - return metadata; - } - - @Override - public long expiry() { - return expiry; - } - - @Override - public boolean isValid() { - return !isEmpty() && getAllowedRequests() > 0 && !isExpired(); - } - - /** - * try use 1 allowed request of Lease - * - * @return true if used successfully, false if Lease is expired or no allowed requests available - */ - public boolean use() { - if (isExpired()) { - return false; - } - int remaining = - allowedRequests.accumulateAndGet(1, (cur, update) -> Math.max(-1, cur - update)); - return remaining >= 0; - } - - @Override - public double availability() { - return isValid() ? getAllowedRequests() / (double) getStartingAllowedRequests() : 0.0; - } - - @Override - public String toString() { - long now = now(); - return "LeaseImpl{" - + "timeToLiveMillis=" - + timeToLiveMillis - + ", allowedRequests=" - + getAllowedRequests() - + ", startingAllowedRequests=" - + startingAllowedRequests - + ", expired=" - + isExpired(now) - + ", remainingTimeToLiveMillis=" - + getRemainingTimeToLiveMillis(now) - + '}'; - } - - private static long now() { - return System.currentTimeMillis(); - } - - private static void assertLease(int timeToLiveMillis, int numberOfRequests) { - if (numberOfRequests <= 0) { - throw new IllegalArgumentException("Number of requests must be positive"); - } - if (timeToLiveMillis <= 0) { - throw new IllegalArgumentException("Time-to-live must be positive"); - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/lease/LeaseSender.java b/rsocket-core/src/main/java/io/rsocket/lease/LeaseSender.java new file mode 100644 index 000000000..48bd38494 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/lease/LeaseSender.java @@ -0,0 +1,8 @@ +package io.rsocket.lease; + +import reactor.core.publisher.Flux; + +public interface LeaseSender { + + Flux send(); +} diff --git a/rsocket-core/src/main/java/io/rsocket/lease/Leases.java b/rsocket-core/src/main/java/io/rsocket/lease/Leases.java deleted file mode 100644 index 4c90e38ce..000000000 --- a/rsocket-core/src/main/java/io/rsocket/lease/Leases.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright 2015-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.lease; - -import java.util.Objects; -import java.util.Optional; -import java.util.function.Consumer; -import java.util.function.Function; -import reactor.core.publisher.Flux; - -public class Leases { - private static final Function> noopLeaseSender = leaseStats -> Flux.never(); - private static final Consumer> noopLeaseReceiver = leases -> {}; - - private Function> leaseSender = noopLeaseSender; - private Consumer> leaseReceiver = noopLeaseReceiver; - private Optional stats = Optional.empty(); - - public static Leases create() { - return new Leases<>(); - } - - public Leases sender(Function, Flux> leaseSender) { - this.leaseSender = leaseSender; - return this; - } - - public Leases receiver(Consumer> leaseReceiver) { - this.leaseReceiver = leaseReceiver; - return this; - } - - public Leases stats(T stats) { - this.stats = Optional.of(Objects.requireNonNull(stats)); - return this; - } - - @SuppressWarnings("unchecked") - public Function, Flux> sender() { - return (Function, Flux>) leaseSender; - } - - public Consumer> receiver() { - return leaseReceiver; - } - - @SuppressWarnings("unchecked") - public Optional stats() { - return (Optional) stats; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/lease/MissingLeaseException.java b/rsocket-core/src/main/java/io/rsocket/lease/MissingLeaseException.java index 3b6cec62c..84af91b1b 100644 --- a/rsocket-core/src/main/java/io/rsocket/lease/MissingLeaseException.java +++ b/rsocket-core/src/main/java/io/rsocket/lease/MissingLeaseException.java @@ -16,35 +16,16 @@ package io.rsocket.lease; import io.rsocket.exceptions.RejectedException; -import java.util.Objects; -import reactor.util.annotation.Nullable; public class MissingLeaseException extends RejectedException { private static final long serialVersionUID = -6169748673403858959L; - public MissingLeaseException(Lease lease, String tag) { - super(leaseMessage(Objects.requireNonNull(lease), Objects.requireNonNull(tag))); - } - - public MissingLeaseException(String tag) { - super(leaseMessage(null, Objects.requireNonNull(tag))); + public MissingLeaseException(String message) { + super(message); } @Override public synchronized Throwable fillInStackTrace() { return this; } - - static String leaseMessage(@Nullable Lease lease, String tag) { - if (lease == null) { - return String.format("[%s] Missing leases", tag); - } - if (lease.isEmpty()) { - return String.format("[%s] Lease was not received yet", tag); - } - boolean expired = lease.isExpired(); - int allowedRequests = lease.getAllowedRequests(); - return String.format( - "[%s] Missing leases. Expired: %b, allowedRequests: %d", tag, expired, allowedRequests); - } } diff --git a/rsocket-core/src/main/java/io/rsocket/lease/RequesterLeaseHandler.java b/rsocket-core/src/main/java/io/rsocket/lease/RequesterLeaseHandler.java deleted file mode 100644 index fd569a2c8..000000000 --- a/rsocket-core/src/main/java/io/rsocket/lease/RequesterLeaseHandler.java +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Copyright 2015-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.lease; - -import io.netty.buffer.ByteBuf; -import io.rsocket.Availability; -import io.rsocket.frame.LeaseFrameCodec; -import java.util.function.Consumer; -import reactor.core.Disposable; -import reactor.core.publisher.Flux; -import reactor.core.publisher.ReplayProcessor; - -public interface RequesterLeaseHandler extends Availability, Disposable { - - boolean useLease(); - - Exception leaseError(); - - void receive(ByteBuf leaseFrame); - - void dispose(); - - final class Impl implements RequesterLeaseHandler { - private final String tag; - private final ReplayProcessor receivedLease; - private volatile LeaseImpl currentLease = LeaseImpl.empty(); - - public Impl(String tag, Consumer> leaseReceiver) { - this.tag = tag; - receivedLease = ReplayProcessor.create(1); - leaseReceiver.accept(receivedLease); - } - - @Override - public boolean useLease() { - return currentLease.use(); - } - - @Override - public Exception leaseError() { - LeaseImpl l = this.currentLease; - String t = this.tag; - if (!l.isValid()) { - return new MissingLeaseException(l, t); - } else { - return new MissingLeaseException(t); - } - } - - @Override - public void receive(ByteBuf leaseFrame) { - int numberOfRequests = LeaseFrameCodec.numRequests(leaseFrame); - int timeToLiveMillis = LeaseFrameCodec.ttl(leaseFrame); - ByteBuf metadata = LeaseFrameCodec.metadata(leaseFrame); - LeaseImpl lease = LeaseImpl.create(timeToLiveMillis, numberOfRequests, metadata); - currentLease = lease; - receivedLease.onNext(lease); - } - - @Override - public void dispose() { - receivedLease.onComplete(); - } - - @Override - public boolean isDisposed() { - return receivedLease.isTerminated(); - } - - @Override - public double availability() { - return currentLease.availability(); - } - } - - RequesterLeaseHandler None = - new RequesterLeaseHandler() { - @Override - public boolean useLease() { - return true; - } - - @Override - public Exception leaseError() { - throw new AssertionError("Error not possible with NOOP leases handler"); - } - - @Override - public void receive(ByteBuf leaseFrame) {} - - @Override - public void dispose() {} - - @Override - public double availability() { - return 1.0; - } - }; -} diff --git a/rsocket-core/src/main/java/io/rsocket/lease/ResponderLeaseHandler.java b/rsocket-core/src/main/java/io/rsocket/lease/ResponderLeaseHandler.java deleted file mode 100644 index df8787cb7..000000000 --- a/rsocket-core/src/main/java/io/rsocket/lease/ResponderLeaseHandler.java +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Copyright 2015-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.lease; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.Availability; -import io.rsocket.frame.LeaseFrameCodec; -import java.util.Optional; -import java.util.function.Consumer; -import java.util.function.Function; -import reactor.core.Disposable; -import reactor.core.Disposables; -import reactor.core.publisher.Flux; -import reactor.util.annotation.Nullable; - -public interface ResponderLeaseHandler extends Availability { - - boolean useLease(); - - Exception leaseError(); - - Disposable send(Consumer leaseFrameSender); - - final class Impl implements ResponderLeaseHandler { - private volatile LeaseImpl currentLease = LeaseImpl.empty(); - private final String tag; - private final ByteBufAllocator allocator; - private final Function, Flux> leaseSender; - private final Optional leaseStatsOption; - private final T leaseStats; - - public Impl( - String tag, - ByteBufAllocator allocator, - Function, Flux> leaseSender, - Optional leaseStatsOption) { - this.tag = tag; - this.allocator = allocator; - this.leaseSender = leaseSender; - this.leaseStatsOption = leaseStatsOption; - this.leaseStats = leaseStatsOption.orElse(null); - } - - @Override - public boolean useLease() { - boolean success = currentLease.use(); - onUseEvent(success, leaseStats); - return success; - } - - @Override - public Exception leaseError() { - LeaseImpl l = currentLease; - String t = tag; - if (!l.isValid()) { - return new MissingLeaseException(l, t); - } else { - return new MissingLeaseException(t); - } - } - - @Override - public Disposable send(Consumer leaseFrameSender) { - return leaseSender - .apply(leaseStatsOption) - .doOnTerminate(this::onTerminateEvent) - .subscribe( - lease -> { - currentLease = create(lease); - leaseFrameSender.accept(createLeaseFrame(lease)); - }); - } - - @Override - public double availability() { - return currentLease.availability(); - } - - private ByteBuf createLeaseFrame(Lease lease) { - return LeaseFrameCodec.encode( - allocator, lease.getTimeToLiveMillis(), lease.getAllowedRequests(), lease.getMetadata()); - } - - private void onTerminateEvent() { - T ls = leaseStats; - if (ls != null) { - ls.onEvent(LeaseStats.EventType.TERMINATE); - } - } - - private void onUseEvent(boolean success, @Nullable T ls) { - if (ls != null) { - LeaseStats.EventType eventType = - success ? LeaseStats.EventType.ACCEPT : LeaseStats.EventType.REJECT; - ls.onEvent(eventType); - } - } - - private static LeaseImpl create(Lease lease) { - if (lease instanceof LeaseImpl) { - return (LeaseImpl) lease; - } else { - return LeaseImpl.create( - lease.getTimeToLiveMillis(), lease.getAllowedRequests(), lease.getMetadata()); - } - } - } - - ResponderLeaseHandler None = - new ResponderLeaseHandler() { - @Override - public boolean useLease() { - return true; - } - - @Override - public Exception leaseError() { - throw new AssertionError("Error not possible with NOOP leases handler"); - } - - @Override - public Disposable send(Consumer leaseFrameSender) { - return Disposables.disposed(); - } - - @Override - public double availability() { - return 1.0; - } - }; -} diff --git a/rsocket-core/src/main/java/io/rsocket/lease/TrackingLeaseSender.java b/rsocket-core/src/main/java/io/rsocket/lease/TrackingLeaseSender.java new file mode 100644 index 000000000..3e6f68321 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/lease/TrackingLeaseSender.java @@ -0,0 +1,5 @@ +package io.rsocket.lease; + +import io.rsocket.plugins.RequestInterceptor; + +public interface TrackingLeaseSender extends LeaseSender, RequestInterceptor {} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/BaseWeightedStats.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/BaseWeightedStats.java new file mode 100644 index 000000000..fdbbeb25d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/BaseWeightedStats.java @@ -0,0 +1,235 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.util.Clock; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; + +/** + * Implementation of {@link WeightedStats} that manages tracking state and exposes the required + * stats. + * + *

    A sub-class or a different class (delegation) needs to call {@link #startStream()}, {@link + * #stopStream()}, {@link #startRequest()}, and {@link #stopRequest(long)} to drive state tracking. + * + * @since 1.1 + * @see WeightedStatsRequestInterceptor + */ +public class BaseWeightedStats implements WeightedStats { + + private static final double DEFAULT_LOWER_QUANTILE = 0.5; + private static final double DEFAULT_HIGHER_QUANTILE = 0.8; + private static final int INACTIVITY_FACTOR = 500; + private static final long DEFAULT_INITIAL_INTER_ARRIVAL_TIME = + Clock.unit().convert(1L, TimeUnit.SECONDS); + + private static final double STARTUP_PENALTY = Long.MAX_VALUE >> 12; + + private final Quantile lowerQuantile; + private final Quantile higherQuantile; + private final Ewma availabilityPercentage; + private final Median median; + private final Ewma interArrivalTime; + + private final long tau; + private final long inactivityFactor; + + private long errorStamp; // last we got an error + private long stamp; // last timestamp we sent a request + private long stamp0; // last timestamp we sent a request or receive a response + private long duration; // instantaneous cumulative duration + + private volatile int pendingRequests; // instantaneous rate + private static final AtomicIntegerFieldUpdater PENDING_REQUESTS = + AtomicIntegerFieldUpdater.newUpdater(BaseWeightedStats.class, "pendingRequests"); + private volatile int pendingStreams; // number of active streams + private static final AtomicIntegerFieldUpdater PENDING_STREAMS = + AtomicIntegerFieldUpdater.newUpdater(BaseWeightedStats.class, "pendingStreams"); + + protected BaseWeightedStats() { + this( + new FrugalQuantile(DEFAULT_LOWER_QUANTILE), + new FrugalQuantile(DEFAULT_HIGHER_QUANTILE), + INACTIVITY_FACTOR); + } + + private BaseWeightedStats( + Quantile lowerQuantile, Quantile higherQuantile, long inactivityFactor) { + this.lowerQuantile = lowerQuantile; + this.higherQuantile = higherQuantile; + this.inactivityFactor = inactivityFactor; + + long now = Clock.now(); + this.stamp = now; + this.errorStamp = now; + this.stamp0 = now; + this.duration = 0L; + this.pendingRequests = 0; + this.median = new Median(); + this.interArrivalTime = new Ewma(1, TimeUnit.MINUTES, DEFAULT_INITIAL_INTER_ARRIVAL_TIME); + this.availabilityPercentage = new Ewma(5, TimeUnit.SECONDS, 1.0); + this.tau = Clock.unit().convert((long) (5 / Math.log(2)), TimeUnit.SECONDS); + } + + @Override + public double lowerQuantileLatency() { + return lowerQuantile.estimation(); + } + + @Override + public double higherQuantileLatency() { + return higherQuantile.estimation(); + } + + @Override + public int pending() { + return pendingRequests + pendingStreams; + } + + @Override + public double weightedAvailability() { + if (Clock.now() - stamp > tau) { + updateAvailability(1.0); + } + return availabilityPercentage.value(); + } + + @Override + public double predictedLatency() { + final long now = Clock.now(); + final long elapsed; + + synchronized (this) { + elapsed = Math.max(now - stamp, 1L); + } + + final double latency; + final double prediction = median.estimation(); + + final int pending = this.pending(); + if (prediction == 0.0) { + if (pending == 0) { + latency = 0.0; // first request + } else { + // subsequent requests while we don't have any history + latency = STARTUP_PENALTY + pending; + } + } else if (pending == 0 && elapsed > inactivityFactor * interArrivalTime.value()) { + // if we did't see any data for a while, we decay the prediction by inserting + // artificial 0.0 into the median + median.insert(0.0); + latency = median.estimation(); + } else { + final double predicted = prediction * pending; + final double instant = instantaneous(now, pending); + + if (predicted < instant) { // NB: (0.0 < 0.0) == false + latency = instant / pending; // NB: pending never equal 0 here + } else { + // we are under the predictions + latency = prediction; + } + } + + return latency; + } + + long instantaneous(long now, int pending) { + return duration + (now - stamp0) * pending; + } + + void startStream() { + PENDING_STREAMS.incrementAndGet(this); + } + + void stopStream() { + PENDING_STREAMS.decrementAndGet(this); + } + + synchronized long startRequest() { + final long now = Clock.now(); + final int pendingRequests = this.pendingRequests; + + interArrivalTime.insert(now - stamp); + duration += Math.max(0, now - stamp0) * pendingRequests; + PENDING_REQUESTS.lazySet(this, pendingRequests + 1); + stamp = now; + stamp0 = now; + + return now; + } + + synchronized long stopRequest(long timestamp) { + final long now = Clock.now(); + final int pendingRequests = this.pendingRequests; + + duration += Math.max(0, now - stamp0) * pendingRequests - (now - timestamp); + PENDING_REQUESTS.lazySet(this, pendingRequests - 1); + stamp0 = now; + + return now; + } + + synchronized void record(double roundTripTime) { + median.insert(roundTripTime); + lowerQuantile.insert(roundTripTime); + higherQuantile.insert(roundTripTime); + } + + void updateAvailability(double value) { + availabilityPercentage.insert(value); + if (value == 0.0d) { + synchronized (this) { + errorStamp = Clock.now(); + } + } + } + + @Override + public String toString() { + return "Stats{" + + "lowerQuantile=" + + lowerQuantile.estimation() + + ", higherQuantile=" + + higherQuantile.estimation() + + ", inactivityFactor=" + + inactivityFactor + + ", tau=" + + tau + + ", errorPercentage=" + + availabilityPercentage.value() + + ", pending=" + + pendingRequests + + ", errorStamp=" + + errorStamp + + ", stamp=" + + stamp + + ", stamp0=" + + stamp0 + + ", duration=" + + duration + + ", median=" + + median.estimation() + + ", interArrivalTime=" + + interArrivalTime.value() + + ", pendingStreams=" + + pendingStreams + + ", availability=" + + availabilityPercentage.value() + + '}'; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/ClientLoadbalanceStrategy.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/ClientLoadbalanceStrategy.java new file mode 100644 index 000000000..528f4f896 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/ClientLoadbalanceStrategy.java @@ -0,0 +1,40 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.core.RSocketConnector; +import io.rsocket.plugins.InterceptorRegistry; + +/** + * A {@link LoadbalanceStrategy} with an interest in configuring the {@link RSocketConnector} for + * connecting to load-balance targets in order to hook into request lifecycle and track usage + * statistics. + * + *

    Currently this callback interface is supported for strategies configured in {@link + * LoadbalanceRSocketClient}. + * + * @since 1.1 + */ +public interface ClientLoadbalanceStrategy extends LoadbalanceStrategy { + + /** + * Initialize the connector, for example using the {@link InterceptorRegistry}, to intercept + * requests. + * + * @param connector the connector to configure + */ + void initialize(RSocketConnector connector); +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/Ewma.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/Ewma.java new file mode 100644 index 000000000..0f87f6510 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/Ewma.java @@ -0,0 +1,71 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.loadbalance; + +import io.rsocket.util.Clock; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; + +/** + * Compute the exponential weighted moving average of a series of values. The time at which you + * insert the value into `Ewma` is used to compute a weight (recent points are weighted higher). The + * parameter for defining the convergence speed (like most decay process) is the half-life. + * + *

    e.g. with a half-life of 10 unit, if you insert 100 at t=0 and 200 at t=10 the ewma will be + * equal to (200 - 100)/2 = 150 (half of the distance between the new and the old value) + */ +class Ewma { + + final long tau; + + volatile long stamp; + static final AtomicLongFieldUpdater STAMP = + AtomicLongFieldUpdater.newUpdater(Ewma.class, "stamp"); + volatile double ewma; + + public Ewma(long halfLife, TimeUnit unit, double initialValue) { + this.tau = Clock.unit().convert((long) (halfLife / Math.log(2)), unit); + + this.ewma = initialValue; + + STAMP.lazySet(this, 0L); + } + + public synchronized void insert(double x) { + final long now = Clock.now(); + final double elapsed = Math.max(0, now - stamp); + + STAMP.lazySet(this, now); + + double w = Math.exp(-elapsed / tau); + ewma = w * ewma + (1.0 - w) * x; + } + + public synchronized void reset(double value) { + stamp = 0L; + ewma = value; + } + + public double value() { + return ewma; + } + + @Override + public String toString() { + return "Ewma(value=" + ewma + ", age=" + (Clock.now() - stamp) + ")"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/FluxDeferredResolution.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/FluxDeferredResolution.java new file mode 100644 index 000000000..6c2b9c3ea --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/FluxDeferredResolution.java @@ -0,0 +1,228 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Payload; +import io.rsocket.frame.FrameType; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.function.BiConsumer; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +abstract class FluxDeferredResolution extends Flux + implements CoreSubscriber, Subscription, BiConsumer, Scannable { + + final ResolvingOperator parent; + final INPUT fluxOrPayload; + final FrameType requestType; + + volatile long requested; + + @SuppressWarnings("rawtypes") + static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(FluxDeferredResolution.class, "requested"); + + static final long STATE_UNSUBSCRIBED = -1; + static final long STATE_SUBSCRIBER_SET = 0; + static final long STATE_SUBSCRIBED = -2; + static final long STATE_TERMINATED = Long.MIN_VALUE; + + Subscription s; + CoreSubscriber actual; + boolean done; + + FluxDeferredResolution(ResolvingOperator parent, INPUT fluxOrPayload, FrameType requestType) { + this.parent = parent; + this.fluxOrPayload = fluxOrPayload; + this.requestType = requestType; + + REQUESTED.lazySet(this, STATE_UNSUBSCRIBED); + } + + @Override + public final void subscribe(CoreSubscriber actual) { + if (this.requested == STATE_UNSUBSCRIBED + && REQUESTED.compareAndSet(this, STATE_UNSUBSCRIBED, STATE_SUBSCRIBER_SET)) { + + actual.onSubscribe(this); + + if (this.requested == STATE_TERMINATED) { + return; + } + + this.actual = actual; + this.parent.observe(this); + } else { + Operators.error(actual, new IllegalStateException("Only a single Subscriber allowed")); + } + } + + @Override + public final Context currentContext() { + return this.actual.currentContext(); + } + + @Nullable + @Override + public final Object scanUnsafe(Attr key) { + long state = this.requested; + + if (key == Attr.PARENT) { + return this.s; + } + if (key == Attr.ACTUAL) { + return this.parent; + } + if (key == Attr.TERMINATED) { + return this.done; + } + if (key == Attr.CANCELLED) { + return state == STATE_TERMINATED; + } + + return null; + } + + @Override + public final void onSubscribe(Subscription s) { + final long state = this.requested; + Subscription a = this.s; + if (state == STATE_TERMINATED) { + s.cancel(); + return; + } + if (a != null) { + s.cancel(); + return; + } + + long r; + long accumulated = 0; + for (; ; ) { + r = this.requested; + + if (r == STATE_TERMINATED || r == STATE_SUBSCRIBED) { + s.cancel(); + return; + } + + this.s = s; + + long toRequest = r - accumulated; + if (toRequest > 0) { // if there is something, + s.request(toRequest); // then we do a request on the given subscription + } + accumulated = r; + + if (REQUESTED.compareAndSet(this, r, STATE_SUBSCRIBED)) { + return; + } + } + } + + @Override + public final void onNext(Payload payload) { + this.actual.onNext(payload); + } + + @Override + public final void onError(Throwable t) { + if (this.done) { + Operators.onErrorDropped(t, this.actual.currentContext()); + return; + } + + this.done = true; + this.actual.onError(t); + } + + @Override + public final void onComplete() { + if (this.done) { + return; + } + + this.done = true; + this.actual.onComplete(); + } + + @Override + public final void request(long n) { + if (Operators.validate(n)) { + long r = this.requested; // volatile read beforehand + + if (r > STATE_SUBSCRIBED) { // works only in case onSubscribe has not happened + long u; + for (; ; ) { // normal CAS loop with overflow protection + if (r == Long.MAX_VALUE) { + // if r == Long.MAX_VALUE then we dont care and we can loose this + // request just in case of racing + return; + } + u = Operators.addCap(r, n); + if (REQUESTED.compareAndSet(this, r, u)) { + // Means increment happened before onSubscribe + return; + } else { + // Means increment happened after onSubscribe + + // update new state to see what exactly happened (onSubscribe |cancel | requestN) + r = this.requested; + + // check state (expect -1 | -2 to exit, otherwise repeat) + if (r < 0) { + break; + } + } + } + } + + if (r == STATE_TERMINATED) { // if canceled, just exit + return; + } + + // if onSubscribe -> subscription exists (and we sure of that because volatile read + // after volatile write) so we can execute requestN on the subscription + this.s.request(n); + } + } + + public final void cancel() { + long state = REQUESTED.getAndSet(this, STATE_TERMINATED); + if (state == STATE_TERMINATED) { + return; + } + + if (state == STATE_SUBSCRIBED) { + this.s.cancel(); + } else { + this.parent.remove(this); + if (requestType == FrameType.REQUEST_STREAM) { + ReferenceCountUtil.safeRelease(this.fluxOrPayload); + } + } + } + + boolean isTerminated() { + return this.requested == STATE_TERMINATED; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/FrugalQuantile.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/FrugalQuantile.java new file mode 100644 index 000000000..cdbdc19b3 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/FrugalQuantile.java @@ -0,0 +1,133 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.loadbalance; + +import java.util.SplittableRandom; + +/** + * Reference: Ma, Qiang, S. Muthukrishnan, and Mark Sandler. "Frugal Streaming for Estimating + * Quantiles." Space-Efficient Data Structures, Streams, and Algorithms. Springer Berlin Heidelberg, + * 2013. 77-96. + * + *

    More info: http://blog.aggregateknowledge.com/2013/09/16/sketch-of-the-day-frugal-streaming/ + */ +class FrugalQuantile implements Quantile { + final double increment; + final SplittableRandom rnd; + + int step; + int sign; + double quantile; + + volatile double estimate; + + public FrugalQuantile(double quantile, double increment) { + this.increment = increment; + this.quantile = quantile; + this.estimate = 0.0; + this.step = 1; + this.sign = 0; + this.rnd = new SplittableRandom(System.nanoTime()); + } + + public FrugalQuantile(double quantile) { + this(quantile, 1.0); + } + + public synchronized void reset(double quantile) { + this.quantile = quantile; + this.estimate = 0.0; + this.step = 1; + this.sign = 0; + } + + public double estimation() { + return estimate; + } + + @Override + public synchronized void insert(double x) { + if (sign == 0) { + estimate = x; + sign = 1; + } else { + final double v = rnd.nextDouble(); + final double estimate = this.estimate; + + if (x > estimate && v > (1 - quantile)) { + higher(x); + } else if (x < estimate && v > quantile) { + lower(x); + } + } + } + + private void higher(double x) { + double estimate = this.estimate; + + step += sign * increment; + + if (step > 0) { + estimate += step; + } else { + estimate += 1; + } + + if (estimate > x) { + step += (x - estimate); + estimate = x; + } + + if (sign < 0) { + step = 1; + } + + sign = 1; + + this.estimate = estimate; + } + + private void lower(double x) { + double estimate = this.estimate; + + step -= sign * increment; + + if (step > 0) { + estimate -= step; + } else { + estimate--; + } + + if (estimate < x) { + step += (estimate - x); + estimate = x; + } + + if (sign > 0) { + step = 1; + } + + sign = -1; + + this.estimate = estimate; + } + + @Override + public String toString() { + return "FrugalQuantile(q=" + quantile + ", v=" + estimate + ")"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/Int2LongHashMap.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/Int2LongHashMap.java new file mode 100644 index 000000000..eebf82fe9 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/Int2LongHashMap.java @@ -0,0 +1,1005 @@ +/* + * Copyright 2014-2020 Real Logic Limited. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import java.io.Serializable; +import java.util.AbstractCollection; +import java.util.AbstractSet; +import java.util.Arrays; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.function.Function; +import java.util.function.IntToLongFunction; +import reactor.util.annotation.Nullable; + +/** A open addressing with linear probing hash map specialised for primitive key and value pairs. */ +class Int2LongHashMap implements Map, Serializable { + static final float DEFAULT_LOAD_FACTOR = 0.55f; + static final int MIN_CAPACITY = 8; + private static final long serialVersionUID = -690554872053575793L; + + private final float loadFactor; + private final long missingValue; + private int resizeThreshold; + private int size = 0; + private final boolean shouldAvoidAllocation; + + private long[] entries; + private KeySet keySet; + private ValueCollection values; + private EntrySet entrySet; + + /** @param missingValue for the map that represents null. */ + public Int2LongHashMap(final long missingValue) { + this(MIN_CAPACITY, DEFAULT_LOAD_FACTOR, missingValue); + } + + /** + * @param initialCapacity for the map to override {@link #MIN_CAPACITY} + * @param loadFactor for the map to override {@link #DEFAULT_LOAD_FACTOR}. + * @param missingValue for the map that represents null. + */ + public Int2LongHashMap( + final int initialCapacity, final float loadFactor, final long missingValue) { + this(initialCapacity, loadFactor, missingValue, true); + } + + /** + * @param initialCapacity for the map to override {@link #MIN_CAPACITY} + * @param loadFactor for the map to override {@link #DEFAULT_LOAD_FACTOR}. + * @param missingValue for the map that represents null. + * @param shouldAvoidAllocation should allocation be avoided by caching iterators and map entries. + */ + public Int2LongHashMap( + final int initialCapacity, + final float loadFactor, + final long missingValue, + final boolean shouldAvoidAllocation) { + validateLoadFactor(loadFactor); + + this.loadFactor = loadFactor; + this.missingValue = missingValue; + this.shouldAvoidAllocation = shouldAvoidAllocation; + + capacity(findNextPositivePowerOfTwo(Math.max(MIN_CAPACITY, initialCapacity))); + } + + /** + * The value to be used as a null marker in the map. + * + * @return value to be used as a null marker in the map. + */ + public long missingValue() { + return missingValue; + } + + /** + * Get the load factor applied for resize operations. + * + * @return the load factor applied for resize operations. + */ + public float loadFactor() { + return loadFactor; + } + + /** + * Get the total capacity for the map to which the load factor will be a fraction of. + * + * @return the total capacity for the map. + */ + public int capacity() { + return entries.length >> 1; + } + + /** + * Get the actual threshold which when reached the map will resize. This is a function of the + * current capacity and load factor. + * + * @return the threshold when the map will resize. + */ + public int resizeThreshold() { + return resizeThreshold; + } + + /** {@inheritDoc} */ + public int size() { + return size; + } + + /** {@inheritDoc} */ + public boolean isEmpty() { + return size == 0; + } + + /** + * Get a value using provided key avoiding boxing. + * + * @param key lookup key. + * @return value associated with the key or {@link #missingValue()} if key is not found in the + * map. + */ + public long get(final int key) { + final int mask = entries.length - 1; + int index = evenHash(key, mask); + + long value = missingValue; + while (entries[index + 1] != missingValue) { + if (entries[index] == key) { + value = entries[index + 1]; + break; + } + + index = next(index, mask); + } + + return value; + } + + /** + * Put a key value pair in the map. + * + * @param key lookup key + * @param value new value, must not be {@link #missingValue()} + * @return previous value associated with the key, or {@link #missingValue()} if none found + * @throws IllegalArgumentException if value is {@link #missingValue()} + */ + public long put(final int key, final long value) { + if (value == missingValue) { + throw new IllegalArgumentException("cannot accept missingValue"); + } + + final int mask = entries.length - 1; + int index = evenHash(key, mask); + long oldValue = missingValue; + + while (entries[index + 1] != missingValue) { + if (entries[index] == key) { + oldValue = entries[index + 1]; + break; + } + + index = next(index, mask); + } + + if (oldValue == missingValue) { + ++size; + entries[index] = key; + } + + entries[index + 1] = value; + + increaseCapacity(); + + return oldValue; + } + + private void increaseCapacity() { + if (size > resizeThreshold) { + // entries.length = 2 * capacity + final int newCapacity = entries.length; + rehash(newCapacity); + } + } + + private void rehash(final int newCapacity) { + final long[] oldEntries = entries; + final int length = entries.length; + + capacity(newCapacity); + + final long[] newEntries = entries; + final int mask = entries.length - 1; + + for (int keyIndex = 0; keyIndex < length; keyIndex += 2) { + final long value = oldEntries[keyIndex + 1]; + if (value != missingValue) { + final int key = (int) oldEntries[keyIndex]; + int index = evenHash(key, mask); + + while (newEntries[index + 1] != missingValue) { + index = next(index, mask); + } + + newEntries[index] = key; + newEntries[index + 1] = value; + } + } + } + + /** + * Int primitive specialised containsKey. + * + * @param key the key to check. + * @return true if the map contains key as a key, false otherwise. + */ + public boolean containsKey(final int key) { + return get(key) != missingValue; + } + + /** + * Does the map contain the value. + * + * @param value to be tested against contained values. + * @return true if contained otherwise value. + */ + public boolean containsValue(final long value) { + boolean found = false; + if (value != missingValue) { + final int length = entries.length; + int remaining = size; + + for (int valueIndex = 1; remaining > 0 && valueIndex < length; valueIndex += 2) { + if (missingValue != entries[valueIndex]) { + if (value == entries[valueIndex]) { + found = true; + break; + } + --remaining; + } + } + } + + return found; + } + + /** {@inheritDoc} */ + public void clear() { + if (size > 0) { + Arrays.fill(entries, missingValue); + size = 0; + } + } + + /** + * Compact the backing arrays by rehashing with a capacity just larger than current size and + * giving consideration to the load factor. + */ + public void compact() { + final int idealCapacity = (int) Math.round(size() * (1.0d / loadFactor)); + rehash(findNextPositivePowerOfTwo(Math.max(MIN_CAPACITY, idealCapacity))); + } + + /** + * Primitive specialised version of {@link #computeIfAbsent(Object, Function)} + * + * @param key to search on. + * @param mappingFunction to provide a value if the get returns null. + * @return the value if found otherwise the missing value. + */ + public long computeIfAbsent(final int key, final IntToLongFunction mappingFunction) { + long value = get(key); + if (value == missingValue) { + value = mappingFunction.applyAsLong(key); + if (value != missingValue) { + put(key, value); + } + } + + return value; + } + + // ---------------- Boxed Versions Below ---------------- + + /** {@inheritDoc} */ + @Nullable + public Long get(final Object key) { + return valOrNull(get((int) key)); + } + + /** {@inheritDoc} */ + public Long put(final Integer key, final Long value) { + return valOrNull(put((int) key, (long) value)); + } + + /** {@inheritDoc} */ + public boolean containsKey(final Object key) { + return containsKey((int) key); + } + + /** {@inheritDoc} */ + public boolean containsValue(final Object value) { + return containsValue((long) value); + } + + /** {@inheritDoc} */ + public void putAll(final Map map) { + for (final Map.Entry entry : map.entrySet()) { + put(entry.getKey(), entry.getValue()); + } + } + + /** {@inheritDoc} */ + public KeySet keySet() { + if (null == keySet) { + keySet = new KeySet(); + } + + return keySet; + } + + /** {@inheritDoc} */ + public ValueCollection values() { + if (null == values) { + values = new ValueCollection(); + } + + return values; + } + + /** {@inheritDoc} */ + public EntrySet entrySet() { + if (null == entrySet) { + entrySet = new EntrySet(); + } + + return entrySet; + } + + /** {@inheritDoc} */ + @Nullable + public Long remove(final Object key) { + return valOrNull(remove((int) key)); + } + + /** + * Remove value from the map using given key avoiding boxing. + * + * @param key whose mapping is to be removed from the map. + * @return removed value or {@link #missingValue()} if key was not found in the map. + */ + public long remove(final int key) { + final int mask = entries.length - 1; + int keyIndex = evenHash(key, mask); + + long oldValue = missingValue; + while (entries[keyIndex + 1] != missingValue) { + if (entries[keyIndex] == key) { + oldValue = entries[keyIndex + 1]; + entries[keyIndex + 1] = missingValue; + size--; + + compactChain(keyIndex); + + break; + } + + keyIndex = next(keyIndex, mask); + } + + return oldValue; + } + + @SuppressWarnings("FinalParameters") + private void compactChain(int deleteKeyIndex) { + final int mask = entries.length - 1; + int keyIndex = deleteKeyIndex; + + while (true) { + keyIndex = next(keyIndex, mask); + if (entries[keyIndex + 1] == missingValue) { + break; + } + + final int hash = evenHash((int) entries[keyIndex], mask); + + if ((keyIndex < hash && (hash <= deleteKeyIndex || deleteKeyIndex <= keyIndex)) + || (hash <= deleteKeyIndex && deleteKeyIndex <= keyIndex)) { + entries[deleteKeyIndex] = entries[keyIndex]; + entries[deleteKeyIndex + 1] = entries[keyIndex + 1]; + + entries[keyIndex + 1] = missingValue; + deleteKeyIndex = keyIndex; + } + } + } + + /** + * Get the minimum value stored in the map. If the map is empty then it will return {@link + * #missingValue()} + * + * @return the minimum value stored in the map. + */ + public long minValue() { + final long missingValue = this.missingValue; + long min = size == 0 ? missingValue : Long.MAX_VALUE; + final int length = entries.length; + + for (int valueIndex = 1; valueIndex < length; valueIndex += 2) { + final long value = entries[valueIndex]; + if (value != missingValue) { + min = Math.min(min, value); + } + } + + return min; + } + + /** + * Get the maximum value stored in the map. If the map is empty then it will return {@link + * #missingValue()} + * + * @return the maximum value stored in the map. + */ + public long maxValue() { + final long missingValue = this.missingValue; + long max = size == 0 ? missingValue : Long.MIN_VALUE; + final int length = entries.length; + + for (int valueIndex = 1; valueIndex < length; valueIndex += 2) { + final long value = entries[valueIndex]; + if (value != missingValue) { + max = Math.max(max, value); + } + } + + return max; + } + + /** {@inheritDoc} */ + public String toString() { + if (isEmpty()) { + return "{}"; + } + + final EntryIterator entryIterator = new EntryIterator(); + entryIterator.reset(); + + final StringBuilder sb = new StringBuilder().append('{'); + while (true) { + entryIterator.next(); + sb.append(entryIterator.getIntKey()).append('=').append(entryIterator.getLongValue()); + if (!entryIterator.hasNext()) { + return sb.append('}').toString(); + } + sb.append(',').append(' '); + } + } + + /** + * Primitive specialised version of {@link #replace(Object, Object)} + * + * @param key key with which the specified value is associated + * @param value value to be associated with the specified key + * @return the previous value associated with the specified key, or {@link #missingValue()} if + * there was no mapping for the key. + */ + public long replace(final int key, final long value) { + long currentValue = get(key); + if (currentValue != missingValue) { + currentValue = put(key, value); + } + + return currentValue; + } + + /** + * Primitive specialised version of {@link #replace(Object, Object, Object)} + * + * @param key key with which the specified value is associated + * @param oldValue value expected to be associated with the specified key + * @param newValue value to be associated with the specified key + * @return {@code true} if the value was replaced + */ + public boolean replace(final int key, final long oldValue, final long newValue) { + final long curValue = get(key); + if (curValue != oldValue || curValue == missingValue) { + return false; + } + + put(key, newValue); + + return true; + } + + /** {@inheritDoc} */ + public boolean equals(final Object o) { + if (this == o) { + return true; + } + + if (!(o instanceof Map)) { + return false; + } + + final Map that = (Map) o; + + return size == that.size() && entrySet().equals(that.entrySet()); + } + + public int hashCode() { + return entrySet().hashCode(); + } + + private static int next(final int index, final int mask) { + return (index + 2) & mask; + } + + private void capacity(final int newCapacity) { + final int entriesLength = newCapacity * 2; + if (entriesLength < 0) { + throw new IllegalStateException("max capacity reached at size=" + size); + } + + /*@DoNotSub*/ resizeThreshold = (int) (newCapacity * loadFactor); + entries = new long[entriesLength]; + Arrays.fill(entries, missingValue); + } + + @Nullable + private Long valOrNull(final long value) { + return value == missingValue ? null : value; + } + + // ---------------- Utility Classes ---------------- + + /** Base iterator implementation. */ + abstract class AbstractIterator implements Serializable { + private static final long serialVersionUID = 5262459454112462433L; + /** Is current position valid. */ + protected boolean isPositionValid = false; + + private int remaining; + private int positionCounter; + private int stopCounter; + + final void reset() { + isPositionValid = false; + remaining = Int2LongHashMap.this.size; + final long missingValue = Int2LongHashMap.this.missingValue; + final long[] entries = Int2LongHashMap.this.entries; + final int capacity = entries.length; + + int keyIndex = capacity; + if (entries[capacity - 1] != missingValue) { + for (int i = 1; i < capacity; i += 2) { + if (entries[i] == missingValue) { + keyIndex = i - 1; + break; + } + } + } + + stopCounter = keyIndex; + positionCounter = keyIndex + capacity; + } + + /** + * Returns position of the key of the current entry. + * + * @return key position. + */ + protected final int keyPosition() { + return positionCounter & entries.length - 1; + } + + /** + * Number of remaining elements. + * + * @return number of remaining elements. + */ + public int remaining() { + return remaining; + } + + /** + * Check if there are more elements remaining. + * + * @return {@code true} if {@code remaining > 0}. + */ + public boolean hasNext() { + return remaining > 0; + } + + /** + * Advance to the next entry. + * + * @throws NoSuchElementException if no more entries available. + */ + protected final void findNext() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + final long[] entries = Int2LongHashMap.this.entries; + final long missingValue = Int2LongHashMap.this.missingValue; + final int mask = entries.length - 1; + + for (int keyIndex = positionCounter - 2; keyIndex >= stopCounter; keyIndex -= 2) { + final int index = keyIndex & mask; + if (entries[index + 1] != missingValue) { + isPositionValid = true; + positionCounter = keyIndex; + --remaining; + return; + } + } + + isPositionValid = false; + throw new IllegalStateException(); + } + + /** {@inheritDoc} */ + public void remove() { + if (isPositionValid) { + final int position = keyPosition(); + entries[position + 1] = missingValue; + --size; + + compactChain(position); + + isPositionValid = false; + } else { + throw new IllegalStateException(); + } + } + } + + /** Iterator over keys which supports access to unboxed keys via {@link #nextValue()}. */ + public final class KeyIterator extends AbstractIterator + implements Iterator, Serializable { + private static final long serialVersionUID = 9151493609653852972L; + + public Integer next() { + return nextValue(); + } + + /** + * Return next key. + * + * @return next key. + */ + public int nextValue() { + findNext(); + return (int) entries[keyPosition()]; + } + } + + /** Iterator over values which supports access to unboxed values. */ + public final class ValueIterator extends AbstractIterator + implements Iterator, Serializable { + private static final long serialVersionUID = -5670291734793552927L; + + public Long next() { + return nextValue(); + } + + /** + * Return next value. + * + * @return next value. + */ + public long nextValue() { + findNext(); + return entries[keyPosition() + 1]; + } + } + + /** Iterator over entries which supports access to unboxed keys and values. */ + public final class EntryIterator extends AbstractIterator + implements Iterator>, Entry, Serializable { + private static final long serialVersionUID = 1744408438593481051L; + + public Integer getKey() { + return getIntKey(); + } + + /** + * Returns the key of the current entry. + * + * @return the key. + */ + public int getIntKey() { + return (int) entries[keyPosition()]; + } + + public Long getValue() { + return getLongValue(); + } + + /** + * Returns the value of the current entry. + * + * @return the value. + */ + public long getLongValue() { + return entries[keyPosition() + 1]; + } + + public Long setValue(final Long value) { + return setValue(value.longValue()); + } + + /** + * Sets the value of the current entry. + * + * @param value to be set. + * @return previous value of the entry. + */ + public long setValue(final long value) { + if (!isPositionValid) { + throw new IllegalStateException(); + } + + if (missingValue == value) { + throw new IllegalArgumentException(); + } + + final int keyPosition = keyPosition(); + final long prevValue = entries[keyPosition + 1]; + entries[keyPosition + 1] = value; + return prevValue; + } + + public Entry next() { + findNext(); + + if (shouldAvoidAllocation) { + return this; + } + + return allocateDuplicateEntry(); + } + + private Entry allocateDuplicateEntry() { + return new MapEntry(getIntKey(), getLongValue()); + } + + /** {@inheritDoc} */ + public int hashCode() { + return Integer.hashCode(getIntKey()) ^ Long.hashCode(getLongValue()); + } + + /** {@inheritDoc} */ + public boolean equals(final Object o) { + if (this == o) { + return true; + } + + if (!(o instanceof Entry)) { + return false; + } + + final Entry that = (Entry) o; + + return Objects.equals(getKey(), that.getKey()) && Objects.equals(getValue(), that.getValue()); + } + + /** An {@link java.util.Map.Entry} implementation. */ + public final class MapEntry implements Entry { + private final int k; + private final long v; + + /** + * Constructs entry with given key and value. + * + * @param k key. + * @param v value. + */ + public MapEntry(final int k, final long v) { + this.k = k; + this.v = v; + } + + public Integer getKey() { + return k; + } + + public Long getValue() { + return v; + } + + public Long setValue(final Long value) { + return Int2LongHashMap.this.put(k, value.longValue()); + } + + public int hashCode() { + return Integer.hashCode(getIntKey()) ^ Long.hashCode(getLongValue()); + } + + public boolean equals(final Object o) { + if (!(o instanceof Map.Entry)) { + return false; + } + + final Entry e = (Entry) o; + + return (e.getKey() != null && e.getValue() != null) + && (e.getKey().equals(k) && e.getValue().equals(v)); + } + + public String toString() { + return k + "=" + v; + } + } + } + + /** Set of keys which supports optional cached iterators to avoid allocation. */ + public final class KeySet extends AbstractSet implements Serializable { + private static final long serialVersionUID = -7645453993079742625L; + private final KeyIterator keyIterator = shouldAvoidAllocation ? new KeyIterator() : null; + + /** {@inheritDoc} */ + public KeyIterator iterator() { + KeyIterator keyIterator = this.keyIterator; + if (null == keyIterator) { + keyIterator = new KeyIterator(); + } + + keyIterator.reset(); + + return keyIterator; + } + + /** {@inheritDoc} */ + public int size() { + return Int2LongHashMap.this.size(); + } + + /** {@inheritDoc} */ + public boolean isEmpty() { + return Int2LongHashMap.this.isEmpty(); + } + + /** {@inheritDoc} */ + public void clear() { + Int2LongHashMap.this.clear(); + } + + /** {@inheritDoc} */ + public boolean contains(final Object o) { + return contains((int) o); + } + + /** + * Checks if key is contained in the map without boxing. + * + * @param key to check. + * @return {@code true} if key is contained in this map. + */ + public boolean contains(final int key) { + return containsKey(key); + } + } + + /** Collection of values which supports optionally cached iterators to avoid allocation. */ + public final class ValueCollection extends AbstractCollection implements Serializable { + private static final long serialVersionUID = -8925598924781601919L; + private final ValueIterator valueIterator = shouldAvoidAllocation ? new ValueIterator() : null; + + /** {@inheritDoc} */ + public ValueIterator iterator() { + ValueIterator valueIterator = this.valueIterator; + if (null == valueIterator) { + valueIterator = new ValueIterator(); + } + + valueIterator.reset(); + + return valueIterator; + } + + /** {@inheritDoc} */ + public int size() { + return Int2LongHashMap.this.size(); + } + + /** {@inheritDoc} */ + public boolean contains(final Object o) { + return contains((long) o); + } + + /** + * Checks if the value is contained in the map. + * + * @param value to be checked. + * @return {@code true} if value is contained in this map. + */ + public boolean contains(final long value) { + return containsValue(value); + } + } + + /** Set of entries which supports optionally cached iterators to avoid allocation. */ + public final class EntrySet extends AbstractSet> + implements Serializable { + private static final long serialVersionUID = 63641283589916174L; + private final EntryIterator entryIterator = shouldAvoidAllocation ? new EntryIterator() : null; + + /** {@inheritDoc} */ + public EntryIterator iterator() { + EntryIterator entryIterator = this.entryIterator; + if (null == entryIterator) { + entryIterator = new EntryIterator(); + } + + entryIterator.reset(); + + return entryIterator; + } + + /** {@inheritDoc} */ + public int size() { + return Int2LongHashMap.this.size(); + } + + /** {@inheritDoc} */ + public boolean isEmpty() { + return Int2LongHashMap.this.isEmpty(); + } + + /** {@inheritDoc} */ + public void clear() { + Int2LongHashMap.this.clear(); + } + + /** {@inheritDoc} */ + public boolean contains(final Object o) { + if (!(o instanceof Entry)) { + return false; + } + final Entry entry = (Entry) o; + final Long value = get(entry.getKey()); + + return value != null && value.equals(entry.getValue()); + } + + /** {@inheritDoc} */ + public Object[] toArray() { + return toArray(new Object[size()]); + } + + /** {@inheritDoc} */ + @SuppressWarnings("unchecked") + public T[] toArray(final T[] a) { + final T[] array = + a.length >= size + ? a + : (T[]) java.lang.reflect.Array.newInstance(a.getClass().getComponentType(), size); + final EntryIterator it = iterator(); + + for (int i = 0; i < array.length; i++) { + if (it.hasNext()) { + it.next(); + array[i] = (T) it.allocateDuplicateEntry(); + } else { + array[i] = null; + break; + } + } + + return array; + } + } + + private static int evenHash(final int value, final int mask) { + final int hash = (value << 1) - (value << 8); + + return hash & mask; + } + + private static void validateLoadFactor(final float loadFactor) { + if (loadFactor < 0.1f || loadFactor > 0.9f) { + throw new IllegalArgumentException( + "load factor must be in the range of 0.1 to 0.9: " + loadFactor); + } + } + + private static int findNextPositivePowerOfTwo(final int value) { + return 1 << (Integer.SIZE - Integer.numberOfLeadingZeros(value - 1)); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceRSocketClient.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceRSocketClient.java new file mode 100644 index 000000000..d59cbb86e --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceRSocketClient.java @@ -0,0 +1,195 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketClient; +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.ClientTransport; +import java.util.List; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; + +/** + * An implementation of {@link RSocketClient} backed by a pool of {@code RSocket} instances and + * using a {@link LoadbalanceStrategy} to select the {@code RSocket} to use for a given request. + * + * @since 1.1 + */ +public class LoadbalanceRSocketClient implements RSocketClient { + + private final RSocketPool rSocketPool; + + private LoadbalanceRSocketClient(RSocketPool rSocketPool) { + this.rSocketPool = rSocketPool; + } + + @Override + public Mono onClose() { + return rSocketPool.onClose(); + } + + @Override + public boolean connect() { + return rSocketPool.connect(); + } + + /** Return {@code Mono} that selects an RSocket from the underlying pool. */ + @Override + public Mono source() { + return Mono.fromSupplier(rSocketPool::select); + } + + @Override + public Mono fireAndForget(Mono payloadMono) { + return payloadMono.flatMap(p -> rSocketPool.select().fireAndForget(p)); + } + + @Override + public Mono requestResponse(Mono payloadMono) { + return payloadMono.flatMap(p -> rSocketPool.select().requestResponse(p)); + } + + @Override + public Flux requestStream(Mono payloadMono) { + return payloadMono.flatMapMany(p -> rSocketPool.select().requestStream(p)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return source().flatMapMany(rSocket -> rSocket.requestChannel(payloads)); + } + + @Override + public Mono metadataPush(Mono payloadMono) { + return payloadMono.flatMap(p -> rSocketPool.select().metadataPush(p)); + } + + @Override + public void dispose() { + rSocketPool.dispose(); + } + + /** + * Shortcut to create an {@link LoadbalanceRSocketClient} with round-robin load balancing. + * Effectively a shortcut for: + * + *

    +   * LoadbalanceRSocketClient.builder(targetPublisher)
    +   *    .connector(RSocketConnector.create())
    +   *    .build();
    +   * 
    + * + * @param connector a "template" for connecting to load balance targets + * @param targetPublisher refreshes the list of load balance targets periodically + * @return the created client instance + */ + public static LoadbalanceRSocketClient create( + RSocketConnector connector, Publisher> targetPublisher) { + return builder(targetPublisher).connector(connector).build(); + } + + /** + * Return a builder for a {@link LoadbalanceRSocketClient}. + * + * @param targetPublisher refreshes the list of load balance targets periodically + * @return the created builder + */ + public static Builder builder(Publisher> targetPublisher) { + return new Builder(targetPublisher); + } + + /** Builder for creating an {@link LoadbalanceRSocketClient}. */ + public static class Builder { + + private final Publisher> targetPublisher; + + @Nullable private RSocketConnector connector; + + @Nullable LoadbalanceStrategy loadbalanceStrategy; + + Builder(Publisher> targetPublisher) { + this.targetPublisher = targetPublisher; + } + + /** + * Configure the "template" connector to use for connecting to load balance targets. To + * establish a connection, the {@link LoadbalanceTarget#getTransport() ClientTransport} + * contained in each target is passed to the connector's {@link + * RSocketConnector#connect(ClientTransport) connect} method and thus the same connector with + * the same settings applies to all targets. + * + *

    By default this is initialized with {@link RSocketConnector#create()}. + * + * @param connector the connector to use as a template + */ + public Builder connector(RSocketConnector connector) { + this.connector = connector; + return this; + } + + /** + * Configure {@link RoundRobinLoadbalanceStrategy} as the strategy to use to select targets. + * + *

    This is the strategy used by default. + */ + public Builder roundRobinLoadbalanceStrategy() { + this.loadbalanceStrategy = new RoundRobinLoadbalanceStrategy(); + return this; + } + + /** + * Configure {@link WeightedLoadbalanceStrategy} as the strategy to use to select targets. + * + *

    By default, {@link RoundRobinLoadbalanceStrategy} is used. + */ + public Builder weightedLoadbalanceStrategy() { + this.loadbalanceStrategy = WeightedLoadbalanceStrategy.create(); + return this; + } + + /** + * Configure the {@link LoadbalanceStrategy} to use. + * + *

    By default, {@link RoundRobinLoadbalanceStrategy} is used. + */ + public Builder loadbalanceStrategy(LoadbalanceStrategy strategy) { + this.loadbalanceStrategy = strategy; + return this; + } + + /** Build the {@link LoadbalanceRSocketClient} instance. */ + public LoadbalanceRSocketClient build() { + final RSocketConnector connector = + (this.connector != null ? this.connector : RSocketConnector.create()); + + final LoadbalanceStrategy strategy = + (this.loadbalanceStrategy != null + ? this.loadbalanceStrategy + : new RoundRobinLoadbalanceStrategy()); + + if (strategy instanceof ClientLoadbalanceStrategy) { + ((ClientLoadbalanceStrategy) strategy).initialize(connector); + } + + return new LoadbalanceRSocketClient( + new RSocketPool(connector, this.targetPublisher, strategy)); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceStrategy.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceStrategy.java new file mode 100644 index 000000000..5662448e7 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceStrategy.java @@ -0,0 +1,38 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.RSocket; +import java.util.List; + +/** + * Strategy to select an {@link RSocket} given a list of instances for load-balancing purposes. A + * simple implementation might go in round-robin fashion while a more sophisticated strategy might + * check availability, track usage stats, and so on. + * + * @since 1.1 + */ +@FunctionalInterface +public interface LoadbalanceStrategy { + + /** + * Select an {@link RSocket} from the given non-empty list. + * + * @param sockets the list to choose from + * @return the selected instance + */ + RSocket select(List sockets); +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceTarget.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceTarget.java new file mode 100644 index 000000000..3b5d71e4e --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceTarget.java @@ -0,0 +1,79 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.ClientTransport; +import org.reactivestreams.Publisher; + +/** + * Representation for a load-balance target used as input to {@link LoadbalanceRSocketClient} that + * in turn maintains and peridodically updates a list of current load-balance targets. The {@link + * #getKey()} is used to identify a target uniquely while the {@link #getTransport() transport} is + * used to connect to the target server. + * + * @since 1.1 + * @see LoadbalanceRSocketClient#create(RSocketConnector, Publisher) + */ +public class LoadbalanceTarget { + + final String key; + final ClientTransport transport; + + private LoadbalanceTarget(String key, ClientTransport transport) { + this.key = key; + this.transport = transport; + } + + /** Return the key that identifies this target uniquely. */ + public String getKey() { + return key; + } + + /** Return the transport to use to connect to the target server. */ + public ClientTransport getTransport() { + return transport; + } + + /** + * Create a new {@link LoadbalanceTarget} with the given key and {@link ClientTransport}. The key + * can be anything that identifies the target uniquely, e.g. SocketAddress, URL, and so on. + * + * @param key identifies the load-balance target uniquely + * @param transport for connecting to the target + * @return the created instance + */ + public static LoadbalanceTarget from(String key, ClientTransport transport) { + return new LoadbalanceTarget(key, transport); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (other == null || getClass() != other.getClass()) { + return false; + } + LoadbalanceTarget that = (LoadbalanceTarget) other; + return key.equals(that.key); + } + + @Override + public int hashCode() { + return key.hashCode(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/Median.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/Median.java new file mode 100644 index 000000000..5319706f9 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/Median.java @@ -0,0 +1,99 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.loadbalance; + +/** This implementation gives better results because it considers more data-point. */ +class Median extends FrugalQuantile { + + public Median() { + super(0.5, 1.0); + } + + public synchronized void reset() { + super.reset(0.5); + } + + @Override + public synchronized void insert(double x) { + if (sign == 0) { + estimate = x; + sign = 1; + } else { + final double estimate = this.estimate; + if (x > estimate) { + greaterThanZero(x); + } else if (x < estimate) { + lessThanZero(x); + } + } + } + + private void greaterThanZero(double x) { + double estimate = this.estimate; + + step += sign; + + if (step > 0) { + estimate += step; + } else { + estimate += 1; + } + + if (estimate > x) { + step += (x - estimate); + estimate = x; + } + + if (sign < 0) { + step = 1; + } + + sign = 1; + + this.estimate = estimate; + } + + private void lessThanZero(double x) { + double estimate = this.estimate; + + step -= sign; + + if (step > 0) { + estimate -= step; + } else { + estimate--; + } + + if (estimate < x) { + step += (estimate - x); + estimate = x; + } + + if (sign > 0) { + step = 1; + } + + sign = -1; + + this.estimate = estimate; + } + + @Override + public String toString() { + return "Median(v=" + estimate + ")"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/MonoDeferredResolution.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/MonoDeferredResolution.java new file mode 100644 index 000000000..69838f1b6 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/MonoDeferredResolution.java @@ -0,0 +1,226 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Payload; +import io.rsocket.frame.FrameType; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.function.BiConsumer; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +abstract class MonoDeferredResolution extends Mono + implements CoreSubscriber, Subscription, Scannable, BiConsumer { + + final ResolvingOperator parent; + final Payload payload; + final FrameType requestType; + + volatile long requested; + + @SuppressWarnings("rawtypes") + static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(MonoDeferredResolution.class, "requested"); + + static final long STATE_UNSUBSCRIBED = -1; + static final long STATE_SUBSCRIBER_SET = 0; + static final long STATE_SUBSCRIBED = -2; + static final long STATE_TERMINATED = Long.MIN_VALUE; + + Subscription s; + CoreSubscriber actual; + boolean done; + + MonoDeferredResolution(ResolvingOperator parent, Payload payload, FrameType requestType) { + this.parent = parent; + this.payload = payload; + this.requestType = requestType; + + REQUESTED.lazySet(this, STATE_UNSUBSCRIBED); + } + + @Override + public final void subscribe(CoreSubscriber actual) { + if (this.requested == STATE_UNSUBSCRIBED + && REQUESTED.compareAndSet(this, STATE_UNSUBSCRIBED, STATE_SUBSCRIBER_SET)) { + + actual.onSubscribe(this); + + if (this.requested == STATE_TERMINATED) { + return; + } + + this.actual = actual; + this.parent.observe(this); + } else { + Operators.error(actual, new IllegalStateException("Only a single Subscriber allowed")); + } + } + + @Override + public final Context currentContext() { + return this.actual.currentContext(); + } + + @Nullable + @Override + public Object scanUnsafe(Attr key) { + long state = this.requested; + + if (key == Attr.PARENT) { + return this.s; + } + if (key == Attr.ACTUAL) { + return this.parent; + } + if (key == Attr.TERMINATED) { + return this.done; + } + if (key == Attr.CANCELLED) { + return state == STATE_TERMINATED; + } + + return null; + } + + @Override + public final void onSubscribe(Subscription s) { + final long state = this.requested; + Subscription a = this.s; + if (state == STATE_TERMINATED) { + s.cancel(); + return; + } + if (a != null) { + s.cancel(); + return; + } + + long r; + long accumulated = 0; + for (; ; ) { + r = this.requested; + + if (r == STATE_TERMINATED || r == STATE_SUBSCRIBED) { + s.cancel(); + return; + } + + this.s = s; + + long toRequest = r - accumulated; + if (toRequest > 0) { // if there is something, + s.request(toRequest); // then we do a request on the given subscription + } + accumulated = r; + + if (REQUESTED.compareAndSet(this, r, STATE_SUBSCRIBED)) { + return; + } + } + } + + @Override + public final void onNext(RESULT payload) { + this.actual.onNext(payload); + } + + @Override + public final void onError(Throwable t) { + if (this.done) { + Operators.onErrorDropped(t, this.actual.currentContext()); + return; + } + + this.done = true; + this.actual.onError(t); + } + + @Override + public final void onComplete() { + if (this.done) { + return; + } + + this.done = true; + this.actual.onComplete(); + } + + @Override + public final void request(long n) { + if (Operators.validate(n)) { + long r = this.requested; // volatile read beforehand + + if (r > STATE_SUBSCRIBED) { // works only in case onSubscribe has not happened + long u; + for (; ; ) { // normal CAS loop with overflow protection + if (r == Long.MAX_VALUE) { + // if r == Long.MAX_VALUE then we dont care and we can loose this + // request just in case of racing + return; + } + u = Operators.addCap(r, n); + if (REQUESTED.compareAndSet(this, r, u)) { + // Means increment happened before onSubscribe + return; + } else { + // Means increment happened after onSubscribe + + // update new state to see what exactly happened (onSubscribe |cancel | requestN) + r = this.requested; + + // check state (expect -1 | -2 to exit, otherwise repeat) + if (r < 0) { + break; + } + } + } + } + + if (r == STATE_TERMINATED) { // if canceled, just exit + return; + } + + // if onSubscribe -> subscription exists (and we sure of that because volatile read + // after volatile write) so we can execute requestN on the subscription + this.s.request(n); + } + } + + public final void cancel() { + long state = REQUESTED.getAndSet(this, STATE_TERMINATED); + if (state == STATE_TERMINATED) { + return; + } + + if (state == STATE_SUBSCRIBED) { + this.s.cancel(); + } else { + this.parent.remove(this); + ReferenceCountUtil.safeRelease(this.payload); + } + } + + boolean isTerminated() { + return this.requested == STATE_TERMINATED; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/PooledRSocket.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/PooledRSocket.java new file mode 100644 index 000000000..a77329d31 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/PooledRSocket.java @@ -0,0 +1,310 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.util.context.Context; + +/** Default implementation of {@link RSocket} stored in {@link RSocketPool} */ +final class PooledRSocket extends ResolvingOperator + implements CoreSubscriber, RSocket { + + final RSocketPool parent; + final Mono rSocketSource; + final LoadbalanceTarget loadbalanceTarget; + final Sinks.Empty onCloseSink; + + volatile Subscription s; + + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(PooledRSocket.class, Subscription.class, "s"); + + PooledRSocket( + RSocketPool parent, Mono rSocketSource, LoadbalanceTarget loadbalanceTarget) { + this.parent = parent; + this.rSocketSource = rSocketSource; + this.loadbalanceTarget = loadbalanceTarget; + this.onCloseSink = Sinks.unsafe().empty(); + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); + } + } + + @Override + public void onComplete() { + final Subscription s = this.s; + final RSocket value = this.value; + + if (s == Operators.cancelledSubscription() || !S.compareAndSet(this, s, null)) { + this.doFinally(); + return; + } + + if (value == null) { + this.terminate(new IllegalStateException("Source completed empty")); + } else { + this.complete(value); + } + } + + @Override + public void onError(Throwable t) { + final Subscription s = this.s; + + if (s == Operators.cancelledSubscription() + || S.getAndSet(this, Operators.cancelledSubscription()) + == Operators.cancelledSubscription()) { + this.doFinally(); + Operators.onErrorDropped(t, Context.empty()); + return; + } + + this.doFinally(); + // terminate upstream (retryBackoff has exhausted) and remove from the parent target list + this.doCleanup(t); + } + + @Override + public void onNext(RSocket value) { + if (this.s == Operators.cancelledSubscription()) { + this.doOnValueExpired(value); + return; + } + + this.value = value; + // volatile write and check on racing + this.doFinally(); + } + + @Override + protected void doSubscribe() { + this.rSocketSource.subscribe(this); + } + + @Override + protected void doOnValueResolved(RSocket value) { + value.onClose().subscribe(null, this::doCleanup, () -> doCleanup(ON_DISPOSE)); + } + + void doCleanup(Throwable t) { + if (isDisposed()) { + return; + } + + this.terminate(t); + + final RSocketPool parent = this.parent; + for (; ; ) { + final PooledRSocket[] sockets = parent.activeSockets; + final int activeSocketsCount = sockets.length; + + int index = -1; + for (int i = 0; i < activeSocketsCount; i++) { + if (sockets[i] == this) { + index = i; + break; + } + } + + if (index == -1) { + break; + } + + final PooledRSocket[] newSockets; + if (activeSocketsCount == 1) { + newSockets = RSocketPool.EMPTY; + } else { + final int lastIndex = activeSocketsCount - 1; + + newSockets = new PooledRSocket[lastIndex]; + if (index != 0) { + System.arraycopy(sockets, 0, newSockets, 0, index); + } + + if (index != lastIndex) { + System.arraycopy(sockets, index + 1, newSockets, index, lastIndex - index); + } + } + + if (RSocketPool.ACTIVE_SOCKETS.compareAndSet(parent, sockets, newSockets)) { + break; + } + } + + if (t == ON_DISPOSE) { + this.onCloseSink.tryEmitEmpty(); + } else { + this.onCloseSink.tryEmitError(t); + } + } + + @Override + protected void doOnValueExpired(RSocket value) { + value.dispose(); + } + + @Override + protected void doOnDispose() { + Operators.terminate(S, this); + + final RSocket value = this.value; + if (value != null) { + value.onClose().subscribe(null, onCloseSink::tryEmitError, onCloseSink::tryEmitEmpty); + } else { + onCloseSink.tryEmitEmpty(); + } + } + + @Override + public Mono fireAndForget(Payload payload) { + return new MonoInner<>(this, payload, FrameType.REQUEST_FNF); + } + + @Override + public Mono requestResponse(Payload payload) { + return new MonoInner<>(this, payload, FrameType.REQUEST_RESPONSE); + } + + @Override + public Flux requestStream(Payload payload) { + return new FluxInner<>(this, payload, FrameType.REQUEST_STREAM); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return new FluxInner<>(this, payloads, FrameType.REQUEST_CHANNEL); + } + + @Override + public Mono metadataPush(Payload payload) { + return new MonoInner<>(this, payload, FrameType.METADATA_PUSH); + } + + LoadbalanceTarget target() { + return this.loadbalanceTarget; + } + + @Override + public Mono onClose() { + return this.onCloseSink.asMono(); + } + + @Override + public double availability() { + final RSocket socket = valueIfResolved(); + return socket != null ? socket.availability() : 0.0d; + } + + static final class MonoInner extends MonoDeferredResolution { + + MonoInner(PooledRSocket parent, Payload payload, FrameType requestType) { + super(parent, payload, requestType); + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public void accept(RSocket rSocket, Throwable t) { + if (isTerminated()) { + return; + } + + if (t != null) { + ReferenceCountUtil.safeRelease(this.payload); + onError(t); + return; + } + + if (rSocket != null) { + Mono source; + switch (this.requestType) { + case REQUEST_FNF: + source = rSocket.fireAndForget(this.payload); + break; + case REQUEST_RESPONSE: + source = rSocket.requestResponse(this.payload); + break; + case METADATA_PUSH: + source = rSocket.metadataPush(this.payload); + break; + default: + Operators.error(this.actual, new IllegalStateException("Should never happen")); + return; + } + + source.subscribe((CoreSubscriber) this); + } else { + parent.observe(this); + } + } + } + + static final class FluxInner extends FluxDeferredResolution { + + FluxInner(PooledRSocket parent, INPUT fluxOrPayload, FrameType requestType) { + super(parent, fluxOrPayload, requestType); + } + + @Override + @SuppressWarnings("unchecked") + public void accept(RSocket rSocket, Throwable t) { + if (isTerminated()) { + return; + } + + if (t != null) { + if (this.requestType == FrameType.REQUEST_STREAM) { + ReferenceCountUtil.safeRelease(this.fluxOrPayload); + } + onError(t); + return; + } + + if (rSocket != null) { + Flux source; + switch (this.requestType) { + case REQUEST_STREAM: + source = rSocket.requestStream((Payload) this.fluxOrPayload); + break; + case REQUEST_CHANNEL: + source = rSocket.requestChannel((Flux) this.fluxOrPayload); + break; + default: + Operators.error(this.actual, new IllegalStateException("Should never happen")); + return; + } + + source.subscribe(this); + } else { + parent.observe(this); + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/RequestListener.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/Quantile.java similarity index 57% rename from rsocket-core/src/main/java/io/rsocket/resume/RequestListener.java rename to rsocket-core/src/main/java/io/rsocket/loadbalance/Quantile.java index 6553e5ec5..84c699197 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/RequestListener.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/Quantile.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2019 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,20 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package io.rsocket.loadbalance; -package io.rsocket.resume; +interface Quantile { + /** @return the estimation of the current value of the quantile */ + double estimation(); -import reactor.core.publisher.Flux; -import reactor.core.publisher.ReplayProcessor; - -class RequestListener { - private final ReplayProcessor requests = ReplayProcessor.create(1); - - public Flux apply(Flux flux) { - return flux.doOnRequest(requests::onNext); - } - - public Flux requests() { - return requests; - } + /** + * Insert a data point `x` in the quantile estimator. + * + * @param x the data point to add. + */ + void insert(double x); } diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/RSocketPool.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/RSocketPool.java new file mode 100644 index 000000000..59d9678d0 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/RSocketPool.java @@ -0,0 +1,532 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.frame.FrameType; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.ListIterator; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.stream.Collectors; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; + +class RSocketPool extends ResolvingOperator + implements CoreSubscriber>, Closeable { + + static final AtomicReferenceFieldUpdater ACTIVE_SOCKETS = + AtomicReferenceFieldUpdater.newUpdater( + RSocketPool.class, PooledRSocket[].class, "activeSockets"); + static final PooledRSocket[] EMPTY = new PooledRSocket[0]; + static final PooledRSocket[] TERMINATED = new PooledRSocket[0]; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(RSocketPool.class, Subscription.class, "s"); + final DeferredResolutionRSocket deferredResolutionRSocket = new DeferredResolutionRSocket(this); + final RSocketConnector connector; + final LoadbalanceStrategy loadbalanceStrategy; + final Sinks.Empty onAllClosedSink = Sinks.unsafe().empty(); + volatile PooledRSocket[] activeSockets; + volatile Subscription s; + + public RSocketPool( + RSocketConnector connector, + Publisher> targetPublisher, + LoadbalanceStrategy loadbalanceStrategy) { + this.connector = connector; + this.loadbalanceStrategy = loadbalanceStrategy; + + ACTIVE_SOCKETS.lazySet(this, EMPTY); + + targetPublisher.subscribe(this); + } + + @Override + public Mono onClose() { + return onAllClosedSink.asMono(); + } + + @Override + protected void doOnDispose() { + Operators.terminate(S, this); + + RSocket[] activeSockets = ACTIVE_SOCKETS.getAndSet(this, TERMINATED); + for (RSocket rSocket : activeSockets) { + rSocket.dispose(); + } + + if (activeSockets.length > 0) { + Mono.whenDelayError( + Arrays.stream(activeSockets).map(RSocket::onClose).collect(Collectors.toList())) + .subscribe(null, onAllClosedSink::tryEmitError, onAllClosedSink::tryEmitEmpty); + } else { + onAllClosedSink.tryEmitEmpty(); + } + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); + } + } + + @Override + public void onNext(List targets) { + if (isDisposed()) { + return; + } + + // This operation should happen less frequently than calls to select() (which are per request) + // and therefore it is acceptable somewhat less efficient. + + PooledRSocket[] previouslyActiveSockets; + PooledRSocket[] inactiveSockets; + PooledRSocket[] socketsToUse; + for (; ; ) { + HashMap rSocketSuppliersCopy = new HashMap<>(targets.size()); + + int j = 0; + for (LoadbalanceTarget target : targets) { + rSocketSuppliersCopy.put(target, j++); + } + + // Intersect current and new list of targets and find the ones to keep vs dispose + previouslyActiveSockets = this.activeSockets; + inactiveSockets = new PooledRSocket[previouslyActiveSockets.length]; + PooledRSocket[] nextActiveSockets = + new PooledRSocket[previouslyActiveSockets.length + rSocketSuppliersCopy.size()]; + int activeSocketsPosition = 0; + int inactiveSocketsPosition = 0; + for (int i = 0; i < previouslyActiveSockets.length; i++) { + PooledRSocket rSocket = previouslyActiveSockets[i]; + + Integer index = rSocketSuppliersCopy.remove(rSocket.target()); + if (index == null) { + // if one of the active rSockets is not included, we remove it and put in the + // pending removal + if (!rSocket.isDisposed()) { + inactiveSockets[inactiveSocketsPosition++] = rSocket; + // TODO: provide a meaningful algo for keeping removed rsocket in the list + // nextActiveSockets[position++] = rSocket; + } + } else { + if (!rSocket.isDisposed()) { + // keep old RSocket instance + nextActiveSockets[activeSocketsPosition++] = rSocket; + } else { + // put newly create RSocket instance + LoadbalanceTarget target = targets.get(index); + nextActiveSockets[activeSocketsPosition++] = + new PooledRSocket(this, this.connector.connect(target.getTransport()), target); + } + } + } + + // The remainder are the brand new targets + for (LoadbalanceTarget target : rSocketSuppliersCopy.keySet()) { + nextActiveSockets[activeSocketsPosition++] = + new PooledRSocket(this, this.connector.connect(target.getTransport()), target); + } + + if (activeSocketsPosition == 0) { + socketsToUse = EMPTY; + } else { + socketsToUse = Arrays.copyOf(nextActiveSockets, activeSocketsPosition); + } + if (ACTIVE_SOCKETS.compareAndSet(this, previouslyActiveSockets, socketsToUse)) { + break; + } + } + + for (PooledRSocket inactiveSocket : inactiveSockets) { + if (inactiveSocket == null) { + break; + } + + inactiveSocket.dispose(); + } + + if (isPending()) { + // notifies that upstream is resolved + if (socketsToUse != EMPTY) { + //noinspection ConstantConditions + complete(this); + } + } + } + + @Override + public void onError(Throwable t) { + // indicates upstream termination + S.set(this, Operators.cancelledSubscription()); + // propagates error and terminates the whole pool + terminate(t); + } + + @Override + public void onComplete() { + // indicates upstream termination + S.set(this, Operators.cancelledSubscription()); + } + + RSocket select() { + if (isDisposed()) { + return this.deferredResolutionRSocket; + } + + RSocket selected = doSelect(); + + if (selected == null) { + if (this.s == Operators.cancelledSubscription()) { + terminate(new CancellationException("Pool is exhausted")); + } else { + invalidate(); + + // check since it is possible that between doSelect() and invalidate() we might + // have received new sockets + selected = doSelect(); + if (selected != null) { + return selected; + } + } + return this.deferredResolutionRSocket; + } + + return selected; + } + + @Nullable + RSocket doSelect() { + PooledRSocket[] sockets = this.activeSockets; + + if (sockets == EMPTY || sockets == TERMINATED) { + return null; + } + + return this.loadbalanceStrategy.select(WrappingList.wrap(sockets)); + } + + static class DeferredResolutionRSocket implements RSocket { + + final RSocketPool parent; + + DeferredResolutionRSocket(RSocketPool parent) { + this.parent = parent; + } + + @Override + public Mono fireAndForget(Payload payload) { + return new MonoInner<>(this.parent, payload, FrameType.REQUEST_FNF); + } + + @Override + public Mono requestResponse(Payload payload) { + return new MonoInner<>(this.parent, payload, FrameType.REQUEST_RESPONSE); + } + + @Override + public Flux requestStream(Payload payload) { + return new FluxInner<>(this.parent, payload, FrameType.REQUEST_STREAM); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return new FluxInner<>(this.parent, payloads, FrameType.REQUEST_CHANNEL); + } + + @Override + public Mono metadataPush(Payload payload) { + return new MonoInner<>(this.parent, payload, FrameType.METADATA_PUSH); + } + } + + static final class MonoInner extends MonoDeferredResolution { + + MonoInner(RSocketPool parent, Payload payload, FrameType requestType) { + super(parent, payload, requestType); + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public void accept(Object aVoid, Throwable t) { + if (isTerminated()) { + return; + } + + if (t != null) { + ReferenceCountUtil.safeRelease(this.payload); + onError(t); + return; + } + + RSocketPool parent = (RSocketPool) this.parent; + for (; ; ) { + RSocket rSocket = parent.doSelect(); + if (rSocket != null) { + Mono source; + switch (this.requestType) { + case REQUEST_FNF: + source = rSocket.fireAndForget(this.payload); + break; + case REQUEST_RESPONSE: + source = rSocket.requestResponse(this.payload); + break; + case METADATA_PUSH: + source = rSocket.metadataPush(this.payload); + break; + default: + Operators.error(this.actual, new IllegalStateException("Should never happen")); + return; + } + + source.subscribe((CoreSubscriber) this); + + return; + } + + final int state = parent.add(this); + + if (state == ADDED_STATE) { + return; + } + + if (state == TERMINATED_STATE) { + final Throwable error = parent.t; + ReferenceCountUtil.safeRelease(this.payload); + onError(error); + return; + } + } + } + } + + static final class FluxInner extends FluxDeferredResolution { + + FluxInner(RSocketPool parent, INPUT fluxOrPayload, FrameType requestType) { + super(parent, fluxOrPayload, requestType); + } + + @Override + @SuppressWarnings("unchecked") + public void accept(Object aVoid, Throwable t) { + if (isTerminated()) { + return; + } + + if (t != null) { + if (this.requestType == FrameType.REQUEST_STREAM) { + ReferenceCountUtil.safeRelease(this.fluxOrPayload); + } + onError(t); + return; + } + + RSocketPool parent = (RSocketPool) this.parent; + for (; ; ) { + RSocket rSocket = parent.doSelect(); + if (rSocket != null) { + Flux source; + switch (this.requestType) { + case REQUEST_STREAM: + source = rSocket.requestStream((Payload) this.fluxOrPayload); + break; + case REQUEST_CHANNEL: + source = rSocket.requestChannel((Flux) this.fluxOrPayload); + break; + default: + Operators.error(this.actual, new IllegalStateException("Should never happen")); + return; + } + + source.subscribe(this); + + return; + } + + final int state = parent.add(this); + + if (state == ADDED_STATE) { + return; + } + + if (state == TERMINATED_STATE) { + final Throwable error = parent.t; + if (this.requestType == FrameType.REQUEST_STREAM) { + ReferenceCountUtil.safeRelease(this.fluxOrPayload); + } + onError(error); + return; + } + } + } + } + + static final class WrappingList implements List { + + static final ThreadLocal INSTANCE = ThreadLocal.withInitial(WrappingList::new); + + private PooledRSocket[] activeSockets; + + static List wrap(PooledRSocket[] activeSockets) { + final WrappingList sockets = INSTANCE.get(); + sockets.activeSockets = activeSockets; + return sockets; + } + + @Override + public RSocket get(int index) { + final PooledRSocket socket = activeSockets[index]; + + RSocket realValue = socket.value; + if (realValue != null) { + return realValue; + } + + realValue = socket.valueIfResolved(); + if (realValue != null) { + return realValue; + } + + return socket; + } + + @Override + public int size() { + return activeSockets.length; + } + + @Override + public boolean isEmpty() { + return activeSockets.length == 0; + } + + @Override + public Object[] toArray() { + return activeSockets; + } + + @Override + @SuppressWarnings("unchecked") + public T[] toArray(T[] a) { + return (T[]) activeSockets; + } + + @Override + public boolean contains(Object o) { + throw new UnsupportedOperationException(); + } + + @Override + public Iterator iterator() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean add(RSocket weightedRSocket) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean remove(Object o) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean containsAll(Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean addAll(Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean addAll(int index, Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean removeAll(Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean retainAll(Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(); + } + + @Override + public RSocket set(int index, RSocket element) { + throw new UnsupportedOperationException(); + } + + @Override + public void add(int index, RSocket element) { + throw new UnsupportedOperationException(); + } + + @Override + public RSocket remove(int index) { + throw new UnsupportedOperationException(); + } + + @Override + public int indexOf(Object o) { + throw new UnsupportedOperationException(); + } + + @Override + public int lastIndexOf(Object o) { + throw new UnsupportedOperationException(); + } + + @Override + public ListIterator listIterator() { + throw new UnsupportedOperationException(); + } + + @Override + public ListIterator listIterator(int index) { + throw new UnsupportedOperationException(); + } + + @Override + public List subList(int fromIndex, int toIndex) { + throw new UnsupportedOperationException(); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/ResolvingOperator.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/ResolvingOperator.java new file mode 100644 index 000000000..52f16e166 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/ResolvingOperator.java @@ -0,0 +1,420 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import java.time.Duration; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.BiConsumer; +import reactor.core.Disposable; +import reactor.core.Exceptions; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +// This class is a copy of the same class in io.rsocket.core + +class ResolvingOperator implements Disposable { + + static final CancellationException ON_DISPOSE = new CancellationException("Disposed"); + + volatile int wip; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(ResolvingOperator.class, "wip"); + + volatile BiConsumer[] subscribers; + + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater SUBSCRIBERS = + AtomicReferenceFieldUpdater.newUpdater( + ResolvingOperator.class, BiConsumer[].class, "subscribers"); + + @SuppressWarnings("unchecked") + static final BiConsumer[] EMPTY_UNSUBSCRIBED = new BiConsumer[0]; + + @SuppressWarnings("unchecked") + static final BiConsumer[] EMPTY_SUBSCRIBED = new BiConsumer[0]; + + @SuppressWarnings("unchecked") + static final BiConsumer[] READY = new BiConsumer[0]; + + @SuppressWarnings("unchecked") + static final BiConsumer[] TERMINATED = new BiConsumer[0]; + + static final int ADDED_STATE = 0; + static final int READY_STATE = 1; + static final int TERMINATED_STATE = 2; + + T value; + Throwable t; + + public ResolvingOperator() { + + SUBSCRIBERS.lazySet(this, EMPTY_UNSUBSCRIBED); + } + + @Override + public final void dispose() { + this.terminate(ON_DISPOSE); + } + + @Override + public final boolean isDisposed() { + return this.subscribers == TERMINATED; + } + + public final boolean isPending() { + BiConsumer[] state = this.subscribers; + return state != READY && state != TERMINATED; + } + + @Nullable + public final T valueIfResolved() { + if (this.subscribers == READY) { + T value = this.value; + if (value != null) { + return value; + } + } + + return null; + } + + final void observe(BiConsumer actual) { + for (; ; ) { + final int state = this.add(actual); + + T value = this.value; + + if (state == READY_STATE) { + if (value != null) { + actual.accept(value, null); + return; + } + // value == null means racing between invalidate and this subscriber + // thus, we have to loop again + continue; + } else if (state == TERMINATED_STATE) { + actual.accept(null, this.t); + return; + } + + return; + } + } + + /** + * Block the calling thread for the specified time, waiting for the completion of this {@code + * ReconnectMono}. If the {@link ResolvingOperator} is completed with an error a RuntimeException + * that wraps the error is thrown. + * + * @param timeout the timeout value as a {@link Duration} + * @return the value of this {@link ResolvingOperator} or {@code null} if the timeout is reached + * and the {@link ResolvingOperator} has not completed + * @throws RuntimeException if terminated with error + * @throws IllegalStateException if timed out or {@link Thread} was interrupted with {@link + * InterruptedException} + */ + @Nullable + @SuppressWarnings({"uncheked", "BusyWait"}) + public T block(@Nullable Duration timeout) { + try { + BiConsumer[] subscribers = this.subscribers; + if (subscribers == READY) { + final T value = this.value; + if (value != null) { + return value; + } else { + // value == null means racing between invalidate and this block + // thus, we have to update the state again and see what happened + subscribers = this.subscribers; + } + } + + if (subscribers == TERMINATED) { + RuntimeException re = Exceptions.propagate(this.t); + re = Exceptions.addSuppressed(re, new Exception("Terminated with an error")); + throw re; + } + + // connect once + if (subscribers == EMPTY_UNSUBSCRIBED + && SUBSCRIBERS.compareAndSet(this, EMPTY_UNSUBSCRIBED, EMPTY_SUBSCRIBED)) { + this.doSubscribe(); + } + + long delay; + if (null == timeout) { + delay = 0L; + } else { + delay = System.nanoTime() + timeout.toNanos(); + } + for (; ; ) { + subscribers = this.subscribers; + + if (subscribers == READY) { + final T value = this.value; + if (value != null) { + return value; + } else { + // value == null means racing between invalidate and this block + // thus, we have to update the state again and see what happened + subscribers = this.subscribers; + } + } + if (subscribers == TERMINATED) { + RuntimeException re = Exceptions.propagate(this.t); + re = Exceptions.addSuppressed(re, new Exception("Terminated with an error")); + throw re; + } + if (timeout != null && delay < System.nanoTime()) { + throw new IllegalStateException("Timeout on Mono blocking read"); + } + + // connect again since invalidate() has happened in between + if (subscribers == EMPTY_UNSUBSCRIBED + && SUBSCRIBERS.compareAndSet(this, EMPTY_UNSUBSCRIBED, EMPTY_SUBSCRIBED)) { + this.doSubscribe(); + } + + Thread.sleep(1); + } + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + + throw new IllegalStateException("Thread Interruption on Mono blocking read"); + } + } + + @SuppressWarnings("unchecked") + final void terminate(Throwable t) { + if (isDisposed()) { + Operators.onErrorDropped(t, Context.empty()); + return; + } + + // writes happens before volatile write + this.t = t; + + final BiConsumer[] subscribers = SUBSCRIBERS.getAndSet(this, TERMINATED); + if (subscribers == TERMINATED) { + Operators.onErrorDropped(t, Context.empty()); + return; + } + + this.doOnDispose(); + + this.doFinally(); + + for (BiConsumer consumer : subscribers) { + consumer.accept(null, t); + } + } + + final void complete(T value) { + BiConsumer[] subscribers = this.subscribers; + if (subscribers == TERMINATED) { + this.doOnValueExpired(value); + return; + } + + this.value = value; + + for (; ; ) { + // ensures TERMINATE is going to be replaced with READY + if (SUBSCRIBERS.compareAndSet(this, subscribers, READY)) { + break; + } + + subscribers = this.subscribers; + + if (subscribers == TERMINATED) { + this.doFinally(); + return; + } + } + + this.doOnValueResolved(value); + + for (BiConsumer consumer : subscribers) { + consumer.accept(value, null); + } + } + + protected void doOnValueResolved(T value) { + // no ops + } + + final void doFinally() { + if (WIP.getAndIncrement(this) != 0) { + return; + } + + int m = 1; + T value; + + for (; ; ) { + value = this.value; + if (value != null && isDisposed()) { + this.value = null; + this.doOnValueExpired(value); + return; + } + + m = WIP.addAndGet(this, -m); + if (m == 0) { + return; + } + } + } + + final void invalidate() { + if (this.subscribers == TERMINATED) { + return; + } + + final BiConsumer[] subscribers = this.subscribers; + + if (subscribers == READY) { + // guarded section to ensure we expire value exactly once if there is racing + if (WIP.getAndIncrement(this) != 0) { + return; + } + + final T value = this.value; + if (value != null) { + this.value = null; + this.doOnValueExpired(value); + } + + int m = 1; + for (; ; ) { + if (isDisposed()) { + return; + } + + m = WIP.addAndGet(this, -m); + if (m == 0) { + break; + } + } + + SUBSCRIBERS.compareAndSet(this, READY, EMPTY_UNSUBSCRIBED); + } + } + + protected void doOnValueExpired(T value) { + // no ops + } + + protected void doOnDispose() { + // no ops + } + + public final boolean connect() { + for (; ; ) { + final BiConsumer[] a = this.subscribers; + + if (a == TERMINATED) { + return false; + } + + if (a == READY) { + return true; + } + + if (a != EMPTY_UNSUBSCRIBED) { + // do nothing if already started + return true; + } + + if (SUBSCRIBERS.compareAndSet(this, a, EMPTY_SUBSCRIBED)) { + this.doSubscribe(); + return true; + } + } + } + + final int add(BiConsumer ps) { + for (; ; ) { + BiConsumer[] a = this.subscribers; + + if (a == TERMINATED) { + return TERMINATED_STATE; + } + + if (a == READY) { + return READY_STATE; + } + + int n = a.length; + @SuppressWarnings("unchecked") + BiConsumer[] b = new BiConsumer[n + 1]; + System.arraycopy(a, 0, b, 0, n); + b[n] = ps; + + if (SUBSCRIBERS.compareAndSet(this, a, b)) { + if (a == EMPTY_UNSUBSCRIBED) { + this.doSubscribe(); + } + return ADDED_STATE; + } + } + } + + protected void doSubscribe() { + // no ops + } + + @SuppressWarnings("unchecked") + final void remove(BiConsumer ps) { + for (; ; ) { + BiConsumer[] a = this.subscribers; + int n = a.length; + if (n == 0) { + return; + } + + int j = -1; + for (int i = 0; i < n; i++) { + if (a[i] == ps) { + j = i; + break; + } + } + + if (j < 0) { + return; + } + + BiConsumer[] b; + + if (n == 1) { + b = EMPTY_SUBSCRIBED; + } else { + b = new BiConsumer[n - 1]; + System.arraycopy(a, 0, b, 0, j); + System.arraycopy(a, j + 1, b, j, n - j - 1); + } + if (SUBSCRIBERS.compareAndSet(this, a, b)) { + return; + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategy.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategy.java new file mode 100644 index 000000000..f1a9f8c55 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategy.java @@ -0,0 +1,42 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.RSocket; +import java.util.List; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; + +/** + * Simple {@link LoadbalanceStrategy} that selects the {@code RSocket} to use in round-robin order. + * + * @since 1.1 + */ +public class RoundRobinLoadbalanceStrategy implements LoadbalanceStrategy { + + volatile int nextIndex; + + private static final AtomicIntegerFieldUpdater NEXT_INDEX = + AtomicIntegerFieldUpdater.newUpdater(RoundRobinLoadbalanceStrategy.class, "nextIndex"); + + @Override + public RSocket select(List sockets) { + int length = sockets.size(); + + int indexToUse = Math.abs(NEXT_INDEX.getAndIncrement(this) % length); + + return sockets.get(indexToUse); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategy.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategy.java new file mode 100644 index 000000000..c30c8ad6b --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategy.java @@ -0,0 +1,249 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.loadbalance; + +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.plugins.RequestInterceptor; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.Function; +import reactor.util.annotation.Nullable; + +/** + * {@link LoadbalanceStrategy} that assigns a weight to each {@code RSocket} based on {@link + * RSocket#availability() availability} and usage statistics. The weight is used to decide which + * {@code RSocket} to select. + * + *

    Use {@link #create()} or a {@link #builder() Builder} to create an instance. + * + * @since 1.1 + * @see Predictive Load-Balancing: Unfair but + * Faster & more Robust + * @see WeightedStatsRequestInterceptor + */ +public class WeightedLoadbalanceStrategy implements ClientLoadbalanceStrategy { + + private static final double EXP_FACTOR = 4.0; + + final int maxPairSelectionAttempts; + final Function weightedStatsResolver; + + private WeightedLoadbalanceStrategy( + int numberOfAttempts, @Nullable Function resolver) { + this.maxPairSelectionAttempts = numberOfAttempts; + this.weightedStatsResolver = (resolver != null ? resolver : new DefaultWeightedStatsResolver()); + } + + @Override + public void initialize(RSocketConnector connector) { + final Function resolver = weightedStatsResolver; + if (resolver instanceof DefaultWeightedStatsResolver) { + ((DefaultWeightedStatsResolver) resolver).init(connector); + } + } + + @Override + public RSocket select(List sockets) { + final int size = sockets.size(); + + RSocket weightedRSocket; + final Function weightedStatsResolver = this.weightedStatsResolver; + switch (size) { + case 1: + weightedRSocket = sockets.get(0); + break; + case 2: + { + RSocket rsc1 = sockets.get(0); + RSocket rsc2 = sockets.get(1); + + double w1 = algorithmicWeight(rsc1, weightedStatsResolver.apply(rsc1)); + double w2 = algorithmicWeight(rsc2, weightedStatsResolver.apply(rsc2)); + if (w1 < w2) { + weightedRSocket = rsc2; + } else { + weightedRSocket = rsc1; + } + } + break; + default: + { + RSocket rsc1 = null; + RSocket rsc2 = null; + + for (int i = 0; i < this.maxPairSelectionAttempts; i++) { + int i1 = ThreadLocalRandom.current().nextInt(size); + int i2 = ThreadLocalRandom.current().nextInt(size - 1); + + if (i2 >= i1) { + i2++; + } + rsc1 = sockets.get(i1); + rsc2 = sockets.get(i2); + if (rsc1.availability() > 0.0 && rsc2.availability() > 0.0) { + break; + } + } + + if (rsc1 != null & rsc2 != null) { + double w1 = algorithmicWeight(rsc1, weightedStatsResolver.apply(rsc1)); + double w2 = algorithmicWeight(rsc2, weightedStatsResolver.apply(rsc2)); + + if (w1 < w2) { + weightedRSocket = rsc2; + } else { + weightedRSocket = rsc1; + } + } else if (rsc1 != null) { + weightedRSocket = rsc1; + } else { + weightedRSocket = rsc2; + } + } + } + + return weightedRSocket; + } + + private static double algorithmicWeight( + RSocket rSocket, @Nullable final WeightedStats weightedStats) { + if (weightedStats == null) { + return 1.0; + } + if (rSocket.isDisposed() || rSocket.availability() == 0.0) { + return 0.0; + } + final int pending = weightedStats.pending(); + + double latency = weightedStats.predictedLatency(); + + final double low = weightedStats.lowerQuantileLatency(); + final double high = + Math.max( + weightedStats.higherQuantileLatency(), + low * 1.001); // ensure higherQuantile > lowerQuantile + .1% + final double bandWidth = Math.max(high - low, 1); + + if (latency < low) { + latency /= calculateFactor(low, latency, bandWidth); + } else if (latency > high) { + latency *= calculateFactor(latency, high, bandWidth); + } + + return (rSocket.availability() * weightedStats.weightedAvailability()) + / (1.0d + latency * (pending + 1)); + } + + private static double calculateFactor(final double u, final double l, final double bandWidth) { + final double alpha = (u - l) / bandWidth; + return Math.pow(1 + alpha, EXP_FACTOR); + } + + /** + * Create an instance of {@link WeightedLoadbalanceStrategy} with default settings, which include + * round-robin load-balancing and 5 {@link #maxPairSelectionAttempts}. + */ + public static WeightedLoadbalanceStrategy create() { + return new Builder().build(); + } + + /** Return a builder to create a {@link WeightedLoadbalanceStrategy} with. */ + public static Builder builder() { + return new Builder(); + } + + /** Builder for {@link WeightedLoadbalanceStrategy}. */ + public static class Builder { + + private int maxPairSelectionAttempts = 5; + + @Nullable private Function weightedStatsResolver; + + private Builder() {} + + /** + * How many times to try to randomly select a pair of RSocket connections with non-zero + * availability. This is applicable when there are more than two connections in the pool. If the + * number of attempts is exceeded, the last selected pair is used. + * + *

    By default this is set to 5. + * + * @param numberOfAttempts the iteration count + */ + public Builder maxPairSelectionAttempts(int numberOfAttempts) { + this.maxPairSelectionAttempts = numberOfAttempts; + return this; + } + + /** + * Configure how the created {@link WeightedLoadbalanceStrategy} should find the stats for a + * given RSocket. + * + *

    By default this resolver is not set. + * + *

    When {@code WeightedLoadbalanceStrategy} is used through the {@link + * LoadbalanceRSocketClient}, the resolver does not need to be set because a {@link + * WeightedStatsRequestInterceptor} is automatically installed through the {@link + * ClientLoadbalanceStrategy} callback. If this strategy is used in any other context however, a + * resolver here must be provided. + * + * @param resolver to find the stats for an RSocket with + */ + public Builder weightedStatsResolver(Function resolver) { + this.weightedStatsResolver = resolver; + return this; + } + + /** Build the {@code WeightedLoadbalanceStrategy} instance. */ + public WeightedLoadbalanceStrategy build() { + return new WeightedLoadbalanceStrategy( + this.maxPairSelectionAttempts, this.weightedStatsResolver); + } + } + + private static class DefaultWeightedStatsResolver implements Function { + + final Map statsMap = new ConcurrentHashMap<>(); + + @Override + public WeightedStats apply(RSocket rSocket) { + return statsMap.get(rSocket); + } + + void init(RSocketConnector connector) { + connector.interceptors( + registry -> + registry.forRequestsInRequester( + (Function) + rSocket -> { + final WeightedStatsRequestInterceptor interceptor = + new WeightedStatsRequestInterceptor() { + @Override + public void dispose() { + statsMap.remove(rSocket); + } + }; + statsMap.put(rSocket, interceptor); + + return interceptor; + })); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStats.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStats.java new file mode 100644 index 000000000..5ebe668ce --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStats.java @@ -0,0 +1,50 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.RSocket; + +/** + * Contract to expose the stats required in {@link WeightedLoadbalanceStrategy} to calculate an + * algorithmic weight for an {@code RSocket}. The weight helps to select an {@code RSocket} for + * load-balancing. + * + * @since 1.1 + */ +public interface WeightedStats { + + double higherQuantileLatency(); + + double lowerQuantileLatency(); + + int pending(); + + double predictedLatency(); + + double weightedAvailability(); + + /** + * Create a proxy for the given {@code RSocket} that attaches the stats contained in this instance + * and exposes them as {@link WeightedStats}. + * + * @param rsocket the RSocket to wrap + * @return the wrapped RSocket + * @since 1.1.1 + */ + default RSocket wrap(RSocket rsocket) { + return new WeightedStatsRSocketProxy(rsocket, this); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRSocketProxy.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRSocketProxy.java new file mode 100644 index 000000000..f2cf3fbd0 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRSocketProxy.java @@ -0,0 +1,62 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.RSocket; +import io.rsocket.util.RSocketProxy; + +/** + * Package private {@code RSocketProxy} used from {@link WeightedStats#wrap(RSocket)} to attach a + * {@link WeightedStats} instance to an {@code RSocket}. + */ +class WeightedStatsRSocketProxy extends RSocketProxy implements WeightedStats { + + private final WeightedStats weightedStats; + + public WeightedStatsRSocketProxy(RSocket source, WeightedStats weightedStats) { + super(source); + this.weightedStats = weightedStats; + } + + @Override + public double higherQuantileLatency() { + return this.weightedStats.higherQuantileLatency(); + } + + @Override + public double lowerQuantileLatency() { + return this.weightedStats.lowerQuantileLatency(); + } + + @Override + public int pending() { + return this.weightedStats.pending(); + } + + @Override + public double predictedLatency() { + return this.weightedStats.predictedLatency(); + } + + @Override + public double weightedAvailability() { + return this.weightedStats.weightedAvailability(); + } + + public WeightedStats getDelegate() { + return this.weightedStats; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRequestInterceptor.java new file mode 100644 index 000000000..ec2c88b19 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRequestInterceptor.java @@ -0,0 +1,112 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.RequestInterceptor; +import reactor.util.annotation.Nullable; + +/** + * {@link RequestInterceptor} that hooks into request lifecycle and calls methods of the parent + * class to manage tracking state and expose {@link WeightedStats}. + * + *

    This interceptor the default mechanism for gathering stats when {@link + * WeightedLoadbalanceStrategy} is used with {@link LoadbalanceRSocketClient}. + * + * @since 1.1 + * @see LoadbalanceRSocketClient + * @see WeightedLoadbalanceStrategy + */ +public class WeightedStatsRequestInterceptor extends BaseWeightedStats + implements RequestInterceptor { + + final Int2LongHashMap requestsStartTime = new Int2LongHashMap(-1); + + public WeightedStatsRequestInterceptor() { + super(); + } + + @Override + public final void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + switch (requestType) { + case REQUEST_FNF: + case REQUEST_RESPONSE: + final long startTime = startRequest(); + final Int2LongHashMap requestsStartTime = this.requestsStartTime; + synchronized (requestsStartTime) { + requestsStartTime.put(streamId, startTime); + } + break; + case REQUEST_STREAM: + case REQUEST_CHANNEL: + this.startStream(); + } + } + + @Override + public final void onTerminate(int streamId, FrameType requestType, @Nullable Throwable t) { + switch (requestType) { + case REQUEST_FNF: + case REQUEST_RESPONSE: + long startTime; + final Int2LongHashMap requestsStartTime = this.requestsStartTime; + synchronized (requestsStartTime) { + startTime = requestsStartTime.remove(streamId); + } + long endTime = stopRequest(startTime); + if (t == null) { + record(endTime - startTime); + } + break; + case REQUEST_STREAM: + case REQUEST_CHANNEL: + stopStream(); + break; + } + + if (t != null) { + updateAvailability(0.0d); + } else { + updateAvailability(1.0d); + } + } + + @Override + public final void onCancel(int streamId, FrameType requestType) { + switch (requestType) { + case REQUEST_FNF: + case REQUEST_RESPONSE: + long startTime; + final Int2LongHashMap requestsStartTime = this.requestsStartTime; + synchronized (requestsStartTime) { + startTime = requestsStartTime.remove(streamId); + } + stopRequest(startTime); + break; + case REQUEST_STREAM: + case REQUEST_CHANNEL: + stopStream(); + break; + } + } + + @Override + public final void onReject(Throwable rejectionReason, FrameType requestType, ByteBuf metadata) {} + + @Override + public void dispose() {} +} diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/package-info.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/package-info.java similarity index 73% rename from rsocket-core/src/main/java/io/rsocket/fragmentation/package-info.java rename to rsocket-core/src/main/java/io/rsocket/loadbalance/package-info.java index 8cc3fb41a..19668e99c 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/package-info.java @@ -14,14 +14,8 @@ * limitations under the License. */ -/** - * Support for frame fragmentation and reassembly. - * - * @see Fragmentation - * and Reassembly - */ +/** Support client load-balancing in RSocket Java. */ @NonNullApi -package io.rsocket.fragmentation; +package io.rsocket.loadbalance; import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/AuthMetadataCodec.java b/rsocket-core/src/main/java/io/rsocket/metadata/AuthMetadataCodec.java index d908abb3c..c16c4dc52 100644 --- a/rsocket-core/src/main/java/io/rsocket/metadata/AuthMetadataCodec.java +++ b/rsocket-core/src/main/java/io/rsocket/metadata/AuthMetadataCodec.java @@ -12,7 +12,7 @@ public class AuthMetadataCodec { static final int STREAM_METADATA_KNOWN_MASK = 0x80; // 1000 0000 static final byte STREAM_METADATA_LENGTH_MASK = 0x7F; // 0111 1111 - static final int USERNAME_BYTES_LENGTH = 1; + static final int USERNAME_BYTES_LENGTH = 2; static final int AUTH_TYPE_ID_LENGTH = 1; static final char[] EMPTY_CHARS_ARRAY = new char[0]; @@ -81,7 +81,7 @@ public static ByteBuf encodeMetadata( /** * Encode a Authentication CompositeMetadata payload using Simple Authentication format * - * @throws IllegalArgumentException if the username length is greater than 255 + * @throws IllegalArgumentException if the username length is greater than 65535 * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. * @param username the char sequence which represents user name. * @param password the char sequence which represents user password. @@ -90,9 +90,9 @@ public static ByteBuf encodeSimpleMetadata( ByteBufAllocator allocator, char[] username, char[] password) { int usernameLength = CharByteBufUtil.utf8Bytes(username); - if (usernameLength > 255) { + if (usernameLength > 65535) { throw new IllegalArgumentException( - "Username should be shorter than or equal to 255 bytes length in UTF-8 encoding"); + "Username should be shorter than or equal to 65535 bytes length in UTF-8 encoding"); } int passwordLength = CharByteBufUtil.utf8Bytes(password); @@ -101,7 +101,7 @@ public static ByteBuf encodeSimpleMetadata( allocator .buffer(capacity, capacity) .writeByte(WellKnownAuthType.SIMPLE.getIdentifier() | STREAM_METADATA_KNOWN_MASK) - .writeByte(usernameLength); + .writeShort(usernameLength); CharByteBufUtil.writeUtf8(buffer, username); CharByteBufUtil.writeUtf8(buffer, password); @@ -235,15 +235,15 @@ public static ByteBuf readPayload(ByteBuf metadata) { } /** - * Read up to 257 {@code bytes} from the given {@link ByteBuf} where the first byte is username - * length and the subsequent number of bytes equal to decoded length + * Read up to 65537 {@code bytes} from the given {@link ByteBuf} where the first two bytes + * represent username length and the subsequent number of bytes equal to read length * * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code - * simpleAuthMetadata#readIndex} should be set to the username length byte + * simpleAuthMetadata#readIndex} should be set to the username length position * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if username length is zero */ public static ByteBuf readUsername(ByteBuf simpleAuthMetadata) { - short usernameLength = readUsernameLength(simpleAuthMetadata); + int usernameLength = readUsernameLength(simpleAuthMetadata); if (usernameLength == 0) { return Unpooled.EMPTY_BUFFER; @@ -268,15 +268,15 @@ public static ByteBuf readPassword(ByteBuf simpleAuthMetadata) { return simpleAuthMetadata.readSlice(simpleAuthMetadata.readableBytes()); } /** - * Read up to 257 {@code bytes} from the given {@link ByteBuf} where the first byte is username - * length and the subsequent number of bytes equal to decoded length + * Read up to 65537 {@code bytes} from the given {@link ByteBuf} where the first two bytes + * represent username length and the subsequent number of bytes equal to read length * * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code * simpleAuthMetadata#readIndex} should be set to the username length byte * @return {@code char[]} which represents UTF-8 username */ public static char[] readUsernameAsCharArray(ByteBuf simpleAuthMetadata) { - short usernameLength = readUsernameLength(simpleAuthMetadata); + int usernameLength = readUsernameLength(simpleAuthMetadata); if (usernameLength == 0) { return EMPTY_CHARS_ARRAY; @@ -302,11 +302,10 @@ public static char[] readPasswordAsCharArray(ByteBuf simpleAuthMetadata) { } /** - * Read all the remaining {@code bytes} from the given {@link ByteBuf} where the first byte is - * username length and the subsequent number of bytes equal to decoded length + * Read all the remaining {@code bytes} from the given {@link ByteBuf} * * @param bearerAuthMetadata the given metadata to read username from. Please note, the {@code - * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes + * bearerAuthMetadata#readIndex} should be set to the beginning of the password bytes * @return {@code char[]} which represents UTF-8 password */ public static char[] readBearerTokenAsCharArray(ByteBuf bearerAuthMetadata) { @@ -317,13 +316,13 @@ public static char[] readBearerTokenAsCharArray(ByteBuf bearerAuthMetadata) { return CharByteBufUtil.readUtf8(bearerAuthMetadata, bearerAuthMetadata.readableBytes()); } - private static short readUsernameLength(ByteBuf simpleAuthMetadata) { - if (simpleAuthMetadata.readableBytes() < 1) { + private static int readUsernameLength(ByteBuf simpleAuthMetadata) { + if (simpleAuthMetadata.readableBytes() < 2) { throw new IllegalStateException( "Unable to decode custom username. Not enough readable bytes"); } - short usernameLength = simpleAuthMetadata.readUnsignedByte(); + int usernameLength = simpleAuthMetadata.readUnsignedShort(); if (simpleAuthMetadata.readableBytes() < usernameLength) { throw new IllegalArgumentException( diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadata.java b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadata.java index 4a48921b1..1c3ae9423 100644 --- a/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadata.java +++ b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadata.java @@ -16,12 +16,12 @@ package io.rsocket.metadata; -import static io.rsocket.metadata.CompositeMetadataFlyweight.computeNextEntryIndex; -import static io.rsocket.metadata.CompositeMetadataFlyweight.decodeMimeAndContentBuffersSlices; -import static io.rsocket.metadata.CompositeMetadataFlyweight.decodeMimeIdFromMimeBuffer; -import static io.rsocket.metadata.CompositeMetadataFlyweight.decodeMimeTypeFromMimeBuffer; -import static io.rsocket.metadata.CompositeMetadataFlyweight.hasEntry; -import static io.rsocket.metadata.CompositeMetadataFlyweight.isWellKnownMimeType; +import static io.rsocket.metadata.CompositeMetadataCodec.computeNextEntryIndex; +import static io.rsocket.metadata.CompositeMetadataCodec.decodeMimeAndContentBuffersSlices; +import static io.rsocket.metadata.CompositeMetadataCodec.decodeMimeIdFromMimeBuffer; +import static io.rsocket.metadata.CompositeMetadataCodec.decodeMimeTypeFromMimeBuffer; +import static io.rsocket.metadata.CompositeMetadataCodec.hasEntry; +import static io.rsocket.metadata.CompositeMetadataCodec.isWellKnownMimeType; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; @@ -52,8 +52,8 @@ * ReservedMimeTypeEntry}. In this case {@link Entry#getMimeType()} will return {@code null}. The * encoded id can be retrieved using {@link ReservedMimeTypeEntry#getType()}. The byte and content * buffer should be kept around and re-encoded using {@link - * CompositeMetadataFlyweight#encodeAndAddMetadata(CompositeByteBuf, ByteBufAllocator, byte, - * ByteBuf)} in case passing that entry through is required. + * CompositeMetadataCodec#encodeAndAddMetadata(CompositeByteBuf, ByteBufAllocator, byte, ByteBuf)} + * in case passing that entry through is required. */ public final class CompositeMetadata implements Iterable { diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataFlyweight.java b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataFlyweight.java deleted file mode 100644 index 9916dfd3b..000000000 --- a/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataFlyweight.java +++ /dev/null @@ -1,262 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.metadata; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.CompositeByteBuf; -import reactor.util.annotation.Nullable; - -/** - * A flyweight class that can be used to encode/decode composite metadata information to/from {@link - * ByteBuf}. This is intended for low-level efficient manipulation of such buffers. See {@link - * CompositeMetadata} for an Iterator-like approach to decoding entries. - * - * @deprecated in favor of {@link CompositeMetadataCodec} - */ -@Deprecated -public class CompositeMetadataFlyweight { - - private CompositeMetadataFlyweight() {} - - public static int computeNextEntryIndex( - int currentEntryIndex, ByteBuf headerSlice, ByteBuf contentSlice) { - return CompositeMetadataCodec.computeNextEntryIndex( - currentEntryIndex, headerSlice, contentSlice); - } - - /** - * Decode the next metadata entry (a mime header + content pair of {@link ByteBuf}) from a {@link - * ByteBuf} that contains at least enough bytes for one more such entry. These buffers are - * actually slices of the full metadata buffer, and this method doesn't move the full metadata - * buffer's {@link ByteBuf#readerIndex()}. As such, it requires the user to provide an {@code - * index} to read from. The next index is computed by calling {@link #computeNextEntryIndex(int, - * ByteBuf, ByteBuf)}. Size of the first buffer (the "header buffer") drives which decoding method - * should be further applied to it. - * - *

    The header buffer is either: - * - *

      - *
    • made up of a single byte: this represents an encoded mime id, which can be further - * decoded using {@link #decodeMimeIdFromMimeBuffer(ByteBuf)} - *
    • made up of 2 or more bytes: this represents an encoded mime String + its length, which - * can be further decoded using {@link #decodeMimeTypeFromMimeBuffer(ByteBuf)}. Note the - * encoded length, in the first byte, is skipped by this decoding method because the - * remaining length of the buffer is that of the mime string. - *
    - * - * @param compositeMetadata the source {@link ByteBuf} that originally contains one or more - * metadata entries - * @param entryIndex the {@link ByteBuf#readerIndex()} to start decoding from. original reader - * index is kept on the source buffer - * @param retainSlices should produced metadata entry buffers {@link ByteBuf#slice() slices} be - * {@link ByteBuf#retainedSlice() retained}? - * @return a {@link ByteBuf} array of length 2 containing the mime header buffer - * slice and the content buffer slice, or one of the - * zero-length error constant arrays - */ - public static ByteBuf[] decodeMimeAndContentBuffersSlices( - ByteBuf compositeMetadata, int entryIndex, boolean retainSlices) { - return CompositeMetadataCodec.decodeMimeAndContentBuffersSlices( - compositeMetadata, entryIndex, retainSlices); - } - - /** - * Decode a {@code byte} compressed mime id from a {@link ByteBuf}, assuming said buffer properly - * contains such an id. - * - *

    The buffer must have exactly one readable byte, which is assumed to have been tested for - * mime id encoding via the {@link CompositeMetadataCodec#STREAM_METADATA_KNOWN_MASK} mask ({@code - * firstByte & STREAM_METADATA_KNOWN_MASK) == STREAM_METADATA_KNOWN_MASK}). - * - *

    If there is no readable byte, the negative identifier of {@link - * WellKnownMimeType#UNPARSEABLE_MIME_TYPE} is returned. - * - * @param mimeBuffer the buffer that should next contain the compressed mime id byte - * @return the compressed mime id, between 0 and 127, or a negative id if the input is invalid - * @see #decodeMimeTypeFromMimeBuffer(ByteBuf) - */ - public static byte decodeMimeIdFromMimeBuffer(ByteBuf mimeBuffer) { - return CompositeMetadataCodec.decodeMimeIdFromMimeBuffer(mimeBuffer); - } - - /** - * Decode a {@link CharSequence} custome mime type from a {@link ByteBuf}, assuming said buffer - * properly contains such a mime type. - * - *

    The buffer must at least have two readable bytes, which distinguishes it from the {@link - * #decodeMimeIdFromMimeBuffer(ByteBuf) compressed id} case. The first byte is a size and the - * remaining bytes must correspond to the {@link CharSequence}, encoded fully in US_ASCII. As a - * result, the first byte can simply be skipped, and the remaining of the buffer be decoded to the - * mime type. - * - *

    If the mime header buffer is less than 2 bytes long, returns {@code null}. - * - * @param flyweightMimeBuffer the mime header {@link ByteBuf} that contains length + custom mime - * type - * @return the decoded custom mime type, as a {@link CharSequence}, or null if the input is - * invalid - * @see #decodeMimeIdFromMimeBuffer(ByteBuf) - */ - @Nullable - public static CharSequence decodeMimeTypeFromMimeBuffer(ByteBuf flyweightMimeBuffer) { - return CompositeMetadataCodec.decodeMimeTypeFromMimeBuffer(flyweightMimeBuffer); - } - - /** - * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf - * buffer}, without checking if the {@link String} can be matched with a well known compressable - * mime type. Prefer using this method and {@link #encodeAndAddMetadata(CompositeByteBuf, - * ByteBufAllocator, WellKnownMimeType, ByteBuf)} if you know in advance whether or not the mime - * is well known. Otherwise use {@link #encodeAndAddMetadataWithCompression(CompositeByteBuf, - * ByteBufAllocator, String, ByteBuf)} - * - * @param compositeMetaData the buffer that will hold all composite metadata information. - * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. - * @param customMimeType the custom mime type to encode. - * @param metadata the metadata value to encode. - */ - // see #encodeMetadataHeader(ByteBufAllocator, String, int) - public static void encodeAndAddMetadata( - CompositeByteBuf compositeMetaData, - ByteBufAllocator allocator, - String customMimeType, - ByteBuf metadata) { - CompositeMetadataCodec.encodeAndAddMetadata( - compositeMetaData, allocator, customMimeType, metadata); - } - - /** - * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf - * buffer}. - * - * @param compositeMetaData the buffer that will hold all composite metadata information. - * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. - * @param knownMimeType the {@link WellKnownMimeType} to encode. - * @param metadata the metadata value to encode. - */ - // see #encodeMetadataHeader(ByteBufAllocator, byte, int) - public static void encodeAndAddMetadata( - CompositeByteBuf compositeMetaData, - ByteBufAllocator allocator, - WellKnownMimeType knownMimeType, - ByteBuf metadata) { - CompositeMetadataCodec.encodeAndAddMetadata( - compositeMetaData, allocator, knownMimeType, metadata); - } - - /** - * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf - * buffer}, first verifying if the passed {@link String} matches a {@link WellKnownMimeType} (in - * which case it will be encoded in a compressed fashion using the mime id of that type). - * - *

    Prefer using {@link #encodeAndAddMetadata(CompositeByteBuf, ByteBufAllocator, String, - * ByteBuf)} if you already know that the mime type is not a {@link WellKnownMimeType}. - * - * @param compositeMetaData the buffer that will hold all composite metadata information. - * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. - * @param mimeType the mime type to encode, as a {@link String}. well known mime types are - * compressed. - * @param metadata the metadata value to encode. - * @see #encodeAndAddMetadata(CompositeByteBuf, ByteBufAllocator, WellKnownMimeType, ByteBuf) - */ - // see #encodeMetadataHeader(ByteBufAllocator, String, int) - public static void encodeAndAddMetadataWithCompression( - CompositeByteBuf compositeMetaData, - ByteBufAllocator allocator, - String mimeType, - ByteBuf metadata) { - CompositeMetadataCodec.encodeAndAddMetadataWithCompression( - compositeMetaData, allocator, mimeType, metadata); - } - - /** - * Returns whether there is another entry available at a given index - * - * @param compositeMetadata the buffer to inspect - * @param entryIndex the index to check at - * @return whether there is another entry available at a given index - */ - public static boolean hasEntry(ByteBuf compositeMetadata, int entryIndex) { - return CompositeMetadataCodec.hasEntry(compositeMetadata, entryIndex); - } - - /** - * Returns whether the header represents a well-known MIME type. - * - * @param header the header to inspect - * @return whether the header represents a well-known MIME type - */ - public static boolean isWellKnownMimeType(ByteBuf header) { - return CompositeMetadataCodec.isWellKnownMimeType(header); - } - - /** - * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf - * buffer}. - * - * @param compositeMetaData the buffer that will hold all composite metadata information. - * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. - * @param unknownCompressedMimeType the id of the {@link - * WellKnownMimeType#UNKNOWN_RESERVED_MIME_TYPE} to encode. - * @param metadata the metadata value to encode. - */ - // see #encodeMetadataHeader(ByteBufAllocator, byte, int) - static void encodeAndAddMetadata( - CompositeByteBuf compositeMetaData, - ByteBufAllocator allocator, - byte unknownCompressedMimeType, - ByteBuf metadata) { - CompositeMetadataCodec.encodeAndAddMetadata( - compositeMetaData, allocator, unknownCompressedMimeType, metadata); - } - - /** - * Encode a custom mime type and a metadata value length into a newly allocated {@link ByteBuf}. - * - *

    This larger representation encodes the mime type representation's length on a single byte, - * then the representation itself, then the unsigned metadata value length on 3 additional bytes. - * - * @param allocator the {@link ByteBufAllocator} to use to create the buffer. - * @param customMime a custom mime type to encode. - * @param metadataLength the metadata length to append to the buffer as an unsigned 24 bits - * integer. - * @return the encoded mime and metadata length information - */ - static ByteBuf encodeMetadataHeader( - ByteBufAllocator allocator, String customMime, int metadataLength) { - return CompositeMetadataCodec.encodeMetadataHeader(allocator, customMime, metadataLength); - } - - /** - * Encode a {@link WellKnownMimeType well known mime type} and a metadata value length into a - * newly allocated {@link ByteBuf}. - * - *

    This compact representation encodes the mime type via its ID on a single byte, and the - * unsigned value length on 3 additional bytes. - * - * @param allocator the {@link ByteBufAllocator} to use to create the buffer. - * @param mimeType a byte identifier of a {@link WellKnownMimeType} to encode. - * @param metadataLength the metadata length to append to the buffer as an unsigned 24 bits - * integer. - * @return the encoded mime and metadata length information - */ - static ByteBuf encodeMetadataHeader( - ByteBufAllocator allocator, byte mimeType, int metadataLength) { - return CompositeMetadataCodec.encodeMetadataHeader(allocator, mimeType, metadataLength); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/MimeTypeMetadataCodec.java b/rsocket-core/src/main/java/io/rsocket/metadata/MimeTypeMetadataCodec.java new file mode 100644 index 000000000..2e03bd754 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/MimeTypeMetadataCodec.java @@ -0,0 +1,137 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.CharsetUtil; +import java.util.ArrayList; +import java.util.List; + +/** + * Provides support for encoding and decoding the per-stream MIME type to use for payload data. + * + *

    For more on the format of the metadata, see the + * Stream Data MIME Types extension specification. + * + * @since 1.1.1 + */ +public class MimeTypeMetadataCodec { + + private static final int STREAM_METADATA_KNOWN_MASK = 0x80; // 1000 0000 + + private static final byte STREAM_METADATA_LENGTH_MASK = 0x7F; // 0111 1111 + + private MimeTypeMetadataCodec() {} + + /** + * Encode a {@link WellKnownMimeType} into a newly allocated single byte {@link ByteBuf}. + * + * @param allocator the allocator to create the buffer with + * @param mimeType well-known MIME type to encode + * @return the resulting buffer + */ + public static ByteBuf encode(ByteBufAllocator allocator, WellKnownMimeType mimeType) { + return allocator.buffer(1, 1).writeByte(mimeType.getIdentifier() | STREAM_METADATA_KNOWN_MASK); + } + + /** + * Encode the given MIME type into a newly allocated {@link ByteBuf}. + * + * @param allocator the allocator to create the buffer with + * @param mimeType MIME type to encode + * @return the resulting buffer + */ + public static ByteBuf encode(ByteBufAllocator allocator, String mimeType) { + if (mimeType == null || mimeType.length() == 0) { + throw new IllegalArgumentException("MIME type is required"); + } + WellKnownMimeType wkn = WellKnownMimeType.fromString(mimeType); + if (wkn == WellKnownMimeType.UNPARSEABLE_MIME_TYPE) { + return encodeCustomMimeType(allocator, mimeType); + } else { + return encode(allocator, wkn); + } + } + + /** + * Encode multiple MIME types into a newly allocated {@link ByteBuf}. + * + * @param allocator the allocator to create the buffer with + * @param mimeTypes MIME types to encode + * @return the resulting buffer + */ + public static ByteBuf encode(ByteBufAllocator allocator, List mimeTypes) { + if (mimeTypes == null || mimeTypes.size() == 0) { + throw new IllegalArgumentException("No MIME types provided"); + } + CompositeByteBuf compositeByteBuf = allocator.compositeBuffer(); + for (String mimeType : mimeTypes) { + ByteBuf byteBuf = encode(allocator, mimeType); + compositeByteBuf.addComponents(true, byteBuf); + } + return compositeByteBuf; + } + + private static ByteBuf encodeCustomMimeType(ByteBufAllocator allocator, String customMimeType) { + ByteBuf byteBuf = allocator.buffer(1 + customMimeType.length()); + + byteBuf.writerIndex(1); + int length = ByteBufUtil.writeUtf8(byteBuf, customMimeType); + + if (!ByteBufUtil.isText(byteBuf, 1, length, CharsetUtil.US_ASCII)) { + byteBuf.release(); + throw new IllegalArgumentException("MIME type must be ASCII characters only"); + } + + if (length < 1 || length > 128) { + byteBuf.release(); + throw new IllegalArgumentException( + "MIME type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + byteBuf.markWriterIndex(); + byteBuf.writerIndex(0); + byteBuf.writeByte(length - 1); + byteBuf.resetWriterIndex(); + + return byteBuf; + } + + /** + * Decode the per-stream MIME type metadata encoded in the given {@link ByteBuf}. + * + * @return the decoded MIME types + */ + public static List decode(ByteBuf byteBuf) { + List mimeTypes = new ArrayList<>(); + while (byteBuf.isReadable()) { + byte idOrLength = byteBuf.readByte(); + if ((idOrLength & STREAM_METADATA_KNOWN_MASK) == STREAM_METADATA_KNOWN_MASK) { + byte id = (byte) (idOrLength & STREAM_METADATA_LENGTH_MASK); + WellKnownMimeType wellKnownMimeType = WellKnownMimeType.fromIdentifier(id); + mimeTypes.add(wellKnownMimeType.toString()); + } else { + int length = Byte.toUnsignedInt(idOrLength) + 1; + mimeTypes.add(byteBuf.readCharSequence(length, CharsetUtil.US_ASCII).toString()); + } + } + return mimeTypes; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadataFlyweight.java b/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadataFlyweight.java deleted file mode 100644 index 718528358..000000000 --- a/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadataFlyweight.java +++ /dev/null @@ -1,62 +0,0 @@ -package io.rsocket.metadata; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import java.util.Collection; - -/** - * A flyweight class that can be used to encode/decode tagging metadata information to/from {@link - * ByteBuf}. This is intended for low-level efficient manipulation of such buffers. See {@link - * TaggingMetadata} for an Iterator-like approach to decoding entries. - * - * @deprecated in favor of {@link TaggingMetadataCodec} - * @author linux_china - */ -@Deprecated -public class TaggingMetadataFlyweight { - /** - * create routing metadata - * - * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. - * @param tags tag values - * @return routing metadata - */ - public static RoutingMetadata createRoutingMetadata( - ByteBufAllocator allocator, Collection tags) { - return TaggingMetadataCodec.createRoutingMetadata(allocator, tags); - } - - /** - * create tagging metadata from composite metadata entry - * - * @param entry composite metadata entry - * @return tagging metadata - */ - public static TaggingMetadata createTaggingMetadata(CompositeMetadata.Entry entry) { - return TaggingMetadataCodec.createTaggingMetadata(entry); - } - - /** - * create tagging metadata - * - * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. - * @param knownMimeType the {@link WellKnownMimeType} to encode. - * @param tags tag values - * @return Tagging Metadata - */ - public static TaggingMetadata createTaggingMetadata( - ByteBufAllocator allocator, String knownMimeType, Collection tags) { - return TaggingMetadataCodec.createTaggingMetadata(allocator, knownMimeType, tags); - } - - /** - * create tagging content - * - * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. - * @param tags tag values - * @return tagging content - */ - public static ByteBuf createTaggingContent(ByteBufAllocator allocator, Collection tags) { - return TaggingMetadataCodec.createTaggingContent(allocator, tags); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/security/AuthMetadataFlyweight.java b/rsocket-core/src/main/java/io/rsocket/metadata/security/AuthMetadataFlyweight.java deleted file mode 100644 index e1a8ba449..000000000 --- a/rsocket-core/src/main/java/io/rsocket/metadata/security/AuthMetadataFlyweight.java +++ /dev/null @@ -1,194 +0,0 @@ -package io.rsocket.metadata.security; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.rsocket.metadata.AuthMetadataCodec; - -/** @deprecated in favor of {@link io.rsocket.metadata.AuthMetadataCodec} */ -@Deprecated -public class AuthMetadataFlyweight { - - static final int STREAM_METADATA_KNOWN_MASK = 0x80; // 1000 0000 - - private AuthMetadataFlyweight() {} - - /** - * Encode a Authentication CompositeMetadata payload using custom authentication type - * - * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. - * @param customAuthType the custom mime type to encode. - * @param metadata the metadata value to encode. - * @throws IllegalArgumentException in case of {@code customAuthType} is non US_ASCII string or - * empty string or its length is greater than 128 bytes - */ - public static ByteBuf encodeMetadata( - ByteBufAllocator allocator, String customAuthType, ByteBuf metadata) { - - return AuthMetadataCodec.encodeMetadata(allocator, customAuthType, metadata); - } - - /** - * Encode a Authentication CompositeMetadata payload using custom authentication type - * - * @param allocator the {@link ByteBufAllocator} to create intermediate buffers as needed. - * @param authType the well-known mime type to encode. - * @param metadata the metadata value to encode. - * @throws IllegalArgumentException in case of {@code authType} is {@link - * WellKnownAuthType#UNPARSEABLE_AUTH_TYPE} or {@link - * WellKnownAuthType#UNKNOWN_RESERVED_AUTH_TYPE} - */ - public static ByteBuf encodeMetadata( - ByteBufAllocator allocator, WellKnownAuthType authType, ByteBuf metadata) { - - return AuthMetadataCodec.encodeMetadata(allocator, WellKnownAuthType.cast(authType), metadata); - } - - /** - * Encode a Authentication CompositeMetadata payload using Simple Authentication format - * - * @throws IllegalArgumentException if the username length is greater than 255 - * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. - * @param username the char sequence which represents user name. - * @param password the char sequence which represents user password. - */ - public static ByteBuf encodeSimpleMetadata( - ByteBufAllocator allocator, char[] username, char[] password) { - return AuthMetadataCodec.encodeSimpleMetadata(allocator, username, password); - } - - /** - * Encode a Authentication CompositeMetadata payload using Bearer Authentication format - * - * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. - * @param token the char sequence which represents BEARER token. - */ - public static ByteBuf encodeBearerMetadata(ByteBufAllocator allocator, char[] token) { - return AuthMetadataCodec.encodeBearerMetadata(allocator, token); - } - - /** - * Encode a new Authentication Metadata payload information, first verifying if the passed {@link - * String} matches a {@link WellKnownAuthType} (in which case it will be encoded in a compressed - * fashion using the mime id of that type). - * - *

    Prefer using {@link #encodeMetadata(ByteBufAllocator, String, ByteBuf)} if you already know - * that the mime type is not a {@link WellKnownAuthType}. - * - * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. - * @param authType the mime type to encode, as a {@link String}. well known mime types are - * compressed. - * @param metadata the metadata value to encode. - * @see #encodeMetadata(ByteBufAllocator, WellKnownAuthType, ByteBuf) - * @see #encodeMetadata(ByteBufAllocator, String, ByteBuf) - */ - public static ByteBuf encodeMetadataWithCompression( - ByteBufAllocator allocator, String authType, ByteBuf metadata) { - return AuthMetadataCodec.encodeMetadataWithCompression(allocator, authType, metadata); - } - - /** - * Get the first {@code byte} from a {@link ByteBuf} and check whether it is length or {@link - * WellKnownAuthType}. Assuming said buffer properly contains such a {@code byte} - * - * @param metadata byteBuf used to get information from - */ - public static boolean isWellKnownAuthType(ByteBuf metadata) { - return AuthMetadataCodec.isWellKnownAuthType(metadata); - } - - /** - * Read first byte from the given {@code metadata} and tries to convert it's value to {@link - * WellKnownAuthType}. - * - * @param metadata given metadata buffer to read from - * @return Return on of the know Auth types or {@link WellKnownAuthType#UNPARSEABLE_AUTH_TYPE} if - * field's value is length or unknown auth type - * @throws IllegalStateException if not enough readable bytes in the given {@link ByteBuf} - */ - public static WellKnownAuthType decodeWellKnownAuthType(ByteBuf metadata) { - return WellKnownAuthType.cast(AuthMetadataCodec.readWellKnownAuthType(metadata)); - } - - /** - * Read up to 129 bytes from the given metadata in order to get the custom Auth Type - * - * @param metadata - * @return - */ - public static CharSequence decodeCustomAuthType(ByteBuf metadata) { - return AuthMetadataCodec.readCustomAuthType(metadata); - } - - /** - * Read all remaining {@code bytes} from the given {@link ByteBuf} and return sliced - * representation of a payload - * - * @param metadata metadata to get payload from. Please note, the {@code metadata#readIndex} - * should be set to the beginning of the payload bytes - * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if no bytes readable in the - * given one - */ - public static ByteBuf decodePayload(ByteBuf metadata) { - return AuthMetadataCodec.readPayload(metadata); - } - - /** - * Read up to 257 {@code bytes} from the given {@link ByteBuf} where the first byte is username - * length and the subsequent number of bytes equal to decoded length - * - * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code - * simpleAuthMetadata#readIndex} should be set to the username length byte - * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if username length is zero - */ - public static ByteBuf decodeUsername(ByteBuf simpleAuthMetadata) { - return AuthMetadataCodec.readUsername(simpleAuthMetadata); - } - - /** - * Read all the remaining {@code byte}s from the given {@link ByteBuf} which represents user's - * password - * - * @param simpleAuthMetadata the given metadata to read password from. Please note, the {@code - * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes - * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if password length is zero - */ - public static ByteBuf decodePassword(ByteBuf simpleAuthMetadata) { - return AuthMetadataCodec.readPassword(simpleAuthMetadata); - } - /** - * Read up to 257 {@code bytes} from the given {@link ByteBuf} where the first byte is username - * length and the subsequent number of bytes equal to decoded length - * - * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code - * simpleAuthMetadata#readIndex} should be set to the username length byte - * @return {@code char[]} which represents UTF-8 username - */ - public static char[] decodeUsernameAsCharArray(ByteBuf simpleAuthMetadata) { - return AuthMetadataCodec.readUsernameAsCharArray(simpleAuthMetadata); - } - - /** - * Read all the remaining {@code byte}s from the given {@link ByteBuf} which represents user's - * password - * - * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code - * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes - * @return {@code char[]} which represents UTF-8 password - */ - public static char[] decodePasswordAsCharArray(ByteBuf simpleAuthMetadata) { - return AuthMetadataCodec.readPasswordAsCharArray(simpleAuthMetadata); - } - - /** - * Read all the remaining {@code bytes} from the given {@link ByteBuf} where the first byte is - * username length and the subsequent number of bytes equal to decoded length - * - * @param bearerAuthMetadata the given metadata to read username from. Please note, the {@code - * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes - * @return {@code char[]} which represents UTF-8 password - */ - public static char[] decodeBearerTokenAsCharArray(ByteBuf bearerAuthMetadata) { - return AuthMetadataCodec.readBearerTokenAsCharArray(bearerAuthMetadata); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/security/WellKnownAuthType.java b/rsocket-core/src/main/java/io/rsocket/metadata/security/WellKnownAuthType.java deleted file mode 100644 index 24e5ff0db..000000000 --- a/rsocket-core/src/main/java/io/rsocket/metadata/security/WellKnownAuthType.java +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.metadata.security; - -import java.util.Arrays; -import java.util.LinkedHashMap; -import java.util.Map; - -/** - * Enumeration of Well Known Auth Types, as defined in the eponymous extension. Such auth types are - * used in composite metadata (which can include routing and/or tracing metadata). Per - * specification, identifiers are between 0 and 127 (inclusive). - * - * @deprecated in favor of {@link io.rsocket.metadata.WellKnownAuthType} - */ -@Deprecated -public enum WellKnownAuthType { - UNPARSEABLE_AUTH_TYPE("UNPARSEABLE_AUTH_TYPE_DO_NOT_USE", (byte) -2), - UNKNOWN_RESERVED_AUTH_TYPE("UNKNOWN_YET_RESERVED_DO_NOT_USE", (byte) -1), - - SIMPLE("simple", (byte) 0x00), - BEARER("bearer", (byte) 0x01); - // ... reserved for future use ... - - static final WellKnownAuthType[] TYPES_BY_AUTH_ID; - static final Map TYPES_BY_AUTH_STRING; - - static { - // precompute an array of all valid auth ids, filling the blanks with the RESERVED enum - TYPES_BY_AUTH_ID = new WellKnownAuthType[128]; // 0-127 inclusive - Arrays.fill(TYPES_BY_AUTH_ID, UNKNOWN_RESERVED_AUTH_TYPE); - // also prepare a Map of the types by auth string - TYPES_BY_AUTH_STRING = new LinkedHashMap<>(128); - - for (WellKnownAuthType value : values()) { - if (value.getIdentifier() >= 0) { - TYPES_BY_AUTH_ID[value.getIdentifier()] = value; - TYPES_BY_AUTH_STRING.put(value.getString(), value); - } - } - } - - private final byte identifier; - private final String str; - - WellKnownAuthType(String str, byte identifier) { - this.str = str; - this.identifier = identifier; - } - - static io.rsocket.metadata.WellKnownAuthType cast(WellKnownAuthType wellKnownAuthType) { - byte identifier = wellKnownAuthType.identifier; - if (identifier == io.rsocket.metadata.WellKnownAuthType.UNPARSEABLE_AUTH_TYPE.getIdentifier()) { - return io.rsocket.metadata.WellKnownAuthType.UNPARSEABLE_AUTH_TYPE; - } else if (identifier - == io.rsocket.metadata.WellKnownAuthType.UNKNOWN_RESERVED_AUTH_TYPE.getIdentifier()) { - return io.rsocket.metadata.WellKnownAuthType.UNKNOWN_RESERVED_AUTH_TYPE; - } else { - return io.rsocket.metadata.WellKnownAuthType.fromIdentifier(identifier); - } - } - - static WellKnownAuthType cast(io.rsocket.metadata.WellKnownAuthType wellKnownAuthType) { - byte identifier = wellKnownAuthType.getIdentifier(); - if (identifier == WellKnownAuthType.UNPARSEABLE_AUTH_TYPE.identifier) { - return WellKnownAuthType.UNPARSEABLE_AUTH_TYPE; - } else if (identifier == WellKnownAuthType.UNKNOWN_RESERVED_AUTH_TYPE.identifier) { - return WellKnownAuthType.UNKNOWN_RESERVED_AUTH_TYPE; - } else { - return TYPES_BY_AUTH_ID[identifier]; - } - } - - /** - * Find the {@link WellKnownAuthType} for the given identifier (as an {@code int}). Valid - * identifiers are defined to be integers between 0 and 127, inclusive. Identifiers outside of - * this range will produce the {@link #UNPARSEABLE_AUTH_TYPE}. Additionally, some identifiers in - * that range are still only reserved and don't have a type associated yet: this method returns - * the {@link #UNKNOWN_RESERVED_AUTH_TYPE} when passing such an identifier, which lets call sites - * potentially detect this and keep the original representation when transmitting the associated - * metadata buffer. - * - * @param id the looked up identifier - * @return the {@link WellKnownAuthType}, or {@link #UNKNOWN_RESERVED_AUTH_TYPE} if the id is out - * of the specification's range, or {@link #UNKNOWN_RESERVED_AUTH_TYPE} if the id is one that - * is merely reserved but unknown to this implementation. - */ - public static WellKnownAuthType fromIdentifier(int id) { - if (id < 0x00 || id > 0x7F) { - return UNPARSEABLE_AUTH_TYPE; - } - return TYPES_BY_AUTH_ID[id]; - } - - /** - * Find the {@link WellKnownAuthType} for the given {@link String} representation. If the - * representation is {@code null} or doesn't match a {@link WellKnownAuthType}, the {@link - * #UNPARSEABLE_AUTH_TYPE} is returned. - * - * @param authType the looked up auth type - * @return the matching {@link WellKnownAuthType}, or {@link #UNPARSEABLE_AUTH_TYPE} if none - * matches - */ - public static WellKnownAuthType fromString(String authType) { - if (authType == null) throw new IllegalArgumentException("type must be non-null"); - - // force UNPARSEABLE if by chance UNKNOWN_RESERVED_AUTH_TYPE's text has been used - if (authType.equals(UNKNOWN_RESERVED_AUTH_TYPE.str)) { - return UNPARSEABLE_AUTH_TYPE; - } - - return TYPES_BY_AUTH_STRING.getOrDefault(authType, UNPARSEABLE_AUTH_TYPE); - } - - /** @return the byte identifier of the auth type, guaranteed to be positive or zero. */ - public byte getIdentifier() { - return identifier; - } - - /** - * @return the auth type represented as a {@link String}, which is made of US_ASCII compatible - * characters only - */ - public String getString() { - return str; - } - - /** @see #getString() */ - @Override - public String toString() { - return str; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/CompositeRequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/CompositeRequestInterceptor.java new file mode 100644 index 000000000..9a134153d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/CompositeRequestInterceptor.java @@ -0,0 +1,147 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import java.util.List; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +class CompositeRequestInterceptor implements RequestInterceptor { + + final RequestInterceptor[] requestInterceptors; + + CompositeRequestInterceptor(RequestInterceptor[] requestInterceptors) { + this.requestInterceptors = requestInterceptors; + } + + @Override + public void dispose() { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + requestInterceptor.dispose(); + } + } + + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onStart(streamId, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Override + public void onTerminate(int streamId, FrameType requestType, @Nullable Throwable cause) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onTerminate(streamId, requestType, cause); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onCancel(streamId, requestType); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onReject(rejectionReason, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Nullable + static RequestInterceptor create(List interceptors) { + switch (interceptors.size()) { + case 0: + return null; + case 1: + return new SafeRequestInterceptor(interceptors.get(0)); + default: + return new CompositeRequestInterceptor(interceptors.toArray(new RequestInterceptor[0])); + } + } + + static class SafeRequestInterceptor implements RequestInterceptor { + + final RequestInterceptor requestInterceptor; + + public SafeRequestInterceptor(RequestInterceptor requestInterceptor) { + this.requestInterceptor = requestInterceptor; + } + + @Override + public void dispose() { + requestInterceptor.dispose(); + } + + @Override + public boolean isDisposed() { + return requestInterceptor.isDisposed(); + } + + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + try { + requestInterceptor.onStart(streamId, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + + @Override + public void onTerminate(int streamId, FrameType requestType, @Nullable Throwable cause) { + try { + requestInterceptor.onTerminate(streamId, requestType, cause); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + try { + requestInterceptor.onCancel(streamId, requestType); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + try { + requestInterceptor.onReject(rejectionReason, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/DuplexConnectionInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/DuplexConnectionInterceptor.java index 6b2a7a71b..5d3a43b03 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/DuplexConnectionInterceptor.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/DuplexConnectionInterceptor.java @@ -27,6 +27,8 @@ extends BiFunction { enum Type { + /** @deprecated since 1.1.0-M2. Will be removed in 1.2 */ + @Deprecated SETUP, CLIENT, SERVER, diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java index fc032847c..7c9a90f54 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java @@ -18,6 +18,9 @@ import io.rsocket.DuplexConnection; import io.rsocket.RSocket; import io.rsocket.SocketAcceptor; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import reactor.util.annotation.Nullable; /** * Extends {@link InterceptorRegistry} with methods for building a chain of registered interceptors. @@ -25,6 +28,27 @@ */ public class InitializingInterceptorRegistry extends InterceptorRegistry { + @Nullable + public RequestInterceptor initRequesterRequestInterceptor(RSocket rSocketRequester) { + return CompositeRequestInterceptor.create( + getRequestInterceptorsForRequester() + .stream() + .map(factory -> factory.apply(rSocketRequester)) + .collect(Collectors.toList())); + } + + @Nullable + public RequestInterceptor initResponderRequestInterceptor( + RSocket rSocketResponder, RequestInterceptor... perConnectionInterceptors) { + return CompositeRequestInterceptor.create( + Stream.concat( + Stream.of(perConnectionInterceptors), + getRequestInterceptorsForResponder() + .stream() + .map(inteptorFactory -> inteptorFactory.apply(rSocketResponder))) + .collect(Collectors.toList())); + } + public DuplexConnection initConnection( DuplexConnectionInterceptor.Type type, DuplexConnection connection) { for (DuplexConnectionInterceptor interceptor : getConnectionInterceptors()) { @@ -34,7 +58,7 @@ public DuplexConnection initConnection( } public RSocket initRequester(RSocket rsocket) { - for (RSocketInterceptor interceptor : getRequesterInteceptors()) { + for (RSocketInterceptor interceptor : getRequesterInterceptors()) { rsocket = interceptor.apply(rsocket); } return rsocket; diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java index 427fa15ae..680fb514f 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java @@ -15,9 +15,11 @@ */ package io.rsocket.plugins; +import io.rsocket.RSocket; import java.util.ArrayList; import java.util.List; import java.util.function.Consumer; +import java.util.function.Function; /** * Provides support for registering interceptors at the following levels: @@ -30,16 +32,46 @@ * */ public class InterceptorRegistry { - private List requesterInteceptors = new ArrayList<>(); - private List responderInterceptors = new ArrayList<>(); + private List> requesterRequestInterceptors = + new ArrayList<>(); + private List> responderRequestInterceptors = + new ArrayList<>(); + private List requesterRSocketInterceptors = new ArrayList<>(); + private List responderRSocketInterceptors = new ArrayList<>(); private List socketAcceptorInterceptors = new ArrayList<>(); private List connectionInterceptors = new ArrayList<>(); + /** + * Add an {@link RequestInterceptor} that will hook into Requester RSocket requests' phases. + * + * @param interceptor a function which accepts an {@link RSocket} and returns a new {@link + * RequestInterceptor} + * @since 1.1 + */ + public InterceptorRegistry forRequestsInRequester( + Function interceptor) { + requesterRequestInterceptors.add(interceptor); + return this; + } + + /** + * Add an {@link RequestInterceptor} that will hook into Requester RSocket requests' phases. + * + * @param interceptor a function which accepts an {@link RSocket} and returns a new {@link + * RequestInterceptor} + * @since 1.1 + */ + public InterceptorRegistry forRequestsInResponder( + Function interceptor) { + responderRequestInterceptors.add(interceptor); + return this; + } + /** * Add an {@link RSocketInterceptor} that will decorate the RSocket used for performing requests. */ public InterceptorRegistry forRequester(RSocketInterceptor interceptor) { - requesterInteceptors.add(interceptor); + requesterRSocketInterceptors.add(interceptor); return this; } @@ -48,7 +80,7 @@ public InterceptorRegistry forRequester(RSocketInterceptor interceptor) { * registrations. */ public InterceptorRegistry forRequester(Consumer> consumer) { - consumer.accept(requesterInteceptors); + consumer.accept(requesterRSocketInterceptors); return this; } @@ -57,7 +89,7 @@ public InterceptorRegistry forRequester(Consumer> consu * requests. */ public InterceptorRegistry forResponder(RSocketInterceptor interceptor) { - responderInterceptors.add(interceptor); + responderRSocketInterceptors.add(interceptor); return this; } @@ -66,7 +98,7 @@ public InterceptorRegistry forResponder(RSocketInterceptor interceptor) { * registrations. */ public InterceptorRegistry forResponder(Consumer> consumer) { - consumer.accept(responderInterceptors); + consumer.accept(responderRSocketInterceptors); return this; } @@ -102,12 +134,20 @@ public InterceptorRegistry forConnection(Consumer getRequesterInteceptors() { - return requesterInteceptors; + List> getRequestInterceptorsForRequester() { + return requesterRequestInterceptors; + } + + List> getRequestInterceptorsForResponder() { + return responderRequestInterceptors; + } + + List getRequesterInterceptors() { + return requesterRSocketInterceptors; } List getResponderInterceptors() { - return responderInterceptors; + return responderRSocketInterceptors; } List getConnectionInterceptors() { diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java new file mode 100644 index 000000000..08131b39d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java @@ -0,0 +1,79 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import reactor.core.Disposable; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +/** + * Class used to track the RSocket requests lifecycles. The main difference and advantage of this + * interceptor compares to {@link RSocketInterceptor} is that it allows intercepting the initial and + * terminal phases on every individual request. + * + *

    Note, if any of the invocations will rise a runtime exception, this exception will be + * caught and be propagated to {@link reactor.core.publisher.Operators#onErrorDropped(Throwable, + * Context)} + * + * @since 1.1 + */ +public interface RequestInterceptor extends Disposable { + + /** + * Method which is being invoked on successful acceptance and start of a request. + * + * @param streamId used for the request + * @param requestType of the request. Must be one of the following types {@link + * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link + * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} + * @param metadata taken from the initial frame + */ + void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata); + + /** + * Method which is being invoked once a successfully accepted request is terminated. This method + * can be invoked only after the {@link #onStart(int, FrameType, ByteBuf)} method. This method is + * exclusive with {@link #onCancel(int, FrameType)}. + * + * @param streamId used by this request + * @param requestType of the request. Must be one of the following types {@link + * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link + * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} + * @param t with which this finished has terminated. Must be one of the following signals + */ + void onTerminate(int streamId, FrameType requestType, @Nullable Throwable t); + + /** + * Method which is being invoked once a successfully accepted request is cancelled. This method + * can be invoked only after the {@link #onStart(int, FrameType, ByteBuf)} method. This method is + * exclusive with {@link #onTerminate(int, FrameType, Throwable)}. + * + * @param requestType of the request. Must be one of the following types {@link + * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link + * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} + * @param streamId used by this request + */ + void onCancel(int streamId, FrameType requestType); + + /** + * Method which is being invoked on the request rejection. This method is being called only if the + * actual request can not be started and is called instead of the {@link #onStart(int, FrameType, + * ByteBuf)} method. The reason for rejection can be one of the following: + * + *

    + * + *

      + *
    • No available {@link io.rsocket.lease.Lease} on the requester or the responder sides + *
    • Invalid {@link io.rsocket.Payload} size or format on the Requester side, so the request + * is being rejected before the actual streamId is generated + *
    • A second subscription on the ongoing Request + *
    + * + * @param rejectionReason exception which causes rejection of a particular request + * @param requestType of the request. Must be one of the following types {@link + * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link + * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} + * @param metadata taken from the initial frame + */ + void onReject(Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata); +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java b/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java index ed9450357..ca4f5dcb4 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java @@ -18,177 +18,366 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.CharsetUtil; import io.rsocket.DuplexConnection; import io.rsocket.exceptions.ConnectionErrorException; -import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.exceptions.Exceptions; +import io.rsocket.exceptions.RejectedResumeException; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; import io.rsocket.frame.ResumeFrameCodec; import io.rsocket.frame.ResumeOkFrameCodec; -import io.rsocket.internal.ClientServerInputMultiplexer; +import io.rsocket.keepalive.KeepAliveSupport; import java.time.Duration; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Function; +import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; +import reactor.core.CoreSubscriber; +import reactor.core.Disposable; import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.function.Tuple2; import reactor.util.retry.Retry; -public class ClientRSocketSession implements RSocketSession> { +public class ClientRSocketSession + implements RSocketSession, + ResumeStateHolder, + CoreSubscriber> { + private static final Logger logger = LoggerFactory.getLogger(ClientRSocketSession.class); - private final ResumableDuplexConnection resumableConnection; - private volatile Mono newConnection; - private volatile ByteBuf resumeToken; - private final ByteBufAllocator allocator; + final ResumableDuplexConnection resumableConnection; + final Mono> connectionFactory; + final ResumableFramesStore resumableFramesStore; + + final ByteBufAllocator allocator; + final Duration resumeSessionDuration; + final Retry retry; + final boolean cleanupStoreOnKeepAlive; + final ByteBuf resumeToken; + final String session; + final Disposable reconnectDisposable; + + volatile Subscription s; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(ClientRSocketSession.class, Subscription.class, "s"); + + KeepAliveSupport keepAliveSupport; public ClientRSocketSession( - DuplexConnection duplexConnection, + ByteBuf resumeToken, + ResumableDuplexConnection resumableDuplexConnection, + Mono connectionFactory, + Function>> connectionTransformer, + ResumableFramesStore resumableFramesStore, Duration resumeSessionDuration, Retry retry, - ResumableFramesStore resumableFramesStore, - Duration resumeStreamTimeout, boolean cleanupStoreOnKeepAlive) { - this.allocator = duplexConnection.alloc(); - this.resumableConnection = - new ResumableDuplexConnection( - "client", - duplexConnection, - resumableFramesStore, - resumeStreamTimeout, - cleanupStoreOnKeepAlive); - - /*session completed: release token initially retained in resumeToken(ByteBuf)*/ - onClose().doFinally(s -> resumeToken.release()).subscribe(); - - resumableConnection - .connectionErrors() - .flatMap( - err -> { - logger.debug("Client session connection error. Starting new connection"); - AtomicBoolean once = new AtomicBoolean(); - return newConnection - .delaySubscription( - once.compareAndSet(false, true) - ? retry.generateCompanion(Flux.just(new RetrySignal(err))) - : Mono.empty()) - .retryWhen(retry) - .timeout(resumeSessionDuration); - }) - .map(ClientServerInputMultiplexer::new) - .subscribe( - multiplexer -> { - /*reconnect resumable connection*/ - reconnect(multiplexer.asClientServerConnection()); - long impliedPosition = resumableConnection.impliedPosition(); - long position = resumableConnection.position(); - logger.debug( - "Client ResumableConnection reconnected. Sending RESUME frame with state: [impliedPos: {}, pos: {}]", - impliedPosition, - position); - /*Connection is established again: send RESUME frame to server, listen for RESUME_OK*/ - sendFrame( + this.resumeToken = resumeToken; + this.session = resumeToken.toString(CharsetUtil.UTF_8); + this.connectionFactory = + connectionFactory + .doOnDiscard( + DuplexConnection.class, + c -> { + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("resumption_server=[Session Expired]"); + c.sendErrorAndClose(connectionErrorException); + c.receive().subscribe(); + }) + .flatMap( + dc -> { + final long impliedPosition = resumableFramesStore.frameImpliedPosition(); + final long position = resumableFramesStore.framePosition(); + dc.sendFrame( + 0, ResumeFrameCodec.encode( - allocator, - /*retain so token is not released once sent as part of resume frame*/ + dc.alloc(), resumeToken.retain(), - impliedPosition, - position)) - .then(multiplexer.asSetupConnection().receive().next()) - .subscribe(this::resumeWith); - }, - err -> { - logger.debug("Client ResumableConnection reconnect timeout"); - resumableConnection.dispose(); - }); - } + // server uses this to release its cache + impliedPosition, // observed on the client side + // server uses this to check whether there is no mismatch + position // sent from the client sent + )); - @Override - public ClientRSocketSession continueWith(Mono connectionFactory) { - this.newConnection = connectionFactory; - return this; - } + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. ResumeFrame[impliedPosition[{}], position[{}]] has been sent.", + session, + impliedPosition, + position); + } - @Override - public ClientRSocketSession resumeWith(ByteBuf resumeOkFrame) { - logger.debug("ResumeOK FRAME received"); - long remotePos = remotePos(resumeOkFrame); - long remoteImpliedPos = remoteImpliedPos(resumeOkFrame); - resumeOkFrame.release(); - - resumableConnection.resume( - remotePos, - remoteImpliedPos, - pos -> - pos.then() - /*Resumption is impossible: send CONNECTION_ERROR*/ - .onErrorResume( - err -> - sendFrame( - ErrorFrameCodec.encode( - allocator, 0, errorFrameThrowable(remoteImpliedPos))) - .then(Mono.fromRunnable(resumableConnection::dispose)) - /*Resumption is impossible: no need to return control to ResumableConnection*/ - .then(Mono.never()))); - return this; + return connectionTransformer.apply(dc); + }) + .doOnDiscard(Tuple2.class, this::tryReestablishSession); + this.resumableFramesStore = resumableFramesStore; + this.allocator = resumableDuplexConnection.alloc(); + this.resumeSessionDuration = resumeSessionDuration; + this.retry = retry; + this.cleanupStoreOnKeepAlive = cleanupStoreOnKeepAlive; + this.resumableConnection = resumableDuplexConnection; + + resumableDuplexConnection.onClose().doFinally(__ -> dispose()).subscribe(); + + this.reconnectDisposable = + resumableDuplexConnection.onActiveConnectionClosed().subscribe(this::reconnect); } - public ClientRSocketSession resumeToken(ByteBuf resumeToken) { - /*retain so token is not released once sent as part of setup frame*/ - this.resumeToken = resumeToken.retain(); - return this; + void reconnect(int index) { + if (this.s == Operators.cancelledSubscription()) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Connection[{}] is lost. Reconnecting rejected since session is closed", + session, + index); + } + return; + } + + keepAliveSupport.stop(); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Connection[{}] is lost. Reconnecting to resume...", + session, + index); + } + connectionFactory + .doOnNext(this::tryReestablishSession) + .retryWhen(retry) + .timeout(resumeSessionDuration) + .subscribe(this); } @Override - public void reconnect(DuplexConnection connection) { - resumableConnection.reconnect(connection); + public long impliedPosition() { + return resumableFramesStore.frameImpliedPosition(); } @Override - public ResumableDuplexConnection resumableConnection() { - return resumableConnection; + public void onImpliedPosition(long remoteImpliedPos) { + if (cleanupStoreOnKeepAlive) { + try { + resumableFramesStore.releaseFrames(remoteImpliedPos); + } catch (Throwable e) { + resumableConnection.sendErrorAndClose(new ConnectionErrorException(e.getMessage(), e)); + } + } } @Override - public ByteBuf token() { - return resumeToken; - } + public void dispose() { + if (logger.isDebugEnabled()) { + logger.debug("Side[client]|Session[{}]. Disposing", session); + } - private Mono sendFrame(ByteBuf frame) { - return resumableConnection.sendOne(frame).onErrorResume(err -> Mono.empty()); - } + boolean result = Operators.terminate(S, this); - private static long remoteImpliedPos(ByteBuf resumeOkFrame) { - return ResumeOkFrameCodec.lastReceivedClientPos(resumeOkFrame); - } + if (logger.isDebugEnabled()) { + logger.debug("Side[client]|Session[{}]. Sessions[isDisposed={}]", session, result); + } - private static long remotePos(ByteBuf resumeOkFrame) { - return -1; + reconnectDisposable.dispose(); + resumableConnection.dispose(); + // frame store is disposed by resumable connection + // resumableFramesStore.dispose(); + + if (resumeToken.refCnt() > 0) { + resumeToken.release(); + } } - private static ConnectionErrorException errorFrameThrowable(long impliedPos) { - return new ConnectionErrorException("resumption_server_pos=[" + impliedPos + "]"); + @Override + public boolean isDisposed() { + return resumableConnection.isDisposed(); } - private static class RetrySignal implements Retry.RetrySignal { + void tryReestablishSession(Tuple2 tuple2) { + if (logger.isDebugEnabled()) { + logger.debug("Active subscription is canceled {}", s == Operators.cancelledSubscription()); + } + ByteBuf shouldBeResumeOKFrame = tuple2.getT1(); + DuplexConnection nextDuplexConnection = tuple2.getT2(); + + final int streamId = FrameHeaderCodec.streamId(shouldBeResumeOKFrame); + if (streamId != 0) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Illegal first frame received. RESUME_OK frame must be received before any others. Terminating received connection", + session); + } + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("RESUME_OK frame must be received before any others"); + resumableConnection.dispose(nextDuplexConnection, connectionErrorException); + nextDuplexConnection.sendErrorAndClose(connectionErrorException); + nextDuplexConnection.receive().subscribe(); + + throw connectionErrorException; // throw to retry connection again + } + + final FrameType frameType = FrameHeaderCodec.nativeFrameType(shouldBeResumeOKFrame); + if (frameType == FrameType.RESUME_OK) { + // how many frames the server has received from the client + // so the client can release cached frames by this point + long remoteImpliedPos = ResumeOkFrameCodec.lastReceivedClientPos(shouldBeResumeOKFrame); + // what was the last notification from the server about number of frames being + // observed + final long position = resumableFramesStore.framePosition(); + final long impliedPosition = resumableFramesStore.frameImpliedPosition(); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. ResumeOK FRAME received. ServerResumeState[remoteImpliedPosition[{}]]. ClientResumeState[impliedPosition[{}], position[{}]]", + session, + remoteImpliedPos, + impliedPosition, + position); + } + if (position <= remoteImpliedPos) { + try { + if (position != remoteImpliedPos) { + resumableFramesStore.releaseFrames(remoteImpliedPos); + } + } catch (IllegalStateException e) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Exception occurred while releasing frames in the frameStore", + session, + e); + } + final ConnectionErrorException t = new ConnectionErrorException(e.getMessage(), e); + + resumableConnection.dispose(nextDuplexConnection, t); + + nextDuplexConnection.sendErrorAndClose(t); + nextDuplexConnection.receive().subscribe(); + + return; + } + + if (!tryCancelSessionTimeout()) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Session has already been expired. Terminating received connection", + session); + } + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("resumption_server=[Session Expired]"); + nextDuplexConnection.sendErrorAndClose(connectionErrorException); + nextDuplexConnection.receive().subscribe(); + return; + } + + keepAliveSupport.start(); - private final Throwable ex; + if (logger.isDebugEnabled()) { + logger.debug("Side[client]|Session[{}]. Session has been resumed successfully", session); + } - RetrySignal(Throwable ex) { - this.ex = ex; + if (!resumableConnection.connect(nextDuplexConnection)) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Session has already been expired. Terminating received connection", + session); + } + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("resumption_server_pos=[Session Expired]"); + nextDuplexConnection.sendErrorAndClose(connectionErrorException); + nextDuplexConnection.receive().subscribe(); + // no need to do anything since connection resumable connection is liklly to + // be disposed + } + } else { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Mismatching remote and local state. Expected RemoteImpliedPosition[{}] to be greater or equal to the LocalPosition[{}]. Terminating received connection", + session, + remoteImpliedPos, + position); + } + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("resumption_server_pos=[" + remoteImpliedPos + "]"); + + resumableConnection.dispose(nextDuplexConnection, connectionErrorException); + + nextDuplexConnection.sendErrorAndClose(connectionErrorException); + nextDuplexConnection.receive().subscribe(); + } + } else if (frameType == FrameType.ERROR) { + final RuntimeException exception = Exceptions.from(0, shouldBeResumeOKFrame); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Received error frame. Terminating received connection", + session, + exception); + } + if (exception instanceof RejectedResumeException) { + resumableConnection.dispose(nextDuplexConnection, exception); + nextDuplexConnection.dispose(); + nextDuplexConnection.receive().subscribe(); + return; + } + + nextDuplexConnection.dispose(); + nextDuplexConnection.receive().subscribe(); + throw exception; // assume retryable exception + } else { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Illegal first frame received. RESUME_OK frame must be received before any others. Terminating received connection", + session); + } + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("RESUME_OK frame must be received before any others"); + + resumableConnection.dispose(nextDuplexConnection, connectionErrorException); + + nextDuplexConnection.sendErrorAndClose(connectionErrorException); + nextDuplexConnection.receive().subscribe(); + + // no need to do anything since remote server rejected our connection completely } + } - @Override - public long totalRetries() { - return 0; + boolean tryCancelSessionTimeout() { + for (; ; ) { + final Subscription subscription = this.s; + + if (subscription == Operators.cancelledSubscription()) { + return false; + } + + if (S.compareAndSet(this, subscription, null)) { + subscription.cancel(); + return true; + } } + } - @Override - public long totalRetriesInARow() { - return 0; + @Override + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); } + } + + @Override + public void onNext(Tuple2 objects) {} - @Override - public Throwable failure() { - return ex; + @Override + public void onError(Throwable t) { + if (!Operators.terminate(S, this)) { + Operators.onErrorDropped(t, currentContext()); } + + resumableConnection.dispose(); + } + + @Override + public void onComplete() {} + + public void setKeepAliveSupport(KeepAliveSupport keepAliveSupport) { + this.keepAliveSupport = keepAliveSupport; } } diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ExponentialBackoffResumeStrategy.java b/rsocket-core/src/main/java/io/rsocket/resume/ExponentialBackoffResumeStrategy.java deleted file mode 100644 index 461be02d2..000000000 --- a/rsocket-core/src/main/java/io/rsocket/resume/ExponentialBackoffResumeStrategy.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright 2015-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.resume; - -import java.time.Duration; -import java.util.Objects; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.util.retry.Retry; - -/** - * @deprecated as of 1.0 RC7 in favor of passing {@link Retry#backoff(long, Duration)} to {@link - * io.rsocket.core.Resume#retry(Retry)}. - */ -@Deprecated -public class ExponentialBackoffResumeStrategy implements ResumeStrategy { - private volatile Duration next; - private final Duration firstBackoff; - private final Duration maxBackoff; - private final int factor; - - public ExponentialBackoffResumeStrategy(Duration firstBackoff, Duration maxBackoff, int factor) { - this.firstBackoff = Objects.requireNonNull(firstBackoff, "firstBackoff"); - this.maxBackoff = Objects.requireNonNull(maxBackoff, "maxBackoff"); - this.factor = requirePositive(factor); - } - - @Override - public Publisher apply(ClientResume clientResume, Throwable throwable) { - return Flux.defer(() -> Mono.delay(next()).thenReturn(toString())); - } - - Duration next() { - next = - next == null - ? firstBackoff - : Duration.ofMillis(Math.min(maxBackoff.toMillis(), next.toMillis() * factor)); - return next; - } - - private static int requirePositive(int value) { - if (value <= 0) { - throw new IllegalArgumentException("Value must be positive: " + value); - } else { - return value; - } - } - - @Override - public String toString() { - return "ExponentialBackoffResumeStrategy{" - + "next=" - + next - + ", firstBackoff=" - + firstBackoff - + ", maxBackoff=" - + maxBackoff - + ", factor=" - + factor - + '}'; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/InMemoryResumableFramesStore.java b/rsocket-core/src/main/java/io/rsocket/resume/InMemoryResumableFramesStore.java index 1875b7eac..e23bc154b 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/InMemoryResumableFramesStore.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/InMemoryResumableFramesStore.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2019 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,225 +16,839 @@ package io.rsocket.resume; +import static io.rsocket.resume.ResumableDuplexConnection.isResumableFrame; + import io.netty.buffer.ByteBuf; +import io.netty.util.CharsetUtil; +import java.util.ArrayDeque; import java.util.Queue; -import org.reactivestreams.Subscriber; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; -import reactor.util.concurrent.Queues; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; + +/** + * writes - n (where n is frequent, primary operation) reads - m (where m == KeepAliveFrequency) + * skip - k -> 0 (where k is the rare operation which happens after disconnection + */ +public class InMemoryResumableFramesStore extends Flux + implements ResumableFramesStore, Subscription { -public class InMemoryResumableFramesStore implements ResumableFramesStore { + private FramesSubscriber framesSubscriber; private static final Logger logger = LoggerFactory.getLogger(InMemoryResumableFramesStore.class); - private static final long SAVE_REQUEST_SIZE = Long.MAX_VALUE; - private final MonoProcessor disposed = MonoProcessor.create(); - volatile long position; - volatile long impliedPosition; - volatile int cacheSize; + final Sinks.Empty disposed = Sinks.empty(); final Queue cachedFrames; - private final String tag; - private final int cacheLimit; - private volatile int upstreamFrameRefCnt; + final String side; + final String session; + final int cacheLimit; + + volatile long impliedPosition; + static final AtomicLongFieldUpdater IMPLIED_POSITION = + AtomicLongFieldUpdater.newUpdater(InMemoryResumableFramesStore.class, "impliedPosition"); + + volatile long firstAvailableFramePosition; + static final AtomicLongFieldUpdater FIRST_AVAILABLE_FRAME_POSITION = + AtomicLongFieldUpdater.newUpdater( + InMemoryResumableFramesStore.class, "firstAvailableFramePosition"); + + long remoteImpliedPosition; + + int cacheSize; - public InMemoryResumableFramesStore(String tag, int cacheSizeBytes) { - this.tag = tag; + Throwable terminal; + + CoreSubscriber actual; + CoreSubscriber pendingActual; + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(InMemoryResumableFramesStore.class, "state"); + + /** + * Flag which indicates that {@link InMemoryResumableFramesStore} is finalized and all related + * stores are cleaned + */ + static final long FINALIZED_FLAG = + 0b1000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that {@link InMemoryResumableFramesStore} is terminated via the {@link + * InMemoryResumableFramesStore#dispose()} method + */ + static final long DISPOSED_FLAG = + 0b0100_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that {@link InMemoryResumableFramesStore} is terminated via the {@link + * FramesSubscriber#onComplete()} or {@link FramesSubscriber#onError(Throwable)} ()} methods + */ + static final long TERMINATED_FLAG = + 0b0010_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** Flag which indicates that {@link InMemoryResumableFramesStore} has active frames consumer */ + static final long CONNECTED_FLAG = + 0b0001_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that {@link InMemoryResumableFramesStore} has no active frames consumer + * but there is a one pending + */ + static final long PENDING_CONNECTION_FLAG = + 0b0000_1000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that there are some received implied position changes from the remote + * party + */ + static final long REMOTE_IMPLIED_POSITION_CHANGED_FLAG = + 0b0000_0100_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that there are some frames stored in the {@link + * io.rsocket.internal.UnboundedProcessor} which has to be cached and sent to the remote party + */ + static final long HAS_FRAME_FLAG = + 0b0000_0010_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that {@link InMemoryResumableFramesStore#drain(long)} has an actor which + * is currently progressing on the work. This flag should work as a guard to enter|exist into|from + * the {@link InMemoryResumableFramesStore#drain(long)} method. + */ + static final long MAX_WORK_IN_PROGRESS = + 0b0000_0000_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111L; + + public InMemoryResumableFramesStore(String side, ByteBuf session, int cacheSizeBytes) { + this.side = side; + this.session = session.toString(CharsetUtil.UTF_8); this.cacheLimit = cacheSizeBytes; - this.cachedFrames = cachedFramesQueue(cacheSizeBytes); + this.cachedFrames = new ArrayDeque<>(); } public Mono saveFrames(Flux frames) { - MonoProcessor completed = MonoProcessor.create(); - frames - .doFinally(s -> completed.onComplete()) - .subscribe(new FramesSubscriber(SAVE_REQUEST_SIZE)); - return completed; + return frames + .transform( + Operators.lift( + (__, actual) -> this.framesSubscriber = new FramesSubscriber(actual, this))) + .then(); } @Override public void releaseFrames(long remoteImpliedPos) { - long pos = position; - logger.debug( - "{} Removing frames for local: {}, remote implied: {}", tag, pos, remoteImpliedPos); - long removeSize = Math.max(0, remoteImpliedPos - pos); - while (removeSize > 0) { - ByteBuf cachedFrame = cachedFrames.poll(); - if (cachedFrame != null) { - removeSize -= releaseTailFrame(cachedFrame); - } else { + long lastReceivedRemoteImpliedPosition = this.remoteImpliedPosition; + if (lastReceivedRemoteImpliedPosition > remoteImpliedPos) { + throw new IllegalStateException( + "Given Remote Implied Position is behind the last received Remote Implied Position"); + } + + this.remoteImpliedPosition = remoteImpliedPos; + + final long previousState = markRemoteImpliedPositionChanged(this); + if (isFinalized(previousState) || isWorkInProgress(previousState)) { + return; + } + + drain((previousState + 1) | REMOTE_IMPLIED_POSITION_CHANGED_FLAG); + } + + void drain(long expectedState) { + final Fuseable.QueueSubscription qs = this.framesSubscriber.qs; + final Queue cachedFrames = this.cachedFrames; + + for (; ; ) { + if (hasRemoteImpliedPositionChanged(expectedState)) { + expectedState = handlePendingRemoteImpliedPositionChanges(expectedState, cachedFrames); + } + + if (hasPendingConnection(expectedState)) { + expectedState = handlePendingConnection(expectedState, cachedFrames); + } + + if (isConnected(expectedState)) { + if (isTerminated(expectedState)) { + handleTerminated(qs, this.terminal); + } else if (isDisposed()) { + handleDisposed(); + } else if (hasFrames(expectedState)) { + handlePendingFrames(qs); + } + } + + if (isDisposed(expectedState) || isTerminated(expectedState)) { + clearAndFinalize(this); + return; + } + + expectedState = markWorkDone(this, expectedState); + if (isFinalized(expectedState)) { + return; + } + + if (!isWorkInProgress(expectedState)) { + return; + } + } + } + + long handlePendingRemoteImpliedPositionChanges(long expectedState, Queue cachedFrames) { + final long remoteImpliedPosition = this.remoteImpliedPosition; + final long firstAvailableFramePosition = this.firstAvailableFramePosition; + final long toDropFromCache = Math.max(0, remoteImpliedPosition - firstAvailableFramePosition); + + if (toDropFromCache > 0) { + final int droppedFromCache = dropFramesFromCache(toDropFromCache, cachedFrames); + + if (toDropFromCache > droppedFromCache) { + this.terminal = + new IllegalStateException( + String.format( + "Local and remote state disagreement: " + + "need to remove additional %d bytes, but cache is empty", + toDropFromCache)); + expectedState = markTerminated(this) | TERMINATED_FLAG; + } + + if (toDropFromCache < droppedFromCache) { + this.terminal = + new IllegalStateException( + "Local and remote state disagreement: local and remote frame sizes are not equal"); + expectedState = markTerminated(this) | TERMINATED_FLAG; + } + + FIRST_AVAILABLE_FRAME_POSITION.lazySet(this, firstAvailableFramePosition + droppedFromCache); + if (this.cacheLimit != Integer.MAX_VALUE) { + this.cacheSize -= droppedFromCache; + + if (logger.isDebugEnabled()) { + logger.debug( + "Side[{}]|Session[{}]. Removed frames from cache to position[{}]. CacheSize[{}]", + this.side, + this.session, + this.remoteImpliedPosition, + this.cacheSize); + } + } + } + + return expectedState; + } + + void handlePendingFrames(Fuseable.QueueSubscription qs) { + for (; ; ) { + final ByteBuf frame = qs.poll(); + final boolean empty = frame == null; + + if (empty) { + break; + } + + handleFrame(frame); + + if (!isConnected(this.state)) { break; } } - if (removeSize > 0) { - throw new IllegalStateException( - String.format( - "Local and remote state disagreement: " - + "need to remove additional %d bytes, but cache is empty", - removeSize)); - } else if (removeSize < 0) { - throw new IllegalStateException( - "Local and remote state disagreement: " + "local and remote frame sizes are not equal"); - } else { - logger.debug("{} Removed frames. Current cache size: {}", tag, cacheSize); + } + + long handlePendingConnection(long expectedState, Queue cachedFrames) { + CoreSubscriber lastActual = null; + for (; ; ) { + final CoreSubscriber nextActual = this.pendingActual; + + if (nextActual != lastActual) { + for (final ByteBuf frame : cachedFrames) { + nextActual.onNext(frame.retainedSlice()); + } + } + + expectedState = markConnected(this, expectedState); + if (isConnected(expectedState)) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[{}]|Session[{}]. Connected at Position[{}] and ImpliedPosition[{}]", + side, + session, + firstAvailableFramePosition, + impliedPosition); + } + + this.actual = nextActual; + break; + } + + if (!hasPendingConnection(expectedState)) { + break; + } + + lastActual = nextActual; } + return expectedState; + } + + static int dropFramesFromCache(long toRemoveBytes, Queue cache) { + int removedBytes = 0; + while (toRemoveBytes > removedBytes && cache.size() > 0) { + final ByteBuf cachedFrame = cache.poll(); + final int frameSize = cachedFrame.readableBytes(); + + cachedFrame.release(); + + removedBytes += frameSize; + } + + return removedBytes; } @Override public Flux resumeStream() { - return Flux.generate( - () -> new ResumeStreamState(cachedFrames.size(), upstreamFrameRefCnt), - (state, sink) -> { - if (state.next()) { - /*spsc queue has no iterator - iterating by consuming*/ - ByteBuf frame = cachedFrames.poll(); - if (state.shouldRetain(frame)) { - frame.retain(); - } - cachedFrames.offer(frame); - sink.next(frame); - } else { - sink.complete(); - logger.debug("{} Resuming stream completed", tag); - } - return state; - }); + return this; } @Override public long framePosition() { - return position; + return this.firstAvailableFramePosition; } @Override public long frameImpliedPosition() { - return impliedPosition; + return this.impliedPosition & Long.MAX_VALUE; } @Override - public void resumableFrameReceived(ByteBuf frame) { - /*called on transport thread so non-atomic on volatile is safe*/ - impliedPosition += frame.readableBytes(); + public boolean resumableFrameReceived(ByteBuf frame) { + final int frameSize = frame.readableBytes(); + for (; ; ) { + final long impliedPosition = this.impliedPosition; + + if (impliedPosition < 0) { + return false; + } + + if (IMPLIED_POSITION.compareAndSet(this, impliedPosition, impliedPosition + frameSize)) { + return true; + } + } + } + + void pauseImplied() { + for (; ; ) { + final long impliedPosition = this.impliedPosition; + + if (IMPLIED_POSITION.compareAndSet(this, impliedPosition, impliedPosition | Long.MIN_VALUE)) { + logger.debug( + "Side[{}]|Session[{}]. Paused at position[{}]", side, session, impliedPosition); + return; + } + } + } + + void resumeImplied() { + for (; ; ) { + final long impliedPosition = this.impliedPosition; + + final long restoredImpliedPosition = impliedPosition & Long.MAX_VALUE; + if (IMPLIED_POSITION.compareAndSet(this, impliedPosition, restoredImpliedPosition)) { + logger.debug( + "Side[{}]|Session[{}]. Resumed at position[{}]", + side, + session, + restoredImpliedPosition); + return; + } + } } @Override public Mono onClose() { - return disposed; + return disposed.asMono(); } @Override public void dispose() { - cacheSize = 0; - ByteBuf frame = cachedFrames.poll(); - while (frame != null) { + final long previousState = markDisposed(this); + if (isFinalized(previousState) + || isDisposed(previousState) + || isWorkInProgress(previousState)) { + return; + } + + drain((previousState + 1) | DISPOSED_FLAG); + } + + void clearCache() { + final Queue frames = this.cachedFrames; + this.cacheSize = 0; + + ByteBuf frame; + while ((frame = frames.poll()) != null) { frame.release(); - frame = cachedFrames.poll(); } - disposed.onComplete(); } @Override public boolean isDisposed() { - return disposed.isTerminated(); + return isDisposed(this.state); } - /* this method and saveFrame() won't be called concurrently, - * so non-atomic on volatile is safe*/ - private int releaseTailFrame(ByteBuf content) { - int frameSize = content.readableBytes(); - cacheSize -= frameSize; - position += frameSize; - content.release(); - return frameSize; + void handleFrame(ByteBuf frame) { + final boolean isResumable = isResumableFrame(frame); + if (isResumable) { + handleResumableFrame(frame); + return; + } + + handleConnectionFrame(frame); } - /*this method and releaseTailFrame() won't be called concurrently, - * so non-atomic on volatile is safe*/ - void saveFrame(ByteBuf frame) { - if (upstreamFrameRefCnt == 0) { - upstreamFrameRefCnt = frame.refCnt(); - } + void handleTerminated(Fuseable.QueueSubscription qs, @Nullable Throwable t) { + for (; ; ) { + final ByteBuf frame = qs.poll(); + final boolean empty = frame == null; - int frameSize = frame.readableBytes(); - long availableSize = cacheLimit - cacheSize; - while (availableSize < frameSize) { - ByteBuf cachedFrame = cachedFrames.poll(); - if (cachedFrame != null) { - availableSize += releaseTailFrame(cachedFrame); - } else { + if (empty) { break; } + + handleFrame(frame); } - if (availableSize >= frameSize) { - cachedFrames.offer(frame.retain()); - cacheSize += frameSize; + if (t != null) { + this.actual.onError(t); } else { - position += frameSize; + this.actual.onComplete(); } } - static class ResumeStreamState { - private final int cacheSize; - private final int expectedRefCnt; - private int cacheCounter; + void handleDisposed() { + this.actual.onError(new CancellationException("Disposed")); + } - public ResumeStreamState(int cacheSize, int expectedRefCnt) { - this.cacheSize = cacheSize; - this.expectedRefCnt = expectedRefCnt; - } + void handleConnectionFrame(ByteBuf frame) { + this.actual.onNext(frame); + } - public boolean next() { - if (cacheCounter < cacheSize) { - cacheCounter++; - return true; + void handleResumableFrame(ByteBuf frame) { + final Queue frames = this.cachedFrames; + final int incomingFrameSize = frame.readableBytes(); + final int cacheLimit = this.cacheLimit; + + final boolean canBeStore; + int cacheSize = this.cacheSize; + if (cacheLimit != Integer.MAX_VALUE) { + final long availableSize = cacheLimit - cacheSize; + + if (availableSize < incomingFrameSize) { + final long firstAvailableFramePosition = this.firstAvailableFramePosition; + final long toRemoveBytes = incomingFrameSize - availableSize; + final int removedBytes = dropFramesFromCache(toRemoveBytes, frames); + + cacheSize = cacheSize - removedBytes; + canBeStore = removedBytes >= toRemoveBytes; + + if (canBeStore) { + FIRST_AVAILABLE_FRAME_POSITION.lazySet(this, firstAvailableFramePosition + removedBytes); + } else { + this.cacheSize = cacheSize; + FIRST_AVAILABLE_FRAME_POSITION.lazySet( + this, firstAvailableFramePosition + removedBytes + incomingFrameSize); + } } else { - return false; + canBeStore = true; + } + } else { + canBeStore = true; + } + + if (canBeStore) { + frames.offer(frame); + + if (cacheLimit != Integer.MAX_VALUE) { + this.cacheSize = cacheSize + incomingFrameSize; } } - public boolean shouldRetain(ByteBuf frame) { - return frame.refCnt() == expectedRefCnt; + this.actual.onNext(canBeStore ? frame.retainedSlice() : frame); + } + + @Override + public void request(long n) {} + + @Override + public void cancel() { + pauseImplied(); + markDisconnected(this); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[{}]|Session[{}]. Disconnected at Position[{}] and ImpliedPosition[{}]", + side, + session, + firstAvailableFramePosition, + frameImpliedPosition()); } } - static Queue cachedFramesQueue(int size) { - return Queues.get(size).get(); + @Override + public void subscribe(CoreSubscriber actual) { + resumeImplied(); + actual.onSubscribe(this); + this.pendingActual = actual; + + final long previousState = markPendingConnection(this); + if (isDisposed(previousState)) { + actual.onError(new CancellationException("Disposed")); + return; + } + + if (isTerminated(previousState)) { + actual.onError(new CancellationException("Disposed")); + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + drain((previousState + 1) | PENDING_CONNECTION_FLAG); } - class FramesSubscriber implements Subscriber { - private final long firstRequestSize; - private final long refillSize; - private int received; - private Subscription s; + static class FramesSubscriber + implements CoreSubscriber, Fuseable.QueueSubscription { + + final CoreSubscriber actual; + final InMemoryResumableFramesStore parent; + + Fuseable.QueueSubscription qs; + + boolean done; - public FramesSubscriber(long requestSize) { - this.firstRequestSize = requestSize; - this.refillSize = firstRequestSize / 2; + FramesSubscriber(CoreSubscriber actual, InMemoryResumableFramesStore parent) { + this.actual = actual; + this.parent = parent; } @Override + @SuppressWarnings("unchecked") public void onSubscribe(Subscription s) { - this.s = s; - s.request(firstRequestSize); + if (Operators.validate(this.qs, s)) { + final Fuseable.QueueSubscription qs = (Fuseable.QueueSubscription) s; + this.qs = qs; + + final int m = qs.requestFusion(Fuseable.ANY); + + if (m != Fuseable.ASYNC) { + s.cancel(); + this.actual.onSubscribe(this); + this.actual.onError(new IllegalStateException("Source has to be ASYNC fuseable")); + return; + } + + this.actual.onSubscribe(this); + } } @Override public void onNext(ByteBuf byteBuf) { - saveFrame(byteBuf); - if (firstRequestSize != Long.MAX_VALUE && ++received == refillSize) { - received = 0; - s.request(refillSize); + final InMemoryResumableFramesStore parent = this.parent; + long previousState = InMemoryResumableFramesStore.markFrameAdded(parent); + + if (isFinalized(previousState)) { + this.qs.clear(); + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + if (isConnected(previousState) || hasPendingConnection(previousState)) { + parent.drain((previousState + 1) | HAS_FRAME_FLAG); } } @Override public void onError(Throwable t) { - logger.info("unexpected onError signal: {}, {}", t.getClass(), t.getMessage()); + if (this.done) { + Operators.onErrorDropped(t, this.actual.currentContext()); + return; + } + + final InMemoryResumableFramesStore parent = this.parent; + + parent.terminal = t; + this.done = true; + + final long previousState = InMemoryResumableFramesStore.markTerminated(parent); + if (isFinalized(previousState)) { + Operators.onErrorDropped(t, this.actual.currentContext()); + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + parent.drain((previousState + 1) | TERMINATED_FLAG); + } + + @Override + public void onComplete() { + if (this.done) { + return; + } + + final InMemoryResumableFramesStore parent = this.parent; + + this.done = true; + + final long previousState = InMemoryResumableFramesStore.markTerminated(parent); + if (isFinalized(previousState)) { + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + parent.drain((previousState + 1) | TERMINATED_FLAG); + } + + @Override + public void cancel() { + if (this.done) { + return; + } + + this.done = true; + + final long previousState = InMemoryResumableFramesStore.markTerminated(parent); + if (isFinalized(previousState)) { + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + parent.drain(previousState | TERMINATED_FLAG); } @Override - public void onComplete() {} + public void request(long n) {} + + @Override + public int requestFusion(int requestedMode) { + return Fuseable.NONE; + } + + @Override + public Void poll() { + return null; + } + + @Override + public int size() { + return 0; + } + + @Override + public boolean isEmpty() { + return false; + } + + @Override + public void clear() {} + } + + static long markFrameAdded(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state)) { + return state; + } + + long nextState = state; + if (isConnected(state) || hasPendingConnection(state) || isWorkInProgress(state)) { + nextState = + (state & MAX_WORK_IN_PROGRESS) == MAX_WORK_IN_PROGRESS ? nextState : nextState + 1; + } + + if (STATE.compareAndSet(store, state, nextState | HAS_FRAME_FLAG)) { + return state; + } + } + } + + static long markPendingConnection(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state) || isDisposed(state) || isTerminated(state)) { + return state; + } + + if (isConnected(state)) { + return state; + } + + final long nextState = + (state & MAX_WORK_IN_PROGRESS) == MAX_WORK_IN_PROGRESS ? state : state + 1; + if (STATE.compareAndSet(store, state, nextState | PENDING_CONNECTION_FLAG)) { + return state; + } + } + } + + static long markRemoteImpliedPositionChanged(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state)) { + return state; + } + + final long nextState = + (state & MAX_WORK_IN_PROGRESS) == MAX_WORK_IN_PROGRESS ? state : (state + 1); + if (STATE.compareAndSet(store, state, nextState | REMOTE_IMPLIED_POSITION_CHANGED_FLAG)) { + return state; + } + } + } + + static long markDisconnected(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state)) { + return state; + } + + if (STATE.compareAndSet(store, state, state & ~CONNECTED_FLAG & ~PENDING_CONNECTION_FLAG)) { + return state; + } + } + } + + static long markWorkDone(InMemoryResumableFramesStore store, long expectedState) { + for (; ; ) { + final long state = store.state; + + if (expectedState != state) { + return state; + } + + if (isFinalized(state)) { + return state; + } + + final long nextState = state & ~MAX_WORK_IN_PROGRESS & ~REMOTE_IMPLIED_POSITION_CHANGED_FLAG; + if (STATE.compareAndSet(store, state, nextState)) { + return nextState; + } + } + } + + static long markConnected(InMemoryResumableFramesStore store, long expectedState) { + for (; ; ) { + final long state = store.state; + + if (state != expectedState) { + return state; + } + + if (isFinalized(state)) { + return state; + } + + final long nextState = state ^ PENDING_CONNECTION_FLAG | CONNECTED_FLAG; + if (STATE.compareAndSet(store, state, nextState)) { + return nextState; + } + } + } + + static long markTerminated(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state)) { + return state; + } + + final long nextState = + (state & MAX_WORK_IN_PROGRESS) == MAX_WORK_IN_PROGRESS ? state : (state + 1); + if (STATE.compareAndSet(store, state, nextState | TERMINATED_FLAG)) { + return state; + } + } + } + + static long markDisposed(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state)) { + return state; + } + + final long nextState = + (state & MAX_WORK_IN_PROGRESS) == MAX_WORK_IN_PROGRESS ? state : (state + 1); + if (STATE.compareAndSet(store, state, nextState | DISPOSED_FLAG)) { + return state; + } + } + } + + static void clearAndFinalize(InMemoryResumableFramesStore store) { + final Fuseable.QueueSubscription qs = store.framesSubscriber.qs; + for (; ; ) { + final long state = store.state; + + qs.clear(); + store.clearCache(); + + if (isFinalized(state)) { + return; + } + + if (STATE.compareAndSet(store, state, state | FINALIZED_FLAG & ~MAX_WORK_IN_PROGRESS)) { + store.disposed.tryEmitEmpty(); + store.framesSubscriber.onComplete(); + return; + } + } + } + + static boolean isConnected(long state) { + return (state & CONNECTED_FLAG) == CONNECTED_FLAG; + } + + static boolean hasRemoteImpliedPositionChanged(long state) { + return (state & REMOTE_IMPLIED_POSITION_CHANGED_FLAG) == REMOTE_IMPLIED_POSITION_CHANGED_FLAG; + } + + static boolean hasPendingConnection(long state) { + return (state & PENDING_CONNECTION_FLAG) == PENDING_CONNECTION_FLAG; + } + + static boolean hasFrames(long state) { + return (state & HAS_FRAME_FLAG) == HAS_FRAME_FLAG; + } + + static boolean isTerminated(long state) { + return (state & TERMINATED_FLAG) == TERMINATED_FLAG; + } + + static boolean isDisposed(long state) { + return (state & DISPOSED_FLAG) == DISPOSED_FLAG; + } + + static boolean isFinalized(long state) { + return (state & FINALIZED_FLAG) == FINALIZED_FLAG; + } + + static boolean isWorkInProgress(long state) { + return (state & MAX_WORK_IN_PROGRESS) > 0; } } diff --git a/rsocket-core/src/main/java/io/rsocket/resume/PeriodicResumeStrategy.java b/rsocket-core/src/main/java/io/rsocket/resume/PeriodicResumeStrategy.java deleted file mode 100644 index bd447c8a9..000000000 --- a/rsocket-core/src/main/java/io/rsocket/resume/PeriodicResumeStrategy.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright 2015-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.resume; - -import java.time.Duration; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Mono; -import reactor.util.retry.Retry; - -/** - * @deprecated as of 1.0 RC7 in favor of passing {@link Retry#fixedDelay(long, Duration)} to {@link - * io.rsocket.core.Resume#retry(Retry)}. - */ -@Deprecated -public class PeriodicResumeStrategy implements ResumeStrategy { - private final Duration interval; - - public PeriodicResumeStrategy(Duration interval) { - this.interval = interval; - } - - @Override - public Publisher apply(ClientResume clientResumeConfiguration, Throwable throwable) { - return Mono.delay(interval).thenReturn(toString()); - } - - @Override - public String toString() { - return "PeriodicResumeStrategy{" + "interval=" + interval + '}'; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/RSocketSession.java b/rsocket-core/src/main/java/io/rsocket/resume/RSocketSession.java index 7ec0abaee..6dd3d5f4d 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/RSocketSession.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/RSocketSession.java @@ -16,35 +16,10 @@ package io.rsocket.resume; -import io.netty.buffer.ByteBuf; -import io.rsocket.Closeable; -import io.rsocket.DuplexConnection; -import reactor.core.publisher.Mono; +import io.rsocket.keepalive.KeepAliveSupport; +import reactor.core.Disposable; -public interface RSocketSession extends Closeable { +public interface RSocketSession extends Disposable { - ByteBuf token(); - - ResumableDuplexConnection resumableConnection(); - - RSocketSession continueWith(T ConnectionFactory); - - RSocketSession resumeWith(ByteBuf resumeFrame); - - void reconnect(DuplexConnection connection); - - @Override - default Mono onClose() { - return resumableConnection().onClose(); - } - - @Override - default void dispose() { - resumableConnection().dispose(); - } - - @Override - default boolean isDisposed() { - return resumableConnection().isDisposed(); - } + void setKeepAliveSupport(KeepAliveSupport keepAliveSupport); } diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java index 461d71228..c8811b9b3 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2019 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,433 +18,430 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.rsocket.Closeable; +import io.netty.util.CharsetUtil; import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.frame.ErrorFrameCodec; import io.rsocket.frame.FrameHeaderCodec; -import java.nio.channels.ClosedChannelException; -import java.time.Duration; -import java.util.Queue; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Function; -import org.reactivestreams.Publisher; +import io.rsocket.internal.UnboundedProcessor; +import java.net.SocketAddress; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; import reactor.core.Disposable; -import reactor.core.Disposables; -import reactor.core.publisher.*; -import reactor.util.concurrent.Queues; - -public class ResumableDuplexConnection implements DuplexConnection, ResumeStateHolder { - private static final Logger logger = LoggerFactory.getLogger(ResumableDuplexConnection.class); - private static final Throwable closedChannelException = new ClosedChannelException(); - - private final String tag; - private final ResumableFramesStore resumableFramesStore; - private final Duration resumeStreamTimeout; - private final boolean cleanupOnKeepAlive; - - private final ReplayProcessor connections = ReplayProcessor.create(1); - private final EmitterProcessor connectionErrors = EmitterProcessor.create(); - private volatile DuplexConnection curConnection; - /*used instead of EmitterProcessor because its autocancel=false capability had no expected effect*/ - private final FluxProcessor downStreamFrames = ReplayProcessor.create(0); - private final FluxProcessor resumeSaveFrames = EmitterProcessor.create(); - private final MonoProcessor resumeSaveCompleted = MonoProcessor.create(); - private final Queue actions = Queues.unboundedMultiproducer().get(); - private final AtomicInteger actionsWip = new AtomicInteger(); - private final AtomicBoolean disposed = new AtomicBoolean(); - - private final Mono framesSent; - private final RequestListener downStreamRequestListener = new RequestListener(); - private final RequestListener resumeSaveStreamRequestListener = new RequestListener(); - private final UnicastProcessor> upstreams = UnicastProcessor.create(); - private final UpstreamFramesSubscriber upstreamSubscriber = - new UpstreamFramesSubscriber( - Queues.SMALL_BUFFER_SIZE, - downStreamRequestListener.requests(), - resumeSaveStreamRequestListener.requests(), - this::dispatch); - - private volatile Runnable onResume; - private volatile Runnable onDisconnect; - private volatile int state; - private volatile Disposable resumedStreamDisposable = Disposables.disposed(); +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; + +public class ResumableDuplexConnection extends Flux + implements DuplexConnection, Subscription { + + static final Logger logger = LoggerFactory.getLogger(ResumableDuplexConnection.class); + + final String side; + final String session; + final ResumableFramesStore resumableFramesStore; + + final UnboundedProcessor savableFramesSender; + final Sinks.Empty onQueueClose; + final Sinks.Empty onLastConnectionClose; + final SocketAddress remoteAddress; + final Sinks.Many onConnectionClosedSink; + + CoreSubscriber receiveSubscriber; + FrameReceivingSubscriber activeReceivingSubscriber; + + volatile int state; + static final AtomicIntegerFieldUpdater STATE = + AtomicIntegerFieldUpdater.newUpdater(ResumableDuplexConnection.class, "state"); + + volatile DuplexConnection activeConnection; + static final AtomicReferenceFieldUpdater + ACTIVE_CONNECTION = + AtomicReferenceFieldUpdater.newUpdater( + ResumableDuplexConnection.class, DuplexConnection.class, "activeConnection"); + + int connectionIndex = 0; public ResumableDuplexConnection( - String tag, - DuplexConnection duplexConnection, - ResumableFramesStore resumableFramesStore, - Duration resumeStreamTimeout, - boolean cleanupOnKeepAlive) { - this.tag = tag; + String side, + ByteBuf session, + DuplexConnection initialConnection, + ResumableFramesStore resumableFramesStore) { + this.side = side; + this.session = session.toString(CharsetUtil.UTF_8); + this.onConnectionClosedSink = Sinks.unsafe().many().unicast().onBackpressureBuffer(); this.resumableFramesStore = resumableFramesStore; - this.resumeStreamTimeout = resumeStreamTimeout; - this.cleanupOnKeepAlive = cleanupOnKeepAlive; - - resumableFramesStore - .saveFrames(resumeSaveStreamRequestListener.apply(resumeSaveFrames)) - .subscribe(resumeSaveCompleted); - - upstreams.flatMap(Function.identity()).subscribe(upstreamSubscriber); - - framesSent = - connections - .switchMap( - c -> { - logger.debug("Switching transport: {}", tag); - return c.send(downStreamRequestListener.apply(downStreamFrames)) - .doFinally( - s -> - logger.debug( - "{} Transport send completed: {}, {}", tag, s, c.toString())) - .onErrorResume(err -> Mono.never()); - }) - .then() - .cache(); - - reconnect(duplexConnection); - } + this.onQueueClose = Sinks.unsafe().empty(); + this.onLastConnectionClose = Sinks.unsafe().empty(); + this.savableFramesSender = new UnboundedProcessor(onQueueClose::tryEmitEmpty); + this.remoteAddress = initialConnection.remoteAddress(); - @Override - public ByteBufAllocator alloc() { - return curConnection.alloc(); + resumableFramesStore.saveFrames(savableFramesSender).subscribe(); + + ACTIVE_CONNECTION.lazySet(this, initialConnection); } - public void disconnect() { - DuplexConnection c = this.curConnection; - if (c != null) { - disconnect(c); + public boolean connect(DuplexConnection nextConnection) { + final DuplexConnection activeConnection = this.activeConnection; + if (activeConnection != DisposedConnection.INSTANCE + && ACTIVE_CONNECTION.compareAndSet(this, activeConnection, nextConnection)) { + + if (!activeConnection.isDisposed()) { + activeConnection.sendErrorAndClose( + new ConnectionErrorException("Connection unexpectedly replaced")); + } + + initConnection(nextConnection); + + return true; + } else { + return false; } } - public void onDisconnect(Runnable onDisconnectAction) { - this.onDisconnect = onDisconnectAction; - } + void initConnection(DuplexConnection nextConnection) { + final int nextConnectionIndex = this.connectionIndex + 1; + final FrameReceivingSubscriber frameReceivingSubscriber = + new FrameReceivingSubscriber(side, resumableFramesStore, receiveSubscriber); - public void onResume(Runnable onResumeAction) { - this.onResume = onResumeAction; - } + this.connectionIndex = nextConnectionIndex; + this.activeReceivingSubscriber = frameReceivingSubscriber; - /*reconnected by session after error. After this downstream can receive frames, - * but sending in suppressed until resume() is called*/ - public void reconnect(DuplexConnection connection) { - if (curConnection == null) { - logger.debug("{} Resumable duplex connection started with connection: {}", tag, connection); - state = State.CONNECTED; - onNewConnection(connection); - } else { + if (logger.isDebugEnabled()) { logger.debug( - "{} Resumable duplex connection reconnected with connection: {}", tag, connection); - /*race between sendFrame and doResumeStart may lead to ongoing upstream frames - written before resume complete*/ - dispatch(new ResumeStart(connection)); + "Side[{}]|Session[{}]|DuplexConnection[{}]. Connecting", side, session, connectionIndex); } + + final Disposable resumeStreamSubscription = + resumableFramesStore + .resumeStream() + .subscribe( + f -> nextConnection.sendFrame(FrameHeaderCodec.streamId(f), f), + t -> { + dispose(nextConnection, t); + nextConnection.sendErrorAndClose(new ConnectionErrorException(t.getMessage(), t)); + }, + () -> { + final ConnectionErrorException e = + new ConnectionErrorException("Connection Closed Unexpectedly"); + dispose(nextConnection, e); + nextConnection.sendErrorAndClose(e); + }); + nextConnection.receive().subscribe(frameReceivingSubscriber); + nextConnection + .onClose() + .doFinally( + __ -> { + frameReceivingSubscriber.dispose(); + resumeStreamSubscription.dispose(); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[{}]|Session[{}]|DuplexConnection[{}]. Disconnected", + side, + session, + connectionIndex); + } + Sinks.EmitResult result = onConnectionClosedSink.tryEmitNext(nextConnectionIndex); + if (!result.equals(Sinks.EmitResult.OK)) { + logger.error( + "Side[{}]|Session[{}]|DuplexConnection[{}]. Failed to notify session of closed connection: {}", + side, + session, + connectionIndex, + result); + } + }) + .subscribe(); } - /*after receiving RESUME (Server) or RESUME_OK (Client) - calculate and send resume frames */ - public void resume( - long remotePos, long remoteImpliedPos, Function, Mono> resumeFrameSent) { - /*race between sendFrame and doResume may lead to duplicate frames on resume store*/ - dispatch(new Resume(remotePos, remoteImpliedPos, resumeFrameSent)); + public void disconnect() { + final DuplexConnection activeConnection = this.activeConnection; + if (activeConnection != DisposedConnection.INSTANCE && !activeConnection.isDisposed()) { + activeConnection.dispose(); + } } @Override - public Mono sendOne(ByteBuf frame) { - return curConnection.sendOne(frame); + public void sendFrame(int streamId, ByteBuf frame) { + if (streamId == 0) { + savableFramesSender.tryEmitPrioritized(frame); + } else { + savableFramesSender.tryEmitNormal(frame); + } } - @Override - public Mono send(Publisher frames) { - upstreams.onNext(Flux.from(frames)); - return framesSent; + /** + * Publisher for a sequence of integers starting at 1, with each next number emitted when the + * currently active connection is closed and should be resumed. The Publisher never emits an error + * and completes when the connection is disposed and not resumed. + */ + Flux onActiveConnectionClosed() { + return onConnectionClosedSink.asFlux(); } @Override - public Flux receive() { - return connections.switchMap( - c -> - c.receive() - .doOnNext( - f -> { - if (isResumableFrame(f)) { - resumableFramesStore.resumableFrameReceived(f); - } - }) - .onErrorResume(err -> Mono.never())); - } + public void sendErrorAndClose(RSocketErrorException rSocketErrorException) { + final DuplexConnection activeConnection = + ACTIVE_CONNECTION.getAndSet(this, DisposedConnection.INSTANCE); + if (activeConnection == DisposedConnection.INSTANCE) { + return; + } - public long position() { - return resumableFramesStore.framePosition(); + savableFramesSender.tryEmitFinal( + ErrorFrameCodec.encode(activeConnection.alloc(), 0, rSocketErrorException)); + + activeConnection + .onClose() + .subscribe( + null, + t -> { + onConnectionClosedSink.tryEmitComplete(); + onLastConnectionClose.tryEmitEmpty(); + }, + () -> { + onConnectionClosedSink.tryEmitComplete(); + + final Throwable cause = rSocketErrorException.getCause(); + if (cause == null) { + onLastConnectionClose.tryEmitEmpty(); + } else { + onLastConnectionClose.tryEmitError(cause); + } + }); } @Override - public long impliedPosition() { - return resumableFramesStore.frameImpliedPosition(); + public Flux receive() { + return this; } @Override - public void onImpliedPosition(long remoteImpliedPos) { - logger.debug("Got remote position from keep-alive: {}", remoteImpliedPos); - if (cleanupOnKeepAlive) { - dispatch(new ReleaseFrames(remoteImpliedPos)); - } + public ByteBufAllocator alloc() { + return activeConnection.alloc(); } @Override public Mono onClose() { - return Flux.merge(connections.last().flatMap(Closeable::onClose), resumeSaveCompleted).then(); + return Mono.whenDelayError( + onQueueClose.asMono(), resumableFramesStore.onClose(), onLastConnectionClose.asMono()); } @Override public void dispose() { - if (disposed.compareAndSet(false, true)) { - logger.debug("Resumable connection disposed: {}, {}", tag, this); - upstreams.onComplete(); - connections.onComplete(); - connectionErrors.onComplete(); - resumeSaveFrames.onComplete(); - curConnection.dispose(); - upstreamSubscriber.dispose(); - resumedStreamDisposable.dispose(); - resumableFramesStore.dispose(); + final DuplexConnection activeConnection = + ACTIVE_CONNECTION.getAndSet(this, DisposedConnection.INSTANCE); + if (activeConnection == DisposedConnection.INSTANCE) { + return; + } + savableFramesSender.onComplete(); + activeConnection + .onClose() + .subscribe( + null, + t -> { + onConnectionClosedSink.tryEmitComplete(); + onLastConnectionClose.tryEmitEmpty(); + }, + () -> { + onConnectionClosedSink.tryEmitComplete(); + onLastConnectionClose.tryEmitEmpty(); + }); + } + + void dispose(DuplexConnection nextConnection, @Nullable Throwable e) { + final DuplexConnection activeConnection = + ACTIVE_CONNECTION.getAndSet(this, DisposedConnection.INSTANCE); + if (activeConnection == DisposedConnection.INSTANCE) { + return; } + savableFramesSender.onComplete(); + nextConnection + .onClose() + .subscribe( + null, + t -> { + if (e != null) { + onLastConnectionClose.tryEmitError(e); + } else { + onLastConnectionClose.tryEmitEmpty(); + } + onConnectionClosedSink.tryEmitComplete(); + }, + () -> { + if (e != null) { + onLastConnectionClose.tryEmitError(e); + } else { + onLastConnectionClose.tryEmitEmpty(); + } + onConnectionClosedSink.tryEmitComplete(); + }); } @Override - public double availability() { - return curConnection.availability(); + @SuppressWarnings("ConstantConditions") + public boolean isDisposed() { + return onQueueClose.scan(Scannable.Attr.TERMINATED) + || onQueueClose.scan(Scannable.Attr.CANCELLED); } @Override - public boolean isDisposed() { - return disposed.get(); + public SocketAddress remoteAddress() { + return remoteAddress; } - private void sendFrame(ByteBuf f) { - if (disposed.get()) { - f.release(); - return; - } - /*resuming from store so no need to save again*/ - if (state != State.RESUME && isResumableFrame(f)) { - resumeSaveFrames.onNext(f); - } - /*filter frames coming from upstream before actual resumption began, - * to preserve frames ordering*/ - if (state != State.RESUME_STARTED) { - downStreamFrames.onNext(f); + @Override + public void request(long n) { + if (state == 1 && STATE.compareAndSet(this, 1, 2)) { + // happens for the very first time with the initial connection + initConnection(this.activeConnection); } } - Flux connectionErrors() { - return connectionErrors; + @Override + public void cancel() { + dispose(); } - private void dispatch(Object action) { - actions.offer(action); - if (actionsWip.getAndIncrement() == 0) { - do { - Object a = actions.poll(); - if (a instanceof ByteBuf) { - sendFrame((ByteBuf) a); - } else { - ((Runnable) a).run(); - } - } while (actionsWip.decrementAndGet() != 0); + @Override + public void subscribe(CoreSubscriber receiverSubscriber) { + if (state == 0 && STATE.compareAndSet(this, 0, 1)) { + receiveSubscriber = receiverSubscriber; + receiverSubscriber.onSubscribe(this); } } - private void doResumeStart(DuplexConnection connection) { - state = State.RESUME_STARTED; - resumedStreamDisposable.dispose(); - upstreamSubscriber.resumeStart(); - onNewConnection(connection); + static boolean isResumableFrame(ByteBuf frame) { + return FrameHeaderCodec.streamId(frame) != 0; } - private void doResume( - long remotePosition, - long remoteImpliedPosition, - Function, Mono> sendResumeFrame) { - long localPosition = position(); - long localImpliedPosition = impliedPosition(); - - logger.debug("Resumption start"); - logger.debug( - "Resumption states. local: [pos: {}, impliedPos: {}], remote: [pos: {}, impliedPos: {}]", - localPosition, - localImpliedPosition, - remotePosition, - remoteImpliedPosition); - - long remoteImpliedPos = - calculateRemoteImpliedPos( - localPosition, localImpliedPosition, - remotePosition, remoteImpliedPosition); - - Mono impliedPositionOrError; - if (remoteImpliedPos >= 0) { - state = State.RESUME; - releaseFramesToPosition(remoteImpliedPos); - impliedPositionOrError = Mono.just(localImpliedPosition); - } else { - impliedPositionOrError = - Mono.error( - new ResumeStateException( - localPosition, localImpliedPosition, - remotePosition, remoteImpliedPosition)); - } + @Override + public String toString() { + return "ResumableDuplexConnection{" + + "side='" + + side + + '\'' + + ", session='" + + session + + '\'' + + ", remoteAddress=" + + remoteAddress + + ", state=" + + state + + ", activeConnection=" + + activeConnection + + ", connectionIndex=" + + connectionIndex + + '}'; + } + + private static final class DisposedConnection implements DuplexConnection { + + static final DisposedConnection INSTANCE = new DisposedConnection(); + + private DisposedConnection() {} - sendResumeFrame - .apply(impliedPositionOrError) - .doOnSuccess( - v -> { - Runnable r = this.onResume; - if (r != null) { - r.run(); - } - }) - .then( - streamResumedFrames( - resumableFramesStore - .resumeStream() - .timeout(resumeStreamTimeout) - .doFinally(s -> dispatch(new ResumeComplete()))) - .doOnError(err -> dispose())) - .onErrorResume(err -> Mono.empty()) - .subscribe(); - } + @Override + public void dispose() {} - static long calculateRemoteImpliedPos( - long pos, long impliedPos, long remotePos, long remoteImpliedPos) { - if (remotePos <= impliedPos && pos <= remoteImpliedPos) { - return remoteImpliedPos; - } else { - return -1L; + @Override + public Mono onClose() { + return Mono.never(); } - } - private void doResumeComplete() { - logger.debug("Completing resumption"); - state = State.RESUME_COMPLETED; - upstreamSubscriber.resumeComplete(); - } + @Override + public void sendFrame(int streamId, ByteBuf frame) {} - private Mono streamResumedFrames(Flux frames) { - return Mono.create( - s -> { - ResumeFramesSubscriber subscriber = - new ResumeFramesSubscriber( - downStreamRequestListener.requests(), this::dispatch, s::error, s::success); - s.onDispose(subscriber); - resumedStreamDisposable = subscriber; - frames.subscribe(subscriber); - }); - } + @Override + public Flux receive() { + return Flux.never(); + } - private void onNewConnection(DuplexConnection connection) { - curConnection = connection; - connection.onClose().doFinally(v -> disconnect(connection)).subscribe(); - connections.onNext(connection); - } + @Override + public void sendErrorAndClose(RSocketErrorException e) {} - private void disconnect(DuplexConnection connection) { - /*do not report late disconnects on old connection if new one is available*/ - if (curConnection == connection && state != State.DISCONNECTED) { - connection.dispose(); - state = State.DISCONNECTED; - logger.debug( - "{} Inner connection disconnected: {}", - tag, - closedChannelException.getClass().getSimpleName()); - connectionErrors.onNext(closedChannelException); - Runnable r = this.onDisconnect; - if (r != null) { - r.run(); - } + @Override + public ByteBufAllocator alloc() { + return ByteBufAllocator.DEFAULT; } - } - - /*remove frames confirmed by implied pos, - set current pos accordingly*/ - private void releaseFramesToPosition(long remoteImpliedPos) { - resumableFramesStore.releaseFrames(remoteImpliedPos); - } - static boolean isResumableFrame(ByteBuf frame) { - switch (FrameHeaderCodec.nativeFrameType(frame)) { - case REQUEST_CHANNEL: - case REQUEST_STREAM: - case REQUEST_RESPONSE: - case REQUEST_FNF: - case REQUEST_N: - case CANCEL: - case ERROR: - case PAYLOAD: - return true; - default: - return false; + @Override + @SuppressWarnings("ConstantConditions") + public SocketAddress remoteAddress() { + return null; } } - static class State { - static int CONNECTED = 0; - static int RESUME_STARTED = 1; - static int RESUME = 2; - static int RESUME_COMPLETED = 3; - static int DISCONNECTED = 4; - } + private static final class FrameReceivingSubscriber + implements CoreSubscriber, Disposable { - class ResumeStart implements Runnable { - private final DuplexConnection connection; + final ResumableFramesStore resumableFramesStore; + final CoreSubscriber actual; + final String tag; - public ResumeStart(DuplexConnection connection) { - this.connection = connection; + volatile Subscription s; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater( + FrameReceivingSubscriber.class, Subscription.class, "s"); + + boolean cancelled; + + private FrameReceivingSubscriber( + String tag, ResumableFramesStore store, CoreSubscriber actual) { + this.tag = tag; + this.resumableFramesStore = store; + this.actual = actual; } @Override - public void run() { - doResumeStart(connection); + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); + } } - } - class Resume implements Runnable { - private final long remotePos; - private final long remoteImpliedPos; - private final Function, Mono> resumeFrameSent; + @Override + public void onNext(ByteBuf frame) { + if (cancelled || s == Operators.cancelledSubscription()) { + return; + } + + if (isResumableFrame(frame)) { + if (resumableFramesStore.resumableFrameReceived(frame)) { + actual.onNext(frame); + } + return; + } - public Resume( - long remotePos, long remoteImpliedPos, Function, Mono> resumeFrameSent) { - this.remotePos = remotePos; - this.remoteImpliedPos = remoteImpliedPos; - this.resumeFrameSent = resumeFrameSent; + actual.onNext(frame); } @Override - public void run() { - doResume(remotePos, remoteImpliedPos, resumeFrameSent); + public void onError(Throwable t) { + Operators.set(S, this, Operators.cancelledSubscription()); } - } - - private class ResumeComplete implements Runnable { @Override - public void run() { - doResumeComplete(); + public void onComplete() { + Operators.set(S, this, Operators.cancelledSubscription()); } - } - private class ReleaseFrames implements Runnable { - private final long remoteImpliedPos; - - public ReleaseFrames(long remoteImpliedPos) { - this.remoteImpliedPos = remoteImpliedPos; + @Override + public void dispose() { + cancelled = true; + Operators.terminate(S, this); } @Override - public void run() { - releaseFramesToPosition(remoteImpliedPos); + public boolean isDisposed() { + return cancelled || s == Operators.cancelledSubscription(); } } } diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumableFramesStore.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumableFramesStore.java index 3a30544b6..80d9a36dd 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/ResumableFramesStore.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/ResumableFramesStore.java @@ -50,6 +50,8 @@ public interface ResumableFramesStore extends Closeable { /** * Received resumable frame as defined by RSocket protocol. Implementation must increment frame * implied position + * + * @return {@code true} if information about the frame has been stored */ - void resumableFrameReceived(ByteBuf frame); + boolean resumableFrameReceived(ByteBuf frame); } diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumeFramesSubscriber.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumeFramesSubscriber.java deleted file mode 100644 index 4facdd3c1..000000000 --- a/rsocket-core/src/main/java/io/rsocket/resume/ResumeFramesSubscriber.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright 2015-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.resume; - -import io.netty.buffer.ByteBuf; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Consumer; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import reactor.core.Disposable; -import reactor.core.publisher.Flux; - -class ResumeFramesSubscriber implements Subscriber, Disposable { - private final Flux requests; - private final Consumer onNext; - private final Consumer onError; - private final Runnable onComplete; - private final AtomicBoolean disposed = new AtomicBoolean(); - private volatile Disposable requestsDisposable; - private volatile Subscription subscription; - - public ResumeFramesSubscriber( - Flux requests, - Consumer onNext, - Consumer onError, - Runnable onComplete) { - this.requests = requests; - this.onNext = onNext; - this.onError = onError; - this.onComplete = onComplete; - } - - @Override - public void onSubscribe(Subscription s) { - if (isDisposed()) { - s.cancel(); - } else { - this.subscription = s; - this.requestsDisposable = requests.subscribe(s::request); - } - } - - @Override - public void onNext(ByteBuf frame) { - this.onNext.accept(frame); - } - - @Override - public void onError(Throwable t) { - this.onError.accept(t); - requestsDisposable.dispose(); - } - - @Override - public void onComplete() { - this.onComplete.run(); - requestsDisposable.dispose(); - } - - @Override - public void dispose() { - if (disposed.compareAndSet(false, true)) { - if (subscription != null) { - subscription.cancel(); - requestsDisposable.dispose(); - } - } - } - - @Override - public boolean isDisposed() { - return disposed.get(); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java b/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java index b54ce644f..ad1b38375 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java @@ -18,143 +18,284 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.CharsetUtil; import io.rsocket.DuplexConnection; +import io.rsocket.exceptions.ConnectionErrorException; import io.rsocket.exceptions.RejectedResumeException; -import io.rsocket.frame.ErrorFrameCodec; import io.rsocket.frame.ResumeFrameCodec; import io.rsocket.frame.ResumeOkFrameCodec; +import io.rsocket.keepalive.KeepAliveSupport; import java.time.Duration; -import java.util.function.Function; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import reactor.core.publisher.FluxProcessor; +import reactor.core.CoreSubscriber; import reactor.core.publisher.Mono; -import reactor.core.publisher.ReplayProcessor; +import reactor.core.publisher.Operators; +import reactor.util.concurrent.Queues; -public class ServerRSocketSession implements RSocketSession { +public class ServerRSocketSession + implements RSocketSession, ResumeStateHolder, CoreSubscriber { private static final Logger logger = LoggerFactory.getLogger(ServerRSocketSession.class); - private final ResumableDuplexConnection resumableConnection; - /*used instead of EmitterProcessor because its autocancel=false capability had no expected effect*/ - private final FluxProcessor newConnections = - ReplayProcessor.create(0); - private final ByteBufAllocator allocator; - private final ByteBuf resumeToken; + final ResumableDuplexConnection resumableConnection; + final Duration resumeSessionDuration; + final ResumableFramesStore resumableFramesStore; + final String session; + final ByteBufAllocator allocator; + final boolean cleanupStoreOnKeepAlive; + + /** + * All incoming connections with the Resume intent are enqueued in this queue. Such an approach + * ensure that the new connection will affect the resumption state anyhow until the previous + * (active) connection is finally closed + */ + final Queue connectionsQueue; + + volatile int wip; + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(ServerRSocketSession.class, "wip"); + + volatile Subscription s; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(ServerRSocketSession.class, Subscription.class, "s"); + + KeepAliveSupport keepAliveSupport; public ServerRSocketSession( - DuplexConnection duplexConnection, + ByteBuf session, + ResumableDuplexConnection resumableDuplexConnection, + DuplexConnection initialDuplexConnection, + ResumableFramesStore resumableFramesStore, Duration resumeSessionDuration, - Duration resumeStreamTimeout, - Function resumeStoreFactory, - ByteBuf resumeToken, boolean cleanupStoreOnKeepAlive) { - this.allocator = duplexConnection.alloc(); - this.resumeToken = resumeToken; - this.resumableConnection = - new ResumableDuplexConnection( - "server", - duplexConnection, - resumeStoreFactory.apply(resumeToken), - resumeStreamTimeout, - cleanupStoreOnKeepAlive); - - Mono timeout = - resumableConnection - .connectionErrors() - .flatMap( - err -> { - logger.debug("Starting session timeout due to error", err); - return newConnections - .next() - .doOnNext(c -> logger.debug("Connection after error: {}", c)) - .timeout(resumeSessionDuration); - }) - .then() - .cast(DuplexConnection.class); - - newConnections - .mergeWith(timeout) - .subscribe( - connection -> { - reconnect(connection); - logger.debug("Server ResumableConnection reconnected: {}", connection); - }, - err -> { - logger.debug("Server ResumableConnection reconnect timeout"); - resumableConnection.dispose(); - }); + this.session = session.toString(CharsetUtil.UTF_8); + this.allocator = initialDuplexConnection.alloc(); + this.resumeSessionDuration = resumeSessionDuration; + this.resumableFramesStore = resumableFramesStore; + this.cleanupStoreOnKeepAlive = cleanupStoreOnKeepAlive; + this.resumableConnection = resumableDuplexConnection; + this.connectionsQueue = Queues.unboundedMultiproducer().get(); + + WIP.lazySet(this, 1); + + resumableDuplexConnection.onClose().doFinally(__ -> dispose()).subscribe(); + resumableDuplexConnection.onActiveConnectionClosed().subscribe(__ -> tryTimeoutSession()); } - @Override - public ServerRSocketSession continueWith(DuplexConnection connectionFactory) { - logger.debug("Server continued with connection: {}", connectionFactory); - newConnections.onNext(connectionFactory); - return this; + void tryTimeoutSession() { + keepAliveSupport.stop(); + + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Connection is lost. Trying to timeout the active session", + session); + } + + Mono.delay(resumeSessionDuration).subscribe(this); + + if (WIP.decrementAndGet(this) == 0) { + return; + } + + final Runnable doResumeRunnable = connectionsQueue.poll(); + if (doResumeRunnable != null) { + doResumeRunnable.run(); + } } - @Override - public ServerRSocketSession resumeWith(ByteBuf resumeFrame) { - logger.debug("Resume FRAME received"); - long remotePos = remotePos(resumeFrame); - long remoteImpliedPos = remoteImpliedPos(resumeFrame); - resumeFrame.release(); - - resumableConnection.resume( - remotePos, - remoteImpliedPos, - pos -> - pos.flatMap(impliedPos -> sendFrame(ResumeOkFrameCodec.encode(allocator, impliedPos))) - .onErrorResume( - err -> - sendFrame(ErrorFrameCodec.encode(allocator, 0, errorFrameThrowable(err))) - .then(Mono.fromRunnable(resumableConnection::dispose)) - /*Resumption is impossible: no need to return control to ResumableConnection*/ - .then(Mono.never()))); - return this; + public void resumeWith(ByteBuf resumeFrame, DuplexConnection nextDuplexConnection) { + + if (logger.isDebugEnabled()) { + logger.debug("Side[server]|Session[{}]. New DuplexConnection received.", session); + } + + long remotePos = ResumeFrameCodec.firstAvailableClientPos(resumeFrame); + long remoteImpliedPos = ResumeFrameCodec.lastReceivedServerPos(resumeFrame); + + connectionsQueue.offer(() -> doResume(remotePos, remoteImpliedPos, nextDuplexConnection)); + + if (WIP.getAndIncrement(this) != 0) { + return; + } + + final Runnable doResumeRunnable = connectionsQueue.poll(); + if (doResumeRunnable != null) { + doResumeRunnable.run(); + } } - @Override - public void reconnect(DuplexConnection connection) { - resumableConnection.reconnect(connection); + void doResume(long remotePos, long remoteImpliedPos, DuplexConnection nextDuplexConnection) { + if (!tryCancelSessionTimeout()) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Session has already been expired. Terminating received connection", + session); + } + final RejectedResumeException rejectedResumeException = + new RejectedResumeException("resume_internal_error: Session Expired"); + nextDuplexConnection.sendErrorAndClose(rejectedResumeException); + nextDuplexConnection.receive().subscribe(); + return; + } + + long impliedPosition = resumableFramesStore.frameImpliedPosition(); + long position = resumableFramesStore.framePosition(); + + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Resume FRAME received. ServerResumeState[impliedPosition[{}], position[{}]]. ClientResumeState[remoteImpliedPosition[{}], remotePosition[{}]]", + session, + impliedPosition, + position, + remoteImpliedPos, + remotePos); + } + + if (remotePos <= impliedPosition && position <= remoteImpliedPos) { + try { + if (position != remoteImpliedPos) { + resumableFramesStore.releaseFrames(remoteImpliedPos); + } + nextDuplexConnection.sendFrame(0, ResumeOkFrameCodec.encode(allocator, impliedPosition)); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. ResumeOKFrame[impliedPosition[{}]] has been sent", + session, + impliedPosition); + } + } catch (Throwable t) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Exception occurred while releasing frames in the frameStore", + session, + t); + } + + dispose(); + + final RejectedResumeException rejectedResumeException = + new RejectedResumeException(t.getMessage(), t); + nextDuplexConnection.sendErrorAndClose(rejectedResumeException); + nextDuplexConnection.receive().subscribe(); + + return; + } + + keepAliveSupport.start(); + + if (logger.isDebugEnabled()) { + logger.debug("Side[server]|Session[{}]. Session has been resumed successfully", session); + } + + if (!resumableConnection.connect(nextDuplexConnection)) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Session has already been expired. Terminating received connection", + session); + } + final RejectedResumeException rejectedResumeException = + new RejectedResumeException("resume_internal_error: Session Expired"); + nextDuplexConnection.sendErrorAndClose(rejectedResumeException); + nextDuplexConnection.receive().subscribe(); + + // resumableConnection is likely to be disposed at this stage. Thus we have + // nothing to do + } + } else { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Mismatching remote and local state. Expected RemoteImpliedPosition[{}] to be greater or equal to the LocalPosition[{}] and RemotePosition[{}] to be less or equal to LocalImpliedPosition[{}]. Terminating received connection", + session, + remoteImpliedPos, + position, + remotePos, + impliedPosition); + } + + dispose(); + + final RejectedResumeException rejectedResumeException = + new RejectedResumeException( + String.format( + "resumption_pos=[ remote: { pos: %d, impliedPos: %d }, local: { pos: %d, impliedPos: %d }]", + remotePos, remoteImpliedPos, position, impliedPosition)); + nextDuplexConnection.sendErrorAndClose(rejectedResumeException); + nextDuplexConnection.receive().subscribe(); + } + } + + boolean tryCancelSessionTimeout() { + for (; ; ) { + final Subscription subscription = this.s; + + if (subscription == Operators.cancelledSubscription()) { + return false; + } + + if (S.compareAndSet(this, subscription, null)) { + subscription.cancel(); + return true; + } + } } @Override - public ResumableDuplexConnection resumableConnection() { - return resumableConnection; + public long impliedPosition() { + return resumableFramesStore.frameImpliedPosition(); } @Override - public ByteBuf token() { - return resumeToken; + public void onImpliedPosition(long remoteImpliedPos) { + if (cleanupStoreOnKeepAlive) { + try { + resumableFramesStore.releaseFrames(remoteImpliedPos); + } catch (Throwable e) { + resumableConnection.sendErrorAndClose(new ConnectionErrorException(e.getMessage(), e)); + } + } } - private Mono sendFrame(ByteBuf frame) { - logger.debug("Sending Resume frame: {}", frame); - return resumableConnection.sendOne(frame).onErrorResume(e -> Mono.empty()); + @Override + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); + } } - private static long remotePos(ByteBuf resumeFrame) { - return ResumeFrameCodec.firstAvailableClientPos(resumeFrame); + @Override + public void onNext(Long aLong) { + if (!Operators.terminate(S, this)) { + return; + } + + resumableConnection.dispose(); } - private static long remoteImpliedPos(ByteBuf resumeFrame) { - return ResumeFrameCodec.lastReceivedServerPos(resumeFrame); + @Override + public void onComplete() {} + + @Override + public void onError(Throwable t) {} + + public void setKeepAliveSupport(KeepAliveSupport keepAliveSupport) { + this.keepAliveSupport = keepAliveSupport; } - private static RejectedResumeException errorFrameThrowable(Throwable err) { - String msg; - if (err instanceof ResumeStateException) { - ResumeStateException resumeException = ((ResumeStateException) err); - msg = - String.format( - "resumption_pos=[ remote: { pos: %d, impliedPos: %d }, local: { pos: %d, impliedPos: %d }]", - resumeException.getRemotePos(), - resumeException.getRemoteImpliedPos(), - resumeException.getLocalPos(), - resumeException.getLocalImpliedPos()); - } else { - msg = String.format("resume_internal_error: %s", err.getMessage()); + @Override + public void dispose() { + if (logger.isDebugEnabled()) { + logger.debug("Side[server]|Session[{}]. Disposing session", session); } - return new RejectedResumeException(msg); + Operators.terminate(S, this); + resumableConnection.dispose(); + } + + @Override + public boolean isDisposed() { + return resumableConnection.isDisposed(); } } diff --git a/rsocket-core/src/main/java/io/rsocket/resume/SessionManager.java b/rsocket-core/src/main/java/io/rsocket/resume/SessionManager.java index 1d5c23bd6..736d7c77c 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/SessionManager.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/SessionManager.java @@ -17,27 +17,36 @@ package io.rsocket.resume; import io.netty.buffer.ByteBuf; +import io.netty.util.CharsetUtil; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.util.annotation.Nullable; public class SessionManager { + static final Logger logger = LoggerFactory.getLogger(SessionManager.class); + private volatile boolean isDisposed; - private final Map sessions = new ConcurrentHashMap<>(); + private final Map sessions = new ConcurrentHashMap<>(); - public ServerRSocketSession save(ServerRSocketSession session) { + public ServerRSocketSession save(ServerRSocketSession session, ByteBuf resumeToken) { if (isDisposed) { session.dispose(); } else { - ByteBuf token = session.token().retain(); + final String token = resumeToken.toString(CharsetUtil.UTF_8); session + .resumableConnection .onClose() - .doOnSuccess( - v -> { + .doFinally( + __ -> { + logger.debug( + "ResumableConnection has been closed. Removing associated session {" + + token + + "}"); if (isDisposed || sessions.get(token) == session) { sessions.remove(token); } - token.release(); }) .subscribe(); ServerRSocketSession prevSession = sessions.remove(token); @@ -51,7 +60,7 @@ public ServerRSocketSession save(ServerRSocketSession session) { @Nullable public ServerRSocketSession get(ByteBuf resumeToken) { - return sessions.get(resumeToken); + return sessions.get(resumeToken.toString(CharsetUtil.UTF_8)); } public void dispose() { diff --git a/rsocket-core/src/main/java/io/rsocket/resume/UpstreamFramesSubscriber.java b/rsocket-core/src/main/java/io/rsocket/resume/UpstreamFramesSubscriber.java deleted file mode 100644 index f010a05bd..000000000 --- a/rsocket-core/src/main/java/io/rsocket/resume/UpstreamFramesSubscriber.java +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Copyright 2015-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.resume; - -import io.netty.buffer.ByteBuf; -import java.util.Queue; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Consumer; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.Disposable; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Operators; -import reactor.util.concurrent.Queues; - -class UpstreamFramesSubscriber implements Subscriber, Disposable { - private static final Logger logger = LoggerFactory.getLogger(UpstreamFramesSubscriber.class); - - private final AtomicBoolean disposed = new AtomicBoolean(); - private final Consumer itemConsumer; - private final Disposable downstreamRequestDisposable; - private final Disposable resumeSaveStreamDisposable; - - private volatile Subscription subs; - private volatile boolean resumeStarted; - private final Queue framesCache; - private long request; - private long downStreamRequestN; - private long resumeSaveStreamRequestN; - - UpstreamFramesSubscriber( - int estimatedDownstreamRequest, - Flux downstreamRequests, - Flux resumeSaveStreamRequests, - Consumer itemConsumer) { - this.itemConsumer = itemConsumer; - this.framesCache = Queues.unbounded(estimatedDownstreamRequest).get(); - - downstreamRequestDisposable = downstreamRequests.subscribe(requestN -> requestN(0, requestN)); - - resumeSaveStreamDisposable = - resumeSaveStreamRequests.subscribe(requestN -> requestN(requestN, 0)); - } - - @Override - public void onSubscribe(Subscription s) { - this.subs = s; - if (!isDisposed()) { - doRequest(); - } else { - s.cancel(); - } - } - - @Override - public void onNext(ByteBuf item) { - processFrame(item); - } - - @Override - public void onError(Throwable t) { - dispose(); - } - - @Override - public void onComplete() { - dispose(); - } - - public void resumeStart() { - resumeStarted = true; - } - - public void resumeComplete() { - ByteBuf frame = framesCache.poll(); - while (frame != null) { - itemConsumer.accept(frame); - frame = framesCache.poll(); - } - resumeStarted = false; - doRequest(); - } - - @Override - public void dispose() { - if (disposed.compareAndSet(false, true)) { - releaseCache(); - if (subs != null) { - subs.cancel(); - } - resumeSaveStreamDisposable.dispose(); - downstreamRequestDisposable.dispose(); - } - } - - @Override - public boolean isDisposed() { - return disposed.get(); - } - - private void requestN(long resumeStreamRequest, long downStreamRequest) { - synchronized (this) { - downStreamRequestN = Operators.addCap(downStreamRequestN, downStreamRequest); - resumeSaveStreamRequestN = Operators.addCap(resumeSaveStreamRequestN, resumeStreamRequest); - - long requests = Math.min(downStreamRequestN, resumeSaveStreamRequestN); - if (requests > 0) { - downStreamRequestN -= requests; - resumeSaveStreamRequestN -= requests; - logger.debug("Upstream subscriber requestN: {}", requests); - request = Operators.addCap(request, requests); - } - } - doRequest(); - } - - private void doRequest() { - if (subs != null && !resumeStarted) { - synchronized (this) { - long r = request; - if (r > 0) { - subs.request(r); - request = 0; - } - } - } - } - - private void releaseCache() { - ByteBuf frame = framesCache.poll(); - while (frame != null && frame.refCnt() > 0) { - frame.release(); - } - } - - private void processFrame(ByteBuf item) { - if (resumeStarted) { - framesCache.offer(item); - } else { - itemConsumer.accept(item); - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/transport/TransportHeaderAware.java b/rsocket-core/src/main/java/io/rsocket/transport/TransportHeaderAware.java deleted file mode 100644 index 16b863d9e..000000000 --- a/rsocket-core/src/main/java/io/rsocket/transport/TransportHeaderAware.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport; - -import java.util.Map; -import java.util.function.Supplier; - -/** - * Extension interface to support Transports with headers at the transport layer, e.g. Websockets, - * Http2. - * - * @deprecated as of 1.0.1 in favor using properties on individual transports. - */ -@Deprecated -public interface TransportHeaderAware { - - /** - * Sets the transport headers - * - * @param transportHeaders the transport headers - * @throws NullPointerException if {@code transportHeaders} is {@code null} - */ - void setTransportHeaders(Supplier> transportHeaders); -} diff --git a/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java index 4cf33fa86..12e0b60dc 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java @@ -115,7 +115,7 @@ public static Payload create(ByteBuf data, @Nullable ByteBuf metadata) { ByteBufPayload payload = RECYCLER.get(); payload.data = data; payload.metadata = metadata; - // unsure data and metadata is set before refCnt change + // ensure data and metadata is set before refCnt change payload.setRefCnt(1); return payload; } diff --git a/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java b/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java index 58f282110..08b8b2fb7 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java @@ -99,13 +99,27 @@ public static Payload create(ByteBuf data) { } public static Payload create(ByteBuf data, @Nullable ByteBuf metadata) { - return create(data.nioBuffer(), metadata == null ? null : metadata.nioBuffer()); + try { + return create(toBytes(data), metadata != null ? toBytes(metadata) : null); + } finally { + data.release(); + if (metadata != null) { + metadata.release(); + } + } } public static Payload create(Payload payload) { return create( - Unpooled.copiedBuffer(payload.sliceData()), - payload.hasMetadata() ? Unpooled.copiedBuffer(payload.sliceMetadata()) : null); + toBytes(payload.data()), payload.hasMetadata() ? toBytes(payload.metadata()) : null); + } + + private static byte[] toBytes(ByteBuf byteBuf) { + byte[] bytes = new byte[byteBuf.readableBytes()]; + byteBuf.markReaderIndex(); + byteBuf.readBytes(bytes); + byteBuf.resetReaderIndex(); + return bytes; } @Override diff --git a/rsocket-core/src/main/resources/META-INF/native-image/io.rsocket/rsocket-core/reflect-config.json b/rsocket-core/src/main/resources/META-INF/native-image/io.rsocket/rsocket-core/reflect-config.json new file mode 100644 index 000000000..0a3844451 --- /dev/null +++ b/rsocket-core/src/main/resources/META-INF/native-image/io.rsocket/rsocket-core/reflect-config.json @@ -0,0 +1,130 @@ +[ + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.BaseLinkedQueueConsumerNodeRef" + }, + "name": "io.rsocket.internal.jctools.queues.BaseLinkedQueueConsumerNodeRef", + "fields": [ + { + "name": "consumerNode" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.BaseLinkedQueueProducerNodeRef" + }, + "name": "io.rsocket.internal.jctools.queues.BaseLinkedQueueProducerNodeRef", + "fields": [ + { + "name": "producerNode" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueColdProducerFields" + }, + "name": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueColdProducerFields", + "fields": [ + { + "name": "producerLimit" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueConsumerFields" + }, + "name": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueConsumerFields", + "fields": [ + { + "name": "consumerIndex" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueProducerFields" + }, + "name": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueProducerFields", + "fields": [ + { + "name": "producerIndex" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.LinkedQueueNode" + }, + "name": "io.rsocket.internal.jctools.queues.LinkedQueueNode", + "fields": [ + { + "name": "next" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.MpscArrayQueueConsumerIndexField" + }, + "name": "io.rsocket.internal.jctools.queues.MpscArrayQueueConsumerIndexField", + "fields": [ + { + "name": "consumerIndex" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.MpscArrayQueueProducerIndexField" + }, + "name": "io.rsocket.internal.jctools.queues.MpscArrayQueueProducerIndexField", + "fields": [ + { + "name": "producerIndex" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.MpscArrayQueueProducerLimitField" + }, + "name": "io.rsocket.internal.jctools.queues.MpscArrayQueueProducerLimitField", + "fields": [ + { + "name": "producerLimit" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.UnsafeAccess" + }, + "name": "sun.misc.Unsafe", + "fields": [ + { + "name": "theUnsafe" + } + ], + "queriedMethods": [ + { + "name": "getAndAddLong", + "parameterTypes": [ + "java.lang.Object", + "long", + "long" + ] + }, + { + "name": "getAndSetObject", + "parameterTypes": [ + "java.lang.Object", + "long", + "java.lang.Object" + ] + } + ] + } +] \ No newline at end of file diff --git a/rsocket-core/src/test/java/io/rsocket/FrameAssert.java b/rsocket-core/src/test/java/io/rsocket/FrameAssert.java new file mode 100644 index 000000000..b5b1e2ec9 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/FrameAssert.java @@ -0,0 +1,336 @@ +package io.rsocket; + +import static org.assertj.core.error.ShouldBe.shouldBe; +import static org.assertj.core.error.ShouldBeEqual.shouldBeEqual; +import static org.assertj.core.error.ShouldHave.shouldHave; +import static org.assertj.core.error.ShouldNotHave.shouldNotHave; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.rsocket.frame.*; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.List; +import org.assertj.core.api.AbstractAssert; +import org.assertj.core.api.Condition; +import org.assertj.core.error.BasicErrorMessageFactory; +import org.assertj.core.internal.Failures; +import org.assertj.core.internal.Objects; +import reactor.util.annotation.Nullable; + +public class FrameAssert extends AbstractAssert { + public static FrameAssert assertThat(@Nullable ByteBuf frame) { + return new FrameAssert(frame); + } + + private final Failures failures = Failures.instance(); + + public FrameAssert(@Nullable ByteBuf frame) { + super(frame, FrameAssert.class); + } + + public FrameAssert hasMetadata() { + assertValid(); + + if (!FrameHeaderCodec.hasMetadata(actual)) { + throw failures.failure(info, shouldHave(actual, new Condition<>("metadata present"))); + } + + return this; + } + + public FrameAssert hasNoMetadata() { + assertValid(); + + if (FrameHeaderCodec.hasMetadata(actual)) { + throw failures.failure(info, shouldHave(actual, new Condition<>("metadata absent"))); + } + + return this; + } + + public FrameAssert hasMetadata(String metadata, Charset charset) { + return hasMetadata(metadata.getBytes(charset)); + } + + public FrameAssert hasMetadata(String metadataUtf8) { + return hasMetadata(metadataUtf8, CharsetUtil.UTF_8); + } + + public FrameAssert hasMetadata(byte[] metadata) { + return hasMetadata(Unpooled.wrappedBuffer(metadata)); + } + + public FrameAssert hasMetadata(ByteBuf metadata) { + hasMetadata(); + + final FrameType frameType = FrameHeaderCodec.frameType(actual); + ByteBuf content; + if (frameType == FrameType.METADATA_PUSH) { + content = MetadataPushFrameCodec.metadata(actual); + } else if (frameType.hasInitialRequestN()) { + content = RequestStreamFrameCodec.metadata(actual); + } else { + content = PayloadFrameCodec.metadata(actual); + } + + if (!ByteBufUtil.equals(content, metadata)) { + throw failures.failure(info, shouldBeEqual(content, metadata, new ByteBufRepresentation())); + } + + return this; + } + + public FrameAssert hasData(String dataUtf8) { + return hasData(dataUtf8, CharsetUtil.UTF_8); + } + + public FrameAssert hasData(String data, Charset charset) { + return hasData(data.getBytes(charset)); + } + + public FrameAssert hasData(byte[] data) { + return hasData(Unpooled.wrappedBuffer(data)); + } + + public FrameAssert hasData(ByteBuf data) { + assertValid(); + + ByteBuf content; + final FrameType frameType = FrameHeaderCodec.frameType(actual); + if (!frameType.canHaveData()) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have data content but frame type %n<%s> does not support data content", + actual, frameType)); + } else if (frameType.hasInitialRequestN()) { + content = RequestStreamFrameCodec.data(actual); + } else if (frameType == FrameType.ERROR) { + content = ErrorFrameCodec.data(actual); + } else { + content = PayloadFrameCodec.data(actual); + } + + if (!ByteBufUtil.equals(content, data)) { + throw failures.failure(info, shouldBeEqual(content, data, new ByteBufRepresentation())); + } + + return this; + } + + public FrameAssert hasFragmentsFollow() { + return hasFollows(true); + } + + public FrameAssert hasNoFragmentsFollow() { + return hasFollows(false); + } + + public FrameAssert hasFollows(boolean hasFollows) { + assertValid(); + + if (FrameHeaderCodec.hasFollows(actual) != hasFollows) { + throw failures.failure( + info, + hasFollows + ? shouldHave(actual, new Condition<>("follows fragment present")) + : shouldNotHave(actual, new Condition<>("follows fragment present"))); + } + + return this; + } + + public FrameAssert typeOf(FrameType frameType) { + assertValid(); + + final FrameType currentFrameType = FrameHeaderCodec.frameType(actual); + if (currentFrameType != frameType) { + throw failures.failure( + info, shouldBe(currentFrameType, new Condition<>("frame of type [" + frameType + "]"))); + } + + return this; + } + + public FrameAssert hasStreamId(int streamId) { + assertValid(); + + final int currentStreamId = FrameHeaderCodec.streamId(actual); + if (currentStreamId != streamId) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting streamId:%n<%s>%n to be equal %n<%s>", currentStreamId, streamId)); + } + + return this; + } + + public FrameAssert hasStreamIdZero() { + return hasStreamId(0); + } + + public FrameAssert hasClientSideStreamId() { + assertValid(); + + final int currentStreamId = FrameHeaderCodec.streamId(actual); + if (currentStreamId % 2 != 1) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting Client Side StreamId %nbut was " + + (currentStreamId == 0 ? "Stream Id 0" : "Server Side Stream Id"))); + } + + return this; + } + + public FrameAssert hasServerSideStreamId() { + assertValid(); + + final int currentStreamId = FrameHeaderCodec.streamId(actual); + if (currentStreamId == 0 || currentStreamId % 2 != 0) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting %n Server Side Stream Id %nbut was %n " + + (currentStreamId == 0 ? "Stream Id 0" : "Client Side Stream Id"))); + } + + return this; + } + + public FrameAssert hasPayloadSize(int payloadLength) { + assertValid(); + + final FrameType currentFrameType = FrameHeaderCodec.frameType(actual); + + final int currentFrameLength = + actual.readableBytes() + - FrameHeaderCodec.size() + - (FrameHeaderCodec.hasMetadata(actual) && currentFrameType.canHaveData() ? 3 : 0) + - (currentFrameType.hasInitialRequestN() ? Integer.BYTES : 0); + if (currentFrameLength != payloadLength) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting %n<%s> %nframe payload size to be equal to %n<%s> %nbut was %n<%s>", + actual, payloadLength, currentFrameLength)); + } + + return this; + } + + public FrameAssert hasRequestN(int n) { + assertValid(); + + final FrameType currentFrameType = FrameHeaderCodec.frameType(actual); + long requestN; + if (currentFrameType.hasInitialRequestN()) { + requestN = RequestStreamFrameCodec.initialRequestN(actual); + } else if (currentFrameType == FrameType.REQUEST_N) { + requestN = RequestNFrameCodec.requestN(actual); + } else { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have requestN but frame type %n<%s> does not support requestN", + actual, currentFrameType)); + } + + if ((requestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : requestN) != n) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have %nrequestN(<%s>) but got %nrequestN(<%s>)", + actual, n, requestN)); + } + + return this; + } + + public FrameAssert hasPayload(Payload expectedPayload) { + assertValid(); + + List failedExpectation = new ArrayList<>(); + FrameType frameType = FrameHeaderCodec.frameType(actual); + boolean hasMetadata = FrameHeaderCodec.hasMetadata(actual); + if (expectedPayload.hasMetadata() != hasMetadata) { + failedExpectation.add( + String.format( + "hasMetadata(%s) but actual was hasMetadata(%s)%n", + expectedPayload.hasMetadata(), hasMetadata)); + } else if (hasMetadata) { + ByteBuf metadataContent; + if (frameType == FrameType.METADATA_PUSH) { + metadataContent = MetadataPushFrameCodec.metadata(actual); + } else if (frameType.hasInitialRequestN()) { + metadataContent = RequestStreamFrameCodec.metadata(actual); + } else { + metadataContent = PayloadFrameCodec.metadata(actual); + } + if (!ByteBufUtil.equals(expectedPayload.sliceMetadata(), metadataContent)) { + failedExpectation.add( + String.format( + "metadata(%s) but actual was metadata(%s)%n", + expectedPayload.sliceMetadata(), metadataContent)); + } + } + + ByteBuf dataContent; + if (!frameType.canHaveData() && expectedPayload.sliceData().readableBytes() > 0) { + failedExpectation.add( + String.format( + "data(%s) but frame type %n<%s> does not support data", actual, frameType)); + } else { + if (frameType.hasInitialRequestN()) { + dataContent = RequestStreamFrameCodec.data(actual); + } else { + dataContent = PayloadFrameCodec.data(actual); + } + + if (!ByteBufUtil.equals(expectedPayload.sliceData(), dataContent)) { + failedExpectation.add( + String.format( + "data(%s) but actual was data(%s)%n", expectedPayload.sliceData(), dataContent)); + } + } + + if (!failedExpectation.isEmpty()) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting be equal to the given payload but the following differences were found" + + " %s", + failedExpectation)); + } + + return this; + } + + public void hasNoLeaks() { + if (!actual.release() || actual.refCnt() > 0) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have refCnt(0) after release but " + + "actual was " + + "%n", + actual, actual.refCnt())); + } + } + + private void assertValid() { + Objects.instance().assertNotNull(info, actual); + + try { + FrameHeaderCodec.frameType(actual); + } catch (Throwable t) { + throw failures.failure( + info, shouldBe(actual, new Condition<>("a valid frame, but got exception [" + t + "]"))); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/PayloadAssert.java b/rsocket-core/src/test/java/io/rsocket/PayloadAssert.java new file mode 100755 index 000000000..847f24722 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/PayloadAssert.java @@ -0,0 +1,180 @@ +package io.rsocket; + +import static org.assertj.core.error.ShouldBeEqual.shouldBeEqual; +import static org.assertj.core.error.ShouldHave.shouldHave; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.rsocket.frame.ByteBufRepresentation; +import io.rsocket.util.DefaultPayload; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.List; +import org.assertj.core.api.AbstractAssert; +import org.assertj.core.api.Condition; +import org.assertj.core.error.BasicErrorMessageFactory; +import org.assertj.core.internal.Failures; +import org.assertj.core.internal.Objects; +import reactor.util.annotation.Nullable; + +public class PayloadAssert extends AbstractAssert { + + public static PayloadAssert assertThat(@Nullable Payload payload) { + return new PayloadAssert(payload); + } + + private final Failures failures = Failures.instance(); + + public PayloadAssert(@Nullable Payload payload) { + super(payload, PayloadAssert.class); + } + + public PayloadAssert hasMetadata() { + assertValid(); + + if (!actual.hasMetadata()) { + throw failures.failure(info, shouldHave(actual, new Condition<>("metadata present"))); + } + + return this; + } + + public PayloadAssert hasNoMetadata() { + assertValid(); + + if (actual.hasMetadata()) { + throw failures.failure(info, shouldHave(actual, new Condition<>("metadata absent"))); + } + + return this; + } + + public PayloadAssert hasMetadata(String metadata, Charset charset) { + return hasMetadata(metadata.getBytes(charset)); + } + + public PayloadAssert hasMetadata(String metadataUtf8) { + return hasMetadata(metadataUtf8, CharsetUtil.UTF_8); + } + + public PayloadAssert hasMetadata(byte[] metadata) { + return hasMetadata(Unpooled.wrappedBuffer(metadata)); + } + + public PayloadAssert hasMetadata(ByteBuf metadata) { + hasMetadata(); + + ByteBuf content = actual.sliceMetadata(); + if (!ByteBufUtil.equals(content, metadata)) { + throw failures.failure(info, shouldBeEqual(content, metadata, new ByteBufRepresentation())); + } + + return this; + } + + public PayloadAssert hasData(String dataUtf8) { + return hasData(dataUtf8, CharsetUtil.UTF_8); + } + + public PayloadAssert hasData(String data, Charset charset) { + return hasData(data.getBytes(charset)); + } + + public PayloadAssert hasData(byte[] data) { + return hasData(Unpooled.wrappedBuffer(data)); + } + + public PayloadAssert hasData(ByteBuf data) { + assertValid(); + + ByteBuf content = actual.sliceData(); + if (!ByteBufUtil.equals(content, data)) { + throw failures.failure(info, shouldBeEqual(content, data, new ByteBufRepresentation())); + } + + return this; + } + + public void hasNoLeaks() { + if (!(actual instanceof DefaultPayload)) { + if (actual.refCnt() == 0) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have refCnt(0) after release but " + + "actual was already released", + actual, actual.refCnt())); + } + if (!actual.release() || actual.refCnt() > 0) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have refCnt(0) after release but " + + "actual was " + + "%n", + actual, actual.refCnt())); + } + } + } + + public void isReleased() { + if (actual.refCnt() > 0) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have refCnt(0) but " + "actual was " + "%n", + actual, actual.refCnt())); + } + } + + @Override + public PayloadAssert isEqualTo(Object expected) { + if (expected instanceof Payload) { + if (expected == actual) { + return this; + } + + Payload expectedPayload = (Payload) expected; + List failedExpectation = new ArrayList<>(); + if (expectedPayload.hasMetadata() != actual.hasMetadata()) { + failedExpectation.add( + String.format( + "hasMetadata(%s) but actual was hasMetadata(%s)%n", + expectedPayload.hasMetadata(), actual.hasMetadata())); + } else { + if (!ByteBufUtil.equals(expectedPayload.sliceMetadata(), actual.sliceMetadata())) { + failedExpectation.add( + String.format( + "metadata(%s) but actual was metadata(%s)%n", + expectedPayload.sliceMetadata(), actual.sliceMetadata())); + } + } + + if (!ByteBufUtil.equals(expectedPayload.sliceData(), actual.sliceData())) { + failedExpectation.add( + String.format( + "data(%s) but actual was data(%s)%n", + expectedPayload.sliceData(), actual.sliceData())); + } + + if (!failedExpectation.isEmpty()) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting be equal to the given one but the following differences were found" + + " %s", + failedExpectation)); + } + + return this; + } + + return super.isEqualTo(expected); + } + + private void assertValid() { + Objects.instance().assertNotNull(info, actual); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/RaceTestConstants.java b/rsocket-core/src/test/java/io/rsocket/RaceTestConstants.java new file mode 100644 index 000000000..d30f1415e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/RaceTestConstants.java @@ -0,0 +1,6 @@ +package io.rsocket; + +public class RaceTestConstants { + public static final int REPEATS = + Integer.parseInt(System.getProperty("rsocket.test.race.repeats", "1000")); +} diff --git a/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java b/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java index 800e5d678..1db708ab5 100644 --- a/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java +++ b/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java @@ -1,16 +1,29 @@ package io.rsocket.buffer; +import static java.util.concurrent.locks.LockSupport.parkNanos; + import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ResourceLeakDetector; +import java.lang.reflect.Field; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; import java.util.concurrent.ConcurrentLinkedQueue; import org.assertj.core.api.Assertions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Additional Utils which allows to decorate a ByteBufAllocator and track/assertOnLeaks all created * ByteBuffs */ public class LeaksTrackingByteBufAllocator implements ByteBufAllocator { + static final Logger LOGGER = LoggerFactory.getLogger(LeaksTrackingByteBufAllocator.class); /** * Allows to instrument any given the instance of ByteBufAllocator @@ -19,24 +32,96 @@ public class LeaksTrackingByteBufAllocator implements ByteBufAllocator { * @return */ public static LeaksTrackingByteBufAllocator instrument(ByteBufAllocator allocator) { - return new LeaksTrackingByteBufAllocator(allocator); + return new LeaksTrackingByteBufAllocator(allocator, Duration.ZERO, ""); + } + + /** + * Allows to instrument any given the instance of ByteBufAllocator + * + * @param allocator + * @return + */ + public static LeaksTrackingByteBufAllocator instrument( + ByteBufAllocator allocator, Duration awaitZeroRefCntDuration, String tag) { + return new LeaksTrackingByteBufAllocator(allocator, awaitZeroRefCntDuration, tag); } final ConcurrentLinkedQueue tracker = new ConcurrentLinkedQueue<>(); final ByteBufAllocator delegate; - private LeaksTrackingByteBufAllocator(ByteBufAllocator delegate) { + final Duration awaitZeroRefCntDuration; + + final String tag; + + private LeaksTrackingByteBufAllocator( + ByteBufAllocator delegate, Duration awaitZeroRefCntDuration, String tag) { this.delegate = delegate; + this.awaitZeroRefCntDuration = awaitZeroRefCntDuration; + this.tag = tag; } public LeaksTrackingByteBufAllocator assertHasNoLeaks() { try { - Assertions.assertThat(tracker) - .allSatisfy( - buf -> - Assertions.assertThat(buf) - .matches(bb -> bb.refCnt() == 0, "buffer should be released")); + ArrayList unreleased = new ArrayList<>(); + for (ByteBuf bb : tracker) { + if (bb.refCnt() != 0) { + unreleased.add(bb); + } + } + + final Duration awaitZeroRefCntDuration = this.awaitZeroRefCntDuration; + if (!unreleased.isEmpty() && !awaitZeroRefCntDuration.isZero()) { + final long startTime = System.currentTimeMillis(); + final long endTimeInMillis = startTime + awaitZeroRefCntDuration.toMillis(); + boolean hasUnreleased; + while (System.currentTimeMillis() <= endTimeInMillis) { + hasUnreleased = false; + for (ByteBuf bb : unreleased) { + if (bb.refCnt() != 0) { + hasUnreleased = true; + break; + } + } + + if (!hasUnreleased) { + return this; + } + + LOGGER.debug(tag + " await buffers to be released"); + for (int i = 0; i < 100; i++) { + System.gc(); + parkNanos(1000); + System.gc(); + } + } + } + + Set collected = new HashSet<>(); + for (ByteBuf buf : unreleased) { + if (buf.refCnt() != 0) { + try { + collected.add(buf); + } catch (IllegalReferenceCountException ignored) { + // fine to ignore if throws because of refCnt + } + } + } + + Assertions.assertThat( + collected + .stream() + .filter(bb -> bb.refCnt() != 0) + .peek( + bb -> { + try { + LOGGER.debug(tag + " " + resolveTrackingInfo(bb)); + } catch (Exception e) { + e.printStackTrace(); + } + })) + .describedAs("[" + tag + "] all buffers expected to be released but got ") + .isEmpty(); } finally { tracker.clear(); } @@ -150,4 +235,60 @@ T track(T buffer) { return buffer; } + + static final Class simpleLeakAwareCompositeByteBufClass; + static final Field leakFieldForComposite; + static final Class simpleLeakAwareByteBufClass; + static final Field leakFieldForNormal; + static final Field allLeaksField; + + static { + try { + { + final Class aClass = Class.forName("io.netty.buffer.SimpleLeakAwareCompositeByteBuf"); + final Field leakField = aClass.getDeclaredField("leak"); + + leakField.setAccessible(true); + + simpleLeakAwareCompositeByteBufClass = aClass; + leakFieldForComposite = leakField; + } + + { + final Class aClass = Class.forName("io.netty.buffer.SimpleLeakAwareByteBuf"); + final Field leakField = aClass.getDeclaredField("leak"); + + leakField.setAccessible(true); + + simpleLeakAwareByteBufClass = aClass; + leakFieldForNormal = leakField; + } + + { + final Class aClass = + Class.forName("io.netty.util.ResourceLeakDetector$DefaultResourceLeak"); + final Field field = aClass.getDeclaredField("allLeaks"); + + field.setAccessible(true); + + allLeaksField = field; + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @SuppressWarnings("unchecked") + static Set resolveTrackingInfo(ByteBuf byteBuf) throws Exception { + if (ResourceLeakDetector.getLevel().ordinal() + >= ResourceLeakDetector.Level.ADVANCED.ordinal()) { + if (simpleLeakAwareCompositeByteBufClass.isInstance(byteBuf)) { + return (Set) allLeaksField.get(leakFieldForComposite.get(byteBuf)); + } else if (simpleLeakAwareByteBufClass.isInstance(byteBuf)) { + return (Set) allLeaksField.get(leakFieldForNormal.get(byteBuf)); + } + } + + return Collections.emptySet(); + } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java b/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java index 7398548be..310e15b3e 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java +++ b/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java @@ -23,45 +23,50 @@ import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.test.util.TestSubscriber; -import org.junit.rules.ExternalResource; -import org.junit.runner.Description; -import org.junit.runners.model.Statement; +import java.time.Duration; import org.reactivestreams.Subscriber; -public abstract class AbstractSocketRule extends ExternalResource { +public abstract class AbstractSocketRule { protected TestDuplexConnection connection; protected Subscriber connectSub; protected T socket; protected LeaksTrackingByteBufAllocator allocator; protected int maxFrameLength = FRAME_LENGTH_MASK; + protected int maxInboundPayloadSize = Integer.MAX_VALUE; - @Override - public Statement apply(final Statement base, Description description) { - return new Statement() { - @Override - public void evaluate() throws Throwable { - allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); - connection = new TestDuplexConnection(allocator); - connectSub = TestSubscriber.create(); - init(); - base.evaluate(); - } - }; + public void init() { + allocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(5), ""); + connectSub = TestSubscriber.create(); + doInit(); } - protected void init() { + protected void doInit() { + if (connection != null) { + connection.dispose(); + } + if (socket != null) { + socket.dispose(); + } + connection = new TestDuplexConnection(allocator); socket = newRSocket(); } + public void setMaxInboundPayloadSize(int maxInboundPayloadSize) { + this.maxInboundPayloadSize = maxInboundPayloadSize; + doInit(); + } + public void setMaxFrameLength(int maxFrameLength) { this.maxFrameLength = maxFrameLength; - init(); + doInit(); } protected abstract T newRSocket(); - public ByteBufAllocator alloc() { + public LeaksTrackingByteBufAllocator alloc() { return allocator; } diff --git a/rsocket-core/src/test/java/io/rsocket/core/ClientServerInputMultiplexerTest.java b/rsocket-core/src/test/java/io/rsocket/core/ClientServerInputMultiplexerTest.java new file mode 100644 index 000000000..195df9434 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ClientServerInputMultiplexerTest.java @@ -0,0 +1,172 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.frame.MetadataPushFrameCodec; +import io.rsocket.plugins.InitializingInterceptorRegistry; +import io.rsocket.test.util.TestDuplexConnection; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class ClientServerInputMultiplexerTest { + private TestDuplexConnection source; + private ClientServerInputMultiplexer clientMultiplexer; + private LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + private ClientServerInputMultiplexer serverMultiplexer; + + @BeforeEach + public void setup() { + source = new TestDuplexConnection(allocator); + clientMultiplexer = + new ClientServerInputMultiplexer(source, new InitializingInterceptorRegistry(), true); + serverMultiplexer = + new ClientServerInputMultiplexer(source, new InitializingInterceptorRegistry(), false); + } + + @Test + public void clientSplits() { + AtomicInteger clientFrames = new AtomicInteger(); + AtomicInteger serverFrames = new AtomicInteger(); + + clientMultiplexer + .asClientConnection() + .receive() + .doOnNext( + f -> { + clientFrames.incrementAndGet(); + f.release(); + }) + .subscribe(); + clientMultiplexer + .asServerConnection() + .receive() + .doOnNext( + f -> { + serverFrames.incrementAndGet(); + f.release(); + }) + .subscribe(); + + source.addToReceivedBuffer(errorFrame(1).retain()); + assertThat(clientFrames.get()).isOne(); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(errorFrame(1).retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(leaseFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(3); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(keepAliveFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(4); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(errorFrame(2).retain()); + assertThat(clientFrames.get()).isEqualTo(4); + assertThat(serverFrames.get()).isOne(); + + source.addToReceivedBuffer(errorFrame(0).retain()); + assertThat(clientFrames.get()).isEqualTo(5); + assertThat(serverFrames.get()).isOne(); + + source.addToReceivedBuffer(metadataPushFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(5); + assertThat(serverFrames.get()).isEqualTo(2); + } + + @Test + public void serverSplits() { + AtomicInteger clientFrames = new AtomicInteger(); + AtomicInteger serverFrames = new AtomicInteger(); + + serverMultiplexer + .asClientConnection() + .receive() + .doOnNext( + f -> { + clientFrames.incrementAndGet(); + f.release(); + }) + .subscribe(); + serverMultiplexer + .asServerConnection() + .receive() + .doOnNext( + f -> { + serverFrames.incrementAndGet(); + f.release(); + }) + .subscribe(); + + source.addToReceivedBuffer(errorFrame(1).retain()); + assertThat(clientFrames.get()).isEqualTo(1); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(errorFrame(1).retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(leaseFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isOne(); + + source.addToReceivedBuffer(keepAliveFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isEqualTo(2); + + source.addToReceivedBuffer(errorFrame(2).retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isEqualTo(3); + + source.addToReceivedBuffer(errorFrame(0).retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isEqualTo(4); + + source.addToReceivedBuffer(metadataPushFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(3); + assertThat(serverFrames.get()).isEqualTo(4); + } + + private ByteBuf leaseFrame() { + return LeaseFrameCodec.encode(allocator, 1_000, 1, Unpooled.EMPTY_BUFFER); + } + + private ByteBuf errorFrame(int i) { + return ErrorFrameCodec.encode(allocator, i, new Exception()); + } + + private ByteBuf keepAliveFrame() { + return KeepAliveFrameCodec.encode(allocator, false, 0, Unpooled.EMPTY_BUFFER); + } + + private ByteBuf metadataPushFrame() { + return MetadataPushFrameCodec.encode(allocator, Unpooled.EMPTY_BUFFER); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java b/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java index 8d1d292c6..84576e6ce 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java +++ b/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java @@ -1,6 +1,6 @@ package io.rsocket.core; /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,35 +15,31 @@ * limitations under the License. */ -import static io.rsocket.frame.FrameHeaderCodec.frameType; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.greaterThanOrEqualTo; -import static org.hamcrest.Matchers.hasSize; - import io.netty.buffer.ByteBuf; import io.netty.util.CharsetUtil; import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCounted; +import io.rsocket.FrameAssert; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketClient; -import io.rsocket.TestScheduler; +import io.rsocket.RaceTestConstants; import io.rsocket.frame.ErrorFrameCodec; import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; import io.rsocket.frame.PayloadFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.internal.subscriber.AssertSubscriber; -import io.rsocket.lease.RequesterLeaseHandler; +import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.RSocketProxy; import java.time.Duration; import java.util.ArrayList; import java.util.Collection; -import java.util.List; import java.util.Map; import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.stream.Collectors; import java.util.stream.Stream; import org.assertj.core.api.Assertions; @@ -54,18 +50,19 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import org.junit.runners.model.Statement; +import org.mockito.Mockito; import org.reactivestreams.Publisher; import reactor.core.Disposable; import reactor.core.publisher.Flux; import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; import reactor.core.publisher.SignalType; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; import reactor.test.publisher.TestPublisher; import reactor.test.util.RaceTestUtils; import reactor.util.context.Context; +import reactor.util.context.ContextView; import reactor.util.retry.Retry; public class DefaultRSocketClientTests { @@ -77,19 +74,33 @@ public void setUp() throws Throwable { Hooks.onNextDropped(ReferenceCountUtil::safeRelease); Hooks.onErrorDropped((t) -> {}); rule = new ClientSocketRule(); - rule.apply( - new Statement() { - @Override - public void evaluate() {} - }, - null) - .evaluate(); + rule.init(); } @AfterEach public void tearDown() { Hooks.resetOnErrorDropped(); Hooks.resetOnNextDropped(); + rule.allocator.assertHasNoLeaks(); + } + + @Test + @SuppressWarnings("unchecked") + void discardElementsConsumerShouldAcceptOtherTypesThanReferenceCounted() { + Consumer discardElementsConsumer = DefaultRSocketClient.DISCARD_ELEMENTS_CONSUMER; + discardElementsConsumer.accept(new Object()); + } + + @Test + void droppedElementsConsumerReleaseReference() { + ReferenceCounted referenceCounted = Mockito.mock(ReferenceCounted.class); + Mockito.when(referenceCounted.release()).thenReturn(true); + Mockito.when(referenceCounted.refCnt()).thenReturn(1); + + Consumer discardElementsConsumer = DefaultRSocketClient.DISCARD_ELEMENTS_CONSUMER; + discardElementsConsumer.accept(referenceCounted); + + Mockito.verify(referenceCounted).release(); } static Stream interactions() { @@ -144,12 +155,26 @@ public void shouldSentFrameOnResolution( }) .then(testPublisher::complete) .then( - () -> + () -> { + if (requestType == FrameType.REQUEST_CHANNEL) { + Assertions.assertThat(rule.connection.getSent()) + .hasSize(2) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .matches(ReferenceCounted::release); + + Assertions.assertThat(rule.connection.getSent()) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(FrameType.COMPLETE)) + .matches(ReferenceCounted::release); + } else { Assertions.assertThat(rule.connection.getSent()) .hasSize(1) .first() .matches(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) - .matches(ReferenceCounted::release)) + .matches(ReferenceCounted::release); + } + }) .then( () -> { if (requestType != FrameType.REQUEST_FNF && requestType != FrameType.METADATA_PUSH) { @@ -167,19 +192,12 @@ public void shouldSentFrameOnResolution( @MethodSource("interactions") @SuppressWarnings({"unchecked", "rawtypes"}) public void shouldHaveNoLeaksOnPayloadInCaseOfRacingOfOnNextAndCancel( - BiFunction, Publisher> request, FrameType requestType) - throws Throwable { + BiFunction, Publisher> request, FrameType requestType) { Assumptions.assumeThat(requestType).isNotEqualTo(FrameType.REQUEST_CHANNEL); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { ClientSocketRule rule = new ClientSocketRule(); - rule.apply( - new Statement() { - @Override - public void evaluate() {} - }, - null) - .evaluate(); + rule.init(); Payload payload = ByteBufPayload.create("test", "testMetadata"); TestPublisher testPublisher = TestPublisher.createNoncompliant(TestPublisher.Violation.DEFER_CANCELLATION); @@ -229,19 +247,12 @@ public void evaluate() {} @MethodSource("interactions") @SuppressWarnings({"unchecked", "rawtypes"}) public void shouldHaveNoLeaksOnPayloadInCaseOfRacingOfRequestAndCancel( - BiFunction, Publisher> request, FrameType requestType) - throws Throwable { + BiFunction, Publisher> request, FrameType requestType) { Assumptions.assumeThat(requestType).isNotEqualTo(FrameType.REQUEST_CHANNEL); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { ClientSocketRule rule = new ClientSocketRule(); - rule.apply( - new Statement() { - @Override - public void evaluate() {} - }, - null) - .evaluate(); + rule.init(); ByteBuf dataBuffer = rule.allocator.buffer(); dataBuffer.writeCharSequence("test", CharsetUtil.UTF_8); @@ -299,14 +310,17 @@ public void shouldPropagateDownstreamContext( Payload payload = ByteBufPayload.create(dataBuffer, metadataBuffer); AssertSubscriber assertSubscriber = new AssertSubscriber(Context.of("test", "test")); - Context[] receivedContext = new Context[1]; + ContextView[] receivedContext = new Context[1]; Publisher publisher = request.apply( rule.client, Mono.just(payload) .mergeWith( - Mono.subscriberContext() - .doOnNext(c -> receivedContext[0] = c) + Mono.deferContextual( + c -> { + receivedContext[0] = c; + return Mono.empty(); + }) .then(Mono.empty()))); publisher.subscribe(assertSubscriber); @@ -395,10 +409,30 @@ public void shouldSupportMultiSubscriptionOnTheSameInteractionPublisher( assertSubscriber.await(Duration.ofSeconds(10)).assertComplete(); - Collection sent = rule.connection.getSent(); - Assertions.assertThat(sent) - .allMatch(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) - .allMatch(ReferenceCounted::release); + if (requestType == FrameType.REQUEST_CHANNEL) { + ArrayList sent = new ArrayList<>(rule.connection.getSent()); + Assertions.assertThat(sent).hasSize(4); + for (int i = 0; i < sent.size(); i++) { + if (i % 2 == 0) { + Assertions.assertThat(sent.get(i)) + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .matches(ReferenceCounted::release); + } else { + Assertions.assertThat(sent.get(i)) + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(FrameType.COMPLETE)) + .matches(ReferenceCounted::release); + } + } + } else { + Collection sent = rule.connection.getSent(); + Assertions.assertThat(sent) + .hasSize( + requestType == FrameType.REQUEST_FNF || requestType == FrameType.METADATA_PUSH + ? 1 + : 2) + .allMatch(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .allMatch(ReferenceCounted::release); + } rule.allocator.assertHasNoLeaks(); } @@ -423,6 +457,8 @@ public void shouldBeAbleToResolveOriginalSource() { assertSubscriber1.assertTerminated().assertValueCount(1); Assertions.assertThat(assertSubscriber1.values()).isEqualTo(assertSubscriber.values()); + + rule.allocator.assertHasNoLeaks(); } @Test @@ -446,19 +482,173 @@ public void shouldDisposeOriginalSource() { .assertErrorMessage("Disposed"); Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + rule.allocator.assertHasNoLeaks(); + } + + @Test + public void shouldReceiveOnCloseNotificationOnDisposeOriginalSource() { + Sinks.Empty onCloseDelayer = Sinks.empty(); + ClientSocketRule rule = + new ClientSocketRule() { + @Override + protected RSocket newRSocket() { + return new RSocketProxy(super.newRSocket()) { + @Override + public Mono onClose() { + return super.onClose().and(onCloseDelayer.asMono()); + } + }; + } + }; + rule.init(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + rule.client.source().subscribe(assertSubscriber); + rule.delayer.run(); + assertSubscriber.assertTerminated().assertValueCount(1); + + rule.client.dispose(); + + Assertions.assertThat(rule.client.isDisposed()).isTrue(); + + AssertSubscriber onCloseSubscriber = AssertSubscriber.create(); + + rule.client.onClose().subscribe(onCloseSubscriber); + onCloseSubscriber.assertNotTerminated(); + + onCloseDelayer.tryEmitEmpty(); + + onCloseSubscriber.assertTerminated().assertComplete(); + + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + rule.allocator.assertHasNoLeaks(); + } + + @Test + public void shouldResolveOnStartSource() { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + Assertions.assertThat(rule.client.connect()).isTrue(); + rule.client.source().subscribe(assertSubscriber); + rule.delayer.run(); + assertSubscriber.assertTerminated().assertValueCount(1); + + rule.client.dispose(); + + Assertions.assertThat(rule.client.isDisposed()).isTrue(); + + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(); + + rule.client.onClose().subscribe(assertSubscriber1); + + assertSubscriber1.assertTerminated().assertComplete(); + + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + rule.allocator.assertHasNoLeaks(); + } + + @Test + public void shouldNotStartIfAlreadyDisposed() { + Assertions.assertThat(rule.client.connect()).isTrue(); + Assertions.assertThat(rule.client.connect()).isTrue(); + rule.delayer.run(); + + rule.client.dispose(); + + Assertions.assertThat(rule.client.connect()).isFalse(); + + Assertions.assertThat(rule.client.isDisposed()).isTrue(); + + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(); + + rule.client.onClose().subscribe(assertSubscriber1); + + assertSubscriber1.assertTerminated().assertComplete(); + + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + rule.allocator.assertHasNoLeaks(); } @Test - public void shouldDisposeOriginalSourceIfRacing() throws Throwable { - for (int i = 0; i < 10000; i++) { + public void shouldBeRestartedIfSourceWasClosed() { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + AssertSubscriber terminateSubscriber = AssertSubscriber.create(); + + Assertions.assertThat(rule.client.connect()).isTrue(); + rule.client.source().subscribe(assertSubscriber); + rule.client.onClose().subscribe(terminateSubscriber); + + rule.delayer.run(); + + assertSubscriber.assertTerminated().assertValueCount(1); + + rule.socket.dispose(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + terminateSubscriber.assertNotTerminated(); + Assertions.assertThat(rule.client.isDisposed()).isFalse(); + + rule.connection = new TestDuplexConnection(rule.allocator); + rule.socket = rule.newRSocket(); + rule.producer = Sinks.one(); + + AssertSubscriber assertSubscriber2 = AssertSubscriber.create(); + + Assertions.assertThat(rule.client.connect()).isTrue(); + rule.client.source().subscribe(assertSubscriber2); + + rule.delayer.run(); + + assertSubscriber2.assertTerminated().assertValueCount(1); + + rule.client.dispose(); + + terminateSubscriber.assertTerminated().assertComplete(); + + Assertions.assertThat(rule.client.connect()).isFalse(); + + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + rule.allocator.assertHasNoLeaks(); + } + + @Test + public void shouldDisposeOriginalSourceIfRacing() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { ClientSocketRule rule = new ClientSocketRule(); - rule.apply( - new Statement() { - @Override - public void evaluate() {} - }, - null) - .evaluate(); + + rule.init(); AssertSubscriber assertSubscriber = AssertSubscriber.create(); rule.client.source().subscribe(assertSubscriber); @@ -478,57 +668,93 @@ public void evaluate() {} .assertTerminated() .assertError(CancellationException.class) .assertErrorMessage("Disposed"); + + ByteBuf buf; + while ((buf = rule.connection.pollFrame()) != null) { + FrameAssert.assertThat(buf).hasStreamIdZero().hasData("Disposed").hasNoLeaks(); + } + + rule.allocator.assertHasNoLeaks(); } } - public static class ClientSocketRule extends AbstractSocketRule { + @Test + public void shouldStartOriginalSourceOnceIfRacing() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + ClientSocketRule rule = new ClientSocketRule(); + + rule.init(); + + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + RaceTestUtils.race( + () -> rule.client.source().subscribe(assertSubscriber), () -> rule.client.connect()); + + Assertions.assertThat(rule.producer.currentSubscriberCount()).isOne(); + + rule.delayer.run(); + + assertSubscriber.assertTerminated(); + + rule.client.dispose(); + + Assertions.assertThat(rule.client.isDisposed()).isTrue(); + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(); + + rule.client.onClose().subscribe(assertSubscriber1); + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + assertSubscriber1.assertTerminated().assertComplete(); + + rule.allocator.assertHasNoLeaks(); + } + } + + public static class ClientSocketRule extends AbstractSocketRule { protected RSocketClient client; protected Runnable delayer; - protected MonoProcessor producer; + protected Sinks.One producer; + + protected Sinks.Empty thisClosedSink; @Override - protected void init() { - super.init(); - delayer = () -> producer.onNext(socket); - producer = MonoProcessor.create(); + protected void doInit() { + super.doInit(); + delayer = () -> producer.tryEmitValue(socket); + producer = Sinks.one(); client = new DefaultRSocketClient( - producer - .doOnCancel(() -> socket.dispose()) - .doOnDiscard(Disposable.class, Disposable::dispose)); + Mono.defer( + () -> + producer + .asMono() + .doOnCancel(() -> socket.dispose()) + .doOnDiscard(Disposable.class, Disposable::dispose))); } @Override - protected RSocketRequester newRSocket() { + protected RSocket newRSocket() { + this.thisClosedSink = Sinks.empty(); return new RSocketRequester( connection, PayloadDecoder.ZERO_COPY, StreamIdSupplier.clientSupplier(), 0, maxFrameLength, + maxInboundPayloadSize, Integer.MAX_VALUE, Integer.MAX_VALUE, null, - RequesterLeaseHandler.None, - TestScheduler.INSTANCE); - } - - public int getStreamIdForRequestType(FrameType expectedFrameType) { - assertThat("Unexpected frames sent.", connection.getSent(), hasSize(greaterThanOrEqualTo(1))); - List framesFound = new ArrayList<>(); - for (ByteBuf frame : connection.getSent()) { - FrameType frameType = frameType(frame); - if (frameType == expectedFrameType) { - return FrameHeaderCodec.streamId(frame); - } - framesFound.add(frameType); - } - throw new AssertionError( - "No frames sent with frame type: " - + expectedFrameType - + ", frames found: " - + framesFound); + __ -> null, + null, + thisClosedSink, + thisClosedSink.asMono()); } } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java b/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java new file mode 100644 index 000000000..f5422a4bf --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java @@ -0,0 +1,448 @@ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA; +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.TestRequesterResponderSupport.genericPayload; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.TestRequestInterceptor; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import java.util.Arrays; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.test.StepVerifier; +import reactor.test.util.RaceTestUtils; + +public class FireAndForgetRequesterMonoTest { + + @BeforeAll + public static void setUp() { + StepVerifier.setDefaultTimeout(Duration.ofSeconds(2)); + } + + /** + * General StateMachine transition test. No Fragmentation enabled In this test we check that the + * given instance of FireAndForgetMono subscribes, and then sends frame immediately + */ + @ParameterizedTest + @MethodSource("frameSent") + public void frameShouldBeSentOnSubscription(Consumer monoConsumer) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Payload payload = genericPayload(activeStreams.getAllocator()); + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + monoConsumer.accept(fireAndForgetRequesterMono); + + Assertions.assertThat(payload.refCnt()).isZero(); + // should not add anything to map + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + final ByteBuf frame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .typeOf(FrameType.REQUEST_FNF) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + activeStreams.getAllocator().assertHasNoLeaks(); + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectNothing(); + } + + /** + * General StateMachine transition test. Fragmentation enabled In this test we check that the + * given instance of FireAndForgetMono subscribes, and then sends all fragments as a separate + * frame immediately + */ + @ParameterizedTest + @MethodSource("frameSent") + public void frameFragmentsShouldBeSentOnSubscription( + Consumer monoConsumer) { + final int mtu = 64; + final TestRequesterResponderSupport streamManager = TestRequesterResponderSupport.client(mtu); + final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); + final TestDuplexConnection sender = streamManager.getDuplexConnection(); + + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, streamManager); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + stateAssert.isUnsubscribed(); + streamManager.assertNoActiveStreams(); + + monoConsumer.accept(fireAndForgetRequesterMono); + + // should not add anything to map + streamManager.assertNoActiveStreams(); + stateAssert.isTerminated(); + + Assertions.assertThat(payload.refCnt()).isZero(); + + final ByteBuf frameFragment1 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment1) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET_WITH_METADATA) // 64 - 6 (frame headers) - 3 (encoded metadata + // length) - 3 frame length + .hasMetadata(Arrays.copyOf(metadata, 52)) + .hasData(Unpooled.EMPTY_BUFFER) + .hasFragmentsFollow() + .typeOf(FrameType.REQUEST_FNF) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment2 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment2) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET_WITH_METADATA) // 64 - 6 (frame headers) - 3 (encoded metadata + // length) - 3 frame length + .hasMetadata(Arrays.copyOfRange(metadata, 52, 65)) + .hasData(Arrays.copyOf(data, 39)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment3 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment3) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET) // 64 - 6 (frame headers) - 3 frame length (no metadata - no length) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 39, 94)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment4 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment4) + .isNotNull() + .hasPayloadSize(35) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 94, 129)) + .hasNoFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + static Stream> frameSent() { + return Stream.of( + (s) -> StepVerifier.create(s).expectSubscription().expectComplete().verify(), + FireAndForgetRequesterMono::block); + } + + /** + * RefCnt validation test. Should send error if RefCnt is incorrect and frame has already been + * released Note: ONCE state should be 0 + */ + @ParameterizedTest + @MethodSource("shouldErrorOnIncorrectRefCntInGivenPayloadSource") + public void shouldErrorOnIncorrectRefCntInGivenPayload( + Consumer monoConsumer) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport streamManager = + TestRequesterResponderSupport.client(testRequestInterceptor); + final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); + final TestDuplexConnection sender = streamManager.getDuplexConnection(); + final Payload payload = ByteBufPayload.create(""); + payload.release(); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, streamManager); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + stateAssert.isUnsubscribed(); + streamManager.assertNoActiveStreams(); + + monoConsumer.accept(fireAndForgetRequesterMono); + + stateAssert.isTerminated(); + streamManager.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + testRequestInterceptor + .expectOnReject(FrameType.REQUEST_FNF, new IllegalReferenceCountException("refCnt: 0")) + .expectNothing(); + } + + static Stream> + shouldErrorOnIncorrectRefCntInGivenPayloadSource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .expectError(IllegalReferenceCountException.class) + .verify(), + fireAndForgetRequesterMono -> + Assertions.assertThatThrownBy(fireAndForgetRequesterMono::block) + .isInstanceOf(IllegalReferenceCountException.class)); + } + + /** + * Check that proper payload size validation is enabled so in case payload fragmentation is + * disabled we will not send anything bigger that 16MB (see specification for MAX frame size) + */ + @ParameterizedTest + @MethodSource("shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource") + public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( + Consumer monoConsumer) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport streamManager = + TestRequesterResponderSupport.client(testRequestInterceptor); + final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); + final TestDuplexConnection sender = streamManager.getDuplexConnection(); + + final byte[] metadata = new byte[FRAME_LENGTH_MASK]; + final byte[] data = new byte[FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, streamManager); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + stateAssert.isUnsubscribed(); + streamManager.assertNoActiveStreams(); + + monoConsumer.accept(fireAndForgetRequesterMono); + + Assertions.assertThat(payload.refCnt()).isZero(); + + stateAssert.isTerminated(); + streamManager.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + testRequestInterceptor + .expectOnReject( + FrameType.REQUEST_FNF, + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK))) + .expectNothing(); + } + + static Stream> + shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)) + .verify(), + fireAndForgetRequesterMono -> + Assertions.assertThatThrownBy(fireAndForgetRequesterMono::block) + .hasMessage(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)); + } + + /** + * Ensures that frame will not be sent if we dont have availability for that. Options: 1. RSocket + * disposed / Connection Error, so all racing on existing interactions should be terminated as + * well 2. RSocket tries to use lease and end-ups with no available leases + */ + @ParameterizedTest + @MethodSource("shouldErrorIfNoAvailabilitySource") + public void shouldErrorIfNoAvailability(Consumer monoConsumer) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final RuntimeException exception = new RuntimeException("test"); + final TestRequesterResponderSupport streamManager = + TestRequesterResponderSupport.client(exception, testRequestInterceptor); + final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); + final TestDuplexConnection sender = streamManager.getDuplexConnection(); + final Payload payload = genericPayload(allocator); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, streamManager); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + stateAssert.isUnsubscribed(); + streamManager.assertNoActiveStreams(); + + monoConsumer.accept(fireAndForgetRequesterMono); + + Assertions.assertThat(payload.refCnt()).isZero(); + + stateAssert.isTerminated(); + streamManager.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + testRequestInterceptor.expectOnReject(FrameType.REQUEST_FNF, exception).expectNothing(); + } + + static Stream> shouldErrorIfNoAvailabilitySource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)) + .verify(), + fireAndForgetRequesterMono -> + Assertions.assertThatThrownBy(fireAndForgetRequesterMono::block) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)); + } + + /** Ensures single subscription happens in case of racing */ + @Test + public void shouldSubscribeExactlyOnce1() { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport streamManager = + TestRequesterResponderSupport.client(testRequestInterceptor); + final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); + final TestDuplexConnection sender = streamManager.getDuplexConnection(); + + for (int i = 1; i < 50000; i += 2) { + final Payload payload = ByteBufPayload.create("testData", "testMetadata"); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, streamManager); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + Assertions.assertThatThrownBy( + () -> + RaceTestUtils.race( + () -> { + AtomicReference atomicReference = new AtomicReference<>(); + fireAndForgetRequesterMono.subscribe(null, atomicReference::set); + Throwable throwable = atomicReference.get(); + if (throwable != null) { + throw Exceptions.propagate(throwable); + } + }, + fireAndForgetRequesterMono::block)) + .matches( + t -> { + Assertions.assertThat(t) + .hasMessageContaining("FireAndForgetMono allows only a single Subscriber"); + return true; + }); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .typeOf(FrameType.REQUEST_FNF) + .hasClientSideStreamId() + .hasStreamId(i) + .hasNoLeaks(); + + stateAssert.isTerminated(); + streamManager.assertNoActiveStreams(); + testRequestInterceptor + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .expectNothing(); + } + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + @Test + public void checkName() { + final TestRequesterResponderSupport testRequesterResponderSupport = + TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); + final Payload payload = ByteBufPayload.create("testData", "testMetadata"); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, testRequesterResponderSupport); + + Assertions.assertThat(Scannable.from(fireAndForgetRequesterMono).name()) + .isEqualTo("source(FireAndForgetMono)"); + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java index 209bc3810..5be59235c 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java @@ -23,26 +23,29 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; +import io.rsocket.FrameAssert; import io.rsocket.RSocket; -import io.rsocket.TestScheduler; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.ConnectionErrorException; import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; import io.rsocket.frame.KeepAliveFrameCodec; -import io.rsocket.lease.RequesterLeaseHandler; +import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.resume.InMemoryResumableFramesStore; +import io.rsocket.resume.RSocketSession; import io.rsocket.resume.ResumableDuplexConnection; +import io.rsocket.resume.ResumeStateHolder; import io.rsocket.test.util.TestDuplexConnection; -import io.rsocket.util.DefaultPayload; import java.time.Duration; import org.assertj.core.api.Assertions; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.Mockito; import reactor.core.Disposable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; import reactor.test.scheduler.VirtualTimeScheduler; @@ -67,19 +70,23 @@ static RSocketState requester(int tickPeriod, int timeout) { LeaksTrackingByteBufAllocator allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); TestDuplexConnection connection = new TestDuplexConnection(allocator); + Sinks.Empty empty = Sinks.empty(); RSocketRequester rSocket = new RSocketRequester( connection, - DefaultPayload::create, + PayloadDecoder.ZERO_COPY, StreamIdSupplier.clientSupplier(), 0, FRAME_LENGTH_MASK, + Integer.MAX_VALUE, tickPeriod, timeout, - new DefaultKeepAliveHandler(connection), - RequesterLeaseHandler.None, - TestScheduler.INSTANCE); - return new RSocketState(rSocket, allocator, connection); + new DefaultKeepAliveHandler(), + r -> null, + null, + empty, + empty.asMono()); + return new RSocketState(rSocket, allocator, connection, empty); } static ResumableRSocketState resumableRequester(int tickPeriod, int timeout) { @@ -89,24 +96,30 @@ static ResumableRSocketState resumableRequester(int tickPeriod, int timeout) { ResumableDuplexConnection resumableConnection = new ResumableDuplexConnection( "test", + Unpooled.EMPTY_BUFFER, connection, - new InMemoryResumableFramesStore("test", 10_000), - Duration.ofSeconds(10), - false); + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 10_000)); + Sinks.Empty onClose = Sinks.empty(); RSocketRequester rSocket = new RSocketRequester( resumableConnection, - DefaultPayload::create, + PayloadDecoder.ZERO_COPY, StreamIdSupplier.clientSupplier(), 0, FRAME_LENGTH_MASK, + Integer.MAX_VALUE, tickPeriod, timeout, - new ResumableKeepAliveHandler(resumableConnection), - RequesterLeaseHandler.None, - TestScheduler.INSTANCE); - return new ResumableRSocketState(rSocket, connection, resumableConnection, allocator); + new ResumableKeepAliveHandler( + resumableConnection, + Mockito.mock(RSocketSession.class), + Mockito.mock(ResumeStateHolder.class)), + __ -> null, + null, + onClose, + onClose.asMono()); + return new ResumableRSocketState(rSocket, connection, resumableConnection, onClose, allocator); } @Test @@ -146,11 +159,14 @@ void noKeepAlivesSentAfterRSocketDispose() { requesterState.rSocket().dispose(); Duration duration = Duration.ofMillis(500); - StepVerifier.create(Flux.from(requesterState.connection().getSentAsPublisher()).take(duration)) - .then(() -> virtualTimeScheduler.advanceTimeBy(duration)) - .expectComplete() - .verify(Duration.ofSeconds(1)); + virtualTimeScheduler.advanceTimeBy(duration); + + FrameAssert.assertThat(requesterState.connection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasData("Disposed") + .hasNoLeaks(); + FrameAssert.assertThat(requesterState.connection.pollFrame()).isNull(); requesterState.allocator.assertHasNoLeaks(); } @@ -179,17 +195,18 @@ void clientRequesterSendsKeepAlives() { RSocketState RSocketState = requester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); TestDuplexConnection connection = RSocketState.connection(); - StepVerifier.create(Flux.from(connection.getSentAsPublisher()).take(3)) - .then(() -> virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL))) - .expectNextMatches(this::keepAliveFrameWithRespondFlag) - .then(() -> virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL))) - .expectNextMatches(this::keepAliveFrameWithRespondFlag) - .then(() -> virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL))) - .expectNextMatches(this::keepAliveFrameWithRespondFlag) - .expectComplete() - .verify(Duration.ofSeconds(5)); + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL)); + this.keepAliveFrameWithRespondFlag(connection.pollFrame()); + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL)); + this.keepAliveFrameWithRespondFlag(connection.pollFrame()); + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL)); + this.keepAliveFrameWithRespondFlag(connection.pollFrame()); RSocketState.rSocket.dispose(); + FrameAssert.assertThat(connection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasData("Disposed") + .hasNoLeaks(); RSocketState.connection.dispose(); RSocketState.allocator.assertHasNoLeaks(); @@ -207,13 +224,17 @@ void requesterRespondsToKeepAlives() { KeepAliveFrameCodec.encode( rSocketState.allocator, true, 0, Unpooled.EMPTY_BUFFER))); - StepVerifier.create(Flux.from(connection.getSentAsPublisher()).take(1)) - .then(() -> virtualTimeScheduler.advanceTimeBy(duration)) - .expectNextMatches(this::keepAliveFrameWithoutRespondFlag) - .expectComplete() - .verify(Duration.ofSeconds(5)); + virtualTimeScheduler.advanceTimeBy(duration); + FrameAssert.assertThat(connection.awaitFrame()) + .typeOf(FrameType.KEEPALIVE) + .matches(this::keepAliveFrameWithoutRespondFlag); rSocketState.rSocket.dispose(); + FrameAssert.assertThat(rSocketState.connection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); rSocketState.connection.dispose(); rSocketState.allocator.assertHasNoLeaks(); @@ -228,11 +249,9 @@ void resumableRequesterNoKeepAlivesAfterDisconnect() { resumableDuplexConnection.disconnect(); - Duration duration = Duration.ofMillis(500); - StepVerifier.create(Flux.from(testConnection.getSentAsPublisher()).take(duration)) - .then(() -> virtualTimeScheduler.advanceTimeBy(duration)) - .expectComplete() - .verify(Duration.ofSeconds(5)); + Duration duration = Duration.ofMillis(KEEP_ALIVE_INTERVAL * 5); + virtualTimeScheduler.advanceTimeBy(duration); + Assertions.assertThat(testConnection.pollFrame()).isNull(); rSocketState.rSocket.dispose(); rSocketState.connection.dispose(); @@ -247,17 +266,28 @@ void resumableRequesterKeepAlivesAfterReconnect() { ResumableDuplexConnection resumableDuplexConnection = rSocketState.resumableDuplexConnection(); resumableDuplexConnection.disconnect(); TestDuplexConnection newTestConnection = new TestDuplexConnection(rSocketState.alloc()); - resumableDuplexConnection.reconnect(newTestConnection); - resumableDuplexConnection.resume(0, 0, ignored -> Mono.empty()); + resumableDuplexConnection.connect(newTestConnection); + // resumableDuplexConnection.(0, 0, ignored -> Mono.empty()); - StepVerifier.create(Flux.from(newTestConnection.getSentAsPublisher()).take(1)) - .then(() -> virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL))) - .expectNextMatches(frame -> keepAliveFrame(frame) && frame.release()) - .expectComplete() - .verify(Duration.ofSeconds(5)); + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL)); + + FrameAssert.assertThat(newTestConnection.awaitFrame()) + .typeOf(FrameType.KEEPALIVE) + .hasStreamIdZero() + .hasNoLeaks(); rSocketState.rSocket.dispose(); - rSocketState.connection.dispose(); + FrameAssert.assertThat(newTestConnection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + FrameAssert.assertThat(newTestConnection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasStreamIdZero() + .hasData("Connection Closed Unexpectedly") // API limitations + .hasNoLeaks(); + newTestConnection.dispose(); rSocketState.allocator.assertHasNoLeaks(); } @@ -274,7 +304,17 @@ void resumableRequesterNoKeepAlivesAfterDispose() { .verify(Duration.ofSeconds(5)); rSocketState.rSocket.dispose(); + FrameAssert.assertThat(rSocketState.connection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); rSocketState.connection.dispose(); + FrameAssert.assertThat(rSocketState.connection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasStreamIdZero() + .hasData("Connection Closed Unexpectedly") + .hasNoLeaks(); rSocketState.allocator.assertHasNoLeaks(); } @@ -315,12 +355,17 @@ static class RSocketState { private final RSocket rSocket; private final TestDuplexConnection connection; private final LeaksTrackingByteBufAllocator allocator; + private final Sinks.Empty onClose; public RSocketState( - RSocket rSocket, LeaksTrackingByteBufAllocator allocator, TestDuplexConnection connection) { + RSocket rSocket, + LeaksTrackingByteBufAllocator allocator, + TestDuplexConnection connection, + Sinks.Empty onClose) { this.rSocket = rSocket; this.connection = connection; this.allocator = allocator; + this.onClose = onClose; } public TestDuplexConnection connection() { @@ -341,15 +386,18 @@ static class ResumableRSocketState { private final TestDuplexConnection connection; private final ResumableDuplexConnection resumableDuplexConnection; private final LeaksTrackingByteBufAllocator allocator; + private final Sinks.Empty onClose; public ResumableRSocketState( RSocket rSocket, TestDuplexConnection connection, ResumableDuplexConnection resumableDuplexConnection, + Sinks.Empty onClose, LeaksTrackingByteBufAllocator allocator) { this.rSocket = rSocket; this.connection = connection; this.resumableDuplexConnection = resumableDuplexConnection; + this.onClose = onClose; this.allocator = allocator; } diff --git a/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java b/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java index 1d93d9388..707d42afe 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java @@ -1,5 +1,7 @@ package io.rsocket.core; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_SIZE; + import io.rsocket.Payload; import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameLengthCodec; @@ -12,26 +14,45 @@ class PayloadValidationUtilsTest { @Test void shouldBeValidFrameWithNoFragmentation() { + int maxFrameLength = + ThreadLocalRandom.current().nextInt(64, FrameLengthCodec.FRAME_LENGTH_MASK); + byte[] data = new byte[maxFrameLength - FRAME_LENGTH_SIZE - FrameHeaderCodec.size()]; + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isFalse(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation1() { int maxFrameLength = ThreadLocalRandom.current().nextInt(64, FrameLengthCodec.FRAME_LENGTH_MASK); byte[] data = - new byte[maxFrameLength - FrameLengthCodec.FRAME_LENGTH_SIZE - FrameHeaderCodec.size()]; + new byte[maxFrameLength - FRAME_LENGTH_SIZE - Integer.BYTES - FrameHeaderCodec.size()]; ThreadLocalRandom.current().nextBytes(data); final Payload payload = DefaultPayload.create(data); - Assertions.assertThat(PayloadValidationUtils.isValid(0, payload, maxFrameLength)).isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isTrue(); } @Test void shouldBeInValidFrameWithNoFragmentation() { int maxFrameLength = ThreadLocalRandom.current().nextInt(64, FrameLengthCodec.FRAME_LENGTH_MASK); - byte[] data = - new byte[maxFrameLength - FrameLengthCodec.FRAME_LENGTH_SIZE - FrameHeaderCodec.size() + 1]; + byte[] data = new byte[maxFrameLength - FRAME_LENGTH_SIZE - FrameHeaderCodec.size() + 1]; ThreadLocalRandom.current().nextBytes(data); final Payload payload = DefaultPayload.create(data); - Assertions.assertThat(PayloadValidationUtils.isValid(0, payload, maxFrameLength)).isFalse(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isFalse(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isFalse(); } @Test @@ -41,15 +62,18 @@ void shouldBeValidFrameWithNoFragmentation0() { byte[] metadata = new byte[maxFrameLength / 2]; byte[] data = new byte - [maxFrameLength / 2 - - FrameLengthCodec.FRAME_LENGTH_SIZE + [(maxFrameLength / 2 + 1) + - FRAME_LENGTH_SIZE - FrameHeaderCodec.size() - FrameHeaderCodec.size()]; ThreadLocalRandom.current().nextBytes(data); ThreadLocalRandom.current().nextBytes(metadata); final Payload payload = DefaultPayload.create(data, metadata); - Assertions.assertThat(PayloadValidationUtils.isValid(0, payload, maxFrameLength)).isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isFalse(); } @Test @@ -62,7 +86,10 @@ void shouldBeInValidFrameWithNoFragmentation1() { ThreadLocalRandom.current().nextBytes(data); final Payload payload = DefaultPayload.create(data, metadata); - Assertions.assertThat(PayloadValidationUtils.isValid(0, payload, maxFrameLength)).isFalse(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isFalse(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isFalse(); } @Test @@ -75,7 +102,10 @@ void shouldBeValidFrameWithNoFragmentation2() { ThreadLocalRandom.current().nextBytes(data); final Payload payload = DefaultPayload.create(data, metadata); - Assertions.assertThat(PayloadValidationUtils.isValid(0, payload, maxFrameLength)).isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isTrue(); } @Test @@ -88,7 +118,10 @@ void shouldBeValidFrameWithNoFragmentation3() { ThreadLocalRandom.current().nextBytes(data); final Payload payload = DefaultPayload.create(data, metadata); - Assertions.assertThat(PayloadValidationUtils.isValid(64, payload, maxFrameLength)).isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(64, maxFrameLength, payload, true)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(64, maxFrameLength, payload, false)) + .isTrue(); } @Test @@ -101,6 +134,9 @@ void shouldBeValidFrameWithNoFragmentation4() { ThreadLocalRandom.current().nextBytes(data); final Payload payload = DefaultPayload.create(data, metadata); - Assertions.assertThat(PayloadValidationUtils.isValid(64, payload, maxFrameLength)).isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(64, maxFrameLength, payload, true)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(64, maxFrameLength, payload, false)) + .isTrue(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketConnectorTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketConnectorTest.java index 468a13505..7cf12a81e 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketConnectorTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketConnectorTest.java @@ -1,29 +1,112 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package io.rsocket.core; import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static org.assertj.core.api.Assertions.assertThat; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import io.netty.util.CharsetUtil; import io.netty.util.ReferenceCounted; import io.rsocket.ConnectionSetupPayload; +import io.rsocket.FrameAssert; import io.rsocket.Payload; import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; import io.rsocket.test.util.TestClientTransport; +import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.util.ByteBufPayload; import java.time.Duration; import java.util.ArrayList; import java.util.List; -import org.assertj.core.api.Assertions; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; import reactor.test.StepVerifier; +import reactor.util.retry.Retry; public class RSocketConnectorTest { + @ParameterizedTest + @ValueSource(strings = {"KEEPALIVE", "REQUEST_RESPONSE"}) + public void unexpectedFramesBeforeResumeOKFrame(String frameType) { + TestClientTransport transport = new TestClientTransport(); + RSocketConnector.create() + .resume(new Resume().retry(Retry.indefinitely())) + .connect(transport) + .block(); + + final TestDuplexConnection duplexConnection = transport.testConnection(); + + duplexConnection.addToReceivedBuffer( + KeepAliveFrameCodec.encode(duplexConnection.alloc(), false, 1, Unpooled.EMPTY_BUFFER)); + FrameAssert.assertThat(duplexConnection.pollFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); + + FrameAssert.assertThat(duplexConnection.pollFrame()).isNull(); + + duplexConnection.dispose(); + + final TestDuplexConnection duplexConnection2 = transport.testConnection(); + + final ByteBuf frame; + switch (frameType) { + case "KEEPALIVE": + frame = + KeepAliveFrameCodec.encode(duplexConnection2.alloc(), false, 1, Unpooled.EMPTY_BUFFER); + break; + case "REQUEST_RESPONSE": + default: + frame = + RequestResponseFrameCodec.encode( + duplexConnection2.alloc(), 2, false, Unpooled.EMPTY_BUFFER, Unpooled.EMPTY_BUFFER); + } + duplexConnection2.addToReceivedBuffer(frame); + + StepVerifier.create(duplexConnection2.onClose()) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(10)); + + FrameAssert.assertThat(duplexConnection2.pollFrame()) + .typeOf(FrameType.RESUME) + .hasStreamIdZero() + .hasNoLeaks(); + + FrameAssert.assertThat(duplexConnection2.pollFrame()) + .isNotNull() + .typeOf(FrameType.ERROR) + .hasData("RESUME_OK frame must be received before any others") + .hasStreamIdZero() + .hasNoLeaks(); + + transport.alloc().assertHasNoLeaks(); + } + @Test public void ensuresThatSetupPayloadCanBeRetained() { - MonoProcessor retainedSetupPayload = MonoProcessor.create(); + AtomicReference retainedSetupPayload = new AtomicReference<>(); TestClientTransport transport = new TestClientTransport(); ByteBuf data = transport.alloc().buffer(); @@ -34,13 +117,13 @@ public void ensuresThatSetupPayloadCanBeRetained() { .setupPayload(ByteBufPayload.create(data)) .acceptor( (setup, sendingSocket) -> { - retainedSetupPayload.onNext(setup.retain()); + retainedSetupPayload.set(setup.retain()); return Mono.just(new RSocket() {}); }) .connect(transport) .block(); - Assertions.assertThat(transport.testConnection().getSent()) + assertThat(transport.testConnection().getSent()) .hasSize(1) .first() .matches( @@ -55,17 +138,10 @@ public void ensuresThatSetupPayloadCanBeRetained() { return buf.refCnt() == 1; }); - retainedSetupPayload - .as(StepVerifier::create) - .expectNextMatches( - setup -> { - String dataUtf8 = setup.getDataUtf8(); - return "data".equals(dataUtf8) && setup.release(); - }) - .expectComplete() - .verify(Duration.ofSeconds(5)); - - Assertions.assertThat(retainedSetupPayload.peek().refCnt()).isZero(); + ConnectionSetupPayload setup = retainedSetupPayload.get(); + String dataUtf8 = setup.getDataUtf8(); + assertThat("data".equals(dataUtf8) && setup.release()).isTrue(); + assertThat(setup.refCnt()).isZero(); transport.alloc().assertHasNoLeaks(); } @@ -73,8 +149,13 @@ public void ensuresThatSetupPayloadCanBeRetained() { @Test public void ensuresThatMonoFromRSocketConnectorCanBeUsedForMultipleSubscriptions() { Payload setupPayload = ByteBufPayload.create("TestData", "TestMetadata"); + assertThat(setupPayload.refCnt()).isOne(); - Assertions.assertThat(setupPayload.refCnt()).isOne(); + // Keep the data and metadata around so we can try changing them independently + ByteBuf dataBuf = setupPayload.data(); + ByteBuf metadataBuf = setupPayload.metadata(); + dataBuf.retain(); + metadataBuf.retain(); TestClientTransport testClientTransport = new TestClientTransport(); Mono connectionMono = @@ -86,31 +167,59 @@ public void ensuresThatMonoFromRSocketConnectorCanBeUsedForMultipleSubscriptions .expectComplete() .verify(Duration.ofMillis(100)); + assertThat(testClientTransport.testConnection().getSent()) + .hasSize(1) + .allMatch( + bb -> { + DefaultConnectionSetupPayload payload = new DefaultConnectionSetupPayload(bb); + return payload.getDataUtf8().equals("TestData") + && payload.getMetadataUtf8().equals("TestMetadata"); + }) + .allMatch(ReferenceCounted::release); + connectionMono .as(StepVerifier::create) .expectNextCount(1) .expectComplete() .verify(Duration.ofMillis(100)); - Assertions.assertThat(testClientTransport.testConnection().getSent()) - .hasSize(2) + // Changing the original data and metadata should not impact the SetupPayload + dataBuf.writerIndex(dataBuf.readerIndex()); + dataBuf.writeChar('d'); + dataBuf.release(); + + metadataBuf.writerIndex(metadataBuf.readerIndex()); + metadataBuf.writeChar('m'); + metadataBuf.release(); + + assertThat(testClientTransport.testConnection().getSent()) + .hasSize(1) .allMatch( bb -> { DefaultConnectionSetupPayload payload = new DefaultConnectionSetupPayload(bb); return payload.getDataUtf8().equals("TestData") && payload.getMetadataUtf8().equals("TestMetadata"); }) - .allMatch(ReferenceCounted::release); - Assertions.assertThat(setupPayload.refCnt()).isZero(); + .allMatch( + byteBuf -> { + System.out.println("calling release " + byteBuf.refCnt()); + return byteBuf.release(); + }); + assertThat(setupPayload.refCnt()).isZero(); + + testClientTransport.alloc().assertHasNoLeaks(); } @Test public void ensuresThatSetupPayloadProvidedAsMonoIsReleased() { List saved = new ArrayList<>(); + AtomicLong subscriptions = new AtomicLong(); Mono setupPayloadMono = Mono.create( sink -> { - Payload payload = ByteBufPayload.create("TestData", "TestMetadata"); + final long subscriptionN = subscriptions.getAndIncrement(); + Payload payload = + ByteBufPayload.create("TestData" + subscriptionN, "TestMetadata" + subscriptionN); saved.add(payload); sink.success(payload); }); @@ -125,29 +234,41 @@ public void ensuresThatSetupPayloadProvidedAsMonoIsReleased() { .expectComplete() .verify(Duration.ofMillis(100)); + assertThat(testClientTransport.testConnection().getSent()) + .hasSize(1) + .allMatch( + bb -> { + DefaultConnectionSetupPayload payload = new DefaultConnectionSetupPayload(bb); + return payload.getDataUtf8().equals("TestData0") + && payload.getMetadataUtf8().equals("TestMetadata0"); + }) + .allMatch(ReferenceCounted::release); + connectionMono .as(StepVerifier::create) .expectNextCount(1) .expectComplete() .verify(Duration.ofMillis(100)); - Assertions.assertThat(testClientTransport.testConnection().getSent()) - .hasSize(2) + assertThat(testClientTransport.testConnection().getSent()) + .hasSize(1) .allMatch( bb -> { DefaultConnectionSetupPayload payload = new DefaultConnectionSetupPayload(bb); - return payload.getDataUtf8().equals("TestData") - && payload.getMetadataUtf8().equals("TestMetadata"); + return payload.getDataUtf8().equals("TestData1") + && payload.getMetadataUtf8().equals("TestMetadata1"); }) .allMatch(ReferenceCounted::release); - Assertions.assertThat(saved) + assertThat(saved) .as("Metadata and data were consumed and released as slices") .allMatch( payload -> payload.refCnt() == 1 && payload.data().refCnt() == 0 && payload.metadata().refCnt() == 0); + + testClientTransport.alloc().assertHasNoLeaks(); } @Test diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java index 7faef600a..a461833d3 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,12 @@ package io.rsocket.core; import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; -import static io.rsocket.frame.FrameType.*; +import static io.rsocket.frame.FrameType.COMPLETE; +import static io.rsocket.frame.FrameType.ERROR; +import static io.rsocket.frame.FrameType.LEASE; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_FNF; +import static io.rsocket.frame.FrameType.SETUP; import static org.assertj.core.data.Offset.offset; import static org.mockito.Mockito.any; import static org.mockito.Mockito.mock; @@ -30,18 +35,21 @@ import io.netty.util.ReferenceCounted; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.TestScheduler; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.Exceptions; +import io.rsocket.exceptions.RejectedException; import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; import io.rsocket.frame.LeaseFrameCodec; import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; import io.rsocket.frame.SetupFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.ClientServerInputMultiplexer; import io.rsocket.internal.subscriber.AssertSubscriber; -import io.rsocket.lease.*; +import io.rsocket.lease.Lease; import io.rsocket.lease.MissingLeaseException; import io.rsocket.plugins.InitializingInterceptorRegistry; import io.rsocket.test.util.TestClientTransport; @@ -52,12 +60,11 @@ import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.time.Duration; -import java.util.ArrayList; import java.util.Collection; -import java.util.Optional; import java.util.function.BiFunction; import java.util.stream.Stream; import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -65,24 +72,28 @@ import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mockito; import org.reactivestreams.Publisher; -import reactor.core.publisher.EmitterProcessor; +import org.reactivestreams.Subscription; +import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; class RSocketLeaseTest { private static final String TAG = "test"; private RSocket rSocketRequester; - private ResponderLeaseHandler responderLeaseHandler; + private ResponderLeaseTracker responderLeaseTracker; private LeaksTrackingByteBufAllocator byteBufAllocator; private TestDuplexConnection connection; private RSocketResponder rSocketResponder; private RSocket mockRSocketHandler; - private EmitterProcessor leaseSender = EmitterProcessor.create(); - private Flux leaseReceiver; - private RequesterLeaseHandler requesterLeaseHandler; + private Sinks.Many leaseSender = Sinks.many().multicast().onBackpressureBuffer(); + private RequesterLeaseTracker requesterLeaseTracker; + protected Sinks.Empty thisClosedSink; + protected Sinks.Empty otherClosedSink; @BeforeEach void setUp() { @@ -90,10 +101,10 @@ void setUp() { byteBufAllocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); connection = new TestDuplexConnection(byteBufAllocator); - requesterLeaseHandler = new RequesterLeaseHandler.Impl(TAG, leases -> leaseReceiver = leases); - responderLeaseHandler = - new ResponderLeaseHandler.Impl<>( - TAG, byteBufAllocator, stats -> leaseSender, Optional.empty()); + requesterLeaseTracker = new RequesterLeaseTracker(TAG, 0); + responderLeaseTracker = new ResponderLeaseTracker(TAG, connection, () -> leaseSender.asFlux()); + this.thisClosedSink = Sinks.empty(); + this.otherClosedSink = Sinks.empty(); ClientServerInputMultiplexer multiplexer = new ClientServerInputMultiplexer(connection, new InitializingInterceptorRegistry(), true); @@ -104,11 +115,14 @@ void setUp() { StreamIdSupplier.clientSupplier(), 0, FRAME_LENGTH_MASK, + Integer.MAX_VALUE, 0, 0, null, - requesterLeaseHandler, - TestScheduler.INSTANCE); + __ -> null, + requesterLeaseTracker, + thisClosedSink, + otherClosedSink.asMono().and(thisClosedSink.asMono())); mockRSocketHandler = mock(RSocket.class); when(mockRSocketHandler.metadataPush(any())) @@ -145,7 +159,25 @@ void setUp() { Publisher payloadPublisher = a.getArgument(0); return Flux.from(payloadPublisher) .doOnNext(ReferenceCounted::release) - .thenMany(Flux.empty()); + .transform( + Operators.lift( + (__, actual) -> + new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) { + actual.onSubscribe(this); + } + + @Override + protected void hookOnComplete() { + actual.onComplete(); + } + + @Override + protected void hookOnError(Throwable throwable) { + actual.onError(throwable); + } + })); }); rSocketResponder = @@ -153,9 +185,17 @@ void setUp() { multiplexer.asServerConnection(), mockRSocketHandler, payloadDecoder, - responderLeaseHandler, + responderLeaseTracker, 0, - FRAME_LENGTH_MASK); + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + __ -> null, + otherClosedSink); + } + + @AfterEach + void tearDownAndCheckForLeaks() { + byteBufAllocator.assertHasNoLeaks(); } @Test @@ -183,18 +223,26 @@ public void serverRSocketFactoryRejectsUnsupportedLease() { Assertions.assertThat(FrameHeaderCodec.frameType(error)).isEqualTo(ERROR); Assertions.assertThat(Exceptions.from(0, error).getMessage()) .isEqualTo("lease is not supported"); + error.release(); + connection.dispose(); + transport.alloc().assertHasNoLeaks(); } @Test public void clientRSocketFactorySetsLeaseFlag() { TestClientTransport clientTransport = new TestClientTransport(); - RSocketConnector.create().lease(Leases::new).connect(clientTransport).block(); - - Collection sent = clientTransport.testConnection().getSent(); - Assertions.assertThat(sent).hasSize(1); - ByteBuf setup = sent.iterator().next(); - Assertions.assertThat(FrameHeaderCodec.frameType(setup)).isEqualTo(SETUP); - Assertions.assertThat(SetupFrameCodec.honorLease(setup)).isTrue(); + try { + RSocketConnector.create().lease().connect(clientTransport).block(); + Collection sent = clientTransport.testConnection().getSent(); + Assertions.assertThat(sent).hasSize(1); + ByteBuf setup = sent.iterator().next(); + Assertions.assertThat(FrameHeaderCodec.frameType(setup)).isEqualTo(SETUP); + Assertions.assertThat(SetupFrameCodec.honorLease(setup)).isTrue(); + setup.release(); + } finally { + clientTransport.testConnection().dispose(); + clientTransport.alloc().assertHasNoLeaks(); + } } @ParameterizedTest @@ -217,7 +265,7 @@ void requesterMissingLeaseRequestsAreRejected( void requesterPresentLeaseRequestsAreAccepted( BiFunction> interaction, FrameType frameType) { ByteBuf frame = leaseFrame(5_000, 2, Unpooled.EMPTY_BUFFER); - requesterLeaseHandler.receive(frame); + requesterLeaseTracker.handleLeaseFrame(frame); Assertions.assertThat(rSocketRequester.availability()).isCloseTo(1.0, offset(1e-2)); ByteBuf buffer = byteBufAllocator.buffer(); @@ -235,10 +283,23 @@ void requesterPresentLeaseRequestsAreAccepted( .expectComplete() .verify(Duration.ofSeconds(5)); - Assertions.assertThat(connection.getSent()) - .hasSize(1) - .first() - .matches(ReferenceCounted::release); + if (frameType == REQUEST_CHANNEL) { + Assertions.assertThat(connection.getSent()) + .hasSize(2) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == frameType) + .matches(ReferenceCounted::release); + Assertions.assertThat(connection.getSent()) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == COMPLETE) + .matches(ReferenceCounted::release); + } else { + Assertions.assertThat(connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == frameType) + .matches(ReferenceCounted::release); + } Assertions.assertThat(rSocketRequester.availability()).isCloseTo(0.5, offset(1e-2)); @@ -256,13 +317,13 @@ void requesterDepletedAllowedLeaseRequestsAreRejected( buffer.writeCharSequence("test", CharsetUtil.UTF_8); Payload payload1 = ByteBufPayload.create(buffer); ByteBuf leaseFrame = leaseFrame(5_000, 1, Unpooled.EMPTY_BUFFER); - requesterLeaseHandler.receive(leaseFrame); + requesterLeaseTracker.handleLeaseFrame(leaseFrame); - double initialAvailability = requesterLeaseHandler.availability(); + double initialAvailability = requesterLeaseTracker.availability(); Publisher request = interaction.apply(rSocketRequester, payload1); // ensures that lease is not used until the frame is sent - Assertions.assertThat(initialAvailability).isEqualTo(requesterLeaseHandler.availability()); + Assertions.assertThat(initialAvailability).isEqualTo(requesterLeaseTracker.availability()); Assertions.assertThat(connection.getSent()).hasSize(0); AssertSubscriber assertSubscriber = AssertSubscriber.create(0); @@ -271,7 +332,7 @@ void requesterDepletedAllowedLeaseRequestsAreRejected( // if request is FNF, then request frame is sent on subscribe // otherwise we need to make request(1) if (interactionType != REQUEST_FNF) { - Assertions.assertThat(initialAvailability).isEqualTo(requesterLeaseHandler.availability()); + Assertions.assertThat(initialAvailability).isEqualTo(requesterLeaseTracker.availability()); Assertions.assertThat(connection.getSent()).hasSize(0); assertSubscriber.request(1); @@ -279,11 +340,24 @@ void requesterDepletedAllowedLeaseRequestsAreRejected( // ensures availability is changed and lease is used only up on frame sending Assertions.assertThat(rSocketRequester.availability()).isCloseTo(0.0, offset(1e-2)); - Assertions.assertThat(connection.getSent()) - .hasSize(1) - .first() - .matches(bb -> FrameHeaderCodec.frameType(bb) == interactionType) - .matches(ReferenceCounted::release); + + if (interactionType == REQUEST_CHANNEL) { + Assertions.assertThat(connection.getSent()) + .hasSize(2) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == interactionType) + .matches(ReferenceCounted::release); + Assertions.assertThat(connection.getSent()) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == COMPLETE) + .matches(ReferenceCounted::release); + } else { + Assertions.assertThat(connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == interactionType) + .matches(ReferenceCounted::release); + } ByteBuf buffer2 = byteBufAllocator.buffer(); buffer2.writeCharSequence("test", CharsetUtil.UTF_8); @@ -303,7 +377,7 @@ void requesterDepletedAllowedLeaseRequestsAreRejected( void requesterExpiredLeaseRequestsAreRejected( BiFunction> interaction) { ByteBuf frame = leaseFrame(50, 1, Unpooled.EMPTY_BUFFER); - requesterLeaseHandler.receive(frame); + requesterLeaseTracker.handleLeaseFrame(frame); ByteBuf buffer = byteBufAllocator.buffer(); buffer.writeCharSequence("test", CharsetUtil.UTF_8); @@ -322,39 +396,99 @@ void requesterExpiredLeaseRequestsAreRejected( @Test void requesterAvailabilityRespectsTransport() { - requesterLeaseHandler.receive(leaseFrame(5_000, 1, Unpooled.EMPTY_BUFFER)); - double unavailable = 0.0; - connection.setAvailability(unavailable); - Assertions.assertThat(rSocketRequester.availability()).isCloseTo(unavailable, offset(1e-2)); + ByteBuf frame = leaseFrame(5_000, 1, Unpooled.EMPTY_BUFFER); + try { + + requesterLeaseTracker.handleLeaseFrame(frame); + double unavailable = 0.0; + connection.setAvailability(unavailable); + Assertions.assertThat(rSocketRequester.availability()).isCloseTo(unavailable, offset(1e-2)); + } finally { + frame.release(); + } } @ParameterizedTest - @MethodSource("interactions") - void responderMissingLeaseRequestsAreRejected( - BiFunction> interaction) { + @MethodSource("responderInteractions") + void responderMissingLeaseRequestsAreRejected(FrameType frameType) { ByteBuf buffer = byteBufAllocator.buffer(); buffer.writeCharSequence("test", CharsetUtil.UTF_8); Payload payload1 = ByteBufPayload.create(buffer); - StepVerifier.create(interaction.apply(rSocketResponder, payload1)) - .expectError(MissingLeaseException.class) - .verify(Duration.ofSeconds(5)); + switch (frameType) { + case REQUEST_FNF: + final ByteBuf fnfFrame = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFrame(fnfFrame); + fnfFrame.release(); + break; + case REQUEST_RESPONSE: + final ByteBuf requestResponseFrame = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFrame(requestResponseFrame); + requestResponseFrame.release(); + break; + case REQUEST_STREAM: + final ByteBuf requestStreamFrame = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, 1, payload1); + rSocketResponder.handleFrame(requestStreamFrame); + requestStreamFrame.release(); + break; + case REQUEST_CHANNEL: + final ByteBuf requestChannelFrame = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, true, 1, payload1); + rSocketResponder.handleFrame(requestChannelFrame); + requestChannelFrame.release(); + break; + } + + if (frameType != REQUEST_FNF) { + Assertions.assertThat(connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == ERROR) + .matches(bb -> Exceptions.from(1, bb) instanceof RejectedException) + .matches(ReferenceCounted::release); + } + + byteBufAllocator.assertHasNoLeaks(); } @ParameterizedTest - @MethodSource("interactions") - void responderPresentLeaseRequestsAreAccepted( - BiFunction> interaction, FrameType frameType) { - leaseSender.onNext(Lease.create(5_000, 2)); + @MethodSource("responderInteractions") + void responderPresentLeaseRequestsAreAccepted(FrameType frameType) { + leaseSender.tryEmitNext(Lease.create(Duration.ofMillis(5_000), 2)); ByteBuf buffer = byteBufAllocator.buffer(); buffer.writeCharSequence("test", CharsetUtil.UTF_8); Payload payload1 = ByteBufPayload.create(buffer); - Flux.from(interaction.apply(rSocketResponder, payload1)) - .as(StepVerifier::create) - .expectComplete() - .verify(Duration.ofSeconds(5)); + switch (frameType) { + case REQUEST_FNF: + final ByteBuf fnfFrame = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFireAndForget(1, fnfFrame); + fnfFrame.release(); + break; + case REQUEST_RESPONSE: + final ByteBuf requestResponseFrame = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFrame(requestResponseFrame); + requestResponseFrame.release(); + break; + case REQUEST_STREAM: + final ByteBuf requestStreamFrame = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, 1, payload1); + rSocketResponder.handleFrame(requestStreamFrame); + requestStreamFrame.release(); + break; + case REQUEST_CHANNEL: + final ByteBuf requestChannelFrame = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, true, 1, payload1); + rSocketResponder.handleFrame(requestChannelFrame); + requestChannelFrame.release(); + break; + } switch (frameType) { case REQUEST_FNF: @@ -372,47 +506,119 @@ void responderPresentLeaseRequestsAreAccepted( } Assertions.assertThat(connection.getSent()) - .hasSize(1) .first() .matches(bb -> FrameHeaderCodec.frameType(bb) == LEASE) .matches(ReferenceCounted::release); + if (frameType != REQUEST_FNF) { + Assertions.assertThat(connection.getSent()) + .hasSize(2) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == COMPLETE) + .matches(ReferenceCounted::release); + } + byteBufAllocator.assertHasNoLeaks(); } @ParameterizedTest - @MethodSource("interactions") - void responderDepletedAllowedLeaseRequestsAreRejected( - BiFunction> interaction) { - leaseSender.onNext(Lease.create(5_000, 1)); + @MethodSource("responderInteractions") + void responderDepletedAllowedLeaseRequestsAreRejected(FrameType frameType) { + leaseSender.tryEmitNext(Lease.create(Duration.ofMillis(5_000), 1)); ByteBuf buffer = byteBufAllocator.buffer(); buffer.writeCharSequence("test", CharsetUtil.UTF_8); Payload payload1 = ByteBufPayload.create(buffer); - Flux responder = Flux.from(interaction.apply(rSocketResponder, payload1)); - responder.subscribe(); + ByteBuf buffer2 = byteBufAllocator.buffer(); + buffer2.writeCharSequence("test2", CharsetUtil.UTF_8); + Payload payload2 = ByteBufPayload.create(buffer2); + + switch (frameType) { + case REQUEST_FNF: + final ByteBuf fnfFrame = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + final ByteBuf fnfFrame2 = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, payload2); + rSocketResponder.handleFrame(fnfFrame); + rSocketResponder.handleFrame(fnfFrame2); + fnfFrame.release(); + fnfFrame2.release(); + break; + case REQUEST_RESPONSE: + final ByteBuf requestResponseFrame = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + final ByteBuf requestResponseFrame2 = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, payload2); + rSocketResponder.handleFrame(requestResponseFrame); + rSocketResponder.handleFrame(requestResponseFrame2); + requestResponseFrame.release(); + requestResponseFrame2.release(); + break; + case REQUEST_STREAM: + final ByteBuf requestStreamFrame = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, 1, payload1); + final ByteBuf requestStreamFrame2 = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, 1, payload2); + rSocketResponder.handleFrame(requestStreamFrame); + rSocketResponder.handleFrame(requestStreamFrame2); + requestStreamFrame.release(); + requestStreamFrame2.release(); + break; + case REQUEST_CHANNEL: + final ByteBuf requestChannelFrame = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, true, 1, payload1); + final ByteBuf requestChannelFrame2 = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, true, 1, payload2); + rSocketResponder.handleFrame(requestChannelFrame); + rSocketResponder.handleFrame(requestChannelFrame2); + requestChannelFrame.release(); + requestChannelFrame2.release(); + break; + } + + switch (frameType) { + case REQUEST_FNF: + Mockito.verify(mockRSocketHandler).fireAndForget(any()); + break; + case REQUEST_RESPONSE: + Mockito.verify(mockRSocketHandler).requestResponse(any()); + break; + case REQUEST_STREAM: + Mockito.verify(mockRSocketHandler).requestStream(any()); + break; + case REQUEST_CHANNEL: + Mockito.verify(mockRSocketHandler).requestChannel(any()); + break; + } Assertions.assertThat(connection.getSent()) - .hasSize(1) .first() .matches(bb -> FrameHeaderCodec.frameType(bb) == LEASE) .matches(ReferenceCounted::release); - ByteBuf buffer2 = byteBufAllocator.buffer(); - buffer2.writeCharSequence("test", CharsetUtil.UTF_8); - Payload payload2 = ByteBufPayload.create(buffer2); + if (frameType != REQUEST_FNF) { + Assertions.assertThat(connection.getSent()) + .hasSize(3) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == COMPLETE) + .matches(ReferenceCounted::release); + + Assertions.assertThat(connection.getSent()) + .hasSize(3) + .element(2) + .matches(bb -> FrameHeaderCodec.frameType(bb) == ERROR) + .matches(bb -> Exceptions.from(1, bb) instanceof RejectedException) + .matches(ReferenceCounted::release); + } - Flux.from(interaction.apply(rSocketResponder, payload2)) - .as(StepVerifier::create) - .expectError(MissingLeaseException.class) - .verify(Duration.ofSeconds(5)); + byteBufAllocator.assertHasNoLeaks(); } @ParameterizedTest @MethodSource("interactions") void expiredLeaseRequestsAreRejected(BiFunction> interaction) { - leaseSender.onNext(Lease.create(50, 1)); + leaseSender.tryEmitNext(Lease.create(Duration.ofMillis(50), 1)); ByteBuf buffer = byteBufAllocator.buffer(); buffer.writeCharSequence("test", CharsetUtil.UTF_8); @@ -441,7 +647,7 @@ void sendLease() { metadata.writeCharSequence(metadataContent, utf8); int ttl = 5_000; int numberOfRequests = 2; - leaseSender.onNext(Lease.create(5_000, 2, metadata)); + leaseSender.tryEmitNext(Lease.create(Duration.ofMillis(5_000), 2, metadata)); ByteBuf leaseFrame = connection @@ -451,34 +657,41 @@ void sendLease() { .findFirst() .orElseThrow(() -> new IllegalStateException("Lease frame not sent")); - Assertions.assertThat(LeaseFrameCodec.ttl(leaseFrame)).isEqualTo(ttl); - Assertions.assertThat(LeaseFrameCodec.numRequests(leaseFrame)).isEqualTo(numberOfRequests); - Assertions.assertThat(LeaseFrameCodec.metadata(leaseFrame).toString(utf8)) - .isEqualTo(metadataContent); + try { + Assertions.assertThat(LeaseFrameCodec.ttl(leaseFrame)).isEqualTo(ttl); + Assertions.assertThat(LeaseFrameCodec.numRequests(leaseFrame)).isEqualTo(numberOfRequests); + Assertions.assertThat(LeaseFrameCodec.metadata(leaseFrame).toString(utf8)) + .isEqualTo(metadataContent); + } finally { + leaseFrame.release(); + } } - @Test - void receiveLease() { - Collection receivedLeases = new ArrayList<>(); - leaseReceiver.subscribe(lease -> receivedLeases.add(lease)); - - ByteBuf metadata = byteBufAllocator.buffer(); - Charset utf8 = StandardCharsets.UTF_8; - String metadataContent = "test"; - metadata.writeCharSequence(metadataContent, utf8); - int ttl = 5_000; - int numberOfRequests = 2; - - ByteBuf leaseFrame = leaseFrame(ttl, numberOfRequests, metadata).retain(1); - - connection.addToReceivedBuffer(leaseFrame); - - Assertions.assertThat(receivedLeases.isEmpty()).isFalse(); - Lease receivedLease = receivedLeases.iterator().next(); - Assertions.assertThat(receivedLease.getTimeToLiveMillis()).isEqualTo(ttl); - Assertions.assertThat(receivedLease.getStartingAllowedRequests()).isEqualTo(numberOfRequests); - Assertions.assertThat(receivedLease.getMetadata().toString(utf8)).isEqualTo(metadataContent); - } + // @Test + // void receiveLease() { + // Collection receivedLeases = new ArrayList<>(); + // leaseReceiver.subscribe(lease -> receivedLeases.add(lease)); + // + // ByteBuf metadata = byteBufAllocator.buffer(); + // Charset utf8 = StandardCharsets.UTF_8; + // String metadataContent = "test"; + // metadata.writeCharSequence(metadataContent, utf8); + // int ttl = 5_000; + // int numberOfRequests = 2; + // + // ByteBuf leaseFrame = leaseFrame(ttl, numberOfRequests, metadata).retain(1); + // + // connection.addToReceivedBuffer(leaseFrame); + // + // Assertions.assertThat(receivedLeases.isEmpty()).isFalse(); + // Lease receivedLease = receivedLeases.iterator().next(); + // Assertions.assertThat(receivedLease.getTimeToLiveMillis()).isEqualTo(ttl); + // + // Assertions.assertThat(receivedLease.getStartingAllowedRequests()).isEqualTo(numberOfRequests); + // Assertions.assertThat(receivedLease.metadata().toString(utf8)).isEqualTo(metadataContent); + // + // ReferenceCountUtil.safeRelease(leaseFrame); + // } ByteBuf leaseFrame(int ttl, int requests, ByteBuf metadata) { return LeaseFrameCodec.encode(byteBufAllocator, ttl, requests, metadata); @@ -500,4 +713,12 @@ static Stream interactions() { (rSocket, payload) -> rSocket.requestChannel(Mono.just(payload)), FrameType.REQUEST_CHANNEL)); } + + static Stream responderInteractions() { + return Stream.of( + FrameType.REQUEST_FNF, + FrameType.REQUEST_RESPONSE, + FrameType.REQUEST_STREAM, + FrameType.REQUEST_CHANNEL); + } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java index 9ecdd13ba..966fd65f2 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java @@ -15,25 +15,25 @@ */ package io.rsocket.core; -import static org.junit.Assert.assertEquals; +import static org.assertj.core.api.Assertions.assertThat; +import io.rsocket.FrameAssert; import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; import io.rsocket.test.util.TestClientTransport; +import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.transport.ClientTransport; import java.io.UncheckedIOException; import java.time.Duration; import java.util.Iterator; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; import org.mockito.Mockito; import reactor.core.Exceptions; import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; import reactor.util.retry.Retry; public class RSocketReconnectTest { @@ -42,14 +42,6 @@ public class RSocketReconnectTest { @Test public void shouldBeASharedReconnectableInstanceOfRSocketMono() throws InterruptedException { - CountDownLatch latch = new CountDownLatch(1); - Schedulers.onScheduleHook( - "test", - r -> - () -> { - r.run(); - latch.countDown(); - }); TestClientTransport[] testClientTransport = new TestClientTransport[] {new TestClientTransport()}; Mono rSocketMono = @@ -60,29 +52,44 @@ public void shouldBeASharedReconnectableInstanceOfRSocketMono() throws Interrupt RSocket rSocket1 = rSocketMono.block(); RSocket rSocket2 = rSocketMono.block(); - Assertions.assertThat(rSocket1).isEqualTo(rSocket2); + FrameAssert.assertThat(testClientTransport[0].testConnection().awaitFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); + + assertThat(rSocket1).isEqualTo(rSocket2); testClientTransport[0].testConnection().dispose(); - Assertions.assertThat(latch.await(5, TimeUnit.SECONDS)).isTrue(); + rSocket1.onClose().block(Duration.ofSeconds(1)); + testClientTransport[0].alloc().assertHasNoLeaks(); testClientTransport[0] = new TestClientTransport(); - System.out.println("here"); RSocket rSocket3 = rSocketMono.block(); RSocket rSocket4 = rSocketMono.block(); - Assertions.assertThat(rSocket3).isEqualTo(rSocket4).isNotEqualTo(rSocket2); + FrameAssert.assertThat(testClientTransport[0].testConnection().awaitFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); + + assertThat(rSocket3).isEqualTo(rSocket4).isNotEqualTo(rSocket2); + + testClientTransport[0].testConnection().dispose(); + rSocket3.onClose().block(Duration.ofSeconds(1)); + testClientTransport[0].alloc().assertHasNoLeaks(); } @Test - @SuppressWarnings({"rawtype", "unchecked"}) + @SuppressWarnings({"rawtype"}) public void shouldBeRetrieableConnectionSharedReconnectableInstanceOfRSocketMono() { ClientTransport transport = Mockito.mock(ClientTransport.class); + TestClientTransport transport1 = new TestClientTransport(); Mockito.when(transport.connect()) .thenThrow(UncheckedIOException.class) .thenThrow(UncheckedIOException.class) .thenThrow(UncheckedIOException.class) .thenThrow(UncheckedIOException.class) - .thenReturn(new TestClientTransport().connect()); + .thenReturn(transport1.connect()); Mono rSocketMono = RSocketConnector.create() .reconnect( @@ -94,25 +101,35 @@ public void shouldBeRetrieableConnectionSharedReconnectableInstanceOfRSocketMono RSocket rSocket1 = rSocketMono.block(); RSocket rSocket2 = rSocketMono.block(); - Assertions.assertThat(rSocket1).isEqualTo(rSocket2); + assertThat(rSocket1).isEqualTo(rSocket2); assertRetries( UncheckedIOException.class, UncheckedIOException.class, UncheckedIOException.class, UncheckedIOException.class); + + FrameAssert.assertThat(transport1.testConnection().awaitFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); + + transport1.testConnection().dispose(); + rSocket1.onClose().block(Duration.ofSeconds(1)); + transport1.alloc().assertHasNoLeaks(); } @Test - @SuppressWarnings({"rawtype", "unchecked"}) + @SuppressWarnings({"rawtype"}) public void shouldBeExaustedRetrieableConnectionSharedReconnectableInstanceOfRSocketMono() { ClientTransport transport = Mockito.mock(ClientTransport.class); + TestClientTransport transport1 = new TestClientTransport(); Mockito.when(transport.connect()) .thenThrow(UncheckedIOException.class) .thenThrow(UncheckedIOException.class) .thenThrow(UncheckedIOException.class) .thenThrow(UncheckedIOException.class) .thenThrow(UncheckedIOException.class) - .thenReturn(new TestClientTransport().connect()); + .thenReturn(transport1.connect()); Mono rSocketMono = RSocketConnector.create() .reconnect( @@ -134,27 +151,48 @@ public void shouldBeExaustedRetrieableConnectionSharedReconnectableInstanceOfRSo UncheckedIOException.class, UncheckedIOException.class, UncheckedIOException.class); + + transport1.alloc().assertHasNoLeaks(); } @Test public void shouldBeNotBeASharedReconnectableInstanceOfRSocketMono() { - - Mono rSocketMono = RSocketConnector.connectWith(new TestClientTransport()); + TestClientTransport transport = new TestClientTransport(); + Mono rSocketMono = RSocketConnector.connectWith(transport); RSocket rSocket1 = rSocketMono.block(); + TestDuplexConnection connection1 = transport.testConnection(); + + FrameAssert.assertThat(connection1.awaitFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); + RSocket rSocket2 = rSocketMono.block(); + TestDuplexConnection connection2 = transport.testConnection(); + + assertThat(rSocket1).isNotEqualTo(rSocket2); + + FrameAssert.assertThat(connection2.awaitFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); - Assertions.assertThat(rSocket1).isNotEqualTo(rSocket2); + connection1.dispose(); + connection2.dispose(); + rSocket1.onClose().block(Duration.ofSeconds(1)); + rSocket2.onClose().block(Duration.ofSeconds(1)); + transport.alloc().assertHasNoLeaks(); } @SafeVarargs private final void assertRetries(Class... exceptions) { - assertEquals(exceptions.length, retries.size()); + assertThat(retries.size()).isEqualTo(exceptions.length); int index = 0; for (Iterator it = retries.iterator(); it.hasNext(); ) { Retry.RetrySignal retryContext = it.next(); - assertEquals(index, retryContext.totalRetries()); - assertEquals(exceptions[index], retryContext.failure().getClass()); + assertThat(retryContext.totalRetries()).isEqualTo(index); + assertThat(retryContext.failure().getClass()).isEqualTo(exceptions[index]); index++; } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java index 4ffd00f14..01eb998c7 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java @@ -20,15 +20,15 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.CharsetUtil; +import io.rsocket.FrameAssert; import io.rsocket.RSocket; -import io.rsocket.TestScheduler; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; import io.rsocket.frame.PayloadFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.internal.subscriber.AssertSubscriber; -import io.rsocket.lease.RequesterLeaseHandler; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.util.DefaultPayload; import java.util.Arrays; @@ -38,11 +38,15 @@ import java.util.function.Function; import java.util.stream.Stream; import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; import reactor.test.util.RaceTestUtils; class RSocketRequesterSubscribersTest { @@ -59,11 +63,20 @@ class RSocketRequesterSubscribersTest { private LeaksTrackingByteBufAllocator allocator; private RSocket rSocketRequester; private TestDuplexConnection connection; + protected Sinks.Empty thisClosedSink; + protected Sinks.Empty otherClosedSink; + + @AfterEach + void tearDownAndCheckNoLeaks() { + allocator.assertHasNoLeaks(); + } @BeforeEach void setUp() { allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); connection = new TestDuplexConnection(allocator); + this.thisClosedSink = Sinks.empty(); + this.otherClosedSink = Sinks.empty(); rSocketRequester = new RSocketRequester( connection, @@ -71,16 +84,20 @@ void setUp() { StreamIdSupplier.clientSupplier(), 0, FRAME_LENGTH_MASK, + Integer.MAX_VALUE, 0, 0, null, - RequesterLeaseHandler.None, - TestScheduler.INSTANCE); + __ -> null, + null, + thisClosedSink, + otherClosedSink.asMono().and(thisClosedSink.asMono())); } @ParameterizedTest @MethodSource("allInteractions") - void singleSubscriber(Function> interaction) { + @SuppressWarnings({"rawtypes", "unchecked"}) + void singleSubscriber(Function> interaction, FrameType requestType) { Flux response = Flux.from(interaction.apply(rSocketRequester)); AssertSubscriber assertSubscriberA = AssertSubscriber.create(); @@ -89,17 +106,24 @@ void singleSubscriber(Function> interaction) { response.subscribe(assertSubscriberA); response.subscribe(assertSubscriberB); - connection.addToReceivedBuffer(PayloadFrameCodec.encodeComplete(connection.alloc(), 1)); + if (requestType != FrameType.REQUEST_FNF && requestType != FrameType.METADATA_PUSH) { + connection.addToReceivedBuffer(PayloadFrameCodec.encodeComplete(connection.alloc(), 1)); + } assertSubscriberA.assertTerminated(); assertSubscriberB.assertTerminated(); - Assertions.assertThat(requestFramesCount(connection.getSent())).isEqualTo(1); + FrameAssert.assertThat(connection.pollFrame()).typeOf(requestType).hasNoLeaks(); + + if (requestType == FrameType.REQUEST_CHANNEL) { + FrameAssert.assertThat(connection.pollFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + } } @ParameterizedTest @MethodSource("allInteractions") - void singleSubscriberInCaseOfRacing(Function> interaction) { + void singleSubscriberInCaseOfRacing( + Function> interaction, FrameType requestType) { for (int i = 1; i < 20000; i += 2) { Flux response = Flux.from(interaction.apply(rSocketRequester)); AssertSubscriber assertSubscriberA = AssertSubscriber.create(); @@ -108,7 +132,9 @@ void singleSubscriberInCaseOfRacing(Function> interaction) RaceTestUtils.race( () -> response.subscribe(assertSubscriberA), () -> response.subscribe(assertSubscriberB)); - connection.addToReceivedBuffer(PayloadFrameCodec.encodeComplete(connection.alloc(), i)); + if (requestType != FrameType.REQUEST_FNF && requestType != FrameType.METADATA_PUSH) { + connection.addToReceivedBuffer(PayloadFrameCodec.encodeComplete(connection.alloc(), i)); + } assertSubscriberA.assertTerminated(); assertSubscriberB.assertTerminated(); @@ -116,10 +142,23 @@ void singleSubscriberInCaseOfRacing(Function> interaction) Assertions.assertThat(new AssertSubscriber[] {assertSubscriberA, assertSubscriberB}) .anySatisfy(as -> as.assertError(IllegalStateException.class)); - Assertions.assertThat(connection.getSent()) - .hasSize(1) - .first() - .matches(bb -> REQUEST_TYPES.contains(FrameHeaderCodec.frameType(bb))); + if (requestType == FrameType.REQUEST_CHANNEL) { + Assertions.assertThat(connection.getSent()) + .hasSize(2) + .first() + .matches(bb -> REQUEST_TYPES.contains(FrameHeaderCodec.frameType(bb))) + .matches(ByteBuf::release); + Assertions.assertThat(connection.getSent()) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == FrameType.COMPLETE) + .matches(ByteBuf::release); + } else { + Assertions.assertThat(connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> REQUEST_TYPES.contains(FrameHeaderCodec.frameType(bb))) + .matches(ByteBuf::release); + } connection.clearSendReceiveBuffers(); } } @@ -139,12 +178,29 @@ static long requestFramesCount(Collection frames) { .count(); } - static Stream>> allInteractions() { + static Stream allInteractions() { return Stream.of( - rSocket -> rSocket.fireAndForget(DefaultPayload.create("test")), - rSocket -> rSocket.requestResponse(DefaultPayload.create("test")), - rSocket -> rSocket.requestStream(DefaultPayload.create("test")), - // rSocket -> rSocket.requestChannel(Mono.just(DefaultPayload.create("test"))), - rSocket -> rSocket.metadataPush(DefaultPayload.create("", "test"))); + Arguments.of( + (Function>) + rSocket -> rSocket.fireAndForget(DefaultPayload.create("test")), + FrameType.REQUEST_FNF), + Arguments.of( + (Function>) + rSocket -> rSocket.requestResponse(DefaultPayload.create("test")), + FrameType.REQUEST_RESPONSE), + Arguments.of( + (Function>) + rSocket -> rSocket.requestStream(DefaultPayload.create("test")), + FrameType.REQUEST_STREAM), + Arguments.of( + (Function>) + rSocket -> rSocket.requestChannel(Mono.just(DefaultPayload.create("test"))), + FrameType.REQUEST_CHANNEL), + Arguments.of( + (Function>) + rSocket -> + rSocket.metadataPush( + DefaultPayload.create(new byte[0], "test".getBytes(CharsetUtil.UTF_8))), + FrameType.METADATA_PUSH)); } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTerminationTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTerminationTest.java index de6f86c57..5cfa76a1c 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTerminationTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTerminationTest.java @@ -1,45 +1,59 @@ package io.rsocket.core; +import io.rsocket.FrameAssert; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.core.RSocketRequesterTest.ClientSocketRule; +import io.rsocket.frame.FrameType; import io.rsocket.util.EmptyPayload; import java.nio.channels.ClosedChannelException; import java.time.Duration; import java.util.Arrays; import java.util.function.Function; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; -@RunWith(Parameterized.class) public class RSocketRequesterTerminationTest { - @Rule public final ClientSocketRule rule = new ClientSocketRule(); - private Function> interaction; + public final ClientSocketRule rule = new ClientSocketRule(); - public RSocketRequesterTerminationTest(Function> interaction) { - this.interaction = interaction; + @BeforeEach + public void setup() { + rule.init(); } - @Test - public void testCurrentStreamIsTerminatedOnConnectionClose() { - RSocketRequester rSocket = rule.socket; + @AfterEach + public void tearDownAndCheckNoLeaks() { + rule.assertHasNoLeaks(); + } - Mono.delay(Duration.ofSeconds(1)).doOnNext(v -> rule.connection.dispose()).subscribe(); + @ParameterizedTest + @MethodSource("rsocketInteractions") + public void testCurrentStreamIsTerminatedOnConnectionClose( + FrameType requestType, Function> interaction) { + RSocketRequester rSocket = rule.socket; StepVerifier.create(interaction.apply(rSocket)) + .then( + () -> { + FrameAssert.assertThat(rule.connection.pollFrame()).typeOf(requestType).hasNoLeaks(); + }) + .then(() -> rule.connection.dispose()) .expectError(ClosedChannelException.class) .verify(Duration.ofSeconds(5)); } - @Test - public void testSubsequentStreamIsTerminatedAfterConnectionClose() { + @ParameterizedTest + @MethodSource("rsocketInteractions") + public void testSubsequentStreamIsTerminatedAfterConnectionClose( + FrameType requestType, Function> interaction) { RSocketRequester rSocket = rule.socket; rule.connection.dispose(); @@ -48,14 +62,51 @@ public void testSubsequentStreamIsTerminatedAfterConnectionClose() { .verify(Duration.ofSeconds(5)); } - @Parameterized.Parameters - public static Iterable>> rsocketInteractions() { + public static Iterable rsocketInteractions() { EmptyPayload payload = EmptyPayload.INSTANCE; - Publisher payloadStream = Flux.just(payload); - Function> resp = rSocket -> rSocket.requestResponse(payload); - Function> stream = rSocket -> rSocket.requestStream(payload); - Function> channel = rSocket -> rSocket.requestChannel(payloadStream); + Arguments resp = + Arguments.of( + FrameType.REQUEST_RESPONSE, + new Function>() { + @Override + public Mono apply(RSocket rSocket) { + return rSocket.requestResponse(payload); + } + + @Override + public String toString() { + return "Request Response"; + } + }); + Arguments stream = + Arguments.of( + FrameType.REQUEST_STREAM, + new Function>() { + @Override + public Flux apply(RSocket rSocket) { + return rSocket.requestStream(payload); + } + + @Override + public String toString() { + return "Request Stream"; + } + }); + Arguments channel = + Arguments.of( + FrameType.REQUEST_CHANNEL, + new Function>() { + @Override + public Flux apply(RSocket rSocket) { + return rSocket.requestChannel(Flux.never().startWith(payload)); + } + + @Override + public String toString() { + return "Request Channel"; + } + }); return Arrays.asList(resp, stream, channel); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java index cf5a164d7..a1199f698 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,13 +17,23 @@ package io.rsocket.core; import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.ReassemblyUtils.ILLEGAL_REASSEMBLED_PAYLOAD_SIZE; +import static io.rsocket.core.TestRequesterResponderSupport.fixedSizePayload; +import static io.rsocket.core.TestRequesterResponderSupport.genericPayload; +import static io.rsocket.core.TestRequesterResponderSupport.prepareFragments; +import static io.rsocket.core.TestRequesterResponderSupport.randomMetadataOnlyPayload; +import static io.rsocket.core.TestRequesterResponderSupport.randomPayload; import static io.rsocket.frame.FrameHeaderCodec.frameType; import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; -import static io.rsocket.frame.FrameType.*; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.greaterThanOrEqualTo; -import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.is; +import static io.rsocket.frame.FrameType.CANCEL; +import static io.rsocket.frame.FrameType.COMPLETE; +import static io.rsocket.frame.FrameType.METADATA_PUSH; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_FNF; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.verify; @@ -35,9 +45,12 @@ import io.netty.util.IllegalReferenceCountException; import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCounted; +import io.rsocket.FrameAssert; import io.rsocket.Payload; +import io.rsocket.PayloadAssert; import io.rsocket.RSocket; -import io.rsocket.TestScheduler; +import io.rsocket.RaceTestConstants; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.ApplicationErrorException; import io.rsocket.exceptions.CustomRSocketException; import io.rsocket.exceptions.RejectedSetupException; @@ -54,7 +67,6 @@ import io.rsocket.frame.RequestStreamFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.internal.subscriber.AssertSubscriber; -import io.rsocket.lease.RequesterLeaseHandler; import io.rsocket.test.util.TestSubscriber; import io.rsocket.util.ByteBufPayload; import io.rsocket.util.DefaultPayload; @@ -70,28 +82,26 @@ import java.util.function.BiFunction; import java.util.function.Function; import java.util.stream.Stream; -import org.assertj.core.api.Assertions; import org.assertj.core.api.Assumptions; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; -import org.junit.runners.model.Statement; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; +import reactor.core.Scannable; import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; -import reactor.core.publisher.UnicastProcessor; -import reactor.core.scheduler.Schedulers; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; import reactor.test.publisher.TestPublisher; import reactor.test.util.RaceTestUtils; @@ -105,26 +115,21 @@ public void setUp() throws Throwable { Hooks.onNextDropped(ReferenceCountUtil::safeRelease); Hooks.onErrorDropped((t) -> {}); rule = new ClientSocketRule(); - rule.apply( - new Statement() { - @Override - public void evaluate() {} - }, - null) - .evaluate(); + rule.init(); } @AfterEach public void tearDown() { Hooks.resetOnErrorDropped(); Hooks.resetOnNextDropped(); + rule.assertHasNoLeaks(); } @Test @Timeout(2_000) public void testInvalidFrameOnStream0ShouldNotTerminateRSocket() { rule.connection.addToReceivedBuffer(RequestNFrameCodec.encode(rule.alloc(), 0, 10)); - Assertions.assertThat(rule.socket.isDisposed()).isFalse(); + assertThat(rule.socket.isDisposed()).isFalse(); rule.assertHasNoLeaks(); } @@ -142,19 +147,21 @@ protected void hookOnSubscribe(Subscription subscription) { }; stream.subscribe(subscriber); - Assertions.assertThat(rule.connection.getSent()).isEmpty(); + assertThat(rule.connection.getSent()).isEmpty(); subscriber.request(5); List sent = new ArrayList<>(rule.connection.getSent()); - assertThat("sent frame count", sent.size(), is(1)); + assertThat(sent.size()).describedAs("sent frame count").isEqualTo(1); ByteBuf f = sent.get(0); - assertThat("initial frame", frameType(f), is(REQUEST_STREAM)); - assertThat("initial request n", RequestStreamFrameCodec.initialRequestN(f), is(5L)); - assertThat("should be released", f.release(), is(true)); + assertThat(frameType(f)).describedAs("initial frame").isEqualTo(REQUEST_STREAM); + assertThat(RequestStreamFrameCodec.initialRequestN(f)) + .describedAs("initial request n") + .isEqualTo(5L); + assertThat(f.release()).describedAs("should be released").isEqualTo(true); rule.assertHasNoLeaks(); } @@ -163,7 +170,7 @@ protected void hookOnSubscribe(Subscription subscription) { public void testHandleSetupException() { rule.connection.addToReceivedBuffer( ErrorFrameCodec.encode(rule.alloc(), 0, new RejectedSetupException("boom"))); - Assertions.assertThatThrownBy(() -> rule.socket.onClose().block()) + assertThatThrownBy(() -> rule.socket.onClose().block()) .isInstanceOf(RejectedSetupException.class); rule.assertHasNoLeaks(); } @@ -182,7 +189,7 @@ public void testHandleApplicationException() { verify(responseSub).onError(any(ApplicationErrorException.class)); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) // requestResponseFrame .hasSize(1) .allMatch(ReferenceCounted::release); @@ -203,7 +210,7 @@ public void testHandleValidFrame() { rule.alloc(), streamId, EmptyPayload.INSTANCE)); verify(sub).onComplete(); - Assertions.assertThat(rule.connection.getSent()).hasSize(1).allMatch(ReferenceCounted::release); + assertThat(rule.connection.getSent()).hasSize(1).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); } @@ -219,10 +226,13 @@ public void testRequestReplyWithCancel() { List sent = new ArrayList<>(rule.connection.getSent()); - assertThat( - "Unexpected frame sent on the connection.", frameType(sent.get(0)), is(REQUEST_RESPONSE)); - assertThat("Unexpected frame sent on the connection.", frameType(sent.get(1)), is(CANCEL)); - Assertions.assertThat(sent).hasSize(2).allMatch(ReferenceCounted::release); + assertThat(frameType(sent.get(0))) + .describedAs("Unexpected frame sent on the connection.") + .isEqualTo(REQUEST_RESPONSE); + assertThat(frameType(sent.get(1))) + .describedAs("Unexpected frame sent on the connection.") + .isEqualTo(CANCEL); + assertThat(sent).hasSize(2).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); } @@ -252,11 +262,11 @@ public void testRequestReplyErrorOnSend() { @Test @Timeout(2_000) public void testChannelRequestCancellation() { - MonoProcessor cancelled = MonoProcessor.create(); - Flux request = Flux.never().doOnCancel(cancelled::onComplete); + Sinks.Empty cancelled = Sinks.empty(); + Flux request = Flux.never().doOnCancel(cancelled::tryEmitEmpty); rule.socket.requestChannel(request).subscribe().dispose(); - Flux.first( - cancelled, + Flux.firstWithSignal( + cancelled.asMono(), Flux.error(new IllegalStateException("Channel request not cancelled")) .delaySubscription(Duration.ofSeconds(1))) .blockFirst(); @@ -266,36 +276,39 @@ public void testChannelRequestCancellation() { @Test @Timeout(2_000) public void testChannelRequestCancellation2() { - MonoProcessor cancelled = MonoProcessor.create(); + Sinks.Empty cancelled = Sinks.empty(); Flux request = - Flux.just(EmptyPayload.INSTANCE).repeat(259).doOnCancel(cancelled::onComplete); + Flux.just(EmptyPayload.INSTANCE).repeat(259).doOnCancel(cancelled::tryEmitEmpty); rule.socket.requestChannel(request).subscribe().dispose(); - Flux.first( - cancelled, + Flux.firstWithSignal( + cancelled.asMono(), Flux.error(new IllegalStateException("Channel request not cancelled")) .delaySubscription(Duration.ofSeconds(1))) .blockFirst(); - Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); } @Test public void testChannelRequestServerSideCancellation() { - MonoProcessor cancelled = MonoProcessor.create(); - UnicastProcessor request = UnicastProcessor.create(); - request.onNext(EmptyPayload.INSTANCE); - rule.socket.requestChannel(request).subscribe(cancelled); + Sinks.One cancelled = Sinks.one(); + Sinks.Many request = Sinks.many().unicast().onBackpressureBuffer(); + request.tryEmitNext(EmptyPayload.INSTANCE); + rule.socket + .requestChannel(request.asFlux()) + .subscribe(cancelled::tryEmitValue, cancelled::tryEmitError, cancelled::tryEmitEmpty); int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); rule.connection.addToReceivedBuffer(CancelFrameCodec.encode(rule.alloc(), streamId)); rule.connection.addToReceivedBuffer(PayloadFrameCodec.encodeComplete(rule.alloc(), streamId)); - Flux.first( - cancelled, + Flux.firstWithSignal( + cancelled.asMono(), Flux.error(new IllegalStateException("Channel request not cancelled")) .delaySubscription(Duration.ofSeconds(1))) .blockFirst(); - Assertions.assertThat(request.isDisposed()).isTrue(); - Assertions.assertThat(rule.connection.getSent()) + assertThat(request.scan(Scannable.Attr.TERMINATED) || request.scan(Scannable.Attr.CANCELLED)) + .isTrue(); + assertThat(rule.connection.getSent()) .hasSize(1) .first() .matches(bb -> frameType(bb) == REQUEST_CHANNEL) @@ -305,7 +318,7 @@ public void testChannelRequestServerSideCancellation() { @Test public void testCorrectFrameOrder() { - MonoProcessor delayer = MonoProcessor.create(); + Sinks.One delayer = Sinks.one(); BaseSubscriber subscriber = new BaseSubscriber() { @Override @@ -313,26 +326,25 @@ protected void hookOnSubscribe(Subscription subscription) {} }; rule.socket .requestChannel( - Flux.concat(Flux.just(0).delayUntil(i -> delayer), Flux.range(1, 999)) + Flux.concat(Flux.just(0).delayUntil(i -> delayer.asMono()), Flux.range(1, 999)) .map(i -> DefaultPayload.create(i + ""))) .subscribe(subscriber); subscriber.request(1); subscriber.request(Long.MAX_VALUE); - delayer.onComplete(); + delayer.tryEmitEmpty(); Iterator iterator = rule.connection.getSent().iterator(); ByteBuf initialFrame = iterator.next(); - Assertions.assertThat(FrameHeaderCodec.frameType(initialFrame)).isEqualTo(REQUEST_CHANNEL); - Assertions.assertThat(RequestChannelFrameCodec.initialRequestN(initialFrame)) - .isEqualTo(Long.MAX_VALUE); - Assertions.assertThat(RequestChannelFrameCodec.data(initialFrame).toString(CharsetUtil.UTF_8)) + assertThat(FrameHeaderCodec.frameType(initialFrame)).isEqualTo(REQUEST_CHANNEL); + assertThat(RequestChannelFrameCodec.initialRequestN(initialFrame)).isEqualTo(Long.MAX_VALUE); + assertThat(RequestChannelFrameCodec.data(initialFrame).toString(CharsetUtil.UTF_8)) .isEqualTo("0"); - Assertions.assertThat(initialFrame.release()).isTrue(); + assertThat(initialFrame.release()).isTrue(); - Assertions.assertThat(iterator.hasNext()).isFalse(); + assertThat(iterator.hasNext()).isFalse(); rule.assertHasNoLeaks(); } @@ -353,14 +365,74 @@ public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmen .expectSubscription() .expectErrorSatisfies( t -> - Assertions.assertThat(t) + assertThat(t) .isInstanceOf(IllegalArgumentException.class) - .hasMessage(INVALID_PAYLOAD_ERROR_MESSAGE)) + .hasMessage( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, maxFrameLength))) .verify(); rule.assertHasNoLeaks(); }); } + @ParameterizedTest + @ValueSource(ints = {128, 256, FRAME_LENGTH_MASK}) + public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentation1( + int maxFrameLength) { + rule.setMaxFrameLength(maxFrameLength); + prepareCalls() + .forEach( + generator -> { + byte[] metadata = new byte[maxFrameLength]; + byte[] data = new byte[maxFrameLength]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + assertThatThrownBy( + () -> { + final Publisher source = + generator.apply(rule.socket, DefaultPayload.create(data, metadata)); + + if (source instanceof Mono) { + ((Mono) source).block(); + } else { + ((Flux) source).blockLast(); + } + }) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, maxFrameLength)); + + rule.assertHasNoLeaks(); + }); + } + + @Test + public void shouldRejectCallOfNoMetadataPayload() { + final ByteBuf data = rule.allocator.buffer(10); + final Payload payload = ByteBufPayload.create(data); + StepVerifier.create(rule.socket.metadataPush(payload)) + .expectSubscription() + .expectErrorSatisfies( + t -> + assertThat(t) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Metadata push should have metadata field present")) + .verify(); + PayloadAssert.assertThat(payload).isReleased(); + rule.assertHasNoLeaks(); + } + + @Test + public void shouldRejectCallOfNoMetadataPayloadBlocking() { + final ByteBuf data = rule.allocator.buffer(10); + final Payload payload = ByteBufPayload.create(data); + + assertThatThrownBy(() -> rule.socket.metadataPush(payload).block()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Metadata push should have metadata field present"); + PayloadAssert.assertThat(payload).isReleased(); + rule.assertHasNoLeaks(); + } + static Stream>> prepareCalls() { return Stream.of( RSocket::fireAndForget, @@ -370,29 +442,35 @@ static Stream>> prepareCalls() { RSocket::metadataPush); } - @Test + @ParameterizedTest + @ValueSource(ints = {128, 256, FrameLengthCodec.FRAME_LENGTH_MASK}) public void - shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentationForRequestChannelCase() { - byte[] metadata = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; - byte[] data = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; + shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentationForRequestChannelCase( + int maxFrameLength) { + rule.setMaxFrameLength(maxFrameLength); + byte[] metadata = new byte[maxFrameLength]; + byte[] data = new byte[maxFrameLength]; ThreadLocalRandom.current().nextBytes(metadata); ThreadLocalRandom.current().nextBytes(data); StepVerifier.create( rule.socket.requestChannel( - Flux.just(EmptyPayload.INSTANCE, DefaultPayload.create(data, metadata)))) + Flux.just(EmptyPayload.INSTANCE, DefaultPayload.create(data, metadata))), + 0) .expectSubscription() + .thenRequest(2) .then( - () -> - rule.connection.addToReceivedBuffer( - RequestNFrameCodec.encode( - rule.alloc(), rule.getStreamIdForRequestType(REQUEST_CHANNEL), 2))) + () -> { + rule.connection.addToReceivedBuffer( + RequestNFrameCodec.encode( + rule.alloc(), rule.getStreamIdForRequestType(REQUEST_CHANNEL), 2)); + }) .expectErrorSatisfies( t -> - Assertions.assertThat(t) + assertThat(t) .isInstanceOf(IllegalArgumentException.class) - .hasMessage(INVALID_PAYLOAD_ERROR_MESSAGE)) + .hasMessage(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, maxFrameLength))) .verify(); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) // expect to be sent RequestChannelFrame // expect to be sent CancelFrame .hasSize(2) @@ -405,20 +483,10 @@ static Stream>> prepareCalls() { public void checkNoLeaksOnRacing( Function> initiator, BiConsumer, ClientSocketRule> runner) { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { ClientSocketRule clientSocketRule = new ClientSocketRule(); - try { - clientSocketRule - .apply( - new Statement() { - @Override - public void evaluate() {} - }, - null) - .evaluate(); - } catch (Throwable throwable) { - throwable.printStackTrace(); - } + + clientSocketRule.init(); Publisher payloadP = initiator.apply(clientSocketRule); AssertSubscriber assertSubscriber = AssertSubscriber.create(0); @@ -431,8 +499,7 @@ public void evaluate() {} runner.accept(assertSubscriber, clientSocketRule); - Assertions.assertThat(clientSocketRule.connection.getSent()) - .allMatch(ReferenceCounted::release); + assertThat(clientSocketRule.connection.getSent()).allMatch(ReferenceCounted::release); clientSocketRule.assertHasNoLeaks(); } @@ -493,8 +560,8 @@ private static Stream racingCases() { RaceTestUtils.race(() -> as.request(1), as::cancel); // ensures proper frames order if (rule.connection.getSent().size() > 0) { - Assertions.assertThat(rule.connection.getSent()).hasSize(2); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()).hasSize(2); + assertThat(rule.connection.getSent()) .element(0) .matches( bb -> frameType(bb) == REQUEST_STREAM, @@ -503,7 +570,7 @@ private static Stream racingCases() { + "} but was {" + frameType(rule.connection.getSent().stream().findFirst().get()) + "}"); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .element(1) .matches( bb -> frameType(bb) == CANCEL, @@ -537,10 +604,11 @@ private static Stream racingCases() { (as, rule) -> { RaceTestUtils.race(() -> as.request(1), as::cancel); // ensures proper frames order - if (rule.connection.getSent().size() > 0) { - // - // Assertions.assertThat(rule.connection.getSent()).hasSize(2); - Assertions.assertThat(rule.connection.getSent()) + int size = rule.connection.getSent().size(); + if (size > 0) { + + assertThat(size).isLessThanOrEqualTo(3).isGreaterThanOrEqualTo(2); + assertThat(rule.connection.getSent()) .element(0) .matches( bb -> frameType(bb) == REQUEST_CHANNEL, @@ -549,16 +617,43 @@ private static Stream racingCases() { + "} but was {" + frameType(rule.connection.getSent().stream().findFirst().get()) + "}"); - Assertions.assertThat(rule.connection.getSent()) - .element(1) - .matches( - bb -> frameType(bb) == CANCEL, - "Expected first frame matches {" - + CANCEL - + "} but was {" - + frameType( - rule.connection.getSent().stream().skip(1).findFirst().get()) - + "}"); + if (size == 2) { + assertThat(rule.connection.getSent()) + .element(1) + .matches( + bb -> frameType(bb) == CANCEL, + "Expected second frame matches {" + + CANCEL + + "} but was {" + + frameType( + rule.connection.getSent().stream().skip(1).findFirst().get()) + + "}"); + } else { + assertThat(rule.connection.getSent()) + .element(1) + .matches( + bb -> frameType(bb) == COMPLETE || frameType(bb) == CANCEL, + "Expected second frame matches {" + + COMPLETE + + " or " + + CANCEL + + "} but was {" + + frameType( + rule.connection.getSent().stream().skip(1).findFirst().get()) + + "}"); + assertThat(rule.connection.getSent()) + .element(2) + .matches( + bb -> frameType(bb) == CANCEL || frameType(bb) == COMPLETE, + "Expected third frame matches {" + + COMPLETE + + " or " + + CANCEL + + "} but was {" + + frameType( + rule.connection.getSent().stream().skip(2).findFirst().get()) + + "}"); + } } }), Arguments.of( @@ -616,7 +711,15 @@ private static Stream racingCases() { }), Arguments.of( (Function>) - (rule) -> rule.socket.requestResponse(EmptyPayload.INSTANCE), + (rule) -> { + ByteBuf data = rule.allocator.buffer(); + data.writeCharSequence("testData", CharsetUtil.UTF_8); + + ByteBuf metadata = rule.allocator.buffer(); + metadata.writeCharSequence("testMetadata", CharsetUtil.UTF_8); + Payload requestPayload = ByteBufPayload.create(data, metadata); + return rule.socket.requestResponse(requestPayload); + }, (BiConsumer, ClientSocketRule>) (as, rule) -> { ByteBufAllocator allocator = rule.alloc(); @@ -630,6 +733,32 @@ private static Stream racingCases() { PayloadFrameCodec.encode( allocator, streamId, false, false, true, metadata, data); + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Function>) + (rule) -> { + ByteBuf data = rule.allocator.buffer(); + data.writeCharSequence("testData", CharsetUtil.UTF_8); + + ByteBuf metadata = rule.allocator.buffer(); + metadata.writeCharSequence("testMetadata", CharsetUtil.UTF_8); + Payload requestPayload = ByteBufPayload.create(data, metadata); + return rule.socket.requestStream(requestPayload); + }, + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("def", CharsetUtil.UTF_8); + as.request(Long.MAX_VALUE); + int streamId = rule.getStreamIdForRequestType(REQUEST_STREAM); + ByteBuf frame = + PayloadFrameCodec.encode( + allocator, streamId, false, true, true, metadata, data); + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); })); } @@ -637,20 +766,25 @@ private static Stream racingCases() { @Test public void simpleOnDiscardRequestChannelTest() { AssertSubscriber assertSubscriber = AssertSubscriber.create(1); - TestPublisher testPublisher = TestPublisher.create(); + Sinks.Many testPublisher = Sinks.many().unicast().onBackpressureBuffer(); - Flux payloadFlux = rule.socket.requestChannel(testPublisher); + Flux payloadFlux = rule.socket.requestChannel(testPublisher.asFlux()); payloadFlux.subscribe(assertSubscriber); - testPublisher.next( - ByteBufPayload.create("d", "m"), - ByteBufPayload.create("d1", "m1"), - ByteBufPayload.create("d2", "m2")); + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d"), ByteBufUtil.writeUtf8(rule.alloc(), "m"))); + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d1"), ByteBufUtil.writeUtf8(rule.alloc(), "m1"))); + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d2"), ByteBufUtil.writeUtf8(rule.alloc(), "m2"))); assertSubscriber.cancel(); - Assertions.assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); + assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); rule.assertHasNoLeaks(); } @@ -659,22 +793,29 @@ public void simpleOnDiscardRequestChannelTest() { public void simpleOnDiscardRequestChannelTest2() { ByteBufAllocator allocator = rule.alloc(); AssertSubscriber assertSubscriber = AssertSubscriber.create(1); - TestPublisher testPublisher = TestPublisher.create(); + Sinks.Many testPublisher = Sinks.many().unicast().onBackpressureBuffer(); - Flux payloadFlux = rule.socket.requestChannel(testPublisher); + Flux payloadFlux = rule.socket.requestChannel(testPublisher.asFlux()); payloadFlux.subscribe(assertSubscriber); - testPublisher.next(ByteBufPayload.create("d", "m")); + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d"), ByteBufUtil.writeUtf8(rule.alloc(), "m"))); int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); - testPublisher.next(ByteBufPayload.create("d1", "m1"), ByteBufPayload.create("d2", "m2")); + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d1"), ByteBufUtil.writeUtf8(rule.alloc(), "m1"))); + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d2"), ByteBufUtil.writeUtf8(rule.alloc(), "m2"))); rule.connection.addToReceivedBuffer( ErrorFrameCodec.encode( allocator, streamId, new CustomRSocketException(0x00000404, "test"))); - Assertions.assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); + assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); rule.assertHasNoLeaks(); } @@ -692,7 +833,7 @@ public void verifiesThatFrameWithNoMetadataHasDecodedCorrectlyIntoPayload( switch (frameType) { case REQUEST_FNF: response = - testPublisher.mono().flatMap(p -> rule.socket.fireAndForget(p).then(Mono.empty())); + testPublisher.mono().flatMap(p -> rule.socket.fireAndForget(p)).then(Mono.empty()); break; case REQUEST_RESPONSE: response = testPublisher.mono().flatMap(p -> rule.socket.requestResponse(p)); @@ -708,7 +849,7 @@ public void verifiesThatFrameWithNoMetadataHasDecodedCorrectlyIntoPayload( } response.subscribe(assertSubscriber); - testPublisher.next(ByteBufPayload.create("d")); + testPublisher.next(ByteBufPayload.create(ByteBufUtil.writeUtf8(rule.alloc(), "d"))); int streamId = rule.getStreamIdForRequestType(frameType); @@ -742,21 +883,21 @@ public void verifiesThatFrameWithNoMetadataHasDecodedCorrectlyIntoPayload( } for (int i = 1; i < framesCnt; i++) { - testPublisher.next(ByteBufPayload.create("d" + i)); + testPublisher.next(ByteBufPayload.create(ByteBufUtil.writeUtf8(rule.alloc(), "d" + i))); } - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .describedAs( "Interaction Type :[%s]. Expected to observe %s frames sent", frameType, framesCnt) .hasSize(framesCnt) .allMatch(bb -> !FrameHeaderCodec.hasMetadata(bb)) .allMatch(ByteBuf::release); - Assertions.assertThat(assertSubscriber.isTerminated()) + assertThat(assertSubscriber.isTerminated()) .describedAs("Interaction Type :[%s]. Expected to be terminated", frameType) .isTrue(); - Assertions.assertThat(assertSubscriber.values()) + assertThat(assertSubscriber.values()) .describedAs( "Interaction Type :[%s]. Expected to observe %s frames received", frameType, responsesCnt) @@ -779,25 +920,40 @@ static Stream encodeDecodePayloadCases() { @ParameterizedTest @MethodSource("refCntCases") public void ensureSendsErrorOnIllegalRefCntPayload( - BiFunction> sourceProducer) { - Payload invalidPayload = ByteBufPayload.create("test", "test"); + BiFunction> sourceProducer) { + Payload invalidPayload = + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "test"), + ByteBufUtil.writeUtf8(rule.alloc(), "test")); invalidPayload.release(); - Publisher source = sourceProducer.apply(invalidPayload, rule.socket); + Publisher source = sourceProducer.apply(invalidPayload, rule); - StepVerifier.create(source, 0) + StepVerifier.create(source, 1) .expectError(IllegalReferenceCountException.class) - .verify(Duration.ofMillis(100)); + .verify(Duration.ofMillis(1000)); } - private static Stream>> refCntCases() { + private static Stream>> refCntCases() { return Stream.of( - (p, r) -> r.fireAndForget(p), - (p, r) -> r.requestResponse(p), - (p, r) -> r.requestStream(p), - (p, r) -> r.requestChannel(Mono.just(p)), - (p, r) -> - r.requestChannel(Flux.just(EmptyPayload.INSTANCE, p).doOnSubscribe(s -> s.request(1)))); + (p, clientSocketRule) -> clientSocketRule.socket.fireAndForget(p), + (p, clientSocketRule) -> clientSocketRule.socket.requestResponse(p), + (p, clientSocketRule) -> clientSocketRule.socket.requestStream(p), + (p, clientSocketRule) -> clientSocketRule.socket.requestChannel(Mono.just(p)), + (p, clientSocketRule) -> { + Flux.from(clientSocketRule.connection.getSentAsPublisher()) + .filter(bb -> frameType(bb) == REQUEST_CHANNEL) + .doOnDiscard(ByteBuf.class, ReferenceCounted::release) + .subscribe( + bb -> { + clientSocketRule.connection.addToReceivedBuffer( + RequestNFrameCodec.encode( + clientSocketRule.allocator, FrameHeaderCodec.streamId(bb), 1)); + bb.release(); + }); + + return clientSocketRule.socket.requestChannel(Flux.just(EmptyPayload.INSTANCE, p)); + }); } @Test @@ -808,12 +964,12 @@ public void ensuresThatNoOpsMustHappenUntilSubscriptionInCaseOfFnfCall() { Payload payload2 = ByteBufPayload.create("abc2"); Mono fnf2 = rule.socket.fireAndForget(payload2); - Assertions.assertThat(rule.connection.getSent()).isEmpty(); + assertThat(rule.connection.getSent()).isEmpty(); // checks that fnf2 should have id 1 even though it was generated later than fnf1 AssertSubscriber voidAssertSubscriber2 = fnf2.subscribeWith(AssertSubscriber.create(0)); voidAssertSubscriber2.assertTerminated().assertNoError(); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .hasSize(1) .first() .matches(bb -> frameType(bb) == REQUEST_FNF) @@ -831,7 +987,7 @@ public void ensuresThatNoOpsMustHappenUntilSubscriptionInCaseOfFnfCall() { // checks that fnf1 should have id 3 even though it was generated earlier AssertSubscriber voidAssertSubscriber1 = fnf1.subscribeWith(AssertSubscriber.create(0)); voidAssertSubscriber1.assertTerminated().assertNoError(); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .hasSize(1) .first() .matches(bb -> frameType(bb) == REQUEST_FNF) @@ -855,7 +1011,7 @@ public void ensuresThatNoOpsMustHappenUntilFirstRequestN( Payload payload2 = ByteBufPayload.create("abc2"); Publisher interaction2 = interaction.apply(rule, payload2); - Assertions.assertThat(rule.connection.getSent()).isEmpty(); + assertThat(rule.connection.getSent()).isEmpty(); AssertSubscriber assertSubscriber1 = AssertSubscriber.create(0); interaction1.subscribe(assertSubscriber1); @@ -864,13 +1020,13 @@ public void ensuresThatNoOpsMustHappenUntilFirstRequestN( assertSubscriber1.assertNotTerminated().assertNoError(); assertSubscriber2.assertNotTerminated().assertNoError(); // even though we subscribed, nothing should happen until the first requestN - Assertions.assertThat(rule.connection.getSent()).isEmpty(); + assertThat(rule.connection.getSent()).isEmpty(); // first request on the second interaction to ensure that stream id issuing on the first request assertSubscriber2.request(1); - Assertions.assertThat(rule.connection.getSent()) - .hasSize(1) + assertThat(rule.connection.getSent()) + .hasSize(frameType == REQUEST_CHANNEL ? 2 : 1) .first() .matches(bb -> frameType(bb) == frameType) .matches( @@ -897,11 +1053,23 @@ public void ensuresThatNoOpsMustHappenUntilFirstRequestN( }) .matches(ReferenceCounted::release); + if (frameType == REQUEST_CHANNEL) { + assertThat(rule.connection.getSent()) + .element(1) + .matches(bb -> frameType(bb) == COMPLETE) + .matches( + bb -> FrameHeaderCodec.streamId(bb) == 1, + "Expected to have stream ID {1} but got {" + + FrameHeaderCodec.streamId(new ArrayList<>(rule.connection.getSent()).get(1)) + + "}") + .matches(ReferenceCounted::release); + } + rule.connection.clearSendReceiveBuffers(); assertSubscriber1.request(1); - Assertions.assertThat(rule.connection.getSent()) - .hasSize(1) + assertThat(rule.connection.getSent()) + .hasSize(frameType == REQUEST_CHANNEL ? 2 : 1) .first() .matches(bb -> frameType(bb) == frameType) .matches( @@ -927,6 +1095,18 @@ public void ensuresThatNoOpsMustHappenUntilFirstRequestN( return false; }) .matches(ReferenceCounted::release); + + if (frameType == REQUEST_CHANNEL) { + assertThat(rule.connection.getSent()) + .element(1) + .matches(bb -> frameType(bb) == COMPLETE) + .matches( + bb -> FrameHeaderCodec.streamId(bb) == 3, + "Expected to have stream ID {1} but got {" + + FrameHeaderCodec.streamId(new ArrayList<>(rule.connection.getSent()).get(1)) + + "}") + .matches(ReferenceCounted::release); + } } private static Stream requestNInteractions() { @@ -947,6 +1127,7 @@ private static Stream requestNInteractions() { @ParameterizedTest @MethodSource("streamRacingCases") + @Disabled("Connection should take care of ordering if such is necessary") public void ensuresCorrectOrderOfStreamIdIssuingInCaseOfRacing( BiFunction> interaction1, BiFunction> interaction2, @@ -954,7 +1135,7 @@ public void ensuresCorrectOrderOfStreamIdIssuingInCaseOfRacing( FrameType interactionType2) { Assumptions.assumeThat(interactionType1).isNotEqualTo(METADATA_PUSH); Assumptions.assumeThat(interactionType2).isNotEqualTo(METADATA_PUSH); - for (int i = 1; i < 10000; i += 4) { + for (int i = 1; i < RaceTestConstants.REPEATS; i += 4) { Payload payload = DefaultPayload.create("test", "test"); Publisher publisher1 = interaction1.apply(rule, payload); Publisher publisher2 = interaction2.apply(rule, payload); @@ -962,7 +1143,7 @@ public void ensuresCorrectOrderOfStreamIdIssuingInCaseOfRacing( () -> publisher1.subscribe(AssertSubscriber.create()), () -> publisher2.subscribe(AssertSubscriber.create())); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .extracting(FrameHeaderCodec::streamId) .containsExactly(i, i + 2); rule.connection.getSent().forEach(bb -> bb.release()); @@ -1039,7 +1220,7 @@ public void shouldTerminateAllStreamsIfThereRacingBetweenDisposeAndRequests( BiFunction> interaction2, FrameType interactionType1, FrameType interactionType2) { - for (int i = 1; i < 10000; i++) { + for (int i = 1; i < RaceTestConstants.REPEATS; i++) { Payload payload1 = ByteBufPayload.create("test", "test"); Payload payload2 = ByteBufPayload.create("test", "test"); AssertSubscriber assertSubscriber1 = AssertSubscriber.create(); @@ -1048,15 +1229,11 @@ public void shouldTerminateAllStreamsIfThereRacingBetweenDisposeAndRequests( Publisher publisher2 = interaction2.apply(rule, payload2); RaceTestUtils.race( () -> rule.socket.dispose(), - () -> - RaceTestUtils.race( - () -> publisher1.subscribe(assertSubscriber1), - () -> publisher2.subscribe(assertSubscriber2), - Schedulers.parallel()), - Schedulers.parallel()); + () -> publisher1.subscribe(assertSubscriber1), + () -> publisher2.subscribe(assertSubscriber2)); assertSubscriber1.await().assertTerminated(); - if (interactionType1 != REQUEST_FNF) { + if (interactionType1 != REQUEST_FNF && interactionType1 != METADATA_PUSH) { assertSubscriber1.assertError(ClosedChannelException.class); } else { try { @@ -1067,7 +1244,7 @@ public void shouldTerminateAllStreamsIfThereRacingBetweenDisposeAndRequests( } } assertSubscriber2.await().assertTerminated(); - if (interactionType2 != REQUEST_FNF) { + if (interactionType2 != REQUEST_FNF && interactionType2 != METADATA_PUSH) { assertSubscriber2.assertError(ClosedChannelException.class); } else { try { @@ -1078,11 +1255,11 @@ public void shouldTerminateAllStreamsIfThereRacingBetweenDisposeAndRequests( } } - Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.connection.getSent().clear(); - Assertions.assertThat(payload1.refCnt()).isZero(); - Assertions.assertThat(payload2.refCnt()).isZero(); + assertThat(payload1.refCnt()).isZero(); + assertThat(payload2.refCnt()).isZero(); } } @@ -1097,35 +1274,230 @@ public void testWorkaround858() { rule.connection.addToReceivedBuffer( ErrorFrameCodec.encode(rule.alloc(), 1, new RuntimeException("test"))); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .hasSize(1) .first() .matches(bb -> FrameHeaderCodec.frameType(bb) == REQUEST_RESPONSE) .matches(ByteBuf::release); - Assertions.assertThat(rule.socket.isDisposed()).isFalse(); + assertThat(rule.socket.isDisposed()).isFalse(); rule.assertHasNoLeaks(); } + @DisplayName("reassembles data") + @ParameterizedTest + @MethodSource("requestNInteractions") + void reassembleData( + FrameType frameType, + BiFunction> requestFunction) { + final int mtu = ThreadLocalRandom.current().nextInt(64, 256); + final LeaksTrackingByteBufAllocator leaksTrackingByteBufAllocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + final Payload requestPayload = genericPayload(leaksTrackingByteBufAllocator); + final Payload randomPayload = randomPayload(leaksTrackingByteBufAllocator); + List fragments = prepareFragments(leaksTrackingByteBufAllocator, mtu, randomPayload); + + final Publisher responsePublisher = requestFunction.apply(rule, requestPayload); + StepVerifier.create(responsePublisher) + .then(() -> rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0]))) + .assertNext( + p -> { + PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + }) + .thenCancel() + .verify(); + + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(frameType).hasNoLeaks(); + + if (frameType == REQUEST_CHANNEL) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + + if (!rule.connection.getSent().isEmpty()) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(CANCEL).hasNoLeaks(); + } + + leaksTrackingByteBufAllocator.assertHasNoLeaks(); + } + + @DisplayName("reassembles metadata") + @ParameterizedTest + @MethodSource("requestNInteractions") + void reassembleMetadata( + FrameType frameType, + BiFunction> requestFunction) { + final int mtu = ThreadLocalRandom.current().nextInt(64, 256); + final LeaksTrackingByteBufAllocator leaksTrackingByteBufAllocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + final Payload requestPayload = genericPayload(leaksTrackingByteBufAllocator); + final Payload metadataOnlyPayload = randomMetadataOnlyPayload(leaksTrackingByteBufAllocator); + List fragments = + prepareFragments(leaksTrackingByteBufAllocator, mtu, metadataOnlyPayload); + + StepVerifier.create(requestFunction.apply(rule, requestPayload)) + .then(() -> rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0]))) + .assertNext( + responsePayload -> { + PayloadAssert.assertThat(responsePayload).isEqualTo(metadataOnlyPayload).hasNoLeaks(); + metadataOnlyPayload.release(); + }) + .thenCancel() + .verify(); + + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(frameType).hasNoLeaks(); + + if (frameType == REQUEST_CHANNEL) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + + if (!rule.connection.getSent().isEmpty()) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(CANCEL).hasNoLeaks(); + } + + leaksTrackingByteBufAllocator.assertHasNoLeaks(); + } + + @ParameterizedTest(name = "throws error if reassembling payload size exceeds {0}") + @MethodSource("requestNInteractions") + public void errorTooBigPayload( + FrameType frameType, + BiFunction> requestFunction) { + final int mtu = ThreadLocalRandom.current().nextInt(64, 256); + final int maxInboundPayloadSize = ThreadLocalRandom.current().nextInt(mtu + 1, 4096); + final LeaksTrackingByteBufAllocator leaksTrackingByteBufAllocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + final Payload requestPayload = genericPayload(leaksTrackingByteBufAllocator); + final Payload responsePayload = + fixedSizePayload(leaksTrackingByteBufAllocator, maxInboundPayloadSize + 1); + List fragments = prepareFragments(leaksTrackingByteBufAllocator, mtu, responsePayload); + responsePayload.release(); + + rule.setMaxInboundPayloadSize(maxInboundPayloadSize); + + StepVerifier.create(requestFunction.apply(rule, requestPayload)) + .then(() -> rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0]))) + .expectErrorMessage(String.format(ILLEGAL_REASSEMBLED_PAYLOAD_SIZE, maxInboundPayloadSize)) + .verify(); + + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(frameType).hasNoLeaks(); + + if (frameType == REQUEST_CHANNEL) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(CANCEL).hasNoLeaks(); + + leaksTrackingByteBufAllocator.assertHasNoLeaks(); + } + + @ParameterizedTest(name = "throws error if fragment before the last is < min MTU {0}") + @MethodSource("requestNInteractions") + public void errorFragmentTooSmall( + FrameType frameType, + BiFunction> requestFunction) { + final int mtu = 32; + final LeaksTrackingByteBufAllocator leaksTrackingByteBufAllocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + final Payload requestPayload = genericPayload(leaksTrackingByteBufAllocator); + final Payload responsePayload = fixedSizePayload(leaksTrackingByteBufAllocator, 156); + List fragments = prepareFragments(leaksTrackingByteBufAllocator, mtu, responsePayload); + responsePayload.release(); + + StepVerifier.create(requestFunction.apply(rule, requestPayload)) + .then(() -> rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0]))) + .expectErrorMessage("Fragment is too small.") + .verify(); + + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(frameType).hasNoLeaks(); + + if (frameType == REQUEST_CHANNEL) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(CANCEL).hasNoLeaks(); + + leaksTrackingByteBufAllocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @ValueSource(strings = {"stream", "channel"}) + // see https://github.com/rsocket/rsocket-java/issues/959 + public void testWorkaround959(String type) { + for (int i = 1; i < 20000; i += 2) { + ByteBuf buffer = rule.alloc().buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(3); + if (type.equals("stream")) { + rule.socket.requestStream(ByteBufPayload.create(buffer)).subscribe(assertSubscriber); + } else if (type.equals("channel")) { + rule.socket + .requestChannel(Flux.just(ByteBufPayload.create(buffer))) + .subscribe(assertSubscriber); + } + + final ByteBuf payloadFrame = + PayloadFrameCodec.encode( + rule.alloc(), i, false, false, true, Unpooled.EMPTY_BUFFER, Unpooled.EMPTY_BUFFER); + + RaceTestUtils.race( + () -> { + rule.connection.addToReceivedBuffer(payloadFrame.copy()); + rule.connection.addToReceivedBuffer(payloadFrame.copy()); + rule.connection.addToReceivedBuffer(payloadFrame); + }, + () -> { + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + }); + + assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); + + assertThat(rule.socket.isDisposed()).isFalse(); + + assertSubscriber.values().forEach(ReferenceCountUtil::safeRelease); + assertSubscriber.assertNoError(); + + rule.connection.clearSendReceiveBuffers(); + rule.assertHasNoLeaks(); + } + } + public static class ClientSocketRule extends AbstractSocketRule { + + protected Sinks.Empty thisClosedSink; + protected Sinks.Empty otherClosedSink; + @Override protected RSocketRequester newRSocket() { + this.thisClosedSink = Sinks.empty(); + this.otherClosedSink = Sinks.empty(); return new RSocketRequester( connection, PayloadDecoder.ZERO_COPY, StreamIdSupplier.clientSupplier(), 0, maxFrameLength, + maxInboundPayloadSize, Integer.MAX_VALUE, Integer.MAX_VALUE, null, - RequesterLeaseHandler.None, - TestScheduler.INSTANCE); + (__) -> null, + null, + thisClosedSink, + otherClosedSink.asMono().and(thisClosedSink.asMono())); } public int getStreamIdForRequestType(FrameType expectedFrameType) { - assertThat("Unexpected frames sent.", connection.getSent(), hasSize(greaterThanOrEqualTo(1))); + assertThat(connection.getSent().size()) + .describedAs("Unexpected frames sent.") + .isGreaterThanOrEqualTo(1); List framesFound = new ArrayList<>(); for (ByteBuf frame : connection.getSent()) { FrameType frameType = frameType(frame); diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java index 0d0fbd8c0..4f689e396 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,19 +17,24 @@ package io.rsocket.core; import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.ReassemblyUtils.ILLEGAL_REASSEMBLED_PAYLOAD_SIZE; +import static io.rsocket.core.TestRequesterResponderSupport.fixedSizePayload; +import static io.rsocket.core.TestRequesterResponderSupport.genericPayload; +import static io.rsocket.core.TestRequesterResponderSupport.prepareFragments; +import static io.rsocket.core.TestRequesterResponderSupport.randomMetadataOnlyPayload; +import static io.rsocket.core.TestRequesterResponderSupport.randomPayload; import static io.rsocket.frame.FrameHeaderCodec.frameType; import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameType.COMPLETE; import static io.rsocket.frame.FrameType.ERROR; +import static io.rsocket.frame.FrameType.NEXT; +import static io.rsocket.frame.FrameType.NEXT_COMPLETE; import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; import static io.rsocket.frame.FrameType.REQUEST_FNF; import static io.rsocket.frame.FrameType.REQUEST_N; import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; import static io.rsocket.frame.FrameType.REQUEST_STREAM; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.anyOf; -import static org.hamcrest.Matchers.empty; -import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.is; +import static org.assertj.core.api.Assertions.assertThat; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; @@ -37,8 +42,11 @@ import io.netty.util.CharsetUtil; import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCounted; +import io.rsocket.FrameAssert; import io.rsocket.Payload; +import io.rsocket.PayloadAssert; import io.rsocket.RSocket; +import io.rsocket.RaceTestConstants; import io.rsocket.frame.CancelFrameCodec; import io.rsocket.frame.ErrorFrameCodec; import io.rsocket.frame.FrameHeaderCodec; @@ -52,30 +60,31 @@ import io.rsocket.frame.RequestStreamFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.internal.subscriber.AssertSubscriber; -import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.plugins.RequestInterceptor; +import io.rsocket.plugins.TestRequestInterceptor; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.test.util.TestSubscriber; import io.rsocket.util.ByteBufPayload; import io.rsocket.util.DefaultPayload; import io.rsocket.util.EmptyPayload; -import java.util.Collection; +import java.util.List; import java.util.concurrent.CancellationException; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Stream; -import org.assertj.core.api.Assertions; +import org.assertj.core.api.Assumptions; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; -import org.junit.runners.model.Statement; import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; import reactor.core.publisher.BaseSubscriber; @@ -84,8 +93,7 @@ import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; import reactor.core.publisher.Operators; -import reactor.core.scheduler.Scheduler; -import reactor.core.scheduler.Schedulers; +import reactor.core.publisher.Sinks; import reactor.test.publisher.TestPublisher; import reactor.test.util.RaceTestUtils; @@ -94,66 +102,65 @@ public class RSocketResponderTest { ServerSocketRule rule; @BeforeEach - public void setUp() throws Throwable { + public void setUp() { Hooks.onNextDropped(ReferenceCountUtil::safeRelease); Hooks.onErrorDropped(t -> {}); rule = new ServerSocketRule(); - rule.apply( - new Statement() { - @Override - public void evaluate() {} - }, - null) - .evaluate(); + rule.init(); } @AfterEach public void tearDown() { Hooks.resetOnErrorDropped(); Hooks.resetOnNextDropped(); + rule.assertHasNoLeaks(); } @Test @Timeout(2_000) @Disabled - public void testHandleKeepAlive() throws Exception { + public void testHandleKeepAlive() { rule.connection.addToReceivedBuffer( KeepAliveFrameCodec.encode(rule.alloc(), true, 0, Unpooled.EMPTY_BUFFER)); - ByteBuf sent = rule.connection.awaitSend(); - assertThat("Unexpected frame sent.", frameType(sent), is(FrameType.KEEPALIVE)); + ByteBuf sent = rule.connection.awaitFrame(); + assertThat(frameType(sent)) + .describedAs("Unexpected frame sent.") + .isEqualTo(FrameType.KEEPALIVE); /*Keep alive ack must not have respond flag else, it will result in infinite ping-pong of keep alive frames.*/ - assertThat( - "Unexpected keep-alive frame respond flag.", - KeepAliveFrameCodec.respondFlag(sent), - is(false)); + assertThat(KeepAliveFrameCodec.respondFlag(sent)) + .describedAs("Unexpected keep-alive frame respond flag.") + .isEqualTo(false); } @Test @Timeout(2_000) - @Disabled - public void testHandleResponseFrameNoError() throws Exception { + public void testHandleResponseFrameNoError() { final int streamId = 4; rule.connection.clearSendReceiveBuffers(); - + final TestPublisher testPublisher = TestPublisher.create(); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return testPublisher.mono(); + } + }); rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE); - - Collection> sendSubscribers = rule.connection.getSendSubscribers(); - assertThat("Request not sent.", sendSubscribers, hasSize(1)); - Subscriber sendSub = sendSubscribers.iterator().next(); - assertThat( - "Unexpected frame sent.", - frameType(rule.connection.awaitSend()), - anyOf(is(FrameType.COMPLETE), is(FrameType.NEXT_COMPLETE))); + testPublisher.complete(); + FrameAssert.assertThat(rule.connection.awaitFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + testPublisher.assertWasNotCancelled(); } @Test @Timeout(2_000) - @Disabled - public void testHandlerEmitsError() throws Exception { + public void testHandlerEmitsError() { final int streamId = 4; + rule.prefetch = 1; rule.sendRequest(streamId, FrameType.REQUEST_STREAM); - assertThat( - "Unexpected frame sent.", frameType(rule.connection.awaitSend()), is(FrameType.ERROR)); + FrameAssert.assertThat(rule.connection.awaitFrame()) + .typeOf(FrameType.ERROR) + .hasData("Request-Stream not implemented.") + .hasNoLeaks(); } @Test @@ -172,12 +179,12 @@ public Mono requestResponse(Payload payload) { }); rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE); - assertThat("Unexpected frame sent.", rule.connection.getSent(), is(empty())); + assertThat(rule.connection.getSent()).describedAs("Unexpected frame sent.").isEmpty(); rule.connection.addToReceivedBuffer(CancelFrameCodec.encode(allocator, streamId)); - assertThat("Unexpected frame sent.", rule.connection.getSent(), is(empty())); - assertThat("Subscription not cancelled.", cancelled.get(), is(true)); + assertThat(rule.connection.getSent()).describedAs("Unexpected frame sent.").isEmpty(); + assertThat(cancelled.get()).describedAs("Subscription not cancelled.").isTrue(); rule.assertHasNoLeaks(); } @@ -233,14 +240,17 @@ protected void hookOnSubscribe(Subscription subscription) { for (Runnable runnable : runnables) { rule.connection.clearSendReceiveBuffers(); runnable.run(); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .hasSize(1) .first() .matches(bb -> FrameHeaderCodec.frameType(bb) == FrameType.ERROR) - .matches(bb -> ErrorFrameCodec.dataUtf8(bb).contains(INVALID_PAYLOAD_ERROR_MESSAGE)) + .matches( + bb -> + ErrorFrameCodec.dataUtf8(bb) + .contains(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, maxFrameLength))) .matches(ReferenceCounted::release); - assertThat("Subscription not cancelled.", cancelled.get(), is(true)); + assertThat(cancelled.get()).describedAs("Subscription not cancelled.").isTrue(); } rule.assertHasNoLeaks(); @@ -249,15 +259,18 @@ protected void hookOnSubscribe(Subscription subscription) { @Test public void checkNoLeaksOnRacingCancelFromRequestChannelAndNextFromUpstream() { ByteBufAllocator allocator = rule.alloc(); - for (int i = 0; i < 10000; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { AssertSubscriber assertSubscriber = AssertSubscriber.create(); + final Sinks.One sink = Sinks.one(); rule.setAcceptingSocket( new RSocket() { @Override public Flux requestChannel(Publisher payloads) { payloads.subscribe(assertSubscriber); - return Flux.never(); + return sink.asMono().flux(); } }, Integer.MAX_VALUE); @@ -283,19 +296,21 @@ public Flux requestChannel(Publisher payloads) { ByteBuf data3 = allocator.buffer(); data3.writeCharSequence("def3", CharsetUtil.UTF_8); ByteBuf nextFrame3 = - PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata3, data3); + PayloadFrameCodec.encode(allocator, 1, false, true, true, metadata3, data3); RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3), () -> { - rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3); - }, - assertSubscriber::cancel); + assertSubscriber.cancel(); + sink.tryEmitEmpty(); + }); - Assertions.assertThat(assertSubscriber.values()).allMatch(ReferenceCounted::release); + assertThat(assertSubscriber.values()).allMatch(ReferenceCounted::release); - Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnComplete(1).expectNothing(); } } @@ -303,7 +318,9 @@ public Flux requestChannel(Publisher payloads) { public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestChannelTest() { Hooks.onErrorDropped((e) -> {}); ByteBufAllocator allocator = rule.alloc(); - for (int i = 0; i < 10000; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { AssertSubscriber assertSubscriber = AssertSubscriber.create(); FluxSink[] sinks = new FluxSink[1]; @@ -330,20 +347,23 @@ public Flux requestChannel(Publisher payloads) { sink.next(ByteBufPayload.create("d1", "m1")); sink.next(ByteBufPayload.create("d2", "m2")); sink.next(ByteBufPayload.create("d3", "m3")); + sink.complete(); }); - Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnCancel(1).expectNothing(); } } @Test public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestChannelTest1() { - Scheduler parallel = Schedulers.parallel(); Hooks.onErrorDropped((e) -> {}); ByteBufAllocator allocator = rule.alloc(); - for (int i = 0; i < 10000; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { AssertSubscriber assertSubscriber = AssertSubscriber.create(); FluxSink[] sinks = new FluxSink[1]; @@ -366,20 +386,17 @@ public Flux requestChannel(Publisher payloads) { ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, 1, Integer.MAX_VALUE); FluxSink sink = sinks[0]; RaceTestUtils.race( - () -> - RaceTestUtils.race( - () -> rule.connection.addToReceivedBuffer(requestNFrame), - () -> rule.connection.addToReceivedBuffer(cancelFrame), - parallel), + () -> rule.connection.addToReceivedBuffer(requestNFrame), + () -> rule.connection.addToReceivedBuffer(cancelFrame), () -> { sink.next(ByteBufPayload.create("d1", "m1")); sink.next(ByteBufPayload.create("d2", "m2")); sink.next(ByteBufPayload.create("d3", "m3")); - }, - parallel); - - Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + sink.complete(); + }); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnCancel(1).expectNothing(); rule.assertHasNoLeaks(); } } @@ -387,10 +404,11 @@ public Flux requestChannel(Publisher payloads) { @Test public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromUpstreamOnErrorFromRequestChannelTest1() { - Scheduler parallel = Schedulers.parallel(); Hooks.onErrorDropped((e) -> {}); ByteBufAllocator allocator = rule.alloc(); - for (int i = 0; i < 10000; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { FluxSink[] sinks = new FluxSink[1]; AssertSubscriber assertSubscriber = AssertSubscriber.create(); rule.setAcceptingSocket( @@ -453,41 +471,39 @@ public Flux requestChannel(Publisher payloads) { FluxSink sink = sinks[0]; RaceTestUtils.race( - () -> - RaceTestUtils.race( - () -> rule.connection.addToReceivedBuffer(requestNFrame), - () -> rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3), - parallel), + () -> rule.connection.addToReceivedBuffer(requestNFrame), + () -> rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3), () -> { sink.next(np1); sink.next(np2); sink.next(np3); sink.error(new RuntimeException()); - }, - parallel); + }); - Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); assertSubscriber .assertTerminated() .assertError(CancellationException.class) - .assertErrorMessage("Disposed"); - Assertions.assertThat(assertSubscriber.values()) + .assertErrorMessage("Outbound has terminated with an error"); + assertThat(assertSubscriber.values()) .allMatch( msg -> { ReferenceCountUtil.safeRelease(msg); return msg.refCnt() == 0; }); rule.assertHasNoLeaks(); + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnError(1).expectNothing(); } } @Test public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestStreamTest1() { - Scheduler parallel = Schedulers.parallel(); Hooks.onErrorDropped((e) -> {}); ByteBufAllocator allocator = rule.alloc(); - for (int i = 0; i < 10000; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { FluxSink[] sinks = new FluxSink[1]; rule.setAcceptingSocket( @@ -510,21 +526,23 @@ public Flux requestStream(Payload payload) { sink.next(ByteBufPayload.create("d1", "m1")); sink.next(ByteBufPayload.create("d2", "m2")); sink.next(ByteBufPayload.create("d3", "m3")); - }, - parallel); + }); - Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); + + testRequestInterceptor.expectOnStart(1, REQUEST_STREAM).expectOnCancel(1).expectNothing(); } } @Test public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestResponseTest1() { - Scheduler parallel = Schedulers.parallel(); Hooks.onErrorDropped((e) -> {}); ByteBufAllocator allocator = rule.alloc(); - for (int i = 0; i < 10000; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { Operators.MonoSubscriber[] sources = new Operators.MonoSubscriber[1]; rule.setAcceptingSocket( @@ -550,12 +568,21 @@ public void subscribe(CoreSubscriber actual) { () -> rule.connection.addToReceivedBuffer(cancelFrame), () -> { sources[0].complete(ByteBufPayload.create("d1", "m1")); - }, - parallel); + }); - Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); + + testRequestInterceptor + .expectOnStart(1, REQUEST_RESPONSE) + .assertNext( + e -> + assertThat(e.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_CANCEL)) + .expectNothing(); } } @@ -584,7 +611,7 @@ public Flux requestStream(Payload payload) { sink.next(ByteBufPayload.create("d3", "m3")); rule.connection.addToReceivedBuffer(cancelFrame); - Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); } @@ -630,7 +657,7 @@ public Flux requestChannel(Publisher payloads) { rule.connection.addToReceivedBuffer(cancelFrame); - Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); } @@ -700,8 +727,7 @@ public Flux requestChannel(Publisher payloads) { } if (responsesCnt > 0) { - Assertions.assertThat( - rule.connection.getSent().stream().filter(bb -> frameType(bb) != REQUEST_N)) + assertThat(rule.connection.getSent().stream().filter(bb -> frameType(bb) != REQUEST_N)) .describedAs( "Interaction Type :[%s]. Expected to observe %s frames sent", frameType, responsesCnt) .hasSize(responsesCnt) @@ -709,8 +735,7 @@ public Flux requestChannel(Publisher payloads) { } if (framesCnt > 1) { - Assertions.assertThat( - rule.connection.getSent().stream().filter(bb -> frameType(bb) == REQUEST_N)) + assertThat(rule.connection.getSent().stream().filter(bb -> frameType(bb) == REQUEST_N)) .describedAs( "Interaction Type :[%s]. Expected to observe single RequestN(%s) frame", frameType, framesCnt - 1) @@ -719,9 +744,9 @@ public Flux requestChannel(Publisher payloads) { .matches(bb -> RequestNFrameCodec.requestN(bb) == (framesCnt - 1)); } - Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); - Assertions.assertThat(assertSubscriber.awaitAndAssertNextValueCount(framesCnt).values()) + assertThat(assertSubscriber.awaitAndAssertNextValueCount(framesCnt).values()) .hasSize(framesCnt) .allMatch(p -> !p.hasMetadata()) .allMatch(ReferenceCounted::release); @@ -766,7 +791,7 @@ public Flux requestChannel(Publisher payloads) { rule.sendRequest(1, frameType); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .hasSize(1) .first() .matches( @@ -775,7 +800,8 @@ public Flux requestChannel(Publisher payloads) { + ERROR + "} but was {" + frameType(rule.connection.getSent().iterator().next()) - + "}"); + + "}") + .matches(ByteBuf::release); } private static Stream refCntCases() { @@ -806,25 +832,362 @@ public Flux requestChannel(Publisher payloads) { rule.connection.addToReceivedBuffer( ErrorFrameCodec.encode(rule.alloc(), 1, new RuntimeException("test"))); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .hasSize(1) .first() .matches(bb -> FrameHeaderCodec.frameType(bb) == REQUEST_N) .matches(ReferenceCounted::release); - Assertions.assertThat(rule.socket.isDisposed()).isFalse(); + assertThat(rule.socket.isDisposed()).isFalse(); testPublisher.assertWasCancelled(); rule.assertHasNoLeaks(); } + static Stream requestCases() { + return Stream.of(REQUEST_FNF, REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + } + + @DisplayName("reassembles payload") + @ParameterizedTest + @MethodSource("requestCases") + void reassemblePayload(FrameType frameType) { + AtomicReference receivedPayload = new AtomicReference<>(); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + receivedPayload.set(payload); + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + receivedPayload.set(payload); + return Mono.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestStream(Payload payload) { + receivedPayload.set(payload); + return Flux.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads).subscribe(receivedPayload::set, null, null, s -> s.request(1)); + return Flux.just(genericPayload(rule.allocator)); + } + }); + + final int mtu = ThreadLocalRandom.current().nextInt(64, 256); + final Payload randomPayload = randomPayload(rule.allocator); + List fragments = prepareFragments(rule.allocator, mtu, randomPayload, frameType); + + rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0])); + + PayloadAssert.assertThat(receivedPayload.get()).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + + if (frameType != REQUEST_FNF) { + FrameAssert.assertThat(rule.connection.getSent().poll()) + .typeOf(frameType == REQUEST_RESPONSE ? NEXT_COMPLETE : NEXT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasNoLeaks(); + if (frameType != REQUEST_RESPONSE) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + } + + rule.assertHasNoLeaks(); + } + + @DisplayName("reassembles metadata") + @ParameterizedTest + @MethodSource("requestCases") + void reassembleMetadataOnly(FrameType frameType) { + AtomicReference receivedPayload = new AtomicReference<>(); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + receivedPayload.set(payload); + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + receivedPayload.set(payload); + return Mono.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestStream(Payload payload) { + receivedPayload.set(payload); + return Flux.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads).subscribe(receivedPayload::set, null, null, s -> s.request(1)); + return Flux.just(genericPayload(rule.allocator)); + } + }); + + final int mtu = ThreadLocalRandom.current().nextInt(64, 256); + final Payload randomMetadataOnlyPayload = randomMetadataOnlyPayload(rule.allocator); + List fragments = + prepareFragments(rule.allocator, mtu, randomMetadataOnlyPayload, frameType); + + rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0])); + + PayloadAssert.assertThat(receivedPayload.get()) + .isEqualTo(randomMetadataOnlyPayload) + .hasNoLeaks(); + randomMetadataOnlyPayload.release(); + + if (frameType != REQUEST_FNF) { + FrameAssert.assertThat(rule.connection.getSent().poll()) + .typeOf(frameType == REQUEST_RESPONSE ? NEXT_COMPLETE : NEXT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasNoLeaks(); + if (frameType != REQUEST_RESPONSE) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + } + + rule.assertHasNoLeaks(); + } + + @ParameterizedTest(name = "throws error if reassembling payload size exceeds {0}") + @MethodSource("requestCases") + public void errorTooBigPayload(FrameType frameType) { + final int mtu = ThreadLocalRandom.current().nextInt(64, 256); + final int maxInboundPayloadSize = ThreadLocalRandom.current().nextInt(mtu + 1, 4096); + AtomicReference receivedPayload = new AtomicReference<>(); + rule.setMaxInboundPayloadSize(maxInboundPayloadSize); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + receivedPayload.set(payload); + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + receivedPayload.set(payload); + return Mono.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestStream(Payload payload) { + receivedPayload.set(payload); + return Flux.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads).subscribe(receivedPayload::set, null, null, s -> s.request(1)); + return Flux.just(genericPayload(rule.allocator)); + } + }); + final Payload randomPayload = fixedSizePayload(rule.allocator, maxInboundPayloadSize + 1); + List fragments = prepareFragments(rule.allocator, mtu, randomPayload, frameType); + randomPayload.release(); + + rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0])); + + PayloadAssert.assertThat(receivedPayload.get()).isNull(); + + if (frameType != REQUEST_FNF) { + FrameAssert.assertThat(rule.connection.getSent().poll()) + .typeOf(ERROR) + .hasData( + "Failed to reassemble payload. Cause: " + + String.format(ILLEGAL_REASSEMBLED_PAYLOAD_SIZE, maxInboundPayloadSize)) + .hasNoLeaks(); + } + + rule.assertHasNoLeaks(); + } + + @ParameterizedTest(name = "throws error if fragment before the last is < min MTU {0}") + @MethodSource("requestCases") + public void errorFragmentTooSmall(FrameType frameType) { + final int mtu = 32; + AtomicReference receivedPayload = new AtomicReference<>(); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + receivedPayload.set(payload); + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + receivedPayload.set(payload); + return Mono.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestStream(Payload payload) { + receivedPayload.set(payload); + return Flux.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads).subscribe(receivedPayload::set, null, null, s -> s.request(1)); + return Flux.just(genericPayload(rule.allocator)); + } + }); + final Payload randomPayload = fixedSizePayload(rule.allocator, 156); + List fragments = prepareFragments(rule.allocator, mtu, randomPayload, frameType); + randomPayload.release(); + + rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0])); + + PayloadAssert.assertThat(receivedPayload.get()).isNull(); + + if (frameType != REQUEST_FNF) { + FrameAssert.assertThat(rule.connection.getSent().poll()) + .typeOf(ERROR) + .hasData("Failed to reassemble payload. Cause: Fragment is too small.") + .hasNoLeaks(); + } + + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("requestCases") + void receivingRequestOnStreamIdThaIsAlreadyInUseMUSTBeIgnored_ReassemblyCase( + FrameType requestType) { + AtomicReference receivedPayload = new AtomicReference<>(); + final Sinks.Empty delayer = Sinks.empty(); + rule.setAcceptingSocket( + new RSocket() { + + @Override + public Mono fireAndForget(Payload payload) { + receivedPayload.set(payload); + return delayer.asMono(); + } + + @Override + public Mono requestResponse(Payload payload) { + receivedPayload.set(payload); + return Mono.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); + } + + @Override + public Flux requestStream(Payload payload) { + receivedPayload.set(payload); + return Flux.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads).subscribe(receivedPayload::set, null, null, s -> s.request(1)); + return Flux.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); + } + }); + final Payload randomPayload1 = fixedSizePayload(rule.allocator, 128); + final List fragments1 = + prepareFragments(rule.allocator, 64, randomPayload1, requestType); + final Payload randomPayload2 = fixedSizePayload(rule.allocator, 128); + final List fragments2 = + prepareFragments(rule.allocator, 64, randomPayload2, requestType); + randomPayload2.release(); + rule.connection.addToReceivedBuffer(fragments1.remove(0)); + rule.connection.addToReceivedBuffer(fragments2.remove(0)); + + rule.connection.addToReceivedBuffer(fragments1.toArray(new ByteBuf[0])); + if (requestType != REQUEST_CHANNEL) { + rule.connection.addToReceivedBuffer(fragments2.toArray(new ByteBuf[0])); + delayer.tryEmitEmpty(); + } else { + delayer.tryEmitEmpty(); + rule.connection.addToReceivedBuffer(PayloadFrameCodec.encodeComplete(rule.allocator, 1)); + rule.connection.addToReceivedBuffer(fragments2.toArray(new ByteBuf[0])); + } + + PayloadAssert.assertThat(receivedPayload.get()).isEqualTo(randomPayload1).hasNoLeaks(); + randomPayload1.release(); + + if (requestType != REQUEST_FNF) { + FrameAssert.assertThat(rule.connection.getSent().poll()) + .typeOf(requestType == REQUEST_RESPONSE ? NEXT_COMPLETE : NEXT) + .hasNoLeaks(); + + if (requestType != REQUEST_RESPONSE) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + } + + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("requestCases") + void receivingRequestOnStreamIdThaIsAlreadyInUseMUSTBeIgnored(FrameType requestType) { + Assumptions.assumeThat(requestType).isNotEqualTo(REQUEST_FNF); + AtomicReference receivedPayload = new AtomicReference<>(); + final Sinks.One delayer = Sinks.one(); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + receivedPayload.set(payload); + return Mono.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); + } + + @Override + public Flux requestStream(Payload payload) { + receivedPayload.set(payload); + return Flux.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads).subscribe(receivedPayload::set, null, null, s -> s.request(1)); + return Flux.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); + } + }); + final Payload randomPayload1 = fixedSizePayload(rule.allocator, 64); + final Payload randomPayload2 = fixedSizePayload(rule.allocator, 64); + rule.sendRequest(1, requestType, randomPayload1.retain()); + rule.sendRequest(1, requestType, randomPayload2); + + delayer.tryEmitEmpty(); + + PayloadAssert.assertThat(receivedPayload.get()).isEqualTo(randomPayload1).hasNoLeaks(); + randomPayload1.release(); + + FrameAssert.assertThat(rule.connection.getSent().poll()) + .typeOf(requestType == REQUEST_RESPONSE ? NEXT_COMPLETE : NEXT) + .hasNoLeaks(); + + if (requestType != REQUEST_RESPONSE) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + + rule.assertHasNoLeaks(); + } + public static class ServerSocketRule extends AbstractSocketRule { private RSocket acceptingSocket; private volatile int prefetch; + private RequestInterceptor requestInterceptor; + protected Sinks.Empty onCloseSink; @Override - protected void init() { + protected void doInit() { acceptingSocket = new RSocket() { @Override @@ -832,7 +1195,7 @@ public Mono requestResponse(Payload payload) { return Mono.just(payload); } }; - super.init(); + super.doInit(); } public void setAcceptingSocket(RSocket acceptingSocket) { @@ -840,7 +1203,12 @@ public void setAcceptingSocket(RSocket acceptingSocket) { connection = new TestDuplexConnection(alloc()); connectSub = TestSubscriber.create(); this.prefetch = Integer.MAX_VALUE; - super.init(); + super.doInit(); + } + + public void setRequestInterceptor(RequestInterceptor requestInterceptor) { + this.requestInterceptor = requestInterceptor; + super.doInit(); } public void setAcceptingSocket(RSocket acceptingSocket, int prefetch) { @@ -848,18 +1216,22 @@ public void setAcceptingSocket(RSocket acceptingSocket, int prefetch) { connection = new TestDuplexConnection(alloc()); connectSub = TestSubscriber.create(); this.prefetch = prefetch; - super.init(); + super.doInit(); } @Override protected RSocketResponder newRSocket() { + onCloseSink = Sinks.empty(); return new RSocketResponder( connection, acceptingSocket, PayloadDecoder.ZERO_COPY, - ResponderLeaseHandler.None, + null, 0, - maxFrameLength); + maxFrameLength, + maxInboundPayloadSize, + __ -> requestInterceptor, + onCloseSink); } private void sendRequest(int streamId, FrameType frameType) { diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketServerFragmentationTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketServerFragmentationTest.java index 073ebfd06..90e881257 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketServerFragmentationTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketServerFragmentationTest.java @@ -1,9 +1,12 @@ package io.rsocket.core; +import io.rsocket.Closeable; +import io.rsocket.FrameAssert; +import io.rsocket.frame.FrameType; import io.rsocket.test.util.TestClientTransport; import io.rsocket.test.util.TestServerTransport; import org.assertj.core.api.Assertions; -import org.junit.Test; +import org.junit.jupiter.api.Test; public class RSocketServerFragmentationTest { @@ -16,12 +19,18 @@ public void serverErrorsWithEnabledFragmentationOnInsufficientMtu() { @Test public void serverSucceedsWithEnabledFragmentationOnSufficientMtu() { - RSocketServer.create().fragment(100).bind(new TestServerTransport()).block(); + TestServerTransport transport = new TestServerTransport(); + Closeable closeable = RSocketServer.create().fragment(100).bind(transport).block(); + closeable.dispose(); + transport.alloc().assertHasNoLeaks(); } @Test public void serverSucceedsWithDisabledFragmentation() { - RSocketServer.create().bind(new TestServerTransport()).block(); + TestServerTransport transport = new TestServerTransport(); + Closeable closeable = RSocketServer.create().bind(transport).block(); + closeable.dispose(); + transport.alloc().assertHasNoLeaks(); } @Test @@ -33,11 +42,23 @@ public void clientErrorsWithEnabledFragmentationOnInsufficientMtu() { @Test public void clientSucceedsWithEnabledFragmentationOnSufficientMtu() { - RSocketConnector.create().fragment(100).connect(new TestClientTransport()).block(); + TestClientTransport transport = new TestClientTransport(); + RSocketConnector.create().fragment(100).connect(transport).block(); + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .typeOf(FrameType.SETUP) + .hasNoLeaks(); + transport.testConnection().dispose(); + transport.alloc().assertHasNoLeaks(); } @Test public void clientSucceedsWithDisabledFragmentation() { - RSocketConnector.connectWith(new TestClientTransport()).block(); + TestClientTransport transport = new TestClientTransport(); + RSocketConnector.connectWith(transport).block(); + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .typeOf(FrameType.SETUP) + .hasNoLeaks(); + transport.testConnection().dispose(); + transport.alloc().assertHasNoLeaks(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketServerTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketServerTest.java index 16c75d521..a335ac1f3 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketServerTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketServerTest.java @@ -1,13 +1,99 @@ +/* + * Copyright 2015-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package io.rsocket.core; import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static org.assertj.core.api.Assertions.assertThat; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.Closeable; +import io.rsocket.FrameAssert; +import io.rsocket.RSocket; +import io.rsocket.exceptions.RejectedSetupException; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.test.util.TestServerTransport; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.Random; +import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; +import reactor.test.scheduler.VirtualTimeScheduler; public class RSocketServerTest { + @Test + public void unexpectedFramesBeforeSetupFrame() { + TestServerTransport transport = new TestServerTransport(); + RSocketServer.create().bind(transport).block(); + + final TestDuplexConnection duplexConnection = transport.connect(); + + duplexConnection.addToReceivedBuffer( + KeepAliveFrameCodec.encode(duplexConnection.alloc(), false, 1, Unpooled.EMPTY_BUFFER)); + + StepVerifier.create(duplexConnection.onClose()) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(10)); + + FrameAssert.assertThat(duplexConnection.pollFrame()) + .isNotNull() + .typeOf(FrameType.ERROR) + .hasData("SETUP or RESUME frame must be received before any others") + .hasStreamIdZero() + .hasNoLeaks(); + duplexConnection.alloc().assertHasNoLeaks(); + } + + @Test + public void timeoutOnNoFirstFrame() { + final VirtualTimeScheduler scheduler = VirtualTimeScheduler.getOrSet(); + TestServerTransport transport = new TestServerTransport(); + try { + RSocketServer.create().maxTimeToFirstFrame(Duration.ofMinutes(2)).bind(transport).block(); + + final TestDuplexConnection duplexConnection = transport.connect(); + + scheduler.advanceTimeBy(Duration.ofMinutes(1)); + + Assertions.assertThat(duplexConnection.isDisposed()).isFalse(); + + scheduler.advanceTimeBy(Duration.ofMinutes(1)); + + StepVerifier.create(duplexConnection.onClose()) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(10)); + + FrameAssert.assertThat(duplexConnection.pollFrame()).isNull(); + } finally { + transport.alloc().assertHasNoLeaks(); + VirtualTimeScheduler.reset(); + } + } + @Test public void ensuresMaxFrameLengthCanNotBeLessThenMtu() { RSocketServer.create() @@ -42,4 +128,74 @@ public void ensuresMaxFrameLengthCanNotBeGreaterThenMaxPossibleFrameLength() { + FRAME_LENGTH_MASK) .verify(); } + + @Test + public void unexpectedFramesBeforeSetup() { + Sinks.Empty connectedSink = Sinks.empty(); + + TestServerTransport transport = new TestServerTransport(); + Closeable server = + RSocketServer.create() + .acceptor( + (setup, sendingSocket) -> { + connectedSink.tryEmitEmpty(); + return Mono.just(new RSocket() {}); + }) + .bind(transport) + .block(); + + byte[] bytes = new byte[16_000_000]; + new Random().nextBytes(bytes); + + TestDuplexConnection connection = transport.connect(); + connection.addToReceivedBuffer( + RequestResponseFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + Unpooled.EMPTY_BUFFER, + ByteBufAllocator.DEFAULT.buffer(bytes.length).writeBytes(bytes))); + + StepVerifier.create(connection.onClose()).expectComplete().verify(Duration.ofSeconds(30)); + assertThat(connectedSink.scan(Scannable.Attr.TERMINATED)) + .as("Connection should not succeed") + .isFalse(); + FrameAssert.assertThat(connection.pollFrame()) + .hasStreamIdZero() + .hasData("SETUP or RESUME frame must be received before any others") + .hasNoLeaks(); + server.dispose(); + transport.alloc().assertHasNoLeaks(); + } + + @Test + public void ensuresErrorFrameDeliveredPriorConnectionDisposal() { + TestServerTransport transport = new TestServerTransport(); + Closeable server = + RSocketServer.create() + .acceptor( + (setup, sendingSocket) -> Mono.error(new RejectedSetupException("ACCESS_DENIED"))) + .bind(transport) + .block(); + + TestDuplexConnection connection = transport.connect(); + connection.addToReceivedBuffer( + SetupFrameCodec.encode( + ByteBufAllocator.DEFAULT, + false, + 0, + 1, + Unpooled.EMPTY_BUFFER, + "metadata_type", + "data_type", + EmptyPayload.INSTANCE)); + + StepVerifier.create(connection.onClose()).expectComplete().verify(Duration.ofSeconds(30)); + FrameAssert.assertThat(connection.pollFrame()) + .hasStreamIdZero() + .hasData("ACCESS_DENIED") + .hasNoLeaks(); + server.dispose(); + transport.alloc().assertHasNoLeaks(); + } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java index 1e7bb337f..e01e6ebdc 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,14 +22,11 @@ import io.netty.buffer.ByteBufAllocator; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.TestScheduler; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.ApplicationErrorException; import io.rsocket.exceptions.CustomRSocketException; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.internal.subscriber.AssertSubscriber; -import io.rsocket.lease.RequesterLeaseHandler; -import io.rsocket.lease.ResponderLeaseHandler; import io.rsocket.test.util.LocalDuplexConnection; import io.rsocket.util.DefaultPayload; import io.rsocket.util.EmptyPayload; @@ -38,23 +35,32 @@ import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicReference; import org.assertj.core.api.Assertions; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExternalResource; -import org.junit.runner.Description; -import org.junit.runners.model.Statement; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.reactivestreams.Publisher; import reactor.core.Disposable; import reactor.core.Disposables; -import reactor.core.publisher.DirectProcessor; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; import reactor.test.publisher.TestPublisher; public class RSocketTest { - @Rule public final SocketRule rule = new SocketRule(); + public final SocketRule rule = new SocketRule(); + + @BeforeEach + public void setup() { + rule.init(); + } + + @AfterEach + public void tearDownAndCheckOnLeaks() { + rule.alloc().assertHasNoLeaks(); + } @Test public void rsocketDisposalShouldEndupWithNoErrorsOnClose() { @@ -84,7 +90,8 @@ public boolean isDisposed() { Assertions.assertThat(requestHandlingRSocket.isDisposed()).isTrue(); } - @Test(timeout = 2_000) + @Test + @Timeout(2_000) public void testRequestReplyNoError() { StepVerifier.create(rule.crs.requestResponse(DefaultPayload.create("hello"))) .expectNextCount(1) @@ -92,7 +99,8 @@ public void testRequestReplyNoError() { .verify(); } - @Test(timeout = 2000) + @Test + @Timeout(2000) public void testHandlerEmitsError() { rule.setRequestAcceptor( new RSocket() { @@ -112,7 +120,8 @@ public Mono requestResponse(Payload payload) { .verify(Duration.ofMillis(100)); } - @Test(timeout = 2000) + @Test + @Timeout(2000) public void testHandlerEmitsCustomError() { rule.setRequestAcceptor( new RSocket() { @@ -134,7 +143,8 @@ public Mono requestResponse(Payload payload) { .verify(); } - @Test(timeout = 2000) + @Test + @Timeout(2000) public void testRequestPropagatesCorrectlyForRequestChannel() { rule.setRequestAcceptor( new RSocket() { @@ -143,7 +153,7 @@ public Flux requestChannel(Publisher payloads) { return Flux.from(payloads) // specifically limits request to 3 in order to prevent 256 request from limitRate // hidden on the responder side - .limitRequest(3); + .take(3, true); } }); @@ -157,21 +167,24 @@ public Flux requestChannel(Publisher payloads) { .verify(Duration.ofMillis(5000)); } - @Test(timeout = 2000) - public void testStream() throws Exception { + @Test + @Timeout(2000) + public void testStream() { Flux responses = rule.crs.requestStream(DefaultPayload.create("Payload In")); StepVerifier.create(responses).expectNextCount(10).expectComplete().verify(); } - @Test(timeout = 2000) - public void testChannel() throws Exception { + @Test + @Timeout(200000) + public void testChannel() { Flux requests = Flux.range(0, 10).map(i -> DefaultPayload.create("streaming in -> " + i)); Flux responses = rule.crs.requestChannel(requests); StepVerifier.create(responses).expectNextCount(10).expectComplete().verify(); } - @Test(timeout = 2000) + @Test + @Timeout(2000) public void testErrorPropagatesCorrectly() { AtomicReference error = new AtomicReference<>(); rule.setRequestAcceptor( @@ -490,10 +503,10 @@ void errorFromRequesterPublisher( responderPublisher.assertNoSubscribers(); } - public static class SocketRule extends ExternalResource { + public static class SocketRule { - DirectProcessor serverProcessor; - DirectProcessor clientProcessor; + Sinks.Many serverProcessor; + Sinks.Many clientProcessor; private RSocketRequester crs; @SuppressWarnings("unused") @@ -502,26 +515,20 @@ public static class SocketRule extends ExternalResource { private RSocket requestAcceptor; private LeaksTrackingByteBufAllocator allocator; - - @Override - public Statement apply(Statement base, Description description) { - return new Statement() { - @Override - public void evaluate() throws Throwable { - init(); - base.evaluate(); - } - }; - } + protected Sinks.Empty thisClosedSink; + protected Sinks.Empty otherClosedSink; public LeaksTrackingByteBufAllocator alloc() { return allocator; } - protected void init() { + public void init() { allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); - serverProcessor = DirectProcessor.create(); - clientProcessor = DirectProcessor.create(); + serverProcessor = Sinks.many().multicast().directBestEffort(); + clientProcessor = Sinks.many().multicast().directBestEffort(); + + this.thisClosedSink = Sinks.empty(); + this.otherClosedSink = Sinks.empty(); LocalDuplexConnection serverConnection = new LocalDuplexConnection("server", allocator, clientProcessor, serverProcessor); @@ -543,8 +550,7 @@ public Mono requestResponse(Payload payload) { @Override public Flux requestStream(Payload payload) { return Flux.range(1, 10) - .map( - i -> DefaultPayload.create("server got -> [" + payload.toString() + "]")); + .map(i -> DefaultPayload.create("server got -> [" + payload + "]")); } @Override @@ -567,9 +573,12 @@ public Flux requestChannel(Publisher payloads) { serverConnection, requestAcceptor, PayloadDecoder.DEFAULT, - ResponderLeaseHandler.None, + null, 0, - FRAME_LENGTH_MASK); + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + __ -> null, + otherClosedSink); crs = new RSocketRequester( @@ -578,11 +587,14 @@ public Flux requestChannel(Publisher payloads) { StreamIdSupplier.clientSupplier(), 0, FRAME_LENGTH_MASK, + Integer.MAX_VALUE, 0, 0, null, - RequesterLeaseHandler.None, - TestScheduler.INSTANCE); + __ -> null, + null, + thisClosedSink, + otherClosedSink.asMono().and(thisClosedSink.asMono())); } public void setRequestAcceptor(RSocket requestAcceptor) { diff --git a/rsocket-core/src/test/java/io/rsocket/core/ReconnectMonoTests.java b/rsocket-core/src/test/java/io/rsocket/core/ReconnectMonoTests.java index 8d96222df..3112a0943 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/ReconnectMonoTests.java +++ b/rsocket-core/src/test/java/io/rsocket/core/ReconnectMonoTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,8 +16,11 @@ package io.rsocket.core; -import static org.junit.Assert.assertEquals; +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; +import io.rsocket.RaceTestConstants; +import io.rsocket.internal.subscriber.AssertSubscriber; import java.io.IOException; import java.time.Duration; import java.util.ArrayList; @@ -29,9 +32,10 @@ import java.util.concurrent.TimeoutException; import java.util.function.BiConsumer; import java.util.function.Consumer; +import java.util.function.Supplier; import java.util.stream.Collectors; import org.assertj.core.api.Assertions; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.mockito.Mockito; import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; @@ -39,7 +43,6 @@ import reactor.core.Scannable; import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; import reactor.core.scheduler.Schedulers; import reactor.test.StepVerifier; import reactor.test.publisher.TestPublisher; @@ -58,7 +61,7 @@ public class ReconnectMonoTests { public void shouldExpireValueOnRacingDisposeAndNext() { Hooks.onErrorDropped(t -> {}); Hooks.onNextDropped(System.out::println); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final int index = i; final CoreSubscriber[] monoSubscribers = new CoreSubscriber[1]; Subscription mockSubscription = Mockito.mock(Subscription.class); @@ -76,26 +79,27 @@ public void subscribe(CoreSubscriber actual) { .doOnDiscard(Object.class, System.out::println) .as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); RaceTestUtils.race(() -> monoSubscribers[0].onNext("value" + index), reconnectMono::dispose); monoSubscribers[0].onComplete(); - Assertions.assertThat(processor.isTerminated()).isTrue(); + subscriber.assertTerminated(); Mockito.verify(mockSubscription).cancel(); - if (processor.isError()) { - Assertions.assertThat(processor.getError()) - .isInstanceOf(CancellationException.class) - .hasMessage("ReconnectMono has already been disposed"); + if (!subscriber.errors().isEmpty()) { + subscriber + .assertError(CancellationException.class) + .assertErrorMessage("ReconnectMono has already been disposed"); - Assertions.assertThat(expired).containsOnly("value" + i); + assertThat(expired).containsOnly("value" + i); } else { - Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + subscriber.assertValues("value" + i); } expired.clear(); @@ -106,41 +110,38 @@ public void subscribe(CoreSubscriber actual) { @Test public void shouldNotifyAllTheSubscribersUnderRacingBetweenSubscribeAndComplete() { Hooks.onErrorDropped(t -> {}); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestPublisher cold = TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); final ReconnectMono reconnectMono = cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); - final MonoProcessor racerProcessor = MonoProcessor.create(); + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + final AssertSubscriber raceSubscriber = new AssertSubscriber<>(); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); cold.next("value" + i); - RaceTestUtils.race(cold::complete, () -> reconnectMono.subscribe(racerProcessor)); - - Assertions.assertThat(processor.isTerminated()).isTrue(); + RaceTestUtils.race(cold::complete, () -> reconnectMono.subscribe(raceSubscriber)); - Assertions.assertThat(processor.peek()).isEqualTo("value" + i); - Assertions.assertThat(racerProcessor.peek()).isEqualTo("value" + i); + subscriber.assertTerminated(); + subscriber.assertValues("value" + i); + raceSubscriber.assertValues("value" + i); - Assertions.assertThat(reconnectMono.resolvingInner.subscribers) - .isEqualTo(ResolvingOperator.READY); + assertThat(reconnectMono.resolvingInner.subscribers).isEqualTo(ResolvingOperator.READY); - Assertions.assertThat( + assertThat( reconnectMono.resolvingInner.add( new ResolvingOperator.MonoDeferredResolutionOperator<>( - reconnectMono.resolvingInner, processor))) + reconnectMono.resolvingInner, subscriber))) .isEqualTo(ResolvingOperator.READY_STATE); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received) - .hasSize(1) - .containsOnly(Tuples.of("value" + i, reconnectMono)); + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); received.clear(); } @@ -149,7 +150,7 @@ public void shouldNotifyAllTheSubscribersUnderRacingBetweenSubscribeAndComplete( @Test public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidate() { Hooks.onErrorDropped(t -> {}); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final int index = i; final TestPublisher cold = TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); @@ -157,11 +158,12 @@ public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidate() final ReconnectMono reconnectMono = cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); - final MonoProcessor racerProcessor = MonoProcessor.create(); + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + final AssertSubscriber raceSubscriber = new AssertSubscriber<>(); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); reconnectMono.resolvingInner.mainSubscriber.onNext("value_to_expire" + i); reconnectMono.resolvingInner.mainSubscriber.onComplete(); @@ -169,38 +171,33 @@ public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidate() RaceTestUtils.race( reconnectMono::invalidate, () -> { - reconnectMono.subscribe(racerProcessor); - if (!racerProcessor.isTerminated()) { + reconnectMono.subscribe(raceSubscriber); + if (!raceSubscriber.isTerminated()) { reconnectMono.resolvingInner.mainSubscriber.onNext("value_to_not_expire" + index); reconnectMono.resolvingInner.mainSubscriber.onComplete(); } - }, - Schedulers.parallel()); - - Assertions.assertThat(processor.isTerminated()).isTrue(); - - Assertions.assertThat(processor.peek()).isEqualTo("value_to_expire" + i); - StepVerifier.create(racerProcessor) - .expectNextMatches( - (v) -> { - if (reconnectMono.resolvingInner.subscribers == ResolvingOperator.READY) { - return v.equals("value_to_not_expire" + index); - } else { - return v.equals("value_to_expire" + index); - } - }) - .expectComplete() - .verify(Duration.ofMillis(100)); + }); + + subscriber.assertTerminated(); + subscriber.assertValues("value_to_expire" + i); + + raceSubscriber.assertComplete(); + String v = raceSubscriber.values().get(0); + if (reconnectMono.resolvingInner.subscribers == ResolvingOperator.READY) { + assertThat(v).isEqualTo("value_to_not_expire" + index); + } else { + assertThat(v).isEqualTo("value_to_expire" + index); + } - Assertions.assertThat(expired).hasSize(1).containsOnly("value_to_expire" + i); + assertThat(expired).hasSize(1).containsOnly("value_to_expire" + i); if (reconnectMono.resolvingInner.subscribers == ResolvingOperator.READY) { - Assertions.assertThat(received) + assertThat(received) .hasSize(2) .containsExactly( Tuples.of("value_to_expire" + i, reconnectMono), Tuples.of("value_to_not_expire" + i, reconnectMono)); } else { - Assertions.assertThat(received) + assertThat(received) .hasSize(1) .containsOnly(Tuples.of("value_to_expire" + i, reconnectMono)); } @@ -213,7 +210,7 @@ public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidate() @Test public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidates() { Hooks.onErrorDropped(t -> {}); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final int index = i; final TestPublisher cold = TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); @@ -221,55 +218,50 @@ public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidates( final ReconnectMono reconnectMono = cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); - final MonoProcessor racerProcessor = MonoProcessor.create(); + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + final AssertSubscriber raceSubscriber = new AssertSubscriber<>(); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); reconnectMono.resolvingInner.mainSubscriber.onNext("value_to_expire" + i); reconnectMono.resolvingInner.mainSubscriber.onComplete(); RaceTestUtils.race( - () -> - RaceTestUtils.race( - reconnectMono::invalidate, reconnectMono::invalidate, Schedulers.parallel()), + reconnectMono::invalidate, + reconnectMono::invalidate, () -> { - reconnectMono.subscribe(racerProcessor); - if (!racerProcessor.isTerminated()) { + reconnectMono.subscribe(raceSubscriber); + if (!raceSubscriber.isTerminated()) { reconnectMono.resolvingInner.mainSubscriber.onNext( "value_to_possibly_expire" + index); reconnectMono.resolvingInner.mainSubscriber.onComplete(); } - }, - Schedulers.parallel()); + }); - Assertions.assertThat(processor.isTerminated()).isTrue(); + subscriber.assertTerminated(); + subscriber.assertValues("value_to_expire" + i); - Assertions.assertThat(processor.peek()).isEqualTo("value_to_expire" + i); - StepVerifier.create(racerProcessor) - .expectNextMatches( - (v) -> - v.equals("value_to_possibly_expire" + index) - || v.equals("value_to_expire" + index)) - .expectComplete() - .verify(Duration.ofMillis(100)); + raceSubscriber.assertComplete(); + assertThat(raceSubscriber.values().get(0)) + .isIn("value_to_possibly_expire" + index, "value_to_expire" + index); if (expired.size() == 2) { - Assertions.assertThat(expired) + assertThat(expired) .hasSize(2) .containsExactly("value_to_expire" + i, "value_to_possibly_expire" + i); } else { - Assertions.assertThat(expired).hasSize(1).containsOnly("value_to_expire" + i); + assertThat(expired).hasSize(1).containsOnly("value_to_expire" + i); } if (received.size() == 2) { - Assertions.assertThat(received) + assertThat(received) .hasSize(2) .containsExactly( Tuples.of("value_to_expire" + i, reconnectMono), Tuples.of("value_to_possibly_expire" + i, reconnectMono)); } else { - Assertions.assertThat(received) + assertThat(received) .hasSize(1) .containsOnly(Tuples.of("value_to_expire" + i, reconnectMono)); } @@ -282,58 +274,62 @@ public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidates( @Test public void shouldNotExpireNewlyResolvedValueIfBlockIsRacingWithInvalidate() { Hooks.onErrorDropped(t -> {}); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final int index = i; - final TestPublisher cold = - TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + final Mono source = + Mono.fromSupplier( + new Supplier() { + boolean once = false; + + @Override + public String get() { + + if (!once) { + once = true; + return "value_to_expire" + index; + } + + return "value_to_not_expire" + index; + } + }); final ReconnectMono reconnectMono = - cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + new ReconnectMono<>( + source.subscribeOn(Schedulers.boundedElastic()), onExpire(), onValue()); - final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); - reconnectMono.resolvingInner.mainSubscriber.onNext("value_to_expire" + i); - reconnectMono.resolvingInner.mainSubscriber.onComplete(); + subscriber.await().assertComplete(); + + assertThat(expired).isEmpty(); RaceTestUtils.race( () -> - Assertions.assertThat(reconnectMono.block()) + assertThat(reconnectMono.block()) .matches( (v) -> v.equals("value_to_not_expire" + index) || v.equals("value_to_expire" + index)), - () -> - RaceTestUtils.race( - reconnectMono::invalidate, - () -> { - for (; ; ) { - if (reconnectMono.resolvingInner.subscribers != ResolvingOperator.READY) { - reconnectMono.resolvingInner.mainSubscriber.onNext( - "value_to_not_expire" + index); - reconnectMono.resolvingInner.mainSubscriber.onComplete(); - break; - } - } - }, - Schedulers.parallel()), - Schedulers.parallel()); - - Assertions.assertThat(processor.isTerminated()).isTrue(); - - Assertions.assertThat(processor.peek()).isEqualTo("value_to_expire" + i); - - Assertions.assertThat(expired).hasSize(1).containsOnly("value_to_expire" + i); + reconnectMono::invalidate); + + subscriber.assertTerminated(); + + subscriber.assertValues("value_to_expire" + i); + + assertThat(expired).hasSize(1).containsOnly("value_to_expire" + i); if (reconnectMono.resolvingInner.subscribers == ResolvingOperator.READY) { - Assertions.assertThat(received) + await().atMost(Duration.ofSeconds(5)).until(() -> received.size() == 2); + assertThat(received) .hasSize(2) .containsExactly( Tuples.of("value_to_expire" + i, reconnectMono), Tuples.of("value_to_not_expire" + i, reconnectMono)); } else { - Assertions.assertThat(received) + assertThat(received) .hasSize(1) .containsOnly(Tuples.of("value_to_expire" + i, reconnectMono)); } @@ -345,45 +341,42 @@ public void shouldNotExpireNewlyResolvedValueIfBlockIsRacingWithInvalidate() { @Test public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribers() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestPublisher cold = TestPublisher.createCold(); cold.next("value" + i); final ReconnectMono reconnectMono = cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - final MonoProcessor processor = MonoProcessor.create(); - final MonoProcessor racerProcessor = MonoProcessor.create(); + final AssertSubscriber subscriber = new AssertSubscriber<>(); + final AssertSubscriber raceSubscriber = new AssertSubscriber<>(); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); - Assertions.assertThat(cold.subscribeCount()).isZero(); + assertThat(cold.subscribeCount()).isZero(); RaceTestUtils.race( - () -> reconnectMono.subscribe(processor), () -> reconnectMono.subscribe(racerProcessor)); + () -> reconnectMono.subscribe(subscriber), () -> reconnectMono.subscribe(raceSubscriber)); - Assertions.assertThat(processor.isTerminated()).isTrue(); - Assertions.assertThat(racerProcessor.isTerminated()).isTrue(); + subscriber.assertTerminated(); + assertThat(raceSubscriber.isTerminated()).isTrue(); - Assertions.assertThat(processor.peek()).isEqualTo("value" + i); - Assertions.assertThat(racerProcessor.peek()).isEqualTo("value" + i); + subscriber.assertValues("value" + i); + raceSubscriber.assertValues("value" + i); - Assertions.assertThat(reconnectMono.resolvingInner.subscribers) - .isEqualTo(ResolvingOperator.READY); + assertThat(reconnectMono.resolvingInner.subscribers).isEqualTo(ResolvingOperator.READY); - Assertions.assertThat(cold.subscribeCount()).isOne(); + assertThat(cold.subscribeCount()).isOne(); - Assertions.assertThat( + assertThat( reconnectMono.resolvingInner.add( new ResolvingOperator.MonoDeferredResolutionOperator<>( - reconnectMono.resolvingInner, processor))) + reconnectMono.resolvingInner, subscriber))) .isEqualTo(ResolvingOperator.READY_STATE); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received) - .hasSize(1) - .containsOnly(Tuples.of("value" + i, reconnectMono)); + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); received.clear(); } @@ -392,45 +385,43 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribers() { @Test public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribeAndBlock() { Duration timeout = Duration.ofMillis(100); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestPublisher cold = TestPublisher.createCold(); cold.next("value" + i); final ReconnectMono reconnectMono = cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - final MonoProcessor processor = MonoProcessor.create(); + final AssertSubscriber subscriber = new AssertSubscriber<>(); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); - Assertions.assertThat(cold.subscribeCount()).isZero(); + assertThat(cold.subscribeCount()).isZero(); String[] values = new String[1]; RaceTestUtils.race( - () -> values[0] = reconnectMono.block(timeout), () -> reconnectMono.subscribe(processor)); + () -> values[0] = reconnectMono.block(timeout), + () -> reconnectMono.subscribe(subscriber)); - Assertions.assertThat(processor.isTerminated()).isTrue(); + subscriber.assertTerminated(); - Assertions.assertThat(processor.peek()).isEqualTo("value" + i); - Assertions.assertThat(values).containsExactly("value" + i); + subscriber.assertValues("value" + i); + assertThat(values).containsExactly("value" + i); - Assertions.assertThat(reconnectMono.resolvingInner.subscribers) - .isEqualTo(ResolvingOperator.READY); + assertThat(reconnectMono.resolvingInner.subscribers).isEqualTo(ResolvingOperator.READY); - Assertions.assertThat(cold.subscribeCount()).isOne(); + assertThat(cold.subscribeCount()).isOne(); - Assertions.assertThat( + assertThat( reconnectMono.resolvingInner.add( new ResolvingOperator.MonoDeferredResolutionOperator<>( - reconnectMono.resolvingInner, processor))) + reconnectMono.resolvingInner, subscriber))) .isEqualTo(ResolvingOperator.READY_STATE); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received) - .hasSize(1) - .containsOnly(Tuples.of("value" + i, reconnectMono)); + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); received.clear(); } @@ -439,17 +430,17 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribeAndBlock() { @Test public void shouldEstablishValueOnceInCaseOfRacingBetweenBlocks() { Duration timeout = Duration.ofMillis(100); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestPublisher cold = TestPublisher.createCold(); cold.next("value" + i); final ReconnectMono reconnectMono = cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); - Assertions.assertThat(cold.subscribeCount()).isZero(); + assertThat(cold.subscribeCount()).isZero(); String[] values1 = new String[1]; String[] values2 = new String[1]; @@ -458,24 +449,21 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenBlocks() { () -> values1[0] = reconnectMono.block(timeout), () -> values2[0] = reconnectMono.block(timeout)); - Assertions.assertThat(values2).containsExactly("value" + i); - Assertions.assertThat(values1).containsExactly("value" + i); + assertThat(values2).containsExactly("value" + i); + assertThat(values1).containsExactly("value" + i); - Assertions.assertThat(reconnectMono.resolvingInner.subscribers) - .isEqualTo(ResolvingOperator.READY); + assertThat(reconnectMono.resolvingInner.subscribers).isEqualTo(ResolvingOperator.READY); - Assertions.assertThat(cold.subscribeCount()).isOne(); + assertThat(cold.subscribeCount()).isOne(); - Assertions.assertThat( + assertThat( reconnectMono.resolvingInner.add( new ResolvingOperator.MonoDeferredResolutionOperator<>( - reconnectMono.resolvingInner, MonoProcessor.create()))) + reconnectMono.resolvingInner, new AssertSubscriber<>()))) .isEqualTo(ResolvingOperator.READY_STATE); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received) - .hasSize(1) - .containsOnly(Tuples.of("value" + i, reconnectMono)); + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); received.clear(); } @@ -484,35 +472,36 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenBlocks() { @Test public void shouldExpireValueOnRacingDisposeAndNoValueComplete() { Hooks.onErrorDropped(t -> {}); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestPublisher cold = TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); final ReconnectMono reconnectMono = cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); RaceTestUtils.race(cold::complete, reconnectMono::dispose); - Assertions.assertThat(processor.isTerminated()).isTrue(); + subscriber.assertTerminated(); - Throwable error = processor.getError(); + Throwable error = subscriber.errors().get(0); if (error instanceof CancellationException) { - Assertions.assertThat(error) + assertThat(error) .isInstanceOf(CancellationException.class) .hasMessage("ReconnectMono has already been disposed"); } else { - Assertions.assertThat(error) + assertThat(error) .isInstanceOf(IllegalStateException.class) .hasMessage("Source completed empty"); } - Assertions.assertThat(expired).isEmpty(); + assertThat(expired).isEmpty(); expired.clear(); received.clear(); @@ -522,36 +511,35 @@ public void shouldExpireValueOnRacingDisposeAndNoValueComplete() { @Test public void shouldExpireValueOnRacingDisposeAndComplete() { Hooks.onErrorDropped(t -> {}); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestPublisher cold = TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); final ReconnectMono reconnectMono = cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); cold.next("value" + i); RaceTestUtils.race(cold::complete, reconnectMono::dispose); - Assertions.assertThat(processor.isTerminated()).isTrue(); + subscriber.assertTerminated(); - if (processor.isError()) { - Assertions.assertThat(processor.getError()) + if (!subscriber.errors().isEmpty()) { + assertThat(subscriber.errors().get(0)) .isInstanceOf(CancellationException.class) .hasMessage("ReconnectMono has already been disposed"); } else { - Assertions.assertThat(received) - .hasSize(1) - .containsOnly(Tuples.of("value" + i, reconnectMono)); - Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); + subscriber.assertValues("value" + i); } - Assertions.assertThat(expired).hasSize(1).containsOnly("value" + i); + assertThat(expired).hasSize(1).containsOnly("value" + i); expired.clear(); received.clear(); @@ -562,42 +550,40 @@ public void shouldExpireValueOnRacingDisposeAndComplete() { public void shouldExpireValueOnRacingDisposeAndError() { Hooks.onErrorDropped(t -> {}); RuntimeException runtimeException = new RuntimeException("test"); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestPublisher cold = TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); final ReconnectMono reconnectMono = cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); cold.next("value" + i); RaceTestUtils.race(() -> cold.error(runtimeException), reconnectMono::dispose); - Assertions.assertThat(processor.isTerminated()).isTrue(); + subscriber.assertTerminated(); - if (processor.isError()) { - if (processor.getError() instanceof CancellationException) { - Assertions.assertThat(processor.getError()) + if (!subscriber.errors().isEmpty()) { + Throwable error = subscriber.errors().get(0); + if (error instanceof CancellationException) { + assertThat(error) .isInstanceOf(CancellationException.class) .hasMessage("ReconnectMono has already been disposed"); } else { - Assertions.assertThat(processor.getError()) - .isInstanceOf(RuntimeException.class) - .hasMessage("test"); + assertThat(error).isInstanceOf(RuntimeException.class).hasMessage("test"); } } else { - Assertions.assertThat(received) - .hasSize(1) - .containsOnly(Tuples.of("value" + i, reconnectMono)); - Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); + subscriber.assertValues("value" + i); } - Assertions.assertThat(expired).hasSize(1).containsOnly("value" + i); + assertThat(expired).hasSize(1).containsOnly("value" + i); expired.clear(); received.clear(); @@ -608,7 +594,7 @@ public void shouldExpireValueOnRacingDisposeAndError() { public void shouldExpireValueOnRacingDisposeAndErrorWithNoBackoff() { Hooks.onErrorDropped(t -> {}); RuntimeException runtimeException = new RuntimeException("test"); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestPublisher cold = TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); @@ -617,35 +603,32 @@ public void shouldExpireValueOnRacingDisposeAndErrorWithNoBackoff() { .retryWhen(Retry.max(1).filter(t -> t instanceof Exception)) .as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); cold.next("value" + i); RaceTestUtils.race(() -> cold.error(runtimeException), reconnectMono::dispose); - Assertions.assertThat(processor.isTerminated()).isTrue(); + subscriber.assertTerminated(); - if (processor.isError()) { - - if (processor.getError() instanceof CancellationException) { - Assertions.assertThat(processor.getError()) + if (!subscriber.errors().isEmpty()) { + Throwable error = subscriber.errors().get(0); + if (error instanceof CancellationException) { + assertThat(error) .isInstanceOf(CancellationException.class) .hasMessage("ReconnectMono has already been disposed"); } else { - Assertions.assertThat(processor.getError()) - .matches(t -> Exceptions.isRetryExhausted(t)) - .hasCause(runtimeException); + assertThat(error).matches(Exceptions::isRetryExhausted).hasCause(runtimeException); } - Assertions.assertThat(expired).hasSize(1).containsOnly("value" + i); + assertThat(expired).hasSize(1).containsOnly("value" + i); } else { - Assertions.assertThat(received) - .hasSize(1) - .containsOnly(Tuples.of("value" + i, reconnectMono)); - Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); + subscriber.assertValues("value" + i); } expired.clear(); @@ -693,20 +676,20 @@ public void shouldBeScannable() { final Scannable scannableOfReconnect = Scannable.from(reconnectMono); - Assertions.assertThat( + assertThat( (List) scannableOfReconnect.parents().map(s -> s.getClass()).collect(Collectors.toList())) .hasSize(1) .containsExactly(publisher.mono().getClass()); - Assertions.assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.TERMINATED)) - .isEqualTo(false); - Assertions.assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.ERROR)).isNull(); + assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.TERMINATED)).isEqualTo(false); + assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.ERROR)).isNull(); - final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); - final Scannable scannableOfMonoProcessor = Scannable.from(processor); + final Scannable scannableOfMonoProcessor = Scannable.from(subscriber); - Assertions.assertThat( + assertThat( (List) scannableOfMonoProcessor .parents() @@ -721,9 +704,8 @@ public void shouldBeScannable() { reconnectMono.dispose(); - Assertions.assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.TERMINATED)) - .isEqualTo(true); - Assertions.assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.ERROR)) + assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.TERMINATED)).isEqualTo(true); + assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.ERROR)) .isInstanceOf(CancellationException.class); } @@ -735,36 +717,36 @@ public void shouldNotExpiredIfNotCompleted() { final ReconnectMono reconnectMono = publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - MonoProcessor processor = MonoProcessor.create(); + AssertSubscriber subscriber = new AssertSubscriber<>(); - reconnectMono.subscribe(processor); + reconnectMono.subscribe(subscriber); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); - Assertions.assertThat(processor.isTerminated()).isFalse(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); publisher.next("test"); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); - Assertions.assertThat(processor.isTerminated()).isFalse(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); reconnectMono.invalidate(); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); - Assertions.assertThat(processor.isTerminated()).isFalse(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); publisher.assertSubscribers(1); - Assertions.assertThat(publisher.subscribeCount()).isEqualTo(1); + assertThat(publisher.subscribeCount()).isEqualTo(1); publisher.complete(); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).hasSize(1); - Assertions.assertThat(processor.isTerminated()).isTrue(); + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1); + subscriber.assertTerminated(); publisher.assertSubscribers(0); - Assertions.assertThat(publisher.subscribeCount()).isEqualTo(1); + assertThat(publisher.subscribeCount()).isEqualTo(1); } @Test @@ -775,26 +757,26 @@ public void shouldNotEmitUntilCompletion() { final ReconnectMono reconnectMono = publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - MonoProcessor processor = MonoProcessor.create(); + AssertSubscriber subscriber = new AssertSubscriber<>(); - reconnectMono.subscribe(processor); + reconnectMono.subscribe(subscriber); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); - Assertions.assertThat(processor.isTerminated()).isFalse(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); publisher.next("test"); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); - Assertions.assertThat(processor.isTerminated()).isFalse(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); publisher.complete(); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).hasSize(1); - Assertions.assertThat(processor.isTerminated()).isTrue(); - Assertions.assertThat(processor.peek()).isEqualTo("test"); + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1); + subscriber.assertTerminated(); + subscriber.assertValues("test"); } @Test @@ -805,31 +787,30 @@ public void shouldBePossibleToRemoveThemSelvesFromTheList_CancellationTest() { final ReconnectMono reconnectMono = publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - MonoProcessor processor = MonoProcessor.create(); + AssertSubscriber subscriber = new AssertSubscriber<>(); - reconnectMono.subscribe(processor); + reconnectMono.subscribe(subscriber); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); - Assertions.assertThat(processor.isTerminated()).isFalse(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); publisher.next("test"); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); - Assertions.assertThat(processor.isTerminated()).isFalse(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); - processor.cancel(); + subscriber.cancel(); - Assertions.assertThat(reconnectMono.resolvingInner.subscribers) + assertThat(reconnectMono.resolvingInner.subscribers) .isEqualTo(ResolvingOperator.EMPTY_SUBSCRIBED); publisher.complete(); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).hasSize(1); - Assertions.assertThat(processor.isTerminated()).isFalse(); - Assertions.assertThat(processor.peek()).isNull(); + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1); + assertThat(subscriber.values()).isEmpty(); } @Test @@ -848,16 +829,16 @@ public void shouldExpireValueOnDispose() { .expectComplete() .verify(Duration.ofSeconds(timeout)); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).hasSize(1); + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1); reconnectMono.dispose(); - Assertions.assertThat(expired).hasSize(1); - Assertions.assertThat(received).hasSize(1); - Assertions.assertThat(reconnectMono.isDisposed()).isTrue(); + assertThat(expired).hasSize(1); + assertThat(received).hasSize(1); + assertThat(reconnectMono.isDisposed()).isTrue(); - StepVerifier.create(reconnectMono.subscribeOn(Schedulers.elastic())) + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) .expectSubscription() .expectError(CancellationException.class) .verify(Duration.ofSeconds(timeout)); @@ -870,85 +851,91 @@ public void shouldNotifyAllTheSubscribers() { final ReconnectMono reconnectMono = publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - final MonoProcessor sub1 = MonoProcessor.create(); - final MonoProcessor sub2 = MonoProcessor.create(); - final MonoProcessor sub3 = MonoProcessor.create(); - final MonoProcessor sub4 = MonoProcessor.create(); + final AssertSubscriber sub1 = new AssertSubscriber<>(); + final AssertSubscriber sub2 = new AssertSubscriber<>(); + final AssertSubscriber sub3 = new AssertSubscriber<>(); + final AssertSubscriber sub4 = new AssertSubscriber<>(); reconnectMono.subscribe(sub1); reconnectMono.subscribe(sub2); reconnectMono.subscribe(sub3); reconnectMono.subscribe(sub4); - Assertions.assertThat(reconnectMono.resolvingInner.subscribers).hasSize(4); + assertThat(reconnectMono.resolvingInner.subscribers).hasSize(4); - final ArrayList> processors = new ArrayList<>(200); + final ArrayList> subscribers = new ArrayList<>(200); - for (int i = 0; i < 100; i++) { - final MonoProcessor subA = MonoProcessor.create(); - final MonoProcessor subB = MonoProcessor.create(); - processors.add(subA); - processors.add(subB); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final AssertSubscriber subA = new AssertSubscriber<>(); + final AssertSubscriber subB = new AssertSubscriber<>(); + subscribers.add(subA); + subscribers.add(subB); RaceTestUtils.race(() -> reconnectMono.subscribe(subA), () -> reconnectMono.subscribe(subB)); } - Assertions.assertThat(reconnectMono.resolvingInner.subscribers).hasSize(204); + assertThat(reconnectMono.resolvingInner.subscribers).hasSize(RaceTestConstants.REPEATS * 2 + 4); - sub1.dispose(); + sub1.cancel(); - Assertions.assertThat(reconnectMono.resolvingInner.subscribers).hasSize(203); + assertThat(reconnectMono.resolvingInner.subscribers).hasSize(RaceTestConstants.REPEATS * 2 + 3); publisher.next("value"); - Assertions.assertThatThrownBy(sub1::peek).isInstanceOf(CancellationException.class); - Assertions.assertThat(sub2.peek()).isEqualTo("value"); - Assertions.assertThat(sub3.peek()).isEqualTo("value"); - Assertions.assertThat(sub4.peek()).isEqualTo("value"); + assertThat(sub1.scan(Scannable.Attr.CANCELLED)).isTrue(); + assertThat(sub2.values().get(0)).isEqualTo("value"); + assertThat(sub3.values().get(0)).isEqualTo("value"); + assertThat(sub4.values().get(0)).isEqualTo("value"); - for (MonoProcessor sub : processors) { - Assertions.assertThat(sub.peek()).isEqualTo("value"); - Assertions.assertThat(sub.isTerminated()).isTrue(); + for (AssertSubscriber sub : subscribers) { + assertThat(sub.values().get(0)).isEqualTo("value"); + assertThat(sub.isTerminated()).isTrue(); } - Assertions.assertThat(publisher.subscribeCount()).isEqualTo(1); + assertThat(publisher.subscribeCount()).isEqualTo(1); } @Test public void shouldExpireValueExactlyOnceOnRacingBetweenInvalidates() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestPublisher cold = TestPublisher.createCold(); cold.next("value"); + cold.complete(); final int timeout = 10; final ReconnectMono reconnectMono = - cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + cold.flux() + .takeLast(1) + .next() + .as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - StepVerifier.create(reconnectMono.subscribeOn(Schedulers.elastic())) + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) .expectSubscription() .expectNext("value") .expectComplete() .verify(Duration.ofSeconds(timeout)); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); RaceTestUtils.race(reconnectMono::invalidate, reconnectMono::invalidate); - Assertions.assertThat(expired).hasSize(1).containsOnly("value"); - Assertions.assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + assertThat(expired).hasSize(1).containsOnly("value"); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); - StepVerifier.create(reconnectMono.subscribeOn(Schedulers.elastic())) + cold.next("value2"); + + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) .expectSubscription() - .expectNext("value") + .expectNext("value2") .expectComplete() .verify(Duration.ofSeconds(timeout)); - Assertions.assertThat(expired).hasSize(1).containsOnly("value"); - Assertions.assertThat(received) + assertThat(expired).hasSize(1).containsOnly("value"); + assertThat(received) .hasSize(2) - .containsOnly(Tuples.of("value", reconnectMono), Tuples.of("value", reconnectMono)); + .containsOnly(Tuples.of("value", reconnectMono), Tuples.of("value2", reconnectMono)); - Assertions.assertThat(cold.subscribeCount()).isEqualTo(2); + assertThat(cold.subscribeCount()).isEqualTo(2); expired.clear(); received.clear(); @@ -957,7 +944,7 @@ public void shouldExpireValueExactlyOnceOnRacingBetweenInvalidates() { @Test public void shouldExpireValueExactlyOnceOnRacingBetweenInvalidateAndDispose() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestPublisher cold = TestPublisher.createCold(); cold.next("value"); final int timeout = 10000; @@ -965,29 +952,29 @@ public void shouldExpireValueExactlyOnceOnRacingBetweenInvalidateAndDispose() { final ReconnectMono reconnectMono = cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - StepVerifier.create(reconnectMono.subscribeOn(Schedulers.elastic())) + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) .expectSubscription() .expectNext("value") .expectComplete() .verify(Duration.ofSeconds(timeout)); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); RaceTestUtils.race(reconnectMono::invalidate, reconnectMono::dispose); - Assertions.assertThat(expired).hasSize(1).containsOnly("value"); - Assertions.assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + assertThat(expired).hasSize(1).containsOnly("value"); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); - StepVerifier.create(reconnectMono.subscribeOn(Schedulers.elastic())) + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) .expectSubscription() .expectError(CancellationException.class) .verify(Duration.ofSeconds(timeout)); - Assertions.assertThat(expired).hasSize(1).containsOnly("value"); - Assertions.assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + assertThat(expired).hasSize(1).containsOnly("value"); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); - Assertions.assertThat(cold.subscribeCount()).isEqualTo(1); + assertThat(cold.subscribeCount()).isEqualTo(1); expired.clear(); received.clear(); @@ -1011,14 +998,14 @@ public void shouldTimeoutRetryWithVirtualTime() { .maxBackoff(Duration.ofSeconds(maxBackoff))) .timeout(Duration.ofSeconds(timeout)) .as(m -> new ReconnectMono<>(m, onExpire(), onValue())) - .subscribeOn(Schedulers.elastic())) + .subscribeOn(Schedulers.boundedElastic())) .expectSubscription() .thenAwait(Duration.ofSeconds(timeout)) .expectError(TimeoutException.class) .verify(Duration.ofSeconds(timeout)); - Assertions.assertThat(received).isEmpty(); - Assertions.assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); } @Test @@ -1027,11 +1014,11 @@ public void ensuresThatMainSubscriberAllowsOnlyTerminationWithValue() { final ReconnectMono reconnectMono = new ReconnectMono<>(Mono.empty(), onExpire(), onValue()); - StepVerifier.create(reconnectMono.subscribeOn(Schedulers.elastic())) + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) .expectSubscription() .expectErrorSatisfies( t -> - Assertions.assertThat(t) + assertThat(t) .hasMessage("Source completed empty") .isInstanceOf(IllegalStateException.class)) .verify(Duration.ofSeconds(timeout)); @@ -1047,8 +1034,8 @@ public void monoRetryNoBackoff() { StepVerifier.create(mono).verifyErrorMatches(Exceptions::isRetryExhausted); assertRetries(IOException.class, IOException.class); - Assertions.assertThat(received).isEmpty(); - Assertions.assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); } @Test @@ -1066,8 +1053,8 @@ public void monoRetryFixedBackoff() { assertRetries(IOException.class); - Assertions.assertThat(received).isEmpty(); - Assertions.assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); } @Test @@ -1091,8 +1078,8 @@ public void monoRetryExponentialBackoff() { assertRetries(IOException.class, IOException.class, IOException.class, IOException.class); - Assertions.assertThat(received).isEmpty(); - Assertions.assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); } Consumer onRetry() { @@ -1109,12 +1096,12 @@ Consumer onExpire() { @SafeVarargs private final void assertRetries(Class... exceptions) { - assertEquals(exceptions.length, retries.size()); + assertThat(retries.size()).isEqualTo(exceptions.length); int index = 0; for (Iterator it = retries.iterator(); it.hasNext(); ) { Retry.RetrySignal retryContext = it.next(); - assertEquals(index, retryContext.totalRetries()); - assertEquals(exceptions[index], retryContext.failure().getClass()); + assertThat(retryContext.totalRetries()).isEqualTo(index); + assertThat(retryContext.failure().getClass()).isEqualTo(exceptions[index]); index++; } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequestChannelRequesterFluxTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelRequesterFluxTest.java new file mode 100644 index 000000000..c1e0a6876 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelRequesterFluxTest.java @@ -0,0 +1,845 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameType.CANCEL; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.RaceTestConstants; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Signal; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; + +public class RequestChannelRequesterFluxTest { + + @BeforeAll + public static void setUp() { + StepVerifier.setDefaultTimeout(Duration.ofSeconds(2)); + } + + /* + * +-------------------------------+ + * | General Test Cases | + * +-------------------------------+ + */ + @ParameterizedTest + @ValueSource(strings = {"inbound", "outbound"}) + public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String completionCase) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(payload.refCnt()).isOne(); + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(10); + + Assertions.assertThat(payload.refCnt()).isOne(); + activeStreams.assertNoActiveStreams(); + + stateAssert.hasSubscribedFlag().hasRequestN(10).hasNoFirstFrameSentFlag(); + + publisher.assertMaxRequested(1).next(payload); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(10).hasFirstFrameSentFlag(); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .hasRequestN(10) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + assertSubscriber.request(1); + final ByteBuf requestNFrame = sender.awaitFrame(); + FrameAssert.assertThat(requestNFrame) + .isNotNull() + .hasRequestN(1) + .typeOf(FrameType.REQUEST_N) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check. Request N Frame should sent so request field should be 0 + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(11).hasFirstFrameSentFlag(); + + assertSubscriber.request(Long.MAX_VALUE); + final ByteBuf requestMaxNFrame = sender.awaitFrame(); + FrameAssert.assertThat(requestMaxNFrame) + .isNotNull() + .hasRequestN(Integer.MAX_VALUE) + .typeOf(FrameType.REQUEST_N) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + assertSubscriber.request(6); + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + Payload nextPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelRequesterFlux.handlePayload(nextPayload); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload randomPayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, randomPayload); + + ByteBuf firstFragment = fragments.remove(0); + requestChannelRequesterFlux.handleNext(firstFragment, true, false); + firstFragment.release(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag(); + + for (int i = 0; i < fragments.size(); i++) { + boolean hasFollows = i != fragments.size() - 1; + ByteBuf followingFragment = fragments.get(i); + + requestChannelRequesterFlux.handleNext(followingFragment, hasFollows, false); + followingFragment.release(); + } + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag(); + + if (completionCase.equals("inbound")) { + requestChannelRequesterFlux.handleComplete(); + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isEqualTo(nextPayload).hasNoLeaks(), + p -> { + PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + }) + .assertComplete(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag() + .hasInboundTerminated(); + + publisher.complete(); + FrameAssert.assertThat(sender.awaitFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + } else if (completionCase.equals("outbound")) { + publisher.complete(); + FrameAssert.assertThat(sender.awaitFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag() + .hasOutboundTerminated(); + + requestChannelRequesterFlux.handleComplete(); + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isEqualTo(nextPayload).hasNoLeaks(), + p -> { + PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + }) + .assertComplete(); + } + + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void streamShouldErrorWithoutInitializingRemoteStreamIfSourceIsEmpty(boolean doRequest) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + if (doRequest) { + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + } + + publisher.complete(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + assertSubscriber + .assertTerminated() + .assertError(CancellationException.class) + .assertErrorMessage("Empty Source"); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void streamShouldPropagateErrorWithoutInitializingRemoteStreamIfTheFirstSignalIsError( + boolean doRequest) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + if (doRequest) { + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + } + + publisher.error(new RuntimeException("test")); + Assertions.assertThat(sender.isEmpty()).isTrue(); + + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + assertSubscriber + .assertTerminated() + .assertError(RuntimeException.class) + .assertErrorMessage("test"); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @ValueSource(strings = {"inbound", "outbound"}) + public void streamShouldBeInHalfClosedStateOnTheInboundCancellation(String terminationMode) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + + Payload payload1 = TestRequesterResponderSupport.randomPayload(allocator); + Payload payload2 = TestRequesterResponderSupport.randomPayload(allocator); + Payload payload3 = TestRequesterResponderSupport.randomPayload(allocator); + + publisher.next(payload1.retain()); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasPayload(payload1) + .hasRequestN(Integer.MAX_VALUE) + .hasNoLeaks(); + payload1.release(); + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + publisher.assertMaxRequested(1); + + requestChannelRequesterFlux.handleRequestN(10); + publisher.assertMaxRequested(10); + + requestChannelRequesterFlux.handleRequestN(Long.MAX_VALUE); + publisher.assertMaxRequested(Long.MAX_VALUE); + + publisher.next(payload2.retain(), payload3.retain()); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.NEXT) + .hasPayload(payload2) + .hasNoLeaks(); + payload2.release(); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.NEXT) + .hasPayload(payload3) + .hasNoLeaks(); + payload3.release(); + + if (terminationMode.equals("outbound")) { + requestChannelRequesterFlux.handleCancel(); + + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasOutboundTerminated(); + + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + requestChannelRequesterFlux.handleComplete(); + } else if (terminationMode.equals("inbound")) { + requestChannelRequesterFlux.handleComplete(); + + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasInboundTerminated(); + + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + requestChannelRequesterFlux.handleCancel(); + } + + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + } + + @ParameterizedTest + @ValueSource(strings = {"inbound", "outbound"}) + public void errorShouldTerminateExecution(String terminationMode) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + + Payload payload1 = TestRequesterResponderSupport.randomPayload(allocator); + Payload payload2 = TestRequesterResponderSupport.randomPayload(allocator); + Payload payload3 = TestRequesterResponderSupport.randomPayload(allocator); + + publisher.next(payload1.retain()); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasPayload(payload1) + .hasRequestN(Integer.MAX_VALUE) + .hasNoLeaks(); + payload1.release(); + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + publisher.assertMaxRequested(1); + + requestChannelRequesterFlux.handleRequestN(10); + publisher.assertMaxRequested(10); + + requestChannelRequesterFlux.handleRequestN(Long.MAX_VALUE); + publisher.assertMaxRequested(Long.MAX_VALUE); + + publisher.next(payload2.retain(), payload3.retain()); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.NEXT) + .hasPayload(payload2) + .hasNoLeaks(); + payload2.release(); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.NEXT) + .hasPayload(payload3) + .hasNoLeaks(); + payload3.release(); + + if (terminationMode.equals("outbound")) { + publisher.error(new ApplicationErrorException("test")); + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.ERROR) + .hasData("test") + .hasNoLeaks(); + } else if (terminationMode.equals("inbound")) { + requestChannelRequesterFlux.handleError(new ApplicationErrorException("test")); + publisher.assertWasCancelled(); + } + + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + } + + @Test + public void failOnOverflow() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(1); + stateAssert.hasSubscribedFlag().hasRequestN(1).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + + Payload payload1 = TestRequesterResponderSupport.randomPayload(allocator); + + publisher.next(payload1.retain()); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasPayload(payload1) + .hasRequestN(1) + .hasNoLeaks(); + payload1.release(); + + stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag(); + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + publisher.assertMaxRequested(1); + + Payload nextPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelRequesterFlux.handlePayload(nextPayload); + + Payload unrequestedPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelRequesterFlux.handlePayload(unrequestedPayload); + + final ByteBuf cancelFrame = sender.awaitFrame(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + assertSubscriber + .assertValuesWith(p -> PayloadAssert.assertThat(p).isSameAs(nextPayload).hasNoLeaks()) + .assertError() + .assertErrorMessage("The number of messages received exceeds the number requested"); + + publisher.assertWasCancelled(); + + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + } + + /* + * +--------------------------------+ + * | Racing Test Cases | + * +--------------------------------+ + */ + + static Stream cases() { + return Stream.of( + Arguments.arguments("complete", "sizeError"), + Arguments.arguments("complete", "refCntError"), + Arguments.arguments("complete", "onError"), + Arguments.arguments("error", "sizeError"), + Arguments.arguments("error", "refCntError"), + Arguments.arguments("error", "onError"), + Arguments.arguments("cancel", "sizeError"), + Arguments.arguments("cancel", "refCntError"), + Arguments.arguments("cancel", "onError")); + } + + @ParameterizedTest + @MethodSource("cases") + public void shouldHaveEventsDeliveredSeriallyWhenOutboundErrorRacingWithInboundSignals( + String inboundTerminationMode, String outboundTerminationMode) { + final RuntimeException outboundException = new RuntimeException("outboundException"); + final ApplicationErrorException inboundException = + new ApplicationErrorException("inboundException"); + + final ArrayList droppedErrors = new ArrayList<>(); + final Payload oversizePayload = + DefaultPayload.create(new byte[FRAME_LENGTH_MASK], new byte[FRAME_LENGTH_MASK]); + + Hooks.onErrorDropped(droppedErrors::add); + try { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.DEFER_CANCELLATION); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber> assertSubscriber = + requestChannelRequesterFlux.materialize().subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + + Payload requestPayload = TestRequesterResponderSupport.randomPayload(allocator); + publisher.next(requestPayload); + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasRequestN(Integer.MAX_VALUE) + .hasNoLeaks(); + + requestChannelRequesterFlux.handleRequestN(Long.MAX_VALUE); + + Payload responsePayload1 = TestRequesterResponderSupport.randomPayload(allocator); + Payload responsePayload2 = TestRequesterResponderSupport.randomPayload(allocator); + Payload responsePayload3 = TestRequesterResponderSupport.randomPayload(allocator); + + Payload releasedPayload = ByteBufPayload.create(Unpooled.EMPTY_BUFFER); + releasedPayload.release(); + + RaceTestUtils.race( + () -> { + if (outboundTerminationMode.equals("onError")) { + publisher.error(outboundException); + } else if (outboundTerminationMode.equals("refCntError")) { + publisher.next(releasedPayload); + } else { + publisher.next(oversizePayload); + } + }, + () -> { + requestChannelRequesterFlux.handlePayload(responsePayload1); + requestChannelRequesterFlux.handlePayload(responsePayload2); + requestChannelRequesterFlux.handlePayload(responsePayload3); + + if (inboundTerminationMode.equals("error")) { + requestChannelRequesterFlux.handleError(inboundException); + } else if (inboundTerminationMode.equals("complete")) { + requestChannelRequesterFlux.handleComplete(); + } else { + requestChannelRequesterFlux.handleCancel(); + } + }); + + ByteBuf errorFrameOrEmpty = sender.pollFrame(); + if (errorFrameOrEmpty != null) { + if (outboundTerminationMode.equals("onError")) { + FrameAssert.assertThat(errorFrameOrEmpty) + .typeOf(FrameType.ERROR) + .hasData("outboundException") + .hasNoLeaks(); + } else { + FrameAssert.assertThat(errorFrameOrEmpty).typeOf(FrameType.CANCEL).hasNoLeaks(); + } + } + + List> values = assertSubscriber.values(); + for (int j = 0; j < values.size(); j++) { + Signal signal = values.get(j); + + if (signal.isOnNext()) { + PayloadAssert.assertThat(signal.get()) + .describedAs("Expected that the next signal[%s] to have no leaks", j) + .hasNoLeaks(); + } else { + if (inboundTerminationMode.equals("error")) { + Assertions.assertThat(signal.isOnError()).isTrue(); + Throwable throwable = signal.getThrowable(); + if (throwable == inboundException) { + Assertions.assertThat(droppedErrors.get(0)) + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + Assertions.assertThat(throwable).isEqualTo(inboundException); + } else { + Assertions.assertThat(droppedErrors).containsOnly(inboundException); + Assertions.assertThat(throwable) + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + } + } else if (inboundTerminationMode.equals("complete")) { + if (signal.isOnComplete()) { + Assertions.assertThat(droppedErrors.get(0)) + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + } else { + Assertions.assertThat(droppedErrors).isEmpty(); + Assertions.assertThat(signal.getThrowable()) + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + } + } else { + Assertions.assertThat(signal.getThrowable()) + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + } + + Assertions.assertThat(j) + .describedAs( + "Expected that the error signal[%s] is the last signal, but the last was %s", + j, values.size() - 1) + .isEqualTo(values.size() - 1); + } + } + + allocator.assertHasNoLeaks(); + droppedErrors.clear(); + } + } finally { + Hooks.resetOnErrorDropped(); + } + } + + @ParameterizedTest + @ValueSource(strings = {"complete", "cancel"}) + public void shouldRemoveItselfFromActiveStreamsWhenInboundAndOutboundAreTerminated( + String outboundTerminationMode) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.DEFER_CANCELLATION); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber> assertSubscriber = + requestChannelRequesterFlux.materialize().subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + + Payload requestPayload = TestRequesterResponderSupport.randomPayload(allocator); + publisher.next(requestPayload); + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasRequestN(Integer.MAX_VALUE) + .hasNoLeaks(); + + requestChannelRequesterFlux.handleRequestN(Long.MAX_VALUE); + + RaceTestUtils.race( + () -> { + if (outboundTerminationMode.equals("cancel")) { + requestChannelRequesterFlux.handleCancel(); + } else { + publisher.complete(); + } + }, + requestChannelRequesterFlux::handleComplete); + + ByteBuf completeFrameOrNull = sender.pollFrame(); + if (completeFrameOrNull != null) { + FrameAssert.assertThat(completeFrameOrNull) + .hasStreamId(1) + .typeOf(FrameType.COMPLETE) + .hasNoLeaks(); + } + + assertSubscriber.assertTerminated().assertComplete(); + activeStreams.assertNoActiveStreams(); + allocator.assertHasNoLeaks(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java new file mode 100644 index 000000000..890458caf --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java @@ -0,0 +1,890 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameType.*; +import static reactor.test.publisher.TestPublisher.Violation.CLEANUP_ON_TERMINATE; +import static reactor.test.publisher.TestPublisher.Violation.DEFER_CANCELLATION; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.RaceTestConstants; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.Exceptions; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Signal; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; + +public class RequestChannelResponderSubscriberTest { + + @BeforeAll + public static void setUp() { + StepVerifier.setDefaultTimeout(Duration.ofSeconds(2)); + } + + /* + * +-------------------------------+ + * | General Test Cases | + * +-------------------------------+ + */ + @ParameterizedTest + @ValueSource(strings = {"inbound", "outbound", "inboundCancel"}) + public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String completionCase) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload firstPayload = TestRequesterResponderSupport.genericPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + publisher.assertMaxRequested(1); + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + + final AssertSubscriber assertSubscriber = + requestChannelResponderSubscriber.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(firstPayload.refCnt()).isOne(); + + // state machine check + stateAssert.hasSubscribedFlagOnly().hasRequestN(0); + + assertSubscriber.request(1); + + // state machine check + stateAssert.hasSubscribedFlag().hasFirstFrameSentFlag().hasRequestN(1); + + // should not send requestN since 1 is remaining + Assertions.assertThat(sender.isEmpty()).isTrue(); + + assertSubscriber.request(1); + + stateAssert.hasSubscribedFlag().hasRequestN(2).hasFirstFrameSentFlag(); + + // should not send requestN since 1 is remaining + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(REQUEST_N) + .hasStreamId(1) + .hasRequestN(1) + .hasNoLeaks(); + + publisher.next(TestRequesterResponderSupport.genericPayload(allocator)); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .typeOf(NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + assertSubscriber.request(Long.MAX_VALUE); + final ByteBuf requestMaxNFrame = sender.awaitFrame(); + FrameAssert.assertThat(requestMaxNFrame) + .isNotNull() + .hasRequestN(Integer.MAX_VALUE) + .typeOf(FrameType.REQUEST_N) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + Payload nextPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelResponderSubscriber.handlePayload(nextPayload); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload randomPayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, randomPayload); + + ByteBuf firstFragment = fragments.remove(0); + requestChannelResponderSubscriber.handleNext(firstFragment, true, false); + firstFragment.release(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag(); + + for (int i = 0; i < fragments.size(); i++) { + boolean hasFollows = i != fragments.size() - 1; + ByteBuf followingFragment = fragments.get(i); + + requestChannelResponderSubscriber.handleNext(followingFragment, hasFollows, false); + followingFragment.release(); + } + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag(); + + if (completionCase.equals("inbound")) { + requestChannelResponderSubscriber.handleComplete(); + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(firstPayload).hasNoLeaks(), + p -> PayloadAssert.assertThat(p).isSameAs(nextPayload).hasNoLeaks(), + p -> { + PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + }) + .assertComplete(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag() + .hasInboundTerminated(); + + publisher.complete(); + FrameAssert.assertThat(sender.awaitFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + } else if (completionCase.equals("inboundCancel")) { + assertSubscriber.cancel(); + assertSubscriber.assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(firstPayload).hasNoLeaks(), + p -> PayloadAssert.assertThat(p).isSameAs(nextPayload).hasNoLeaks(), + p -> { + PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + }); + + FrameAssert.assertThat(sender.awaitFrame()).typeOf(CANCEL).hasStreamId(1).hasNoLeaks(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag() + .hasInboundTerminated(); + + publisher.complete(); + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.COMPLETE) + .hasStreamId(1) + .hasNoLeaks(); + } else if (completionCase.equals("outbound")) { + publisher.complete(); + FrameAssert.assertThat(sender.awaitFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag() + .hasOutboundTerminated(); + + requestChannelResponderSubscriber.handleComplete(); + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(p).hasNoLeaks(), + p -> PayloadAssert.assertThat(p).isEqualTo(nextPayload).hasNoLeaks(), + p -> { + PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + }) + .assertComplete(); + } + + Assertions.assertThat(firstPayload.refCnt()).isZero(); + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + @Test + public void failOnOverflow() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload firstPayload = TestRequesterResponderSupport.genericPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + publisher.assertMaxRequested(1); + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + + final AssertSubscriber assertSubscriber = + requestChannelResponderSubscriber.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(firstPayload.refCnt()).isOne(); + + // state machine check + stateAssert.hasSubscribedFlagOnly().hasRequestN(0); + + assertSubscriber.request(1); + + // state machine check + stateAssert.hasSubscribedFlag().hasFirstFrameSentFlag().hasRequestN(1); + + // should not send requestN since 1 is remaining + Assertions.assertThat(sender.isEmpty()).isTrue(); + + assertSubscriber.request(1); + + stateAssert.hasSubscribedFlag().hasRequestN(2).hasFirstFrameSentFlag(); + + // should not send requestN since 1 is remaining + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(REQUEST_N) + .hasStreamId(1) + .hasRequestN(1) + .hasNoLeaks(); + + Payload nextPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelResponderSubscriber.handlePayload(nextPayload); + + Payload unrequestedPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelResponderSubscriber.handlePayload(unrequestedPayload); + + final ByteBuf cancelErrorFrame = sender.awaitFrame(); + FrameAssert.assertThat(cancelErrorFrame) + .isNotNull() + .typeOf(ERROR) + .hasData("The number of messages received exceeds the number requested") + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(firstPayload).hasNoLeaks(), + p -> PayloadAssert.assertThat(p).isSameAs(nextPayload).hasNoLeaks()) + .assertErrorMessage("The number of messages received exceeds the number requested"); + + Assertions.assertThat(firstPayload.refCnt()).isZero(); + Assertions.assertThat(nextPayload.refCnt()).isZero(); + Assertions.assertThat(unrequestedPayload.refCnt()).isZero(); + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + @Test + public void failOnOverflowBeforeFirstPayloadIsSent() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload firstPayload = TestRequesterResponderSupport.genericPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + publisher.assertMaxRequested(1); + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + + final AssertSubscriber assertSubscriber = + requestChannelResponderSubscriber.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(firstPayload.refCnt()).isOne(); + + // state machine check + stateAssert.hasSubscribedFlagOnly().hasRequestN(0); + + Payload unrequestedPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelResponderSubscriber.handlePayload(unrequestedPayload); + + final ByteBuf cancelErrorFrame = sender.awaitFrame(); + FrameAssert.assertThat(cancelErrorFrame) + .isNotNull() + .typeOf(ERROR) + .hasData("The number of messages received exceeds the number requested") + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + assertSubscriber.request(1); + + assertSubscriber + .assertValuesWith(p -> PayloadAssert.assertThat(p).isSameAs(firstPayload).hasNoLeaks()) + .assertErrorMessage("The number of messages received exceeds the number requested"); + + Assertions.assertThat(firstPayload.refCnt()).isZero(); + Assertions.assertThat(unrequestedPayload.refCnt()).isZero(); + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + /* + * +--------------------------------+ + * | Racing Test Cases | + * +--------------------------------+ + */ + + @Test + public void streamShouldWorkCorrectlyWhenRacingHandleCompleteWithSubscription() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; + final Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + + final AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + + RaceTestUtils.race( + () -> + requestChannelResponderSubscriber + .doOnNext(__ -> assertSubscriber.request(1)) + .subscribe(assertSubscriber), + () -> requestChannelResponderSubscriber.handleComplete()); + + stateAssert + .hasSubscribedFlag() + .hasInboundTerminated() + .hasFirstFrameSentFlag() + .hasRequestNBetween(1, 2); + + assertSubscriber + .assertValuesWith(p -> PayloadAssert.assertThat(p).isSameAs(firstPayload).hasNoLeaks()) + .assertTerminated() + .assertComplete(); + + publisher.complete(); + + if (sender.getSent().size() > 1) { + FrameAssert.assertThat(sender.awaitFrame()) + .hasStreamId(1) + .typeOf(REQUEST_N) + .hasRequestN(1) + .hasNoLeaks(); + } + FrameAssert.assertThat(sender.awaitFrame()).hasStreamId(1).typeOf(COMPLETE).hasNoLeaks(); + + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + } + + @Test + public void streamShouldWorkCorrectlyWhenRacingHandleErrorWithSubscription() { + ApplicationErrorException applicationErrorException = new ApplicationErrorException("test"); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + + final AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + + RaceTestUtils.race( + () -> requestChannelResponderSubscriber.subscribe(assertSubscriber), + () -> requestChannelResponderSubscriber.handleError(applicationErrorException)); + + stateAssert.isTerminated(); + + publisher.assertCancelled(1); + + if (!assertSubscriber.values().isEmpty()) { + assertSubscriber.assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(p).hasNoLeaks()); + } + + assertSubscriber + .assertTerminated() + .assertError(applicationErrorException.getClass()) + .assertErrorMessage("test"); + + allocator.assertHasNoLeaks(); + } + } + + @Test + public void streamShouldWorkCorrectlyWhenRacingOutboundErrorWithSubscription() { + RuntimeException exception = new RuntimeException("test"); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + + final AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + + RaceTestUtils.race( + () -> requestChannelResponderSubscriber.subscribe(assertSubscriber), + () -> publisher.error(exception)); + + stateAssert.isTerminated(); + + FrameAssert.assertThat(activeStreams.getDuplexConnection().awaitFrame()) + .typeOf(ERROR) + .hasData("test") + .hasStreamId(1) + .hasNoLeaks(); + + if (!assertSubscriber.values().isEmpty()) { + assertSubscriber.assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(p).hasNoLeaks()); + } + + assertSubscriber + .assertTerminated() + .assertError(CancellationException.class) + .assertErrorMessage("Outbound has terminated with an error"); + + allocator.assertHasNoLeaks(); + } + } + + @Test + public void streamShouldWorkCorrectlyWhenRacingHandleCancelWithSubscription() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + + final AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + + RaceTestUtils.race( + () -> requestChannelResponderSubscriber.subscribe(assertSubscriber), + () -> requestChannelResponderSubscriber.handleCancel()); + + stateAssert.isTerminated(); + + publisher.assertCancelled(1); + + if (!assertSubscriber.values().isEmpty()) { + assertSubscriber.assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(p).hasNoLeaks()); + } + + assertSubscriber + .assertTerminated() + .assertError(CancellationException.class) + .assertErrorMessage("Inbound has been canceled"); + + allocator.assertHasNoLeaks(); + } + } + + static Stream cases() { + return Stream.of( + Arguments.arguments("complete", "sizeError"), + Arguments.arguments("complete", "refCntError"), + Arguments.arguments("complete", "onError"), + Arguments.arguments("error", "sizeError"), + Arguments.arguments("error", "refCntError"), + Arguments.arguments("error", "onError"), + Arguments.arguments("cancel", "sizeError"), + Arguments.arguments("cancel", "refCntError"), + Arguments.arguments("cancel", "onError")); + } + + @ParameterizedTest + @MethodSource("cases") + public void shouldHaveEventsDeliveredSeriallyWhenOutboundErrorRacingWithInboundSignals( + String inboundTerminationMode, String outboundTerminationMode) { + final RuntimeException outboundException = new RuntimeException("outboundException"); + final ApplicationErrorException inboundException = + new ApplicationErrorException("inboundException"); + final ArrayList droppedErrors = new ArrayList<>(); + final Payload oversizePayload = + DefaultPayload.create(new byte[FRAME_LENGTH_MASK], new byte[FRAME_LENGTH_MASK]); + + Hooks.onErrorDropped(droppedErrors::add); + try { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = + TestPublisher.createNoncompliant(DEFER_CANCELLATION, CLEANUP_ON_TERMINATE); + + Payload requestPayload = TestRequesterResponderSupport.randomPayload(allocator); + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, requestPayload, activeStreams); + + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + final AssertSubscriber> assertSubscriber = + requestChannelResponderSubscriber + .materialize() + .subscribeWith(AssertSubscriber.create(0)); + + assertSubscriber.request(Integer.MAX_VALUE); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.REQUEST_N) + .hasRequestN(Integer.MAX_VALUE) + .hasNoLeaks(); + + requestChannelResponderSubscriber.handleRequestN(Long.MAX_VALUE); + + Payload responsePayload1 = TestRequesterResponderSupport.randomPayload(allocator); + Payload responsePayload2 = TestRequesterResponderSupport.randomPayload(allocator); + Payload responsePayload3 = TestRequesterResponderSupport.randomPayload(allocator); + + Payload releasedPayload = ByteBufPayload.create(Unpooled.EMPTY_BUFFER); + releasedPayload.release(); + + RaceTestUtils.race( + () -> { + if (outboundTerminationMode.equals("onError")) { + publisher.error(outboundException); + } else if (outboundTerminationMode.equals("refCntError")) { + publisher.next(releasedPayload); + } else { + publisher.next(oversizePayload); + } + }, + () -> { + requestChannelResponderSubscriber.handlePayload(responsePayload1); + requestChannelResponderSubscriber.handlePayload(responsePayload2); + requestChannelResponderSubscriber.handlePayload(responsePayload3); + + if (inboundTerminationMode.equals("error")) { + requestChannelResponderSubscriber.handleError(inboundException); + } else if (inboundTerminationMode.equals("complete")) { + requestChannelResponderSubscriber.handleComplete(); + } else { + requestChannelResponderSubscriber.handleCancel(); + } + }); + + ByteBuf errorFrameOrEmpty = sender.pollFrame(); + if (errorFrameOrEmpty != null) { + String message; + if (outboundTerminationMode.equals("onError")) { + message = outboundException.getMessage(); + } else if (outboundTerminationMode.equals("sizeError")) { + message = String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK); + } else { + message = "Failed to validate payload. Cause:refCnt: 0"; + } + FrameAssert.assertThat(errorFrameOrEmpty) + .typeOf(FrameType.ERROR) + .hasData(message) + .hasNoLeaks(); + } + + List> values = assertSubscriber.values(); + for (int j = 0; j < values.size(); j++) { + Signal signal = values.get(j); + + if (signal.isOnNext()) { + Payload payload = signal.get(); + if (j == 0) { + Assertions.assertThat(payload).isEqualTo(requestPayload); + } + + PayloadAssert.assertThat(payload) + .describedAs("Expected that the next signal[%s] to have no leaks", j) + .hasNoLeaks(); + } else { + if (inboundTerminationMode.equals("error")) { + Assertions.assertThat(signal.isOnError()).isTrue(); + Throwable throwable = signal.getThrowable(); + if (Exceptions.isMultiple(throwable)) { + Assertions.assertThat( + Arrays.stream(throwable.getSuppressed()).map(Throwable::getMessage)) + .containsExactlyInAnyOrder( + inboundException.getMessage(), + outboundTerminationMode.equals("onError") + ? "Outbound has terminated with an error" + : "Inbound has been canceled"); + } else { + if (throwable == inboundException) { + Assertions.assertThat(droppedErrors) + .hasSize(1) + .first() + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + } else { + Assertions.assertThat(droppedErrors).containsOnly(inboundException); + } + } + } else if (inboundTerminationMode.equals("complete")) { + Assertions.assertThat(droppedErrors).isEmpty(); + if (signal.isOnError()) { + Assertions.assertThat(signal.getThrowable()) + .isExactlyInstanceOf(CancellationException.class) + .matches( + t -> + t.getMessage().equals("Inbound has been canceled") + || t.getMessage().equals("Outbound has terminated with an error")); + } + } else { + Throwable throwable = signal.getThrowable(); + if (Exceptions.isMultiple(throwable)) { + Assertions.assertThat( + Arrays.stream(throwable.getSuppressed()).map(Throwable::getMessage)) + .containsExactlyInAnyOrder( + "Inbound has been canceled", + outboundTerminationMode.equals("onError") + ? "Outbound has terminated with an error" + : "Inbound has been canceled"); + } else { + Assertions.assertThat(throwable).isExactlyInstanceOf(CancellationException.class); + } + } + + Assertions.assertThat(j) + .describedAs( + "Expected that the %s signal[%s] is the last signal, but the last was %s", + signal, j, values.get(values.size() - 1)) + .isEqualTo(values.size() - 1); + } + } + + allocator.assertHasNoLeaks(); + droppedErrors.clear(); + } + } finally { + Hooks.resetOnErrorDropped(); + } + } + + @ParameterizedTest + @ValueSource(strings = {"onError", "sizeError", "refCntError", "cancel"}) + public void shouldHaveNoLeaksOnReassemblyAndCancelRacing(String terminationMode) { + final RuntimeException outboundException = new RuntimeException("outboundException"); + final Payload oversizePayload = + DefaultPayload.create(new byte[FRAME_LENGTH_MASK], new byte[FRAME_LENGTH_MASK]); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; + final TestPublisher publisher = + TestPublisher.createNoncompliant(DEFER_CANCELLATION, CLEANUP_ON_TERMINATE); + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(2); + + Payload firstPayload = TestRequesterResponderSupport.genericPayload(allocator); + final RequestChannelResponderSubscriber requestOperator = + new RequestChannelResponderSubscriber(1, Long.MAX_VALUE, firstPayload, activeStreams); + + publisher.subscribe(requestOperator); + requestOperator.subscribe(assertSubscriber); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload responsePayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, responsePayload); + + Payload releasedPayload1 = ByteBufPayload.create(new byte[0]); + Payload releasedPayload2 = ByteBufPayload.create(new byte[0]); + releasedPayload1.release(); + releasedPayload2.release(); + + RaceTestUtils.race( + () -> { + switch (terminationMode) { + case "onError": + publisher.error(outboundException); + break; + case "sizeError": + publisher.next(oversizePayload); + break; + case "refCntError": + publisher.next(releasedPayload1); + break; + case "cancel": + default: + assertSubscriber.cancel(); + } + }, + () -> { + int lastFragmentId = fragments.size() - 1; + for (int j = 0; j < fragments.size(); j++) { + ByteBuf frame = fragments.get(j); + requestOperator.handleNext(frame, lastFragmentId != j, false); + frame.release(); + } + }); + + List values = assertSubscriber.values(); + + PayloadAssert.assertThat(values.get(0)).isEqualTo(firstPayload).hasNoLeaks(); + + if (values.size() > 1) { + Payload payload = values.get(1); + PayloadAssert.assertThat(payload).isEqualTo(responsePayload).hasNoLeaks(); + } + + if (!sender.isEmpty()) { + if (terminationMode.equals("cancel")) { + assertSubscriber.assertNotTerminated(); + } else { + assertSubscriber.assertTerminated().assertError(); + } + + final ByteBuf requstFrame = sender.awaitFrame(); + FrameAssert.assertThat(requstFrame) + .isNotNull() + .typeOf(REQUEST_N) + .hasRequestN(1) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf terminalFrame = sender.awaitFrame(); + FrameAssert.assertThat(terminalFrame) + .isNotNull() + .typeOf(terminationMode.equals("cancel") ? CANCEL : ERROR) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + } + + PayloadAssert.assertThat(responsePayload).hasNoLeaks(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequestResponseRequesterMonoTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequestResponseRequesterMonoTest.java new file mode 100644 index 000000000..b39ac62d9 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RequestResponseRequesterMonoTest.java @@ -0,0 +1,698 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA; +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.TestRequesterResponderSupport.genericPayload; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.frame.FrameType; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.Arrays; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.Scannable; +import reactor.test.StepVerifier; + +public class RequestResponseRequesterMonoTest { + + @BeforeAll + public static void setUp() { + StepVerifier.setDefaultTimeout(Duration.ofSeconds(2)); + } + + /* + * +-------------------------------+ + * | General Test Cases | + * +-------------------------------+ + * + */ + + /** + * General StateMachine transition test. No Fragmentation enabled In this test we check that the + * given instance of RequestResponseMono: 1) subscribes 2) sends frame on the first request 3) + * terminates up on receiving the first signal (terminates on first next | error | next over + * reassembly | complete) + */ + @ParameterizedTest + @MethodSource("frameShouldBeSentOnSubscriptionResponses") + public void frameShouldBeSentOnSubscription( + BiFunction, StepVerifier> + transformer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = genericPayload(allocator); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + + final StateAssert stateAssert = + StateAssert.assertThat(RequestResponseRequesterMono.STATE, requestResponseRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + transformer + .apply( + requestResponseRequesterMono, + StepVerifier.create(requestResponseRequesterMono, 0) + .expectSubscription() + .then(stateAssert::hasSubscribedFlagOnly) + .then(() -> Assertions.assertThat(payload.refCnt()).isOne()) + .then(activeStreams::assertNoActiveStreams) + .thenRequest(1) + .then(() -> stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag()) + .then(() -> Assertions.assertThat(payload.refCnt()).isZero()) + .then(() -> activeStreams.assertHasStream(1, requestResponseRequesterMono))) + .verify(); + + PayloadAssert.assertThat(payload).isReleased(); + // should not add anything to map + activeStreams.assertNoActiveStreams(); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .typeOf(FrameType.REQUEST_RESPONSE) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + stateAssert.isTerminated(); + + if (!sender.isEmpty()) { + ByteBuf cancelFrame = sender.awaitFrame(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + } + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + static Stream, StepVerifier>> + frameShouldBeSentOnSubscriptionResponses() { + return Stream.of( + // next case + (rrm, sv) -> + sv.then(() -> rrm.handlePayload(EmptyPayload.INSTANCE)) + .expectNext(EmptyPayload.INSTANCE) + .expectComplete(), + // complete case + (rrm, sv) -> sv.then(rrm::handleComplete).expectComplete(), + // error case + (rrm, sv) -> + sv.then(() -> rrm.handleError(new ApplicationErrorException("test"))) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .hasMessage("test") + .isInstanceOf(ApplicationErrorException.class)), + // fragmentation case + (rrm, sv) -> { + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + StateAssert stateAssert = StateAssert.assertThat(rrm); + + return sv.then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFirstFragment( + rrm.allocator, + 64, + FrameType.REQUEST_RESPONSE, + 1, + payload.hasMetadata(), + payload.metadata(), + payload.data()); + rrm.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rrm.allocator, 64, 1, false, payload.metadata(), payload.data()); + rrm.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rrm.allocator, 64, 1, false, payload.metadata(), payload.data()); + rrm.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rrm.allocator, 64, 1, false, payload.metadata(), payload.data()); + rrm.handleNext(followingFrame, false, false); + followingFrame.release(); + }) + .then(stateAssert::isTerminated) + .assertNext( + p -> { + Assertions.assertThat(p.data()).isEqualTo(Unpooled.wrappedBuffer(data)); + + Assertions.assertThat(p.metadata()).isEqualTo(Unpooled.wrappedBuffer(metadata)); + p.release(); + }) + .then(payload::release) + .expectComplete(); + }, + (rrm, sv) -> { + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + StateAssert stateAssert = StateAssert.assertThat(rrm); + + ByteBuf[] fragments = + new ByteBuf[] { + FragmentationUtils.encodeFirstFragment( + rrm.allocator, + 64, + FrameType.REQUEST_RESPONSE, + 1, + payload.hasMetadata(), + payload.metadata(), + payload.data()), + FragmentationUtils.encodeFollowsFragment( + rrm.allocator, 64, 1, false, payload.metadata(), payload.data()), + FragmentationUtils.encodeFollowsFragment( + rrm.allocator, 64, 1, false, payload.metadata(), payload.data()) + }; + + final StepVerifier stepVerifier = + sv.then( + () -> { + rrm.handleNext(fragments[0], true, false); + fragments[0].release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + rrm.handleNext(fragments[1], true, false); + fragments[1].release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + rrm.handleNext(fragments[2], true, false); + fragments[2].release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then(payload::release) + .thenCancel() + .verifyLater(); + + stepVerifier.verify(); + + Assertions.assertThat(fragments).allMatch(bb -> bb.refCnt() == 0); + + return stepVerifier; + }); + } + + /** + * General StateMachine transition test. Fragmentation enabled In this test we check that the + * given instance of RequestResponseMono: 1) subscribes 2) sends fragments frames on the first + * request 3) terminates up on receiving the first signal (terminates on first next | error | next + * over reassembly | complete) + */ + @ParameterizedTest + @MethodSource("frameShouldBeSentOnSubscriptionResponses") + public void frameFragmentsShouldBeSentOnSubscription( + BiFunction, StepVerifier> + transformer) { + final int mtu = 64; + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(mtu); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + transformer + .apply( + requestResponseRequesterMono, + StepVerifier.create(requestResponseRequesterMono, 0) + .expectSubscription() + .then(stateAssert::hasSubscribedFlagOnly) + .then(() -> Assertions.assertThat(payload.refCnt()).isOne()) + .then(activeStreams::assertNoActiveStreams) + .thenRequest(1) + .then(() -> stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag()) + .then(() -> Assertions.assertThat(payload.refCnt()).isZero()) + .then(() -> activeStreams.assertHasStream(1, requestResponseRequesterMono))) + .verify(); + + // should not add anything to map + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(payload.refCnt()).isZero(); + + final ByteBuf frameFragment1 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment1) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET_WITH_METADATA) // 64 - 6 (frame headers) - 3 (encoded metadata + // length) - 3 frame length + .hasMetadata(Arrays.copyOf(metadata, 52)) + .hasData(Unpooled.EMPTY_BUFFER) + .hasFragmentsFollow() + .typeOf(FrameType.REQUEST_RESPONSE) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment2 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment2) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET_WITH_METADATA) // 64 - 6 (frame headers) - 3 (encoded metadata + // length) - 3 frame length + .hasMetadata(Arrays.copyOfRange(metadata, 52, 65)) + .hasData(Arrays.copyOf(data, 39)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment3 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment3) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET) // 64 - 6 (frame headers) - 3 frame length (no metadata - no length) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 39, 94)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment4 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment4) + .isNotNull() + .hasPayloadSize(35) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 94, 129)) + .hasNoFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + if (!sender.isEmpty()) { + FrameAssert.assertThat(sender.awaitFrame()) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + } + Assertions.assertThat(sender.isEmpty()).isTrue(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * General StateMachine transition test. Ensures that no fragment is sent if mono was cancelled + * before any requests + */ + @Test + public void shouldBeNoOpsOnCancel() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = ByteBufPayload.create("testData", "testMetadata"); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + activeStreams.assertNoActiveStreams(); + stateAssert.isUnsubscribed(); + + StepVerifier.create(requestResponseRequesterMono, 0) + .expectSubscription() + .then(() -> stateAssert.hasSubscribedFlagOnly()) + .then(() -> activeStreams.assertNoActiveStreams()) + .thenCancel() + .verify(); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * General state machine test Ensures that a Subscriber receives error signal and state migrate to + * the terminated in case the given payload is an invalid one. + */ + @ParameterizedTest + @MethodSource("shouldErrorOnIncorrectRefCntInGivenPayloadSource") + public void shouldErrorOnIncorrectRefCntInGivenPayload( + Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; + final Payload payload = ByteBufPayload.create(""); + payload.release(); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + monoConsumer.accept(requestResponseRequesterMono); + + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + static Stream> + shouldErrorOnIncorrectRefCntInGivenPayloadSource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .expectError(IllegalReferenceCountException.class) + .verify(), + requestResponseRequesterMono -> + Assertions.assertThatThrownBy(requestResponseRequesterMono::block) + .isInstanceOf(IllegalReferenceCountException.class)); + } + + /** + * General state machine test Ensures that a Subscriber receives error signal and state migrate to + * the terminated in case the given payload was release in the middle of interaction. + * Fragmentation is disabled + */ + @Test + public void shouldErrorOnIncorrectRefCntInGivenPayloadLatePhase() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; + final Payload payload = ByteBufPayload.create(""); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + StepVerifier.create(requestResponseRequesterMono, 0) + .expectSubscription() + .then(payload::release) + .thenRequest(1) + .expectError(IllegalReferenceCountException.class) + .verify(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * General state machine test Ensures that a Subscriber receives error signal and state migrate to + * the terminated in case the given payload was release in the middle of interaction. + * Fragmentation is enabled + */ + @Test + public void shouldErrorOnIncorrectRefCntInGivenPayloadLatePhaseWithFragmentation() { + final int mtu = 64; + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(mtu); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + StepVerifier.create(requestResponseRequesterMono, 0) + .expectSubscription() + .then(payload::release) + .thenRequest(1) + .expectError(IllegalReferenceCountException.class) + .verify(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * General state machine test Ensures that a Subscriber receives error signal and state migrates + * to the terminated in case the given payload is too big with disabled fragmentation + */ + @ParameterizedTest + @MethodSource("shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource") + public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( + Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; + + final byte[] metadata = new byte[FRAME_LENGTH_MASK]; + final byte[] data = new byte[FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + activeStreams.assertNoActiveStreams(); + stateAssert.isUnsubscribed(); + + monoConsumer.accept(requestResponseRequesterMono); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + static Stream> + shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)) + .verify(), + requestResponseRequesterMono -> + Assertions.assertThatThrownBy(requestResponseRequesterMono::block) + .hasMessage(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)); + } + + /** + * Ensures that error check happens exactly before frame sent. This cases ensures that in case no + * lease / other external errors appeared, the local subscriber received the same one. No frames + * should be sent + */ + @ParameterizedTest + @MethodSource("shouldErrorIfNoAvailabilitySource") + public void shouldErrorIfNoAvailability(Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(new RuntimeException("test")); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload payload = genericPayload(allocator); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + activeStreams.assertNoActiveStreams(); + stateAssert.isUnsubscribed(); + + monoConsumer.accept(requestResponseRequesterMono); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + static Stream> shouldErrorIfNoAvailabilitySource() { + return Stream.of( + (s) -> + StepVerifier.create(s, 0) + .expectSubscription() + .then(() -> StateAssert.assertThat(s).hasSubscribedFlagOnly()) + .thenRequest(1) + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)) + .verify(), + requestResponseRequesterMono -> + Assertions.assertThatThrownBy(requestResponseRequesterMono::block) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)); + } + + @Test + public void checkName() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload payload = genericPayload(allocator); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + + Assertions.assertThat(Scannable.from(requestResponseRequesterMono).name()) + .isEqualTo("source(RequestResponseMono)"); + requestResponseRequesterMono.cancel(); + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequestStreamRequesterFluxTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequestStreamRequesterFluxTest.java new file mode 100644 index 000000000..8702d1a80 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RequestStreamRequesterFluxTest.java @@ -0,0 +1,1227 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N; +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.Scannable; +import reactor.test.StepVerifier; + +public class RequestStreamRequesterFluxTest { + + @BeforeAll + public static void setUp() { + StepVerifier.setDefaultTimeout(Duration.ofSeconds(2)); + } + + /* + * +-------------------------------+ + * | General Test Cases | + * +-------------------------------+ + */ + + /** + * State Machine check. Ensure migration from + * + *
    +   * UNSUBSCRIBED -> SUBSCRIBED
    +   * SUBSCRIBED -> REQUESTED(1) -> REQUESTED(0)
    +   * REQUESTED(0) -> REQUESTED(1) -> REQUESTED(0)
    +   * REQUESTED(0) -> REQUESTED(MAX)
    +   * REQUESTED(MAX) -> REQUESTED(MAX) && REASSEMBLY (extra flag enabled which indicates
    +   * reassembly)
    +   * REQUESTED(MAX) && REASSEMBLY -> TERMINATED
    +   * 
    + */ + @Test + public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestStreamRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(payload.refCnt()).isOne(); + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(1); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertHasStream(1, requestStreamRequesterFlux); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag(); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .hasRequestN(1) + .typeOf(FrameType.REQUEST_STREAM) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + assertSubscriber.request(1); + final ByteBuf requestNFrame = sender.awaitFrame(); + FrameAssert.assertThat(requestNFrame) + .isNotNull() + .hasRequestN(1) + .typeOf(FrameType.REQUEST_N) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check. Request N Frame should sent so request field should be 0 + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(2).hasFirstFrameSentFlag(); + + assertSubscriber.request(Long.MAX_VALUE); + final ByteBuf requestMaxNFrame = sender.awaitFrame(); + FrameAssert.assertThat(requestMaxNFrame) + .isNotNull() + .hasRequestN(Integer.MAX_VALUE) + .typeOf(FrameType.REQUEST_N) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + assertSubscriber.request(6); + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload randomPayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, randomPayload); + ByteBuf firstFragment = fragments.remove(0); + requestStreamRequesterFlux.handleNext(firstFragment, true, false); + firstFragment.release(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag(); + + for (int i = 0; i < fragments.size(); i++) { + boolean hasFollowing = i != fragments.size() - 1; + ByteBuf followingFragment = fragments.get(i); + + requestStreamRequesterFlux.handleNext(followingFragment, hasFollowing, false); + followingFragment.release(); + } + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag(); + + Payload finalRandomPayload = TestRequesterResponderSupport.randomPayload(allocator); + requestStreamRequesterFlux.handlePayload(finalRandomPayload); + requestStreamRequesterFlux.handleComplete(); + + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(), + p -> PayloadAssert.assertThat(p).isEqualTo(finalRandomPayload).hasNoLeaks()) + .assertComplete(); + + PayloadAssert.assertThat(randomPayload).hasNoLeaks(); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * State Machine check. Ensure migration from + * + *
    +   * UNSUBSCRIBED -> SUBSCRIBED
    +   * SUBSCRIBED -> REQUESTED(MAX)
    +   * REQUESTED(MAX) -> TERMINATED
    +   * 
    + */ + @Test + public void requestNFrameShouldBeSentExactlyOnceIfItIsMaxAllowed() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestStreamRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(payload.refCnt()).isOne(); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(Long.MAX_VALUE / 2 + 1); + + // state machine check + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertHasStream(1, requestStreamRequesterFlux); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .hasRequestN(Integer.MAX_VALUE) + .typeOf(FrameType.REQUEST_STREAM) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + assertSubscriber.request(1); + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + requestStreamRequesterFlux.handlePayload(EmptyPayload.INSTANCE); + requestStreamRequesterFlux.handleComplete(); + + assertSubscriber.assertValues(EmptyPayload.INSTANCE).assertComplete(); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + /** + * State Machine check. Ensure migration from + * + *
    +   * UNSUBSCRIBED -> SUBSCRIBED
    +   * SUBSCRIBED -> REQUESTED(1) -> REQUESTED(0)
    +   * 
    + * + * And then for the following cases: + * + *
    +   * [0]: REQUESTED(0) -> REQUESTED(MAX) (with onNext and few extra request(1) which should not
    +   * affect state anyhow and should not sent any extra frames)
    +   *      REQUESTED(MAX) -> TERMINATED
    +   *
    +   * [1]: REQUESTED(0) -> REQUESTED(MAX) (with onComplete rightaway)
    +   *      REQUESTED(MAX) -> TERMINATED
    +   *
    +   * [2]: REQUESTED(0) -> REQUESTED(MAX) (with onError rightaway)
    +   *      REQUESTED(MAX) -> TERMINATED
    +   *
    +   * [3]: REQUESTED(0) -> REASSEMBLY
    +   *      REASSEMBLY -> REASSEMBLY && REQUESTED(MAX)
    +   *      REASSEMBLY && REQUESTED(MAX) -> REQUESTED(MAX)
    +   *      REQUESTED(MAX) -> TERMINATED
    +   *
    +   * [4]: REQUESTED(0) -> REQUESTED(MAX)
    +   *      REQUESTED(MAX) -> REASSEMBLY && REQUESTED(MAX)
    +   *      REASSEMBLY && REQUESTED(MAX) -> TERMINATED (because of cancel() invocation)
    +   * 
    + */ + @ParameterizedTest + @MethodSource("frameShouldBeSentOnFirstRequestResponses") + public void frameShouldBeSentOnFirstRequest( + BiFunction, StepVerifier> + transformer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + transformer + .apply( + requestStreamRequesterFlux, + StepVerifier.create(requestStreamRequesterFlux, 0) + .expectSubscription() + .then( + () -> + // state machine check + stateAssert.hasSubscribedFlagOnly()) + .then(() -> Assertions.assertThat(payload.refCnt()).isOne()) + .then(() -> activeStreams.assertNoActiveStreams()) + .thenRequest(1) + .then( + () -> + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag()) + .then(() -> Assertions.assertThat(payload.refCnt()).isZero()) + .then(() -> activeStreams.assertHasStream(1, requestStreamRequesterFlux))) + .verify(); + + Assertions.assertThat(payload.refCnt()).isZero(); + // should not add anything to map + activeStreams.assertNoActiveStreams(); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .hasRequestN(1) + .typeOf(FrameType.REQUEST_STREAM) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf requestNFrame = sender.awaitFrame(); + FrameAssert.assertThat(requestNFrame) + .isNotNull() + .typeOf(FrameType.REQUEST_N) + .hasRequestN(Integer.MAX_VALUE) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + if (!sender.isEmpty()) { + final ByteBuf cancelFrame = sender.awaitFrame(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + } + // state machine check + stateAssert.isTerminated(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + static Stream, StepVerifier>> + frameShouldBeSentOnFirstRequestResponses() { + return Stream.of( + (rsf, sv) -> + sv.then(() -> rsf.handlePayload(EmptyPayload.INSTANCE)) + .expectNext(EmptyPayload.INSTANCE) + .thenRequest(Long.MAX_VALUE) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag()) + .then(() -> rsf.handlePayload(EmptyPayload.INSTANCE)) + .thenRequest(1L) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag()) + .expectNext(EmptyPayload.INSTANCE) + .thenRequest(1L) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag()) + .then(rsf::handleComplete) + .thenRequest(1L) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf).isTerminated()) + .expectComplete(), + (rsf, sv) -> + sv.then(() -> rsf.handlePayload(EmptyPayload.INSTANCE)) + .expectNext(EmptyPayload.INSTANCE) + .thenRequest(Long.MAX_VALUE) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag()) + .then(rsf::handleComplete) + .thenRequest(1L) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf).isTerminated()) + .expectComplete(), + (rsf, sv) -> + sv.then(() -> rsf.handlePayload(EmptyPayload.INSTANCE)) + .expectNext(EmptyPayload.INSTANCE) + .thenRequest(Long.MAX_VALUE) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag()) + .then(() -> rsf.handleError(new ApplicationErrorException("test"))) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf).isTerminated()) + .thenRequest(1L) + .thenRequest(1L) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .hasMessage("test") + .isInstanceOf(ApplicationErrorException.class)), + (rsf, sv) -> { + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + final Payload payload2 = ByteBufPayload.create(data, metadata); + + return sv.then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFirstFragment( + rsf.allocator, + 64, + FrameType.NEXT, + 1, + payload.hasMetadata(), + payload.metadata(), + payload.data()); + rsf.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(1) + .hasSubscribedFlag() + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rsf.allocator, 64, 1, false, payload.metadata(), payload.data()); + rsf.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(1) + .hasSubscribedFlag() + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rsf.allocator, 64, 1, false, payload.metadata(), payload.data()); + rsf.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(1) + .hasSubscribedFlag() + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .thenRequest(Long.MAX_VALUE) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(Integer.MAX_VALUE) + .hasSubscribedFlag() + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rsf.allocator, 64, 1, false, payload.metadata(), payload.data()); + rsf.handleNext(followingFrame, false, false); + followingFrame.release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(Integer.MAX_VALUE) + .hasSubscribedFlag() + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag()) + .assertNext( + p -> { + Assertions.assertThat(p.data()).isEqualTo(Unpooled.wrappedBuffer(data)); + + Assertions.assertThat(p.metadata()).isEqualTo(Unpooled.wrappedBuffer(metadata)); + Assertions.assertThat(p.release()).isTrue(); + Assertions.assertThat(p.refCnt()).isZero(); + }) + .then(payload::release) + .then(() -> rsf.handlePayload(payload2)) + .thenRequest(1) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(Integer.MAX_VALUE) + .hasSubscribedFlag() + .hasFirstFrameSentFlag()) + .assertNext( + p -> { + Assertions.assertThat(p.data()).isEqualTo(Unpooled.wrappedBuffer(data)); + + Assertions.assertThat(p.metadata()).isEqualTo(Unpooled.wrappedBuffer(metadata)); + Assertions.assertThat(p.release()).isTrue(); + Assertions.assertThat(p.refCnt()).isZero(); + }) + .thenRequest(1) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(Integer.MAX_VALUE) + .hasSubscribedFlag() + .hasFirstFrameSentFlag()) + .then(rsf::handleComplete) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf).isTerminated()) + .expectComplete(); + }, + (rsf, sv) -> { + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload0 = ByteBufPayload.create(data, metadata); + final Payload payload = ByteBufPayload.create(data, metadata); + + ByteBuf[] fragments = + new ByteBuf[] { + FragmentationUtils.encodeFirstFragment( + rsf.allocator, + 64, + FrameType.NEXT, + 1, + payload.hasMetadata(), + payload.metadata(), + payload.data()), + FragmentationUtils.encodeFollowsFragment( + rsf.allocator, 64, 1, false, payload.metadata(), payload.data()), + FragmentationUtils.encodeFollowsFragment( + rsf.allocator, 64, 1, false, payload.metadata(), payload.data()) + }; + + final StepVerifier stepVerifier = + sv.then(() -> rsf.handlePayload(payload0)) + .assertNext(p -> PayloadAssert.assertThat(p).isEqualTo(payload0).hasNoLeaks()) + .thenRequest(Long.MAX_VALUE) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag()) + .then( + () -> { + rsf.handleNext(fragments[0], true, false); + fragments[0].release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + rsf.handleNext(fragments[1], true, false); + fragments[1].release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + rsf.handleNext(fragments[2], true, false); + fragments[2].release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .thenRequest(1) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .thenRequest(1) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then(payload::release) + .thenCancel() + .verifyLater(); + + stepVerifier.verify(); + // state machine check + StateAssert.assertThat(rsf).isTerminated(); + + Assertions.assertThat(fragments).allMatch(bb -> bb.refCnt() == 0); + + return stepVerifier; + }); + } + + /** + * State Machine check with fragmentation of the first payload. Ensure migration from + * + *
    +   * UNSUBSCRIBED -> SUBSCRIBED
    +   * SUBSCRIBED -> REQUESTED(1) -> REQUESTED(0)
    +   * 
    + * + * And then for the following cases: + * + *
    +   * [0]: REQUESTED(0) -> REQUESTED(MAX) (with onNext and few extra request(1) which should not
    +   * affect state anyhow and should not sent any extra frames)
    +   *      REQUESTED(MAX) -> TERMINATED
    +   *
    +   * [1]: REQUESTED(0) -> REQUESTED(MAX) (with onComplete rightaway)
    +   *      REQUESTED(MAX) -> TERMINATED
    +   *
    +   * [2]: REQUESTED(0) -> REQUESTED(MAX) (with onError rightaway)
    +   *      REQUESTED(MAX) -> TERMINATED
    +   *
    +   * [3]: REQUESTED(0) -> REASSEMBLY
    +   *      REASSEMBLY -> REASSEMBLY && REQUESTED(MAX)
    +   *      REASSEMBLY && REQUESTED(MAX) -> REQUESTED(MAX)
    +   *      REQUESTED(MAX) -> TERMINATED
    +   *
    +   * [4]: REQUESTED(0) -> REQUESTED(MAX)
    +   *      REQUESTED(MAX) -> REASSEMBLY && REQUESTED(MAX)
    +   *      REASSEMBLY && REQUESTED(MAX) -> TERMINATED (because of cancel() invocation)
    +   * 
    + */ + @ParameterizedTest + @MethodSource("frameShouldBeSentOnFirstRequestResponses") + public void frameFragmentsShouldBeSentOnFirstRequest( + BiFunction, StepVerifier> + transformer) { + final int mtu = 64; + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(mtu); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + transformer + .apply( + requestStreamRequesterFlux, + StepVerifier.create(requestStreamRequesterFlux, 0) + .expectSubscription() + .then(() -> Assertions.assertThat(payload.refCnt()).isOne()) + .then(() -> activeStreams.assertNoActiveStreams()) + .thenRequest(1) + .then(() -> Assertions.assertThat(payload.refCnt()).isZero()) + .then(() -> activeStreams.assertHasStream(1, requestStreamRequesterFlux))) + .verify(); + + // should not add anything to map + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(payload.refCnt()).isZero(); + + final ByteBuf frameFragment1 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment1) + .isNotNull() + .hasPayloadSize(64 - FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N) + // InitialRequestN size + .hasMetadata(Arrays.copyOf(metadata, 64 - FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N)) + .hasData(Unpooled.EMPTY_BUFFER) + .hasFragmentsFollow() + .typeOf(FrameType.REQUEST_STREAM) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment2 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment2) + .isNotNull() + .hasPayloadSize(64 - FRAME_OFFSET_WITH_METADATA) + .hasMetadata( + Arrays.copyOfRange(metadata, 64 - FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N, 65)) + .hasData(Arrays.copyOf(data, 35)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment3 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment3) + .isNotNull() + .hasPayloadSize(64 - FRAME_OFFSET) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 35, 35 + 55)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment4 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment4) + .isNotNull() + .hasPayloadSize(39) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 90, 129)) + .hasNoFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf requestNFrame = sender.awaitFrame(); + FrameAssert.assertThat(requestNFrame) + .isNotNull() + .typeOf(FrameType.REQUEST_N) + .hasRequestN(Integer.MAX_VALUE) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + if (!sender.isEmpty()) { + FrameAssert.assertThat(sender.awaitFrame()) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + } + Assertions.assertThat(sender.isEmpty()).isTrue(); + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * Case which ensures that if Payload has incorrect refCnt, the flux ends up with an appropriate + * error + */ + @ParameterizedTest + @MethodSource("shouldErrorOnIncorrectRefCntInGivenPayloadSource") + public void shouldErrorOnIncorrectRefCntInGivenPayload( + Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = ByteBufPayload.create(""); + payload.release(); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + monoConsumer.accept(requestStreamRequesterFlux); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + static Stream> + shouldErrorOnIncorrectRefCntInGivenPayloadSource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .expectError(IllegalReferenceCountException.class) + .verify(), + requestStreamRequesterFlux -> + Assertions.assertThatThrownBy(requestStreamRequesterFlux::blockLast) + .isInstanceOf(IllegalReferenceCountException.class)); + } + + /** + * Ensures that if Payload is release right after the subscription, the first request will exponse + * the error immediatelly and no frame will be sent to the remote party + */ + @Test + public void shouldErrorOnIncorrectRefCntInGivenPayloadLatePhase() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + + final Payload payload = ByteBufPayload.create(""); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + StepVerifier.create(requestStreamRequesterFlux, 0) + .expectSubscription() + .then( + () -> + // state machine check + stateAssert.hasSubscribedFlagOnly()) + .then(payload::release) + .thenRequest(1) + .then( + () -> + // state machine check + stateAssert.isTerminated()) + .expectError(IllegalReferenceCountException.class) + .verify(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * Ensures that if Payload is release right after the subscription, the first request will expose + * the error immediately and no frame will be sent to the remote party + */ + @Test + public void shouldErrorOnIncorrectRefCntInGivenPayloadLatePhaseWithFragmentation() { + final int mtu = 64; + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(mtu); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + StepVerifier.create(requestStreamRequesterFlux, 0) + .expectSubscription() + .then( + () -> + // state machine check + stateAssert.hasSubscribedFlagOnly()) + .then(payload::release) + .thenRequest(1) + .then( + () -> + // state machine check + stateAssert.isTerminated()) + .expectError(IllegalReferenceCountException.class) + .verify(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * Ensures that if the given payload is exits 16mb size with disabled fragmentation, than the + * appropriate validation happens and a corresponding error will be propagagted to the subscriber + */ + @ParameterizedTest + @MethodSource("shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource") + public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( + Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + + final byte[] metadata = new byte[FRAME_LENGTH_MASK]; + final byte[] data = new byte[FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + monoConsumer.accept(requestStreamRequesterFlux); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + static Stream> + shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource() { + return Stream.of( + (s) -> + StepVerifier.create(s, 0) + .expectSubscription() + .then( + () -> + // state machine check + StateAssert.assertThat(s).isTerminated()) + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)) + .verify(), + requestStreamRequesterFlux -> + Assertions.assertThatThrownBy(requestStreamRequesterFlux::blockLast) + .hasMessage(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)); + } + + /** + * Ensures that the interactions check and respect rsocket availability (such as leasing) and + * propagate an error to the final subscriber. No frame should be sent. Check should happens + * exactly on the first request. + */ + @ParameterizedTest + @MethodSource("shouldErrorIfNoAvailabilitySource") + public void shouldErrorIfNoAvailability(Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(new RuntimeException("test")); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + monoConsumer.accept(requestStreamRequesterFlux); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + allocator.assertHasNoLeaks(); + } + + static Stream> shouldErrorIfNoAvailabilitySource() { + return Stream.of( + (s) -> + StepVerifier.create(s, 0) + .expectSubscription() + .then( + () -> + // state machine check + StateAssert.assertThat(s).hasSubscribedFlagOnly()) + .thenRequest(1) + .then( + () -> + // state machine check + StateAssert.assertThat(s).isTerminated()) + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)) + .verify(), + requestStreamRequesterFlux -> + Assertions.assertThatThrownBy(requestStreamRequesterFlux::blockLast) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)); + } + + @Test + public void failOnOverflow() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestStreamRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(payload.refCnt()).isOne(); + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(1); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertHasStream(1, requestStreamRequesterFlux); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag(); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .hasRequestN(1) + .typeOf(FrameType.REQUEST_STREAM) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + Payload requestedPayload = TestRequesterResponderSupport.randomPayload(allocator); + requestStreamRequesterFlux.handlePayload(requestedPayload); + + Payload unrequestedPayload = TestRequesterResponderSupport.randomPayload(allocator); + requestStreamRequesterFlux.handlePayload(unrequestedPayload); + + final ByteBuf cancelFrame = sender.awaitFrame(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + assertSubscriber + .assertValuesWith(p -> PayloadAssert.assertThat(p).isEqualTo(requestedPayload).hasNoLeaks()) + .assertError() + .assertErrorMessage("The number of messages received exceeds the number requested"); + + PayloadAssert.assertThat(requestedPayload).isReleased(); + PayloadAssert.assertThat(unrequestedPayload).isReleased(); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + @Test + public void checkName() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + + Assertions.assertThat(Scannable.from(requestStreamRequesterFlux).name()) + .isEqualTo("source(RequestStreamFlux)"); + requestStreamRequesterFlux.cancel(); + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java new file mode 100644 index 000000000..06d050f6f --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java @@ -0,0 +1,790 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.frame.FrameType.COMPLETE; +import static io.rsocket.frame.FrameType.METADATA_PUSH; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_N; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.util.CharsetUtil; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.RaceTestConstants; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.plugins.TestRequestInterceptor; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.Supplier; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.assertj.core.api.Assumptions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; +import reactor.test.StepVerifier; +import reactor.test.util.RaceTestUtils; + +@SuppressWarnings("ALL") +public class RequesterOperatorsRacingTest { + + interface Scenario { + FrameType requestType(); + + Publisher requestOperator( + Supplier payloadsSupplier, RequesterResponderSupport requesterResponderSupport); + } + + static Stream scenarios() { + return Stream.of( + new Scenario() { + @Override + public FrameType requestType() { + return METADATA_PUSH; + } + + @Override + public Publisher requestOperator( + Supplier payloadsSupplier, + RequesterResponderSupport requesterResponderSupport) { + return new MetadataPushRequesterMono(payloadsSupplier.get(), requesterResponderSupport); + } + + @Override + public String toString() { + return MetadataPushRequesterMono.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_FNF; + } + + @Override + public Publisher requestOperator( + Supplier payloadsSupplier, + RequesterResponderSupport requesterResponderSupport) { + return new FireAndForgetRequesterMono( + payloadsSupplier.get(), requesterResponderSupport); + } + + @Override + public String toString() { + return FireAndForgetRequesterMono.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_RESPONSE; + } + + @Override + public Publisher requestOperator( + Supplier payloadsSupplier, + RequesterResponderSupport requesterResponderSupport) { + return new RequestResponseRequesterMono( + payloadsSupplier.get(), requesterResponderSupport); + } + + @Override + public String toString() { + return RequestResponseRequesterMono.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_STREAM; + } + + @Override + public Publisher requestOperator( + Supplier payloadsSupplier, + RequesterResponderSupport requesterResponderSupport) { + return new RequestStreamRequesterFlux( + payloadsSupplier.get(), requesterResponderSupport); + } + + @Override + public String toString() { + return RequestStreamRequesterFlux.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_CHANNEL; + } + + @Override + public Publisher requestOperator( + Supplier payloadsSupplier, + RequesterResponderSupport requesterResponderSupport) { + return new RequestChannelRequesterFlux( + Flux.generate(s -> s.next(payloadsSupplier.get())), requesterResponderSupport); + } + + @Override + public String toString() { + return RequestChannelRequesterFlux.class.getSimpleName(); + } + }); + } + + /* + * +--------------------------------+ + * | Racing Test Cases | + * +--------------------------------+ + */ + + /** Ensures single subscription happens in case of racing */ + @ParameterizedTest(name = "Should subscribe exactly once to {0}") + @MethodSource("scenarios") + public void shouldSubscribeExactlyOnce(Scenario scenario) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport requesterResponderSupport = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> + TestRequesterResponderSupport.genericPayload( + requesterResponderSupport.getAllocator()); + + final Publisher requestOperator = + scenario.requestOperator(payloadSupplier, requesterResponderSupport); + + StepVerifier stepVerifier = + StepVerifier.create(requesterResponderSupport.getDuplexConnection().getSentAsPublisher()) + .assertNext( + frame -> { + FrameAssert frameAssert = + FrameAssert.assertThat(frame) + .isNotNull() + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()); + if (scenario.requestType() == METADATA_PUSH) { + frameAssert + .hasStreamIdZero() + .hasPayloadSize( + TestRequesterResponderSupport.METADATA_CONTENT.getBytes( + CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT); + } else { + frameAssert + .hasClientSideStreamId() + .hasStreamId(1) + .hasPayloadSize( + TestRequesterResponderSupport.METADATA_CONTENT.getBytes( + CharsetUtil.UTF_8) + .length + + TestRequesterResponderSupport.DATA_CONTENT.getBytes( + CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT); + } + frameAssert.hasNoLeaks(); + + if (requestOperator instanceof FrameHandler) { + ((FrameHandler) requestOperator).handleComplete(); + if (scenario.requestType() == REQUEST_CHANNEL) { + ((FrameHandler) requestOperator).handleCancel(); + } + } + }) + .thenCancel() + .verifyLater(); + + Assertions.assertThatThrownBy( + () -> + RaceTestUtils.race( + () -> { + AssertSubscriber subscriber = new AssertSubscriber<>(); + requestOperator.subscribe(subscriber); + subscriber.await().assertTerminated().assertNoError(); + }, + () -> { + AssertSubscriber subscriber = new AssertSubscriber<>(); + requestOperator.subscribe(subscriber); + subscriber.await().assertTerminated().assertNoError(); + })) + .matches( + t -> { + Assertions.assertThat(t).hasMessageContaining("allows only a single Subscriber"); + return true; + }); + + stepVerifier.verify(Duration.ofSeconds(1)); + requesterResponderSupport.getAllocator().assertHasNoLeaks(); + if (scenario.requestType() != METADATA_PUSH) { + testRequestInterceptor + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .expectNothing(); + } + } + } + + /** Ensures single frame is sent only once racing between requests */ + @ParameterizedTest(name = "{0} should sent requestFrame exactly once if request(n) is racing") + @MethodSource("scenarios") + public void shouldSentRequestFrameOnceInCaseOfRequestRacing(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final Publisher requestOperator = + (Publisher) scenario.requestOperator(payloadSupplier, activeStreams); + + Payload response = ByteBufPayload.create("test", "test"); + + final AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + requestOperator.subscribe(assertSubscriber); + + RaceTestUtils.race(() -> assertSubscriber.request(1), () -> assertSubscriber.request(1)); + + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); + + if (scenario.requestType().hasInitialRequestN()) { + if (RequestStreamFrameCodec.initialRequestN(sentFrame) == 1) { + FrameAssert.assertThat(activeStreams.getDuplexConnection().awaitFrame()) + .isNotNull() + .hasStreamId(1) + .hasRequestN(1) + .typeOf(REQUEST_N) + .hasNoLeaks(); + } else { + Assertions.assertThat(RequestStreamFrameCodec.initialRequestN(sentFrame)).isEqualTo(2); + } + } + + FrameAssert.assertThat(sentFrame) + .isNotNull() + .hasPayloadSize( + TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length + + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + ((RequesterFrameHandler) requestOperator).handlePayload(response); + ((RequesterFrameHandler) requestOperator).handleComplete(); + + if (scenario.requestType() == REQUEST_CHANNEL) { + ((CoreSubscriber) requestOperator).onComplete(); + FrameAssert.assertThat(activeStreams.getDuplexConnection().awaitFrame()) + .typeOf(COMPLETE) + .hasStreamId(1) + .hasNoLeaks(); + } + + assertSubscriber + .assertTerminated() + .assertValuesWith( + p -> { + Assertions.assertThat(p.release()).isTrue(); + Assertions.assertThat(p.refCnt()).isZero(); + }); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + activeStreams.getAllocator().assertHasNoLeaks(); + if (scenario.requestType() != METADATA_PUSH) { + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); + } + } + } + + /** + * Ensures that no ByteBuf is leaked if reassembly is starting and cancel is happening at the same + * time + */ + @ParameterizedTest(name = "Should have no leaks when {0} is canceled during reassembly") + @MethodSource("scenarios") + public void shouldHaveNoLeaksOnReassemblyAndCancelRacing(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final Publisher requestOperator = + (Publisher) scenario.requestOperator(payloadSupplier, activeStreams); + + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(1); + + requestOperator.subscribe(assertSubscriber); + + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(sentFrame) + .isNotNull() + .hasPayloadSize( + TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length + + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload responsePayload = + TestRequesterResponderSupport.randomPayload(activeStreams.getAllocator()); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments( + activeStreams.getAllocator(), mtu, responsePayload); + RaceTestUtils.race( + assertSubscriber::cancel, + () -> { + FrameHandler frameHandler = (FrameHandler) requestOperator; + int lastFragmentId = fragments.size() - 1; + for (int j = 0; j < fragments.size(); j++) { + ByteBuf frame = fragments.get(j); + frameHandler.handleNext(frame, lastFragmentId != j, lastFragmentId == j); + frame.release(); + } + }); + + List values = assertSubscriber.values(); + if (!values.isEmpty()) { + Assertions.assertThat(values) + .hasSize(1) + .first() + .matches( + p -> { + Assertions.assertThat(p.sliceData()) + .matches(bb -> ByteBufUtil.equals(bb, responsePayload.sliceData())); + Assertions.assertThat(p.hasMetadata()).isEqualTo(responsePayload.hasMetadata()); + Assertions.assertThat(p.sliceMetadata()) + .matches(bb -> ByteBufUtil.equals(bb, responsePayload.sliceMetadata())); + Assertions.assertThat(p.release()).isTrue(); + Assertions.assertThat(p.refCnt()).isZero(); + return true; + }); + } + + if (!activeStreams.getDuplexConnection().isEmpty()) { + if (scenario.requestType() != REQUEST_CHANNEL) { + assertSubscriber.assertNotTerminated(); + } + + final ByteBuf cancellationFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(cancellationFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } else { + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); + } + + Assertions.assertThat(responsePayload.release()).isTrue(); + Assertions.assertThat(responsePayload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + activeStreams.getAllocator().assertHasNoLeaks(); + } + } + + /** + * Ensures that in case of racing between next element and cancel we will not have any memory + * leaks + */ + @ParameterizedTest(name = "Should have no leaks when {0} is canceled during reassembly") + @MethodSource("scenarios") + public void shouldHaveNoLeaksOnNextAndCancelRacing(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final Publisher requestOperator = scenario.requestOperator(payloadSupplier, activeStreams); + + Payload response = ByteBufPayload.create("test", "test"); + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + requestOperator.subscribe((AssertSubscriber) assertSubscriber); + + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(sentFrame) + .isNotNull() + .hasPayloadSize( + TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length + + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + RaceTestUtils.race( + ((Subscription) requestOperator)::cancel, + () -> ((RequesterFrameHandler) requestOperator).handlePayload(response)); + + assertSubscriber.values().forEach(Payload::release); + Assertions.assertThat(response.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + final boolean isEmpty = activeStreams.getDuplexConnection().isEmpty(); + if (!isEmpty) { + final ByteBuf cancellationFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(cancellationFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } else { + assertSubscriber.assertTerminated(); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); + } + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + activeStreams.getAllocator().assertHasNoLeaks(); + } + } + + /** + * Ensures that in case we have element reassembling and then it happens the remote sends + * (errorFrame) and downstream subscriber sends cancel() and we have racing between onError and + * cancel we will not have any memory leaks + */ + @ParameterizedTest + @MethodSource("scenarios") + public void shouldHaveNoUnexpectedErrorDuringOnErrorAndCancelRacing(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + boolean[] withReassemblyOptions = new boolean[] {true, false}; + final ArrayList droppedErrors = new ArrayList<>(); + Hooks.onErrorDropped(droppedErrors::add); + + try { + for (boolean withReassembly : withReassemblyOptions) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final Publisher requestOperator = + scenario.requestOperator(payloadSupplier, activeStreams); + + final StateAssert stateAssert; + if (requestOperator instanceof RequestResponseRequesterMono) { + stateAssert = StateAssert.assertThat((RequestResponseRequesterMono) requestOperator); + } else if (requestOperator instanceof RequestStreamRequesterFlux) { + stateAssert = StateAssert.assertThat((RequestStreamRequesterFlux) requestOperator); + } else { + stateAssert = StateAssert.assertThat((RequestChannelRequesterFlux) requestOperator); + } + + stateAssert.isUnsubscribed(); + final AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + requestOperator.subscribe((AssertSubscriber) assertSubscriber); + + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(1); + + stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag(); + + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(sentFrame) + .isNotNull() + .hasPayloadSize( + TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length + + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + if (withReassembly) { + final ByteBuf fragmentBuf = + activeStreams.getAllocator().buffer().writeBytes(new byte[] {1, 2, 3}); + ((RequesterFrameHandler) requestOperator).handleNext(fragmentBuf, true, false); + // mimic frameHandler behaviour + fragmentBuf.release(); + } + + final RuntimeException testException = new RuntimeException("test"); + RaceTestUtils.race( + ((Subscription) requestOperator)::cancel, + () -> ((RequesterFrameHandler) requestOperator).handleError(testException)); + + activeStreams.assertNoActiveStreams(); + stateAssert.isTerminated(); + + final boolean isEmpty = activeStreams.getDuplexConnection().isEmpty(); + if (!isEmpty) { + final ByteBuf cancellationFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(cancellationFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(droppedErrors).containsExactly(testException); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } else { + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnError(1) + .expectNothing(); + + assertSubscriber.assertTerminated().assertErrorMessage("test"); + } + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + + stateAssert.isTerminated(); + droppedErrors.clear(); + activeStreams.getAllocator().assertHasNoLeaks(); + } + } + } finally { + Hooks.resetOnErrorDropped(); + } + } + + /** + * Ensures that in case of racing between first request and cancel does not going to introduce + * leaks.
    + *
    + * + *

    Please note, first request may or may not happen so in case it happened before cancellation + * signal we have to observe + * + *

      + *
    • RequestResponseFrame + *
    • CancellationFrame + *
    + * + *

    exactly in that order + * + *

    Ensures full serialization of outgoing signal (frames) + */ + @ParameterizedTest + @MethodSource("scenarios") + public void shouldBeConsistentInCaseOfRacingOfCancellationAndRequest(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final Publisher requestOperator = scenario.requestOperator(payloadSupplier, activeStreams); + + Payload response = ByteBufPayload.create("test", "test"); + + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(0); + + requestOperator.subscribe((AssertSubscriber) assertSubscriber); + + RaceTestUtils.race(() -> assertSubscriber.cancel(), () -> assertSubscriber.request(1)); + + if (!activeStreams.getDuplexConnection().isEmpty()) { + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(sentFrame) + .isNotNull() + .typeOf(scenario.requestType()) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasNoFragmentsFollow() + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf cancelFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } + + ((RequesterFrameHandler) requestOperator).handlePayload(response); + assertSubscriber.values().forEach(Payload::release); + + Assertions.assertThat(response.refCnt()).isZero(); + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + activeStreams.getAllocator().assertHasNoLeaks(); + } + } + + /** Ensures that CancelFrame is sent exactly once in case of racing between cancel() methods */ + @ParameterizedTest + @MethodSource("scenarios") + public void shouldSentCancelFrameExactlyOnce(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final Publisher requesterOperator = + scenario.requestOperator(payloadSupplier, activeStreams); + + Payload response = ByteBufPayload.create("test", "test"); + + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(0); + + requesterOperator.subscribe((AssertSubscriber) assertSubscriber); + + assertSubscriber.request(1); + + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(sentFrame) + .isNotNull() + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()) + .hasClientSideStreamId() + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasStreamId(1) + .hasNoLeaks(); + + RaceTestUtils.race( + ((Subscription) requesterOperator)::cancel, ((Subscription) requesterOperator)::cancel); + + final ByteBuf cancelFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + + activeStreams.assertNoActiveStreams(); + + ((RequesterFrameHandler) requesterOperator).handlePayload(response); + assertSubscriber.values().forEach(Payload::release); + Assertions.assertThat(response.refCnt()).isZero(); + + ((RequesterFrameHandler) requesterOperator).handleComplete(); + assertSubscriber.assertNotTerminated(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + activeStreams.getAllocator().assertHasNoLeaks(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/ResolvingOperatorTests.java b/rsocket-core/src/test/java/io/rsocket/core/ResolvingOperatorTests.java index 29748abbe..382240c4a 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/ResolvingOperatorTests.java +++ b/rsocket-core/src/test/java/io/rsocket/core/ResolvingOperatorTests.java @@ -16,6 +16,7 @@ package io.rsocket.core; +import io.rsocket.RaceTestConstants; import io.rsocket.internal.subscriber.AssertSubscriber; import java.time.Duration; import java.util.ArrayList; @@ -36,31 +37,31 @@ import org.mockito.Mockito; import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; import reactor.core.publisher.Hooks; -import reactor.core.publisher.MonoProcessor; -import reactor.core.scheduler.Schedulers; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; import reactor.test.util.RaceTestUtils; -import reactor.util.retry.Retry; public class ResolvingOperatorTests { - private Queue retries = new ConcurrentLinkedQueue<>(); - @Test public void shouldExpireValueOnRacingDisposeAndComplete() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final int index = i; - MonoProcessor processor = MonoProcessor.create(); + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); BiConsumer consumer = (v, t) -> { if (t != null) { - processor.onError(t); + subscriber.onError(t); return; } - processor.onNext(v); + subscriber.onNext(v); + subscriber.onComplete(); }; ResolvingTest.create() @@ -76,42 +77,48 @@ public void shouldExpireValueOnRacingDisposeAndComplete() { .ifResolvedAssertEqual("value" + index) .assertIsDisposed(); - if (processor.isError()) { - Assertions.assertThat(processor.getError()) + subscriber.assertTerminated(); + + if (!subscriber.errors().isEmpty()) { + Assertions.assertThat(subscriber.errors().get(0)) .isInstanceOf(CancellationException.class) .hasMessage("Disposed"); } else { - Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + Assertions.assertThat(subscriber.values()).containsExactly("value" + i); } } } @Test public void shouldNotifyAllTheSubscribersUnderRacingBetweenSubscribeAndComplete() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final String valueToSend = "value" + i; - MonoProcessor processor = MonoProcessor.create(); + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); BiConsumer consumer = (v, t) -> { if (t != null) { - processor.onError(t); + subscriber.onError(t); return; } - processor.onNext(v); + subscriber.onNext(v); + subscriber.onComplete(); }; - MonoProcessor processor2 = MonoProcessor.create(); + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); BiConsumer consumer2 = (v, t) -> { if (t != null) { - processor2.onError(t); + subscriber2.onError(t); return; } - processor2.onNext(v); + subscriber2.onNext(v); + subscriber2.onComplete(); }; ResolvingTest.create() @@ -123,10 +130,7 @@ public void shouldNotifyAllTheSubscribersUnderRacingBetweenSubscribeAndComplete( self -> { RaceTestUtils.race(() -> self.complete(valueToSend), () -> self.observe(consumer)); - StepVerifier.create(processor) - .expectNext(valueToSend) - .expectComplete() - .verify(Duration.ofMillis(10)); + subscriber.await(Duration.ofMillis(10)).assertValues(valueToSend).assertComplete(); }) .assertDisposeCalled(0) .assertReceivedExactly(valueToSend) @@ -134,39 +138,40 @@ public void shouldNotifyAllTheSubscribersUnderRacingBetweenSubscribeAndComplete( .thenAddObserver(consumer2) .assertPendingSubscribers(0); - StepVerifier.create(processor2) - .expectNext(valueToSend) - .expectComplete() - .verify(Duration.ofMillis(10)); + subscriber2.await(Duration.ofMillis(10)).assertValues(valueToSend).assertComplete(); } } @Test public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidate() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final String valueToSend = "value" + i; final String valueToSend2 = "value2" + i; - MonoProcessor processor = MonoProcessor.create(); + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); BiConsumer consumer = (v, t) -> { if (t != null) { - processor.onError(t); + subscriber.onError(t); return; } - processor.onNext(v); + subscriber.onNext(v); + subscriber.onComplete(); }; - MonoProcessor processor2 = MonoProcessor.create(); + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); BiConsumer consumer2 = (v, t) -> { if (t != null) { - processor2.onError(t); + subscriber2.onError(t); return; } - processor2.onNext(v); + subscriber2.onNext(v); + subscriber2.onComplete(); }; ResolvingTest.create() @@ -179,10 +184,7 @@ public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidate() self -> { self.complete(valueToSend); - StepVerifier.create(processor) - .expectNext(valueToSend) - .expectComplete() - .verify(Duration.ofMillis(10)); + subscriber.await(Duration.ofMillis(10)).assertValues(valueToSend).assertComplete(); }) .assertReceivedExactly(valueToSend) .then( @@ -191,11 +193,10 @@ public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidate() self::invalidate, () -> { self.observe(consumer2); - if (!processor2.isTerminated()) { + if (!subscriber2.isTerminated()) { self.complete(valueToSend2); } - }, - Schedulers.parallel())) + })) .then( self -> { if (self.isPending()) { @@ -209,46 +210,51 @@ public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidate() .assertDisposeCalled(0) .then( self -> - StepVerifier.create(processor2) - .expectNextMatches( - (v) -> { + subscriber2 + .await(Duration.ofMillis(100)) + .assertValueCount(1) + .assertValuesWith( + v -> { if (self.subscribers == ResolvingOperator.READY) { - return v.equals(valueToSend2); + Assertions.assertThat(v).isEqualTo(valueToSend2); } else { - return v.equals(valueToSend); + Assertions.assertThat(v).isEqualTo(valueToSend); } }) - .expectComplete() - .verify(Duration.ofMillis(100))); + .assertComplete()); } } @Test public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidates() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final String valueToSend = "value" + i; final String valueToSend2 = "value_to_possibly_expire" + i; - MonoProcessor processor = MonoProcessor.create(); + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); BiConsumer consumer = (v, t) -> { if (t != null) { - processor.onError(t); + subscriber.onError(t); return; } - processor.onNext(v); + subscriber.onNext(v); + subscriber.onComplete(); }; - MonoProcessor processor2 = MonoProcessor.create(); + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); BiConsumer consumer2 = (v, t) -> { if (t != null) { - processor2.onError(t); + subscriber2.onError(t); return; } - processor2.onNext(v); + subscriber2.onNext(v); + subscriber2.onComplete(); }; ResolvingTest.create() @@ -261,25 +267,20 @@ public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidates( self -> { self.complete(valueToSend); - StepVerifier.create(processor) - .expectNext(valueToSend) - .expectComplete() - .verify(Duration.ofMillis(10)); + subscriber.await(Duration.ofMillis(100)).assertValues(valueToSend).assertComplete(); }) .assertReceivedExactly(valueToSend) .then( self -> RaceTestUtils.race( - () -> - RaceTestUtils.race( - self::invalidate, self::invalidate, Schedulers.parallel()), + self::invalidate, + self::invalidate, () -> { self.observe(consumer2); - if (!processor2.isTerminated()) { + if (!subscriber2.isTerminated()) { self.complete(valueToSend2); } - }, - Schedulers.parallel())) + })) .then( self -> { if (!self.isPending()) { @@ -296,7 +297,7 @@ public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidates( .haveAtMost( 2, new Condition<>( - new Predicate() { + new Predicate() { int time = 0; @Override @@ -313,40 +314,39 @@ public boolean test(Object s) { .assertPendingSubscribers(0) .assertDisposeCalled(0) .then( - new Consumer>() { - @Override - public void accept(ResolvingTest self) { - StepVerifier.create(processor2) - .expectNextMatches( - (v) -> { + self -> + subscriber2 + .await(Duration.ofMillis(100)) + .assertValueCount(1) + .assertValuesWith( + v -> { if (self.subscribers == ResolvingOperator.READY) { - return v.equals(valueToSend2); + Assertions.assertThat(v).isEqualTo(valueToSend2); } else { - return v.equals(valueToSend) || v.equals(valueToSend2); + Assertions.assertThat(v).isIn(valueToSend, valueToSend2); } }) - .expectComplete() - .verify(Duration.ofMillis(100)); - } - }); + .assertComplete()); } } @Test public void shouldNotExpireNewlyResolvedValueIfBlockIsRacingWithInvalidate() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final String valueToSend = "value" + i; final String valueToSend2 = "value2" + i; - MonoProcessor processor = MonoProcessor.create(); + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); BiConsumer consumer = (v, t) -> { if (t != null) { - processor.onError(t); + subscriber.onError(t); return; } - processor.onNext(v); + subscriber.onNext(v); + subscriber.onComplete(); }; ResolvingTest.create() @@ -359,10 +359,7 @@ public void shouldNotExpireNewlyResolvedValueIfBlockIsRacingWithInvalidate() { self -> { self.complete(valueToSend); - StepVerifier.create(processor) - .expectNext(valueToSend) - .expectComplete() - .verify(Duration.ofMillis(10)); + subscriber.await(Duration.ofMillis(10)).assertValues(valueToSend).assertComplete(); }) .assertReceivedExactly(valueToSend) .then( @@ -371,19 +368,15 @@ public void shouldNotExpireNewlyResolvedValueIfBlockIsRacingWithInvalidate() { () -> Assertions.assertThat(self.block(null)) .matches((v) -> v.equals(valueToSend) || v.equals(valueToSend2)), - () -> - RaceTestUtils.race( - self::invalidate, - () -> { - for (; ; ) { - if (self.subscribers != ResolvingOperator.READY) { - self.complete(valueToSend2); - break; - } - } - }, - Schedulers.parallel()), - Schedulers.parallel())) + self::invalidate, + () -> { + for (; ; ) { + if (self.subscribers != ResolvingOperator.READY) { + self.complete(valueToSend2); + break; + } + } + })) .then( self -> { if (self.isPending()) { @@ -400,29 +393,33 @@ public void shouldNotExpireNewlyResolvedValueIfBlockIsRacingWithInvalidate() { @Test public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribers() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final String valueToSend = "value" + i; - MonoProcessor processor = MonoProcessor.create(); + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); BiConsumer consumer = (v, t) -> { if (t != null) { - processor.onError(t); + subscriber.onError(t); return; } - processor.onNext(v); + subscriber.onNext(v); + subscriber.onComplete(); }; - MonoProcessor processor2 = MonoProcessor.create(); + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); BiConsumer consumer2 = (v, t) -> { if (t != null) { - processor2.onError(t); + subscriber2.onError(t); return; } - processor2.onNext(v); + subscriber2.onNext(v); + subscriber2.onComplete(); }; ResolvingTest.create() @@ -442,11 +439,11 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribers() { .assertDisposeCalled(0) .then( self -> { - Assertions.assertThat(processor.isTerminated()).isTrue(); - Assertions.assertThat(processor2.isTerminated()).isTrue(); + Assertions.assertThat(subscriber.isTerminated()).isTrue(); + Assertions.assertThat(subscriber2.isTerminated()).isTrue(); - Assertions.assertThat(processor.peek()).isEqualTo(valueToSend); - Assertions.assertThat(processor2.peek()).isEqualTo(valueToSend); + Assertions.assertThat(subscriber.values()).containsExactly(valueToSend); + Assertions.assertThat(subscriber2.values()).containsExactly(valueToSend); Assertions.assertThat(self.subscribers).isEqualTo(ResolvingOperator.READY); @@ -457,20 +454,23 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribers() { @Test public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribeAndBlock() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final String valueToSend = "value" + i; - MonoProcessor processor = MonoProcessor.create(); + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); - MonoProcessor processor2 = MonoProcessor.create(); + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); BiConsumer consumer2 = (v, t) -> { if (t != null) { - processor2.onError(t); + subscriber2.onError(t); return; } - processor2.onNext(v); + subscriber2.onNext(v); + subscriber2.onComplete(); }; ResolvingTest.create() @@ -482,7 +482,11 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribeAndBlock() { .then( self -> RaceTestUtils.race( - () -> processor.onNext(self.block(null)), () -> self.observe(consumer2))) + () -> { + subscriber.onNext(self.block(null)); + subscriber.onComplete(); + }, + () -> self.observe(consumer2))) .assertSubscribeCalled(1) .assertPendingSubscribers(0) .assertReceivedExactly(valueToSend) @@ -490,11 +494,11 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribeAndBlock() { .assertDisposeCalled(0) .then( self -> { - Assertions.assertThat(processor.isTerminated()).isTrue(); - Assertions.assertThat(processor2.isTerminated()).isTrue(); + Assertions.assertThat(subscriber.isTerminated()).isTrue(); + Assertions.assertThat(subscriber2.isTerminated()).isTrue(); - Assertions.assertThat(processor.peek()).isEqualTo(valueToSend); - Assertions.assertThat(processor2.peek()).isEqualTo(valueToSend); + Assertions.assertThat(subscriber.values()).containsExactly(valueToSend); + Assertions.assertThat(subscriber2.values()).containsExactly(valueToSend); Assertions.assertThat(self.subscribers).isEqualTo(ResolvingOperator.READY); @@ -506,11 +510,14 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribeAndBlock() { @Test public void shouldEstablishValueOnceInCaseOfRacingBetweenBlocks() { Duration timeout = Duration.ofMillis(100); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final String valueToSend = "value" + i; - MonoProcessor processor = MonoProcessor.create(); - MonoProcessor processor2 = MonoProcessor.create(); + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); + + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); ResolvingTest.create() .assertNothingExpired() @@ -521,8 +528,14 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenBlocks() { .then( self -> RaceTestUtils.race( - () -> processor.onNext(self.block(timeout)), - () -> processor2.onNext(self.block(timeout)))) + () -> { + subscriber.onNext(self.block(timeout)); + subscriber.onComplete(); + }, + () -> { + subscriber2.onNext(self.block(timeout)); + subscriber2.onComplete(); + })) .assertSubscribeCalled(1) .assertPendingSubscribers(0) .assertReceivedExactly(valueToSend) @@ -530,11 +543,11 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenBlocks() { .assertDisposeCalled(0) .then( self -> { - Assertions.assertThat(processor.isTerminated()).isTrue(); - Assertions.assertThat(processor2.isTerminated()).isTrue(); + Assertions.assertThat(subscriber.isTerminated()).isTrue(); + Assertions.assertThat(subscriber2.isTerminated()).isTrue(); - Assertions.assertThat(processor.peek()).isEqualTo(valueToSend); - Assertions.assertThat(processor2.peek()).isEqualTo(valueToSend); + Assertions.assertThat(subscriber.values()).containsExactly(valueToSend); + Assertions.assertThat(subscriber2.values()).containsExactly(valueToSend); Assertions.assertThat(self.subscribers).isEqualTo(ResolvingOperator.READY); @@ -548,26 +561,31 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenBlocks() { public void shouldExpireValueOnRacingDisposeAndError() { Hooks.onErrorDropped(t -> {}); RuntimeException runtimeException = new RuntimeException("test"); - for (int i = 0; i < 10000; i++) { - MonoProcessor processor = MonoProcessor.create(); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); BiConsumer consumer = (v, t) -> { if (t != null) { - processor.onError(t); + subscriber.onError(t); return; } - processor.onNext(v); + subscriber.onNext(v); + subscriber.onComplete(); }; - MonoProcessor processor2 = MonoProcessor.create(); + + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); BiConsumer consumer2 = (v, t) -> { if (t != null) { - processor2.onError(t); + subscriber2.onError(t); return; } - processor2.onNext(v); + subscriber2.onNext(v); + subscriber2.onComplete(); }; ResolvingTest.create() @@ -591,8 +609,9 @@ public void shouldExpireValueOnRacingDisposeAndError() { }) .thenAddObserver(consumer2); - StepVerifier.create(processor) - .expectErrorSatisfies( + subscriber + .await(Duration.ofMillis(10)) + .assertErrorWith( t -> { if (t instanceof CancellationException) { Assertions.assertThat(t) @@ -601,11 +620,11 @@ public void shouldExpireValueOnRacingDisposeAndError() { } else { Assertions.assertThat(t).isInstanceOf(RuntimeException.class).hasMessage("test"); } - }) - .verify(Duration.ofMillis(10)); + }); - StepVerifier.create(processor2) - .expectErrorSatisfies( + subscriber2 + .await(Duration.ofMillis(10)) + .assertErrorWith( t -> { if (t instanceof CancellationException) { Assertions.assertThat(t) @@ -614,8 +633,7 @@ public void shouldExpireValueOnRacingDisposeAndError() { } else { Assertions.assertThat(t).isInstanceOf(RuntimeException.class).hasMessage("test"); } - }) - .verify(Duration.ofMillis(10)); + }); // no way to guarantee equality because of racing // Assertions.assertThat(processor.getError()) @@ -664,9 +682,10 @@ public void shouldThrowOnBlockingIfHasAlreadyTerminated() { static Stream, Publisher>> innerCases() { return Stream.of( (self) -> { - final MonoProcessor processor = MonoProcessor.create(); + final Sinks.One processor = Sinks.unsafe().one(); final ResolvingOperator.DeferredResolution operator = - new ResolvingOperator.DeferredResolution(self, processor) { + new ResolvingOperator.DeferredResolution( + self, new SinkOneSubscriber(processor)) { @Override public void accept(String v, Throwable t) { if (t != null) { @@ -677,14 +696,21 @@ public void accept(String v, Throwable t) { onNext(v); } }; - return processor.doOnSubscribe(s -> self.observe(operator)).doOnCancel(operator::cancel); + return processor + .asMono() + .doOnSubscribe(s -> self.observe(operator)) + .doOnCancel(operator::cancel); }, (self) -> { - final MonoProcessor processor = MonoProcessor.create(); + final Sinks.One processor = Sinks.unsafe().one(); + final SinkOneSubscriber subscriber = new SinkOneSubscriber(processor); final ResolvingOperator.MonoDeferredResolutionOperator operator = - new ResolvingOperator.MonoDeferredResolutionOperator<>(self, processor); - processor.onSubscribe(operator); - return processor.doOnSubscribe(s -> self.observe(operator)).doOnCancel(operator::cancel); + new ResolvingOperator.MonoDeferredResolutionOperator<>(self, subscriber); + subscriber.onSubscribe(operator); + return processor + .asMono() + .doOnSubscribe(s -> self.observe(operator)) + .doOnCancel(operator::cancel); }); } @@ -737,12 +763,13 @@ public void shouldExpireValueOnDispose( public void shouldNotifyAllTheSubscribers( Function, Publisher> caseProducer) { - final MonoProcessor sub1 = MonoProcessor.create(); - final MonoProcessor sub2 = MonoProcessor.create(); - final MonoProcessor sub3 = MonoProcessor.create(); - final MonoProcessor sub4 = MonoProcessor.create(); + AssertSubscriber sub1 = AssertSubscriber.create(); + AssertSubscriber sub2 = AssertSubscriber.create(); + AssertSubscriber sub3 = AssertSubscriber.create(); + AssertSubscriber sub4 = AssertSubscriber.create(); - final ArrayList> processors = new ArrayList<>(200); + final ArrayList> processors = + new ArrayList<>(RaceTestConstants.REPEATS * 2); ResolvingTest.create() .assertDisposeCalled(0) @@ -761,9 +788,9 @@ public void shouldNotifyAllTheSubscribers( .assertPendingSubscribers(4) .then( self -> { - for (int i = 0; i < 100; i++) { - final MonoProcessor subA = MonoProcessor.create(); - final MonoProcessor subB = MonoProcessor.create(); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + AssertSubscriber subA = AssertSubscriber.create(); + AssertSubscriber subB = AssertSubscriber.create(); processors.add(subA); processors.add(subB); RaceTestUtils.race( @@ -772,21 +799,21 @@ public void shouldNotifyAllTheSubscribers( } }) .assertSubscribeCalled(1) - .assertPendingSubscribers(204) - .then(self -> sub1.dispose()) - .assertPendingSubscribers(203) + .assertPendingSubscribers(RaceTestConstants.REPEATS * 2 + 4) + .then(self -> sub1.cancel()) + .assertPendingSubscribers(RaceTestConstants.REPEATS * 2 + 3) .then( self -> { String valueToSend = "value"; self.complete(valueToSend); - Assertions.assertThatThrownBy(sub1::peek).isInstanceOf(CancellationException.class); - Assertions.assertThat(sub2.peek()).isEqualTo(valueToSend); - Assertions.assertThat(sub3.peek()).isEqualTo(valueToSend); - Assertions.assertThat(sub4.peek()).isEqualTo(valueToSend); + Assertions.assertThat(sub1.isTerminated()).isFalse(); + Assertions.assertThat(sub2.values()).containsExactly(valueToSend); + Assertions.assertThat(sub3.values()).containsExactly(valueToSend); + Assertions.assertThat(sub4.values()).containsExactly(valueToSend); - for (MonoProcessor sub : processors) { - Assertions.assertThat(sub.peek()).isEqualTo(valueToSend); + for (AssertSubscriber sub : processors) { + Assertions.assertThat(sub.values()).containsExactly(valueToSend); Assertions.assertThat(sub.isTerminated()).isTrue(); } }) @@ -797,7 +824,7 @@ public void shouldNotifyAllTheSubscribers( @Test public void shouldBeSerialIfRacyMonoInner() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { long[] requested = new long[] {0}; Subscription mockSubscription = Mockito.mock(Subscription.class); Mockito.doAnswer( @@ -833,7 +860,7 @@ public void accept(Object o, Object o2) {} @Test public void shouldExpireValueExactlyOnceOnRacingBetweenInvalidates() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { ResolvingTest.create() .assertNothingExpired() .assertNothingReceived() @@ -847,7 +874,7 @@ public void shouldExpireValueExactlyOnceOnRacingBetweenInvalidates() { @Test public void shouldExpireValueExactlyOnceOnRacingBetweenInvalidateAndDispose() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { ResolvingTest.create() .assertNothingExpired() .assertNothingReceived() @@ -967,4 +994,37 @@ protected void doOnDispose() { onDisposeCalls.incrementAndGet(); } } + + private static class SinkOneSubscriber implements CoreSubscriber { + + private final Sinks.One processor; + private boolean valueReceived; + + public SinkOneSubscriber(Sinks.One processor) { + this.processor = processor; + } + + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(String s) { + valueReceived = true; + processor.tryEmitValue(s); + } + + @Override + public void onError(Throwable t) { + processor.tryEmitError(t); + } + + @Override + public void onComplete() { + if (!valueReceived) { + processor.tryEmitEmpty(); + } + } + } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java b/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java new file mode 100755 index 000000000..4f7821e4a --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java @@ -0,0 +1,477 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.frame.FrameType.METADATA_PUSH; +import static io.rsocket.frame.FrameType.NEXT; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_FNF; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; + +import io.netty.buffer.ByteBuf; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.plugins.RequestInterceptor; +import io.rsocket.plugins.TestRequestInterceptor; +import io.rsocket.test.util.TestDuplexConnection; +import java.util.ArrayList; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Stream; +import org.assertj.core.api.Assumptions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.test.publisher.TestPublisher; + +public class ResponderOperatorsCommonTest { + + interface Scenario { + FrameType requestType(); + + int maxElements(); + + ResponderFrameHandler responseOperator( + long initialRequestN, + Payload firstPayload, + TestRequesterResponderSupport streamManager, + RSocket handler); + + ResponderFrameHandler responseOperator( + long initialRequestN, + ByteBuf firstFragment, + TestRequesterResponderSupport streamManager, + RSocket handler); + } + + static Stream scenarios() { + return Stream.of( + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_RESPONSE; + } + + @Override + public int maxElements() { + return 1; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + ByteBuf firstFragment, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestResponseResponderSubscriber subscriber = + new RequestResponseResponderSubscriber( + streamId, firstFragment, streamManager, handler); + streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_RESPONSE, null); + } + + return subscriber; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + Payload firstPayload, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestResponseResponderSubscriber subscriber = + new RequestResponseResponderSubscriber(streamId, streamManager); + streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_RESPONSE, null); + } + + return handler.requestResponse(firstPayload).subscribeWith(subscriber); + } + + @Override + public String toString() { + return RequestResponseRequesterMono.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_STREAM; + } + + @Override + public int maxElements() { + return Integer.MAX_VALUE; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + ByteBuf firstFragment, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestStreamResponderSubscriber subscriber = + new RequestStreamResponderSubscriber( + streamId, initialRequestN, firstFragment, streamManager, handler); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_STREAM, null); + } + + streamManager.activeStreams.put(streamId, subscriber); + return subscriber; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + Payload firstPayload, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestStreamResponderSubscriber subscriber = + new RequestStreamResponderSubscriber(streamId, initialRequestN, streamManager); + streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_STREAM, null); + } + + return handler.requestStream(firstPayload).subscribeWith(subscriber); + } + + @Override + public String toString() { + return RequestStreamResponderSubscriber.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_CHANNEL; + } + + @Override + public int maxElements() { + return Integer.MAX_VALUE; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + ByteBuf firstFragment, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestChannelResponderSubscriber subscriber = + new RequestChannelResponderSubscriber( + streamId, initialRequestN, firstFragment, streamManager, handler); + streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_CHANNEL, null); + } + + return subscriber; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + Payload firstPayload, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestChannelResponderSubscriber responderSubscriber = + new RequestChannelResponderSubscriber( + streamId, initialRequestN, firstPayload, streamManager); + streamManager.activeStreams.put(streamId, responderSubscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_CHANNEL, null); + } + + return handler.requestChannel(responderSubscriber).subscribeWith(responderSubscriber); + } + + @Override + public String toString() { + return RequestChannelResponderSubscriber.class.getSimpleName(); + } + }); + } + + static class TestHandler implements RSocket { + + final TestPublisher producer; + final AssertSubscriber consumer; + + TestHandler(TestPublisher producer, AssertSubscriber consumer) { + this.producer = producer; + this.consumer = consumer; + } + + @Override + public Mono fireAndForget(Payload payload) { + consumer.onSubscribe(Operators.emptySubscription()); + consumer.onNext(payload); + consumer.onComplete(); + return producer.mono().then(); + } + + @Override + public Mono requestResponse(Payload payload) { + consumer.onSubscribe(Operators.emptySubscription()); + consumer.onNext(payload); + consumer.onComplete(); + return producer.mono(); + } + + @Override + public Flux requestStream(Payload payload) { + consumer.onSubscribe(Operators.emptySubscription()); + consumer.onNext(payload); + consumer.onComplete(); + return producer.flux(); + } + + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(consumer); + return producer.flux(); + } + } + + @ParameterizedTest + @MethodSource("scenarios") + void shouldHandleRequest(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()).isNotIn(REQUEST_FNF, METADATA_PUSH); + + TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + TestRequesterResponderSupport testRequesterResponderSupport = + TestRequesterResponderSupport.client(testRequestInterceptor); + final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); + final TestDuplexConnection sender = testRequesterResponderSupport.getDuplexConnection(); + TestPublisher testPublisher = TestPublisher.create(); + TestHandler testHandler = new TestHandler(testPublisher, new AssertSubscriber<>(0)); + + ResponderFrameHandler responderFrameHandler = + scenario.responseOperator( + Long.MAX_VALUE, + TestRequesterResponderSupport.genericPayload(allocator), + testRequesterResponderSupport, + testHandler); + + Payload randomPayload = TestRequesterResponderSupport.randomPayload(allocator); + testPublisher.assertWasSubscribed(); + testPublisher.next(randomPayload.retain()); + testPublisher.complete(); + + FrameAssert.assertThat(sender.awaitFrame()) + .isNotNull() + .hasStreamId(1) + .typeOf(scenario.requestType() == REQUEST_RESPONSE ? FrameType.NEXT_COMPLETE : NEXT) + .hasPayloadSize( + randomPayload.data().readableBytes() + randomPayload.sliceMetadata().readableBytes()) + .hasData(randomPayload.data()) + .hasNoLeaks(); + + PayloadAssert.assertThat(randomPayload).hasNoLeaks(); + + if (scenario.requestType() != REQUEST_RESPONSE) { + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.COMPLETE) + .hasStreamId(1) + .hasNoLeaks(); + + if (scenario.requestType() == REQUEST_CHANNEL) { + testHandler.consumer.request(2); + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.REQUEST_N) + .hasStreamId(1) + .hasRequestN(1) + .hasNoLeaks(); + + responderFrameHandler.handleComplete(); + testHandler.consumer.assertComplete(); + } + } + + testHandler + .consumer + .assertValueCount(1) + .assertValuesWith(p -> PayloadAssert.assertThat(p).hasNoLeaks()); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("scenarios") + void shouldHandleFragmentedRequest(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()).isNotIn(REQUEST_FNF, METADATA_PUSH); + + TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + TestRequesterResponderSupport testRequesterResponderSupport = + TestRequesterResponderSupport.client(testRequestInterceptor); + final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); + final TestDuplexConnection sender = testRequesterResponderSupport.getDuplexConnection(); + TestPublisher testPublisher = TestPublisher.create(); + TestHandler testHandler = new TestHandler(testPublisher, new AssertSubscriber<>(0)); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, firstPayload); + + ByteBuf firstFragment = fragments.remove(0); + ResponderFrameHandler responderFrameHandler = + scenario.responseOperator( + Long.MAX_VALUE, firstFragment, testRequesterResponderSupport, testHandler); + firstFragment.release(); + + testPublisher.assertWasNotSubscribed(); + testRequesterResponderSupport.assertHasStream(1, responderFrameHandler); + + for (int i = 0; i < fragments.size(); i++) { + ByteBuf fragment = fragments.get(i); + boolean hasFollows = i != fragments.size() - 1; + responderFrameHandler.handleNext(fragment, hasFollows, !hasFollows); + fragment.release(); + } + + Payload randomPayload = TestRequesterResponderSupport.randomPayload(allocator); + testPublisher.assertWasSubscribed(); + testPublisher.next(randomPayload.retain()); + testPublisher.complete(); + + FrameAssert.assertThat(sender.awaitFrame()) + .isNotNull() + .hasStreamId(1) + .typeOf(scenario.requestType() == REQUEST_RESPONSE ? FrameType.NEXT_COMPLETE : NEXT) + .hasPayloadSize( + randomPayload.data().readableBytes() + randomPayload.sliceMetadata().readableBytes()) + .hasData(randomPayload.data()) + .hasNoLeaks(); + + PayloadAssert.assertThat(randomPayload).hasNoLeaks(); + + if (scenario.requestType() != REQUEST_RESPONSE) { + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.COMPLETE) + .hasStreamId(1) + .hasNoLeaks(); + + if (scenario.requestType() == REQUEST_CHANNEL) { + testHandler.consumer.request(2); + FrameAssert.assertThat(sender.pollFrame()).isNull(); + } + } + + testHandler + .consumer + .assertValueCount(1) + .assertValuesWith( + p -> PayloadAssert.assertThat(p).hasData(firstPayload.sliceData()).hasNoLeaks()) + .assertComplete(); + + testRequesterResponderSupport.assertNoActiveStreams(); + + firstPayload.release(); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); + + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("scenarios") + void shouldHandleInterruptedFragmentation(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()).isNotIn(REQUEST_FNF, METADATA_PUSH); + + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + TestRequesterResponderSupport testRequesterResponderSupport = + TestRequesterResponderSupport.client(testRequestInterceptor); + final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); + TestPublisher testPublisher = TestPublisher.create(); + TestHandler testHandler = new TestHandler(testPublisher, new AssertSubscriber<>(0)); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, firstPayload); + firstPayload.release(); + + ByteBuf firstFragment = fragments.remove(0); + ResponderFrameHandler responderFrameHandler = + scenario.responseOperator( + Long.MAX_VALUE, firstFragment, testRequesterResponderSupport, testHandler); + firstFragment.release(); + + testPublisher.assertWasNotSubscribed(); + testRequesterResponderSupport.assertHasStream(1, responderFrameHandler); + + for (int i = 0; i < fragments.size(); i++) { + ByteBuf fragment = fragments.get(i); + boolean hasFollows = i != fragments.size() - 1; + if (hasFollows) { + responderFrameHandler.handleNext(fragment, true, false); + } else { + responderFrameHandler.handleCancel(); + } + fragment.release(); + } + + testPublisher.assertWasNotSubscribed(); + testRequesterResponderSupport.assertNoActiveStreams(); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/SendUtilsTest.java b/rsocket-core/src/test/java/io/rsocket/core/SendUtilsTest.java new file mode 100644 index 000000000..9a51b9419 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/SendUtilsTest.java @@ -0,0 +1,31 @@ +package io.rsocket.core; + +import static org.mockito.Mockito.*; + +import io.netty.util.ReferenceCounted; +import java.util.function.Consumer; +import org.junit.jupiter.api.Test; + +public class SendUtilsTest { + + @Test + void droppedElementsConsumerShouldAcceptOtherTypesThanReferenceCounted() { + Consumer value = extractDroppedElementConsumer(); + value.accept(new Object()); + } + + @Test + void droppedElementsConsumerReleaseReference() { + ReferenceCounted referenceCounted = mock(ReferenceCounted.class); + when(referenceCounted.release()).thenReturn(true); + + Consumer value = extractDroppedElementConsumer(); + value.accept(referenceCounted); + + verify(referenceCounted).release(); + } + + private static Consumer extractDroppedElementConsumer() { + return (Consumer) SendUtils.DISCARD_CONTEXT.stream().findAny().get().getValue(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java index fe53b7df4..87c3a865f 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java @@ -1,3 +1,18 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package io.rsocket.core; import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; @@ -6,7 +21,12 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.rsocket.*; +import io.rsocket.Closeable; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.Exceptions; import io.rsocket.exceptions.RejectedSetupException; @@ -14,15 +34,13 @@ import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; import io.rsocket.frame.SetupFrameCodec; -import io.rsocket.lease.RequesterLeaseHandler; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.transport.ServerTransport; import io.rsocket.util.DefaultPayload; import java.time.Duration; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; -import reactor.core.publisher.UnicastProcessor; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; public class SetupRejectionTest { @@ -40,18 +58,21 @@ void responderRejectSetup() { ByteBuf sentFrame = transport.awaitSent(); assertThat(FrameHeaderCodec.frameType(sentFrame)).isEqualTo(FrameType.ERROR); RuntimeException error = Exceptions.from(0, sentFrame); + sentFrame.release(); assertThat(errorMsg).isEqualTo(error.getMessage()); assertThat(error).isInstanceOf(RejectedSetupException.class); RSocket acceptorSender = acceptor.senderRSocket().block(); assertThat(acceptorSender.isDisposed()).isTrue(); + transport.allocator.assertHasNoLeaks(); } @Test - @Disabled("FIXME: needs to be revised") void requesterStreamsTerminatedOnZeroErrorFrame() { LeaksTrackingByteBufAllocator allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); TestDuplexConnection conn = new TestDuplexConnection(allocator); + Sinks.Empty onThisSideClosedSink = Sinks.empty(); + RSocketRequester rSocket = new RSocketRequester( conn, @@ -59,11 +80,14 @@ void requesterStreamsTerminatedOnZeroErrorFrame() { StreamIdSupplier.clientSupplier(), 0, FRAME_LENGTH_MASK, + Integer.MAX_VALUE, 0, 0, null, - RequesterLeaseHandler.None, - TestScheduler.INSTANCE); + __ -> null, + null, + onThisSideClosedSink, + onThisSideClosedSink.asMono()); String errorMsg = "error"; @@ -82,6 +106,7 @@ void requesterStreamsTerminatedOnZeroErrorFrame() { .verify(Duration.ofSeconds(5)); assertThat(rSocket.isDisposed()).isTrue(); + allocator.assertHasNoLeaks(); } @Test @@ -89,6 +114,7 @@ void requesterNewStreamsTerminatedAfterZeroErrorFrame() { LeaksTrackingByteBufAllocator allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); TestDuplexConnection conn = new TestDuplexConnection(allocator); + Sinks.Empty onThisSideClosedSink = Sinks.empty(); RSocketRequester rSocket = new RSocketRequester( conn, @@ -96,11 +122,14 @@ void requesterNewStreamsTerminatedAfterZeroErrorFrame() { StreamIdSupplier.clientSupplier(), 0, FRAME_LENGTH_MASK, + Integer.MAX_VALUE, 0, 0, null, - RequesterLeaseHandler.None, - TestScheduler.INSTANCE); + __ -> null, + null, + onThisSideClosedSink, + onThisSideClosedSink.asMono()); conn.addToReceivedBuffer( ErrorFrameCodec.encode(ByteBufAllocator.DEFAULT, 0, new RejectedSetupException("error"))); @@ -112,11 +141,13 @@ void requesterNewStreamsTerminatedAfterZeroErrorFrame() { .expectErrorMatches( err -> err instanceof RejectedSetupException && "error".equals(err.getMessage())) .verify(Duration.ofSeconds(5)); + allocator.assertHasNoLeaks(); } private static class RejectingAcceptor implements SocketAcceptor { private final String errorMessage; - private final UnicastProcessor senderRSockets = UnicastProcessor.create(); + private final Sinks.Many senderRSockets = + Sinks.many().unicast().onBackpressureBuffer(); public RejectingAcceptor(String errorMessage) { this.errorMessage = errorMessage; @@ -124,12 +155,12 @@ public RejectingAcceptor(String errorMessage) { @Override public Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket) { - senderRSockets.onNext(sendingSocket); + senderRSockets.tryEmitNext(sendingSocket); return Mono.error(new RuntimeException(errorMessage)); } public Mono senderRSocket() { - return senderRSockets.next(); + return senderRSockets.asFlux().next(); } } @@ -145,11 +176,7 @@ public Mono start(ConnectionAcceptor acceptor) { } public ByteBuf awaitSent() { - try { - return conn.awaitSend(); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } + return conn.awaitFrame(); } public void connect() { diff --git a/rsocket-core/src/test/java/io/rsocket/core/ShouldHaveFlag.java b/rsocket-core/src/test/java/io/rsocket/core/ShouldHaveFlag.java new file mode 100644 index 000000000..88e0dc8e2 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ShouldHaveFlag.java @@ -0,0 +1,98 @@ +package io.rsocket.core; + +import static io.rsocket.core.StateUtils.REQUEST_MASK; +import static io.rsocket.core.StateUtils.SUBSCRIBED_FLAG; +import static io.rsocket.core.StateUtils.extractRequestN; + +import java.util.HashMap; +import java.util.Map; +import org.assertj.core.error.BasicErrorMessageFactory; +import org.assertj.core.error.ErrorMessageFactory; + +class ShouldHaveFlag extends BasicErrorMessageFactory { + + static final Map FLAGS_NAMES = + new HashMap() { + { + put(StateUtils.UNSUBSCRIBED_STATE, "UNSUBSCRIBED"); + put(StateUtils.TERMINATED_STATE, "TERMINATED"); + put(SUBSCRIBED_FLAG, "SUBSCRIBED"); + put(StateUtils.REQUEST_MASK, "REQUESTED(%s)"); + put(StateUtils.FIRST_FRAME_SENT_FLAG, "FIRST_FRAME_SENT"); + put(StateUtils.REASSEMBLING_FLAG, "REASSEMBLING"); + put(StateUtils.INBOUND_TERMINATED_FLAG, "INBOUND_TERMINATED"); + put(StateUtils.OUTBOUND_TERMINATED_FLAG, "OUTBOUND_TERMINATED"); + } + }; + + static final String SHOULD_HAVE_FLAG = "Expected state\n\t%s\nto have\n\t%s\nbut had\n\t[%s]"; + + private ShouldHaveFlag(long currentState, String expectedFlag, String actualFlags) { + super(SHOULD_HAVE_FLAG, toBinaryString(currentState), expectedFlag, actualFlags); + } + + static ErrorMessageFactory shouldHaveFlag(long currentState, long expectedFlag) { + String stateAsString = extractStateAsString(currentState); + return new ShouldHaveFlag(currentState, FLAGS_NAMES.get(expectedFlag), stateAsString); + } + + static ErrorMessageFactory shouldHaveRequestN(long currentState, long expectedRequestN) { + String stateAsString = extractStateAsString(currentState); + return new ShouldHaveFlag( + currentState, + String.format( + FLAGS_NAMES.get(REQUEST_MASK), + expectedRequestN == Integer.MAX_VALUE ? "MAX" : expectedRequestN), + stateAsString); + } + + static ErrorMessageFactory shouldHaveRequestNBetween( + long currentState, long expectedRequestNMin, long expectedRequestNMax) { + String stateAsString = extractStateAsString(currentState); + return new ShouldHaveFlag( + currentState, + String.format( + FLAGS_NAMES.get(REQUEST_MASK), + (expectedRequestNMin == Integer.MAX_VALUE ? "MAX" : expectedRequestNMin) + + " - " + + (expectedRequestNMax == Integer.MAX_VALUE ? "MAX" : expectedRequestNMax)), + stateAsString); + } + + private static String extractStateAsString(long currentState) { + StringBuilder stringBuilder = new StringBuilder(); + long flag = 1L << 31; + for (int i = 0; i < 33; i++, flag <<= 1) { + if ((currentState & flag) == flag) { + if (stringBuilder.length() > 0) { + stringBuilder.append(", "); + } + stringBuilder.append(FLAGS_NAMES.get(flag)); + } + } + long requestN = extractRequestN(currentState); + if (requestN > 0) { + if (stringBuilder.length() > 0) { + stringBuilder.append(", "); + } + stringBuilder.append( + String.format( + FLAGS_NAMES.get(REQUEST_MASK), requestN >= Integer.MAX_VALUE ? "MAX" : requestN)); + } + return stringBuilder.toString(); + } + + static String toBinaryString(long state) { + StringBuilder binaryString = new StringBuilder(Long.toBinaryString(state)); + + int diff = 64 - binaryString.length(); + for (int i = 0; i < diff; i++) { + binaryString.insert(0, "0"); + } + + binaryString.insert(33, "_"); + binaryString.insert(0, "0b"); + + return binaryString.toString(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/ShouldNotHaveFlag.java b/rsocket-core/src/test/java/io/rsocket/core/ShouldNotHaveFlag.java new file mode 100644 index 000000000..e281e548c --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ShouldNotHaveFlag.java @@ -0,0 +1,73 @@ +package io.rsocket.core; + +import static io.rsocket.core.StateUtils.REQUEST_MASK; +import static io.rsocket.core.StateUtils.SUBSCRIBED_FLAG; +import static io.rsocket.core.StateUtils.extractRequestN; + +import java.util.HashMap; +import java.util.Map; +import org.assertj.core.error.BasicErrorMessageFactory; +import org.assertj.core.error.ErrorMessageFactory; + +class ShouldNotHaveFlag extends BasicErrorMessageFactory { + + static final Map FLAGS_NAMES = + new HashMap() { + { + put(StateUtils.UNSUBSCRIBED_STATE, "UNSUBSCRIBED"); + put(StateUtils.TERMINATED_STATE, "TERMINATED"); + put(SUBSCRIBED_FLAG, "SUBSCRIBED"); + put(StateUtils.REQUEST_MASK, "REQUESTED(%n)"); + put(StateUtils.FIRST_FRAME_SENT_FLAG, "FIRST_FRAME_SENT"); + put(StateUtils.REASSEMBLING_FLAG, "REASSEMBLING"); + put(StateUtils.INBOUND_TERMINATED_FLAG, "INBOUND_TERMINATED"); + put(StateUtils.OUTBOUND_TERMINATED_FLAG, "OUTBOUND_TERMINATED"); + } + }; + + static final String SHOULD_NOT_HAVE_FLAG = + "Expected state\n\t%s\nto not have\n\t%s\nbut had\n\t[%s]"; + + private ShouldNotHaveFlag(long currentState, long expectedFlag, String actualFlags) { + super( + SHOULD_NOT_HAVE_FLAG, + toBinaryString(currentState), + FLAGS_NAMES.get(expectedFlag), + actualFlags); + } + + static ErrorMessageFactory shouldNotHaveFlag(long currentState, long expectedFlag) { + StringBuilder stringBuilder = new StringBuilder(); + long flag = 1L << 31; + for (int i = 0; i < 33; i++, flag <<= 1) { + if ((currentState & flag) == flag) { + if (stringBuilder.length() > 0) { + stringBuilder.append(", "); + } + stringBuilder.append(FLAGS_NAMES.get(flag)); + } + } + long requestN = extractRequestN(currentState); + if (requestN > 0) { + if (stringBuilder.length() > 0) { + stringBuilder.append(", "); + } + stringBuilder.append(String.format(FLAGS_NAMES.get(REQUEST_MASK), requestN)); + } + return new ShouldNotHaveFlag(currentState, expectedFlag, stringBuilder.toString()); + } + + static String toBinaryString(long state) { + StringBuilder binaryString = new StringBuilder(Long.toBinaryString(state)); + + int diff = 64 - binaryString.length(); + for (int i = 0; i < diff; i++) { + binaryString.insert(0, "0"); + } + + binaryString.insert(33, "_"); + binaryString.insert(0, "0b"); + + return binaryString.toString(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/StateAssert.java b/rsocket-core/src/test/java/io/rsocket/core/StateAssert.java new file mode 100644 index 000000000..64253984b --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/StateAssert.java @@ -0,0 +1,161 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.ShouldHaveFlag.*; +import static io.rsocket.core.ShouldNotHaveFlag.shouldNotHaveFlag; +import static io.rsocket.core.StateUtils.*; + +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.assertj.core.api.AbstractAssert; +import org.assertj.core.internal.Failures; + +public class StateAssert extends AbstractAssert, AtomicLongFieldUpdater> { + + public static StateAssert assertThat(AtomicLongFieldUpdater updater, T instance) { + return new StateAssert<>(updater, instance); + } + + public static StateAssert assertThat( + FireAndForgetRequesterMono instance) { + return new StateAssert<>(FireAndForgetRequesterMono.STATE, instance); + } + + public static StateAssert assertThat( + RequestResponseRequesterMono instance) { + return new StateAssert<>(RequestResponseRequesterMono.STATE, instance); + } + + public static StateAssert assertThat( + RequestStreamRequesterFlux instance) { + return new StateAssert<>(RequestStreamRequesterFlux.STATE, instance); + } + + public static StateAssert assertThat( + RequestChannelRequesterFlux instance) { + return new StateAssert<>(RequestChannelRequesterFlux.STATE, instance); + } + + public static StateAssert assertThat( + RequestChannelResponderSubscriber instance) { + return new StateAssert<>(RequestChannelResponderSubscriber.STATE, instance); + } + + private final Failures failures = Failures.instance(); + private final T instance; + + public StateAssert(AtomicLongFieldUpdater updater, T instance) { + super(updater, StateAssert.class); + this.instance = instance; + } + + public StateAssert isUnsubscribed() { + long currentState = actual.get(instance); + if (isSubscribed(currentState) || StateUtils.isTerminated(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, UNSUBSCRIBED_STATE)); + } + return this; + } + + public StateAssert hasSubscribedFlagOnly() { + long currentState = actual.get(instance); + if (currentState != SUBSCRIBED_FLAG) { + throw failures.failure(info, shouldHaveFlag(currentState, SUBSCRIBED_FLAG)); + } + return this; + } + + public StateAssert hasSubscribedFlag() { + long currentState = actual.get(instance); + if (!isSubscribed(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, SUBSCRIBED_FLAG)); + } + return this; + } + + public StateAssert hasRequestN(long n) { + long currentState = actual.get(instance); + if (extractRequestN(currentState) != n) { + throw failures.failure(info, shouldHaveRequestN(currentState, n)); + } + return this; + } + + public StateAssert hasRequestNBetween(long min, long max) { + long currentState = actual.get(instance); + final long requestN = extractRequestN(currentState); + if (requestN < min || requestN > max) { + throw failures.failure(info, shouldHaveRequestNBetween(currentState, min, max)); + } + return this; + } + + public StateAssert hasFirstFrameSentFlag() { + long currentState = actual.get(instance); + if (!isFirstFrameSent(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, FIRST_FRAME_SENT_FLAG)); + } + return this; + } + + public StateAssert hasNoFirstFrameSentFlag() { + long currentState = actual.get(instance); + if (isFirstFrameSent(currentState)) { + throw failures.failure(info, shouldNotHaveFlag(currentState, FIRST_FRAME_SENT_FLAG)); + } + return this; + } + + public StateAssert hasReassemblingFlag() { + long currentState = actual.get(instance); + if (!isReassembling(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, REASSEMBLING_FLAG)); + } + return this; + } + + public StateAssert hasNoReassemblingFlag() { + long currentState = actual.get(instance); + if (isReassembling(currentState)) { + throw failures.failure(info, shouldNotHaveFlag(currentState, REASSEMBLING_FLAG)); + } + return this; + } + + public StateAssert hasInboundTerminated() { + long currentState = actual.get(instance); + if (!StateUtils.isInboundTerminated(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, INBOUND_TERMINATED_FLAG)); + } + return this; + } + + public StateAssert hasOutboundTerminated() { + long currentState = actual.get(instance); + if (!StateUtils.isOutboundTerminated(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, OUTBOUND_TERMINATED_FLAG)); + } + return this; + } + + public StateAssert isTerminated() { + long currentState = actual.get(instance); + if (!StateUtils.isTerminated(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, TERMINATED_STATE)); + } + return this; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java b/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java index 00248b6d8..16bd9f16e 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java @@ -16,27 +16,28 @@ package io.rsocket.core; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import io.netty.util.collection.IntObjectHashMap; import io.netty.util.collection.IntObjectMap; -import io.rsocket.internal.SynchronizedIntObjectHashMap; -import org.junit.Test; +import org.junit.jupiter.api.Test; public class StreamIdSupplierTest { @Test public void testClientSequence() { - IntObjectMap map = new SynchronizedIntObjectHashMap<>(); + IntObjectMap map = new IntObjectHashMap<>(); StreamIdSupplier s = StreamIdSupplier.clientSupplier(); - assertEquals(1, s.nextStreamId(map)); - assertEquals(3, s.nextStreamId(map)); - assertEquals(5, s.nextStreamId(map)); + assertThat(s.nextStreamId(map)).isEqualTo(1); + assertThat(s.nextStreamId(map)).isEqualTo(3); + assertThat(s.nextStreamId(map)).isEqualTo(5); } @Test public void testServerSequence() { - IntObjectMap map = new SynchronizedIntObjectHashMap<>(); + IntObjectMap map = new IntObjectHashMap<>(); StreamIdSupplier s = StreamIdSupplier.serverSupplier(); assertEquals(2, s.nextStreamId(map)); assertEquals(4, s.nextStreamId(map)); @@ -45,7 +46,7 @@ public void testServerSequence() { @Test public void testClientIsValid() { - IntObjectMap map = new SynchronizedIntObjectHashMap<>(); + IntObjectMap map = new IntObjectHashMap<>(); StreamIdSupplier s = StreamIdSupplier.clientSupplier(); assertFalse(s.isBeforeOrCurrent(1)); @@ -68,7 +69,7 @@ public void testClientIsValid() { @Test public void testServerIsValid() { - IntObjectMap map = new SynchronizedIntObjectHashMap<>(); + IntObjectMap map = new IntObjectHashMap<>(); StreamIdSupplier s = StreamIdSupplier.serverSupplier(); assertFalse(s.isBeforeOrCurrent(2)); @@ -91,7 +92,7 @@ public void testServerIsValid() { @Test public void testWrap() { - IntObjectMap map = new SynchronizedIntObjectHashMap<>(); + IntObjectMap map = new IntObjectHashMap<>(); StreamIdSupplier s = new StreamIdSupplier(Integer.MAX_VALUE - 3); assertEquals(2147483646, s.nextStreamId(map)); @@ -107,7 +108,7 @@ public void testWrap() { @Test public void testSkipFound() { - IntObjectMap map = new SynchronizedIntObjectHashMap<>(); + IntObjectMap map = new IntObjectHashMap<>(); map.put(5, new Object()); map.put(9, new Object()); StreamIdSupplier s = StreamIdSupplier.clientSupplier(); diff --git a/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java b/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java new file mode 100644 index 000000000..e282d72d5 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java @@ -0,0 +1,281 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.ByteBufPayload; +import java.util.ArrayList; +import java.util.concurrent.ThreadLocalRandom; +import org.assertj.core.api.Assertions; +import reactor.core.Exceptions; +import reactor.util.annotation.Nullable; + +final class TestRequesterResponderSupport extends RequesterResponderSupport implements RSocket { + + static final String DATA_CONTENT = "testData"; + static final String METADATA_CONTENT = "testMetadata"; + + final Throwable error; + + TestRequesterResponderSupport( + @Nullable Throwable error, + StreamIdSupplier streamIdSupplier, + DuplexConnection connection, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize, + @Nullable RequestInterceptor requestInterceptor) { + super( + mtu, + maxFrameLength, + maxInboundPayloadSize, + PayloadDecoder.ZERO_COPY, + connection, + streamIdSupplier, + (__) -> requestInterceptor); + this.error = error; + } + + @Override + public TestDuplexConnection getDuplexConnection() { + return (TestDuplexConnection) super.getDuplexConnection(); + } + + static Payload genericPayload(LeaksTrackingByteBufAllocator allocator) { + ByteBuf data = allocator.buffer(); + data.writeCharSequence(DATA_CONTENT, CharsetUtil.UTF_8); + + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence(METADATA_CONTENT, CharsetUtil.UTF_8); + + return ByteBufPayload.create(data, metadata); + } + + static Payload fixedSizePayload(LeaksTrackingByteBufAllocator allocator, int contentSize) { + final int dataSize = ThreadLocalRandom.current().nextInt(0, contentSize); + final byte[] dataBytes = new byte[dataSize]; + ThreadLocalRandom.current().nextBytes(dataBytes); + ByteBuf data = allocator.buffer(dataSize); + data.writeBytes(dataBytes); + + ByteBuf metadata; + int metadataSize = contentSize - dataSize; + if (metadataSize > 0) { + final byte[] metadataBytes = new byte[metadataSize]; + metadata = allocator.buffer(metadataSize); + metadata.writeBytes(metadataBytes); + } else { + metadata = ThreadLocalRandom.current().nextBoolean() ? Unpooled.EMPTY_BUFFER : null; + } + + return ByteBufPayload.create(data, metadata); + } + + static Payload randomPayload(LeaksTrackingByteBufAllocator allocator) { + boolean hasMetadata = ThreadLocalRandom.current().nextBoolean(); + ByteBuf metadataByteBuf; + if (hasMetadata) { + byte[] randomMetadata = new byte[ThreadLocalRandom.current().nextInt(0, 512)]; + ThreadLocalRandom.current().nextBytes(randomMetadata); + metadataByteBuf = allocator.buffer().writeBytes(randomMetadata); + } else { + metadataByteBuf = null; + } + byte[] randomData = new byte[ThreadLocalRandom.current().nextInt(512, 1024)]; + ThreadLocalRandom.current().nextBytes(randomData); + + ByteBuf dataByteBuf = allocator.buffer().writeBytes(randomData); + return ByteBufPayload.create(dataByteBuf, metadataByteBuf); + } + + static Payload randomMetadataOnlyPayload(LeaksTrackingByteBufAllocator allocator) { + byte[] randomMetadata = new byte[ThreadLocalRandom.current().nextInt(512, 1024)]; + ThreadLocalRandom.current().nextBytes(randomMetadata); + ByteBuf metadataByteBuf = allocator.buffer().writeBytes(randomMetadata); + + return ByteBufPayload.create(Unpooled.EMPTY_BUFFER, metadataByteBuf); + } + + static ArrayList prepareFragments( + LeaksTrackingByteBufAllocator allocator, int mtu, Payload payload) { + + return prepareFragments(allocator, mtu, payload, FrameType.NEXT_COMPLETE); + } + + static ArrayList prepareFragments( + LeaksTrackingByteBufAllocator allocator, int mtu, Payload payload, FrameType frameType) { + + boolean hasMetadata = payload.hasMetadata(); + ByteBuf data = payload.sliceData(); + ByteBuf metadata = payload.sliceMetadata(); + ArrayList fragments = new ArrayList<>(); + + fragments.add( + frameType.hasInitialRequestN() + ? FragmentationUtils.encodeFirstFragment( + allocator, mtu, 1L, frameType, 1, hasMetadata, metadata, data) + : FragmentationUtils.encodeFirstFragment( + allocator, mtu, frameType, 1, hasMetadata, metadata, data)); + + while (metadata.isReadable() || data.isReadable()) { + fragments.add( + FragmentationUtils.encodeFollowsFragment(allocator, mtu, 1, true, metadata, data)); + } + + return fragments; + } + + @Override + public synchronized int getNextStreamId() { + int nextStreamId = super.getNextStreamId(); + + if (error != null) { + throw Exceptions.propagate(error); + } + + return nextStreamId; + } + + @Override + public synchronized int addAndGetNextStreamId(FrameHandler frameHandler) { + int nextStreamId = super.addAndGetNextStreamId(frameHandler); + + if (error != null) { + super.remove(nextStreamId, frameHandler); + throw Exceptions.propagate(error); + } + + return nextStreamId; + } + + public static TestRequesterResponderSupport client( + @Nullable Throwable e, @Nullable RequestInterceptor requestInterceptor) { + return client( + new TestDuplexConnection( + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT)), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + requestInterceptor, + e); + } + + public static TestRequesterResponderSupport client(@Nullable Throwable e) { + return client(0, FRAME_LENGTH_MASK, Integer.MAX_VALUE, e); + } + + public static TestRequesterResponderSupport client( + int mtu, int maxFrameLength, int maxInboundPayloadSize, @Nullable Throwable e) { + return client( + new TestDuplexConnection( + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT)), + mtu, + maxFrameLength, + maxInboundPayloadSize, + null, + e); + } + + public static TestRequesterResponderSupport client( + TestDuplexConnection duplexConnection, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize) { + return client(duplexConnection, mtu, maxFrameLength, maxInboundPayloadSize, null); + } + + public static TestRequesterResponderSupport client( + TestDuplexConnection duplexConnection, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize, + @Nullable RequestInterceptor requestInterceptor) { + return client( + duplexConnection, mtu, maxFrameLength, maxInboundPayloadSize, requestInterceptor, null); + } + + public static TestRequesterResponderSupport client( + TestDuplexConnection duplexConnection, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize, + @Nullable RequestInterceptor requestInterceptor, + @Nullable Throwable e) { + return new TestRequesterResponderSupport( + e, + StreamIdSupplier.clientSupplier(), + duplexConnection, + mtu, + maxFrameLength, + maxInboundPayloadSize, + requestInterceptor); + } + + public static TestRequesterResponderSupport client( + int mtu, int maxFrameLength, int maxInboundPayloadSize) { + return client(mtu, maxFrameLength, maxInboundPayloadSize, null); + } + + public static TestRequesterResponderSupport client(int mtu, int maxFrameLength) { + return client(mtu, maxFrameLength, Integer.MAX_VALUE); + } + + public static TestRequesterResponderSupport client(int mtu) { + return client(mtu, FRAME_LENGTH_MASK); + } + + public static TestRequesterResponderSupport client() { + return client(0); + } + + public static TestRequesterResponderSupport client(RequestInterceptor requestInterceptor) { + return client( + new TestDuplexConnection( + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT)), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + requestInterceptor); + } + + public TestRequesterResponderSupport assertNoActiveStreams() { + Assertions.assertThat(activeStreams).isEmpty(); + return this; + } + + public TestRequesterResponderSupport assertHasStream(int i, FrameHandler stream) { + Assertions.assertThat(activeStreams).containsEntry(i, stream); + return this; + } + + @Override + public LeaksTrackingByteBufAllocator getAllocator() { + return (LeaksTrackingByteBufAllocator) super.getAllocator(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/TestingStuff.java b/rsocket-core/src/test/java/io/rsocket/core/TestingStuff.java deleted file mode 100644 index e0ebf5064..000000000 --- a/rsocket-core/src/test/java/io/rsocket/core/TestingStuff.java +++ /dev/null @@ -1,21 +0,0 @@ -package io.rsocket.core; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; -import org.junit.Test; - -public class TestingStuff { - private String f = "00000001110000000068656c6c6f"; - private String f1 = - "00000001286004232e667127bb590fb097cf657776761dcdfe84863f67da47e9de9ac72197424116aae3aadf9e4d8347d8f7923107f7eacb56f741c82327e9c4cbe23e92b5afc306aa24a2153e27082ba0bb1e707bed43100ea84a87539a23442e10431584a42f6eb78fdf140fb14ee71cf4a23ad644dcd3ebceb86e4d0617aa0d2cfee56ce1afea5062842580275b5fdc96cae1bbe9f3a129148dbe1dfc44c8e11aa6a7ec8dafacbbdbd0b68731c16bd16c21eb857c9eb1bb6c359415674b6d41d14e99b1dd56a40fc836d723dd534e83c44d6283745332c627e13bcfc2cd483bccec67232fff0b2ccb7388f0d37da27562d7c3635fef061767400e45729bdef57ca8c041e33074ea1a42004c1b8cb02eb3afeaf5f6d82162d4174c549f840bdb88632faf2578393194f67538bf581a22f31850f88831af632bdaf32c80aa6d96a7afc20b8067c4f9d17859776e4c40beafff18a848df45927ca1c9b024ef278a9fb60bdf965b5822b64bebc74a8b7d95a4bd9d1c1fc82b4fbacc29e36458a878079ddd402788a462528d6c79df797218563cc70811c09b154588a3edd2e948bb61db7b3a36774e0bd5ab67fec4bf1e70811733213f292a728389473b9f68d288ac481529e10cfd428b14ad3f4592d1cc6dd08b1a7842bb492b51057c4d88ac5d538174560f94b49dce6d20ef71671d2e80c2b92ead6d4a26ed8f4187a563cb53eb0c558fe9f77b2133e835e2d2e671978e82a6f60ed61a6a945e39fe0dedcf73d7cb80253a5eb9f311c78ef2587649436f4ab42bcc882faba7bfd57d451407a07ce1d5ac7b5f0cf1ef84047c92e3fbdb64128925ef6e87def450ae8a1643e9906b7dc1f672bd98e012df3039f2ee412909f4b03db39f45b83955f31986b6fd3b5e4f26b6ec2284dcf55ff5fbbfbfb31cd6b22753c6435dbd3ec5558132c6ede9babd7945ac6e697d28b9697f9b2450db2b643a1abc4c9ad5bfa4529d0e1f261df1da5ee035738a5d8c536466fa741e9190c58cf1cacc819838a6b20d85f901f026c66dbaf23cde3a12ce4b443ef15cc247ba48cc0812c6f2c834c5773f3d4042219727404f0f2640cab486e298ae9f1c2f7a7e6f0619f130895d9f41d343fbdb05d68d6e0308d8d046314811066a13300b1346b8762919d833de7f55fea919ad55500ba4ec7e100b32bbabbf9d378eab61532fd91d4d1977db72b828e8d700062b045459d7729f140d889a67472a035d564384844ff16697743e4017e2bf21511ebb4c939bbab202bf6ef59e2be557027272f1bb21c325cf3e0432120bccba17bea52a7621031466e7973415437cd50cc950e63e6e2d17aad36f7a943892901e763e19082260b88f8971b35b4d9cc8725d6e4137b4648427ae68255e076dfb511871de0f7100d2ece6c8be88a0326ba8d73b5c9883f83c0dccd362e61cb16c7a0cc5ff00f7"; - private String f2 = "00000003110000000068656c6c6f"; - - @Test - public void testStuff() { - ByteBuf byteBuf = Unpooled.wrappedBuffer(ByteBufUtil.decodeHexDump(f1)); - System.out.println(ByteBufUtil.prettyHexDump(byteBuf)); - - new DefaultConnectionSetupPayload(byteBuf); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java index b3f596a37..a316aed8b 100644 --- a/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java @@ -31,6 +31,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.UnpooledByteBufAllocator; +import io.rsocket.RaceTestConstants; import io.rsocket.frame.ErrorFrameCodec; import java.util.concurrent.ThreadLocalRandom; import org.junit.jupiter.api.DisplayName; @@ -42,14 +43,18 @@ final class ExceptionsTest { void fromApplicationException() { ByteBuf byteBuf = createErrorFrame(1, APPLICATION_ERROR, "test-message"); - assertThat(Exceptions.from(1, byteBuf)) - .isInstanceOf(ApplicationErrorException.class) - .hasMessage("test-message"); + try { + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(ApplicationErrorException.class) + .hasMessage("test-message"); - assertThat(Exceptions.from(0, byteBuf)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage( - "Invalid Error frame in Stream ID 0: 0x%08X '%s'", APPLICATION_ERROR, "test-message"); + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Invalid Error frame in Stream ID 0: 0x%08X '%s'", APPLICATION_ERROR, "test-message"); + } finally { + byteBuf.release(); + } } @DisplayName("from returns CanceledException") @@ -57,28 +62,37 @@ void fromApplicationException() { void fromCanceledException() { ByteBuf byteBuf = createErrorFrame(1, CANCELED, "test-message"); - assertThat(Exceptions.from(1, byteBuf)) - .isInstanceOf(CanceledException.class) - .hasMessage("test-message"); + try { - assertThat(Exceptions.from(0, byteBuf)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", CANCELED, "test-message"); + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(CanceledException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", CANCELED, "test-message"); + } finally { + byteBuf.release(); + } } @DisplayName("from returns ConnectionCloseException") @Test void fromConnectionCloseException() { ByteBuf byteBuf = createErrorFrame(0, CONNECTION_CLOSE, "test-message"); + try { - assertThat(Exceptions.from(0, byteBuf)) - .isInstanceOf(ConnectionCloseException.class) - .hasMessage("test-message"); + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(ConnectionCloseException.class) + .hasMessage("test-message"); - assertThat(Exceptions.from(1, byteBuf)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage( - "Invalid Error frame in Stream ID 1: 0x%08X '%s'", CONNECTION_CLOSE, "test-message"); + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", CONNECTION_CLOSE, "test-message"); + } finally { + byteBuf.release(); + } } @DisplayName("from returns ConnectionErrorException") @@ -86,122 +100,152 @@ void fromConnectionCloseException() { void fromConnectionErrorException() { ByteBuf byteBuf = createErrorFrame(0, CONNECTION_ERROR, "test-message"); - assertThat(Exceptions.from(0, byteBuf)) - .isInstanceOf(ConnectionErrorException.class) - .hasMessage("test-message"); + try { + + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(ConnectionErrorException.class) + .hasMessage("test-message"); - assertThat(Exceptions.from(1, byteBuf)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage( - "Invalid Error frame in Stream ID 1: 0x%08X '%s'", CONNECTION_ERROR, "test-message"); + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", CONNECTION_ERROR, "test-message"); + } finally { + byteBuf.release(); + } } @DisplayName("from returns IllegalArgumentException if error frame has illegal error code") @Test void fromIllegalErrorFrame() { ByteBuf byteBuf = createErrorFrame(0, 0x00000000, "test-message"); + try { - assertThat(Exceptions.from(0, byteBuf)) - .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", 0, "test-message") - .isInstanceOf(IllegalArgumentException.class); + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", 0, "test-message") + .isInstanceOf(IllegalArgumentException.class); - assertThat(Exceptions.from(1, byteBuf)) - .hasMessage("Invalid Error frame in Stream ID 1: 0x%08X '%s'", 0x00000000, "test-message") - .isInstanceOf(IllegalArgumentException.class); + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 1: 0x%08X '%s'", 0x00000000, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } } @DisplayName("from returns InvalidException") @Test void fromInvalidException() { ByteBuf byteBuf = createErrorFrame(1, INVALID, "test-message"); + try { + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(InvalidException.class) + .hasMessage("test-message"); - assertThat(Exceptions.from(1, byteBuf)) - .isInstanceOf(InvalidException.class) - .hasMessage("test-message"); - - assertThat(Exceptions.from(0, byteBuf)) - .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", INVALID, "test-message") - .isInstanceOf(IllegalArgumentException.class); + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", INVALID, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } } @DisplayName("from returns InvalidSetupException") @Test void fromInvalidSetupException() { ByteBuf byteBuf = createErrorFrame(0, INVALID_SETUP, "test-message"); + try { + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(InvalidSetupException.class) + .hasMessage("test-message"); - assertThat(Exceptions.from(0, byteBuf)) - .isInstanceOf(InvalidSetupException.class) - .hasMessage("test-message"); - - assertThat(Exceptions.from(1, byteBuf)) - .hasMessage( - "Invalid Error frame in Stream ID 1: 0x%08X '%s'", INVALID_SETUP, "test-message") - .isInstanceOf(IllegalArgumentException.class); + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", INVALID_SETUP, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } } @DisplayName("from returns RejectedException") @Test void fromRejectedException() { ByteBuf byteBuf = createErrorFrame(1, REJECTED, "test-message"); + try { - assertThat(Exceptions.from(1, byteBuf)) - .isInstanceOf(RejectedException.class) - .withFailMessage("test-message"); + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(RejectedException.class) + .withFailMessage("test-message"); - assertThat(Exceptions.from(0, byteBuf)) - .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", REJECTED, "test-message") - .isInstanceOf(IllegalArgumentException.class); + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", REJECTED, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } } @DisplayName("from returns RejectedResumeException") @Test void fromRejectedResumeException() { ByteBuf byteBuf = createErrorFrame(0, REJECTED_RESUME, "test-message"); + try { - assertThat(Exceptions.from(0, byteBuf)) - .isInstanceOf(RejectedResumeException.class) - .hasMessage("test-message"); + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(RejectedResumeException.class) + .hasMessage("test-message"); - assertThat(Exceptions.from(1, byteBuf)) - .hasMessage( - "Invalid Error frame in Stream ID 1: 0x%08X '%s'", REJECTED_RESUME, "test-message") - .isInstanceOf(IllegalArgumentException.class); + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", REJECTED_RESUME, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } } @DisplayName("from returns RejectedSetupException") @Test void fromRejectedSetupException() { ByteBuf byteBuf = createErrorFrame(0, REJECTED_SETUP, "test-message"); + try { - assertThat(Exceptions.from(0, byteBuf)) - .isInstanceOf(RejectedSetupException.class) - .withFailMessage("test-message"); + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(RejectedSetupException.class) + .withFailMessage("test-message"); - assertThat(Exceptions.from(1, byteBuf)) - .hasMessage( - "Invalid Error frame in Stream ID 1: 0x%08X '%s'", REJECTED_SETUP, "test-message") - .isInstanceOf(IllegalArgumentException.class); + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", REJECTED_SETUP, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } } @DisplayName("from returns UnsupportedSetupException") @Test void fromUnsupportedSetupException() { ByteBuf byteBuf = createErrorFrame(0, UNSUPPORTED_SETUP, "test-message"); + try { + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(UnsupportedSetupException.class) + .hasMessage("test-message"); - assertThat(Exceptions.from(0, byteBuf)) - .isInstanceOf(UnsupportedSetupException.class) - .hasMessage("test-message"); - - assertThat(Exceptions.from(1, byteBuf)) - .hasMessage( - "Invalid Error frame in Stream ID 1: 0x%08X '%s'", UNSUPPORTED_SETUP, "test-message") - .isInstanceOf(IllegalArgumentException.class); + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", UNSUPPORTED_SETUP, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } } @DisplayName("from returns CustomRSocketException") @Test void fromCustomRSocketException() { - for (int i = 0; i < 1000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { int randomCode = ThreadLocalRandom.current().nextBoolean() ? ThreadLocalRandom.current() @@ -209,14 +253,18 @@ void fromCustomRSocketException() { : ThreadLocalRandom.current() .nextInt(ErrorFrameCodec.MIN_USER_ALLOWED_ERROR_CODE, Integer.MAX_VALUE); ByteBuf byteBuf = createErrorFrame(0, randomCode, "test-message"); - - assertThat(Exceptions.from(1, byteBuf)) - .isInstanceOf(CustomRSocketException.class) - .hasMessage("test-message"); - - assertThat(Exceptions.from(0, byteBuf)) - .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", randomCode, "test-message") - .isInstanceOf(IllegalArgumentException.class); + try { + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(CustomRSocketException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 0: 0x%08X '%s'", randomCode, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } } } diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/RSocketExceptionTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/RSocketExceptionTest.java index ccf7649d2..9aa8fc364 100644 --- a/rsocket-core/src/test/java/io/rsocket/exceptions/RSocketExceptionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/RSocketExceptionTest.java @@ -18,10 +18,11 @@ import static org.assertj.core.api.Assertions.assertThat; +import io.rsocket.RSocketErrorException; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; -interface RSocketExceptionTest { +interface RSocketExceptionTest { @DisplayName("constructor does not throw NullPointerException with null message") @Test diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/TestRSocketException.java b/rsocket-core/src/test/java/io/rsocket/exceptions/TestRSocketException.java index 6c2e63730..15685aa43 100644 --- a/rsocket-core/src/test/java/io/rsocket/exceptions/TestRSocketException.java +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/TestRSocketException.java @@ -1,6 +1,9 @@ package io.rsocket.exceptions; -public class TestRSocketException extends RSocketException { +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; + +public class TestRSocketException extends RSocketErrorException { private static final long serialVersionUID = 7873267740343446585L; private final int errorCode; @@ -14,7 +17,7 @@ public class TestRSocketException extends RSocketException { * @throws IllegalArgumentException if {@code errorCode} is out of allowed range */ public TestRSocketException(int errorCode, String message) { - super(message); + super(ErrorFrameCodec.APPLICATION_ERROR, message); this.errorCode = errorCode; } @@ -28,7 +31,7 @@ public TestRSocketException(int errorCode, String message) { * @throws IllegalArgumentException if {@code errorCode} is out of allowed range */ public TestRSocketException(int errorCode, String message, Throwable cause) { - super(message, cause); + super(ErrorFrameCodec.APPLICATION_ERROR, message, cause); this.errorCode = errorCode; } diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java deleted file mode 100644 index 246fa1184..000000000 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.fragmentation; - -import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.assertj.core.api.Assertions.assertThatNullPointerException; -import static org.mockito.Mockito.*; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.rsocket.DuplexConnection; -import io.rsocket.buffer.LeaksTrackingByteBufAllocator; -import io.rsocket.frame.*; -import java.util.concurrent.ThreadLocalRandom; -import org.junit.Assert; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; -import org.mockito.ArgumentCaptor; -import org.mockito.Mockito; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -final class FragmentationDuplexConnectionTest { - private static byte[] data = new byte[1024]; - private static byte[] metadata = new byte[1024]; - - static { - ThreadLocalRandom.current().nextBytes(data); - ThreadLocalRandom.current().nextBytes(metadata); - } - - private final DuplexConnection delegate = mock(DuplexConnection.class, RETURNS_SMART_NULLS); - - { - Mockito.when(delegate.onClose()).thenReturn(Mono.never()); - } - - @SuppressWarnings("unchecked") - private final ArgumentCaptor> publishers = - ArgumentCaptor.forClass(Publisher.class); - - private LeaksTrackingByteBufAllocator allocator = - LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); - - @DisplayName("constructor throws IllegalArgumentException with negative maxFragmentLength") - @Test - void constructorInvalidMaxFragmentSize() { - assertThatIllegalArgumentException() - .isThrownBy( - () -> - new FragmentationDuplexConnection( - delegate, Integer.MIN_VALUE, Integer.MAX_VALUE, "")) - .withMessage("The smallest allowed mtu size is 64 bytes, provided: -2147483648"); - } - - @DisplayName("constructor throws IllegalArgumentException with negative maxFragmentLength") - @Test - void constructorMtuLessThanMin() { - assertThatIllegalArgumentException() - .isThrownBy(() -> new FragmentationDuplexConnection(delegate, 2, Integer.MAX_VALUE, "")) - .withMessage("The smallest allowed mtu size is 64 bytes, provided: 2"); - } - - @DisplayName("constructor throws NullPointerException with null delegate") - @Test - void constructorNullDelegate() { - assertThatNullPointerException() - .isThrownBy(() -> new FragmentationDuplexConnection(null, 64, Integer.MAX_VALUE, "")) - .withMessage("delegate must not be null"); - } - - @DisplayName("fragments data") - @Test - void sendData() { - ByteBuf encode = - RequestResponseFrameCodec.encode( - allocator, 1, false, Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(data)); - - when(delegate.onClose()).thenReturn(Mono.never()); - when(delegate.alloc()).thenReturn(allocator); - - new FragmentationDuplexConnection(delegate, 64, Integer.MAX_VALUE, "").sendOne(encode.retain()); - - verify(delegate).send(publishers.capture()); - - StepVerifier.create(Flux.from(publishers.getValue())) - .expectNextCount(17) - .assertNext( - byteBuf -> { - Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); - Assert.assertFalse(FrameHeaderCodec.hasFollows(byteBuf)); - }) - .verifyComplete(); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationIntegrationTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationIntegrationTest.java deleted file mode 100644 index d27905f90..000000000 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationIntegrationTest.java +++ /dev/null @@ -1,56 +0,0 @@ -package io.rsocket.fragmentation; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.frame.FrameHeaderCodec; -import io.rsocket.frame.FrameUtil; -import io.rsocket.frame.PayloadFrameCodec; -import io.rsocket.util.DefaultPayload; -import java.util.concurrent.ThreadLocalRandom; -import org.junit.Assert; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; - -public class FragmentationIntegrationTest { - private static byte[] data = new byte[128]; - private static byte[] metadata = new byte[128]; - - static { - ThreadLocalRandom.current().nextBytes(data); - ThreadLocalRandom.current().nextBytes(metadata); - } - - private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; - - @DisplayName("fragments and reassembles data") - @Test - void fragmentAndReassembleData() { - ByteBuf frame = - PayloadFrameCodec.encodeNextCompleteReleasingPayload( - allocator, 2, DefaultPayload.create(data)); - System.out.println(FrameUtil.toString(frame)); - - frame.retain(); - - Publisher fragments = - FrameFragmenter.fragmentFrame(allocator, 64, frame, FrameHeaderCodec.frameType(frame)); - - FrameReassembler reassembler = new FrameReassembler(allocator, Integer.MAX_VALUE); - - ByteBuf assembled = - Flux.from(fragments) - .doOnNext(byteBuf -> System.out.println(FrameUtil.toString(byteBuf))) - .handle(reassembler::reassembleFrame) - .blockLast(); - - System.out.println("assembled"); - String s = FrameUtil.toString(assembled); - System.out.println(s); - - Assert.assertEquals(FrameHeaderCodec.frameType(frame), FrameHeaderCodec.frameType(assembled)); - Assert.assertEquals(frame.readableBytes(), assembled.readableBytes()); - Assert.assertEquals(PayloadFrameCodec.data(frame), PayloadFrameCodec.data(assembled)); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java deleted file mode 100644 index 4548e4696..000000000 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java +++ /dev/null @@ -1,350 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.fragmentation; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.rsocket.frame.*; -import java.util.concurrent.ThreadLocalRandom; -import org.junit.Assert; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; -import reactor.test.StepVerifier; - -final class FrameFragmenterTest { - private static byte[] data = new byte[4096]; - private static byte[] metadata = new byte[4096]; - - static { - ThreadLocalRandom.current().nextBytes(data); - ThreadLocalRandom.current().nextBytes(metadata); - } - - private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; - - @Test - void testGettingData() { - ByteBuf rr = - RequestResponseFrameCodec.encode(allocator, 1, true, null, Unpooled.wrappedBuffer(data)); - ByteBuf fnf = - RequestFireAndForgetFrameCodec.encode( - allocator, 1, true, null, Unpooled.wrappedBuffer(data)); - ByteBuf rs = - RequestStreamFrameCodec.encode(allocator, 1, true, 1, null, Unpooled.wrappedBuffer(data)); - ByteBuf rc = - RequestChannelFrameCodec.encode( - allocator, 1, true, false, 1, null, Unpooled.wrappedBuffer(data)); - - ByteBuf data = FrameFragmenter.getData(rr, FrameType.REQUEST_RESPONSE); - Assert.assertEquals(data, Unpooled.wrappedBuffer(data)); - data.release(); - - data = FrameFragmenter.getData(fnf, FrameType.REQUEST_FNF); - Assert.assertEquals(data, Unpooled.wrappedBuffer(data)); - data.release(); - - data = FrameFragmenter.getData(rs, FrameType.REQUEST_STREAM); - Assert.assertEquals(data, Unpooled.wrappedBuffer(data)); - data.release(); - - data = FrameFragmenter.getData(rc, FrameType.REQUEST_CHANNEL); - Assert.assertEquals(data, Unpooled.wrappedBuffer(data)); - data.release(); - } - - @Test - void testGettingMetadata() { - ByteBuf rr = - RequestResponseFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); - ByteBuf fnf = - RequestFireAndForgetFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); - ByteBuf rs = - RequestStreamFrameCodec.encode( - allocator, 1, true, 1, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); - ByteBuf rc = - RequestChannelFrameCodec.encode( - allocator, - 1, - true, - false, - 1, - Unpooled.wrappedBuffer(metadata), - Unpooled.wrappedBuffer(data)); - - ByteBuf data = FrameFragmenter.getMetadata(rr, FrameType.REQUEST_RESPONSE); - Assert.assertEquals(data, Unpooled.wrappedBuffer(metadata)); - data.release(); - - data = FrameFragmenter.getMetadata(fnf, FrameType.REQUEST_FNF); - Assert.assertEquals(data, Unpooled.wrappedBuffer(metadata)); - data.release(); - - data = FrameFragmenter.getMetadata(rs, FrameType.REQUEST_STREAM); - Assert.assertEquals(data, Unpooled.wrappedBuffer(metadata)); - data.release(); - - data = FrameFragmenter.getMetadata(rc, FrameType.REQUEST_CHANNEL); - Assert.assertEquals(data, Unpooled.wrappedBuffer(metadata)); - data.release(); - } - - @Test - void returnEmptBufferWhenNoMetadataPresent() { - ByteBuf rr = - RequestResponseFrameCodec.encode(allocator, 1, true, null, Unpooled.wrappedBuffer(data)); - - ByteBuf data = FrameFragmenter.getMetadata(rr, FrameType.REQUEST_RESPONSE); - Assert.assertEquals(data, Unpooled.EMPTY_BUFFER); - data.release(); - } - - @DisplayName("encode first frame") - @Test - void encodeFirstFrameWithData() { - ByteBuf rr = - RequestResponseFrameCodec.encode(allocator, 1, true, null, Unpooled.wrappedBuffer(data)); - - ByteBuf fragment = - FrameFragmenter.encodeFirstFragment( - allocator, - 256, - rr, - FrameType.REQUEST_RESPONSE, - 1, - Unpooled.EMPTY_BUFFER, - Unpooled.wrappedBuffer(data)); - - Assert.assertEquals(256, fragment.readableBytes()); - Assert.assertEquals(FrameType.REQUEST_RESPONSE, FrameHeaderCodec.frameType(fragment)); - Assert.assertEquals(1, FrameHeaderCodec.streamId(fragment)); - Assert.assertTrue(FrameHeaderCodec.hasFollows(fragment)); - - ByteBuf data = RequestResponseFrameCodec.data(fragment); - ByteBuf byteBuf = Unpooled.wrappedBuffer(this.data).readSlice(data.readableBytes()); - Assert.assertEquals(byteBuf, data); - - Assert.assertFalse(FrameHeaderCodec.hasMetadata(fragment)); - } - - @DisplayName("encode first channel frame") - @Test - void encodeFirstWithDataChannel() { - ByteBuf rc = - RequestChannelFrameCodec.encode( - allocator, 1, true, false, 10, null, Unpooled.wrappedBuffer(data)); - - ByteBuf fragment = - FrameFragmenter.encodeFirstFragment( - allocator, - 256, - rc, - FrameType.REQUEST_CHANNEL, - 1, - Unpooled.EMPTY_BUFFER, - Unpooled.wrappedBuffer(data)); - - Assert.assertEquals(256, fragment.readableBytes()); - Assert.assertEquals(FrameType.REQUEST_CHANNEL, FrameHeaderCodec.frameType(fragment)); - Assert.assertEquals(1, FrameHeaderCodec.streamId(fragment)); - Assert.assertEquals(10, RequestChannelFrameCodec.initialRequestN(fragment)); - Assert.assertTrue(FrameHeaderCodec.hasFollows(fragment)); - - ByteBuf data = RequestChannelFrameCodec.data(fragment); - ByteBuf byteBuf = Unpooled.wrappedBuffer(this.data).readSlice(data.readableBytes()); - Assert.assertEquals(byteBuf, data); - - Assert.assertFalse(FrameHeaderCodec.hasMetadata(fragment)); - } - - @DisplayName("encode first stream frame") - @Test - void encodeFirstWithDataStream() { - ByteBuf rc = - RequestStreamFrameCodec.encode(allocator, 1, true, 50, null, Unpooled.wrappedBuffer(data)); - - ByteBuf fragment = - FrameFragmenter.encodeFirstFragment( - allocator, - 256, - rc, - FrameType.REQUEST_STREAM, - 1, - Unpooled.EMPTY_BUFFER, - Unpooled.wrappedBuffer(data)); - - Assert.assertEquals(256, fragment.readableBytes()); - Assert.assertEquals(FrameType.REQUEST_STREAM, FrameHeaderCodec.frameType(fragment)); - Assert.assertEquals(1, FrameHeaderCodec.streamId(fragment)); - Assert.assertEquals(50, RequestStreamFrameCodec.initialRequestN(fragment)); - Assert.assertTrue(FrameHeaderCodec.hasFollows(fragment)); - - ByteBuf data = RequestStreamFrameCodec.data(fragment); - ByteBuf byteBuf = Unpooled.wrappedBuffer(this.data).readSlice(data.readableBytes()); - Assert.assertEquals(byteBuf, data); - - Assert.assertFalse(FrameHeaderCodec.hasMetadata(fragment)); - } - - @DisplayName("encode first frame with only metadata") - @Test - void encodeFirstFrameWithMetadata() { - ByteBuf rr = - RequestResponseFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER); - - ByteBuf fragment = - FrameFragmenter.encodeFirstFragment( - allocator, - 256, - rr, - FrameType.REQUEST_RESPONSE, - 1, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER); - - Assert.assertEquals(256, fragment.readableBytes()); - Assert.assertEquals(FrameType.REQUEST_RESPONSE, FrameHeaderCodec.frameType(fragment)); - Assert.assertEquals(1, FrameHeaderCodec.streamId(fragment)); - Assert.assertTrue(FrameHeaderCodec.hasFollows(fragment)); - - ByteBuf data = RequestResponseFrameCodec.data(fragment); - Assert.assertEquals(data, Unpooled.EMPTY_BUFFER); - - Assert.assertTrue(FrameHeaderCodec.hasMetadata(fragment)); - } - - @DisplayName("encode first stream frame with data and metadata") - @Test - void encodeFirstWithDataAndMetadataStream() { - ByteBuf rc = - RequestStreamFrameCodec.encode( - allocator, 1, true, 50, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); - - ByteBuf fragment = - FrameFragmenter.encodeFirstFragment( - allocator, - 256, - rc, - FrameType.REQUEST_STREAM, - 1, - Unpooled.wrappedBuffer(metadata), - Unpooled.wrappedBuffer(data)); - - Assert.assertEquals(256, fragment.readableBytes()); - Assert.assertEquals(FrameType.REQUEST_STREAM, FrameHeaderCodec.frameType(fragment)); - Assert.assertEquals(1, FrameHeaderCodec.streamId(fragment)); - Assert.assertEquals(50, RequestStreamFrameCodec.initialRequestN(fragment)); - Assert.assertTrue(FrameHeaderCodec.hasFollows(fragment)); - - ByteBuf data = RequestStreamFrameCodec.data(fragment); - Assert.assertEquals(0, data.readableBytes()); - - ByteBuf metadata = RequestStreamFrameCodec.metadata(fragment); - ByteBuf byteBuf = Unpooled.wrappedBuffer(this.metadata).readSlice(metadata.readableBytes()); - Assert.assertEquals(byteBuf, metadata); - - Assert.assertTrue(FrameHeaderCodec.hasMetadata(fragment)); - } - - @DisplayName("fragments frame with only data") - @Test - void fragmentData() { - ByteBuf rr = - RequestResponseFrameCodec.encode(allocator, 1, true, null, Unpooled.wrappedBuffer(data)); - - Publisher fragments = - FrameFragmenter.fragmentFrame(allocator, 1024, rr, FrameType.REQUEST_RESPONSE); - - StepVerifier.create(Flux.from(fragments).doOnError(Throwable::printStackTrace)) - .expectNextCount(1) - .assertNext( - byteBuf -> { - Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); - Assert.assertEquals(1, FrameHeaderCodec.streamId(byteBuf)); - Assert.assertTrue(FrameHeaderCodec.hasFollows(byteBuf)); - }) - .expectNextCount(2) - .assertNext( - byteBuf -> { - Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); - Assert.assertFalse(FrameHeaderCodec.hasFollows(byteBuf)); - }) - .verifyComplete(); - } - - @DisplayName("fragments frame with only metadata") - @Test - void fragmentMetadata() { - ByteBuf rr = - RequestStreamFrameCodec.encode( - allocator, 1, true, 10, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER); - - Publisher fragments = - FrameFragmenter.fragmentFrame(allocator, 1024, rr, FrameType.REQUEST_STREAM); - - StepVerifier.create(Flux.from(fragments).doOnError(Throwable::printStackTrace)) - .expectNextCount(1) - .assertNext( - byteBuf -> { - Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); - Assert.assertEquals(1, FrameHeaderCodec.streamId(byteBuf)); - Assert.assertTrue(FrameHeaderCodec.hasFollows(byteBuf)); - }) - .expectNextCount(2) - .assertNext( - byteBuf -> { - Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); - Assert.assertFalse(FrameHeaderCodec.hasFollows(byteBuf)); - }) - .verifyComplete(); - } - - @DisplayName("fragments frame with data and metadata") - @Test - void fragmentDataAndMetadata() { - ByteBuf rr = - RequestResponseFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); - - Publisher fragments = - FrameFragmenter.fragmentFrame(allocator, 1024, rr, FrameType.REQUEST_RESPONSE); - - StepVerifier.create(Flux.from(fragments).doOnError(Throwable::printStackTrace)) - .assertNext( - byteBuf -> { - Assert.assertEquals(FrameType.REQUEST_RESPONSE, FrameHeaderCodec.frameType(byteBuf)); - Assert.assertTrue(FrameHeaderCodec.hasFollows(byteBuf)); - }) - .expectNextCount(6) - .assertNext( - byteBuf -> { - Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); - Assert.assertTrue(FrameHeaderCodec.hasFollows(byteBuf)); - }) - .assertNext( - byteBuf -> { - Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); - Assert.assertFalse(FrameHeaderCodec.hasFollows(byteBuf)); - }) - .verifyComplete(); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java deleted file mode 100644 index 6f9762042..000000000 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java +++ /dev/null @@ -1,526 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.fragmentation; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.CompositeByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.util.ReferenceCountUtil; -import io.rsocket.frame.*; -import java.util.Arrays; -import java.util.List; -import java.util.concurrent.ThreadLocalRandom; -import org.assertj.core.api.Assertions; -import org.junit.Assert; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import reactor.core.publisher.Flux; -import reactor.test.StepVerifier; - -final class FrameReassemblerTest { - private static byte[] data = new byte[1024]; - private static byte[] metadata = new byte[1024]; - - static { - ThreadLocalRandom.current().nextBytes(data); - ThreadLocalRandom.current().nextBytes(metadata); - } - - private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; - - @DisplayName("reassembles data") - @Test - void reassembleData() { - List byteBufs = - Arrays.asList( - RequestResponseFrameCodec.encode( - allocator, 1, true, null, Unpooled.wrappedBuffer(data)), - PayloadFrameCodec.encode( - allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), - PayloadFrameCodec.encode( - allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), - PayloadFrameCodec.encode( - allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), - PayloadFrameCodec.encode( - allocator, 1, false, false, true, null, Unpooled.wrappedBuffer(data))); - - FrameReassembler reassembler = new FrameReassembler(allocator, Integer.MAX_VALUE); - - Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); - - CompositeByteBuf data = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(FrameReassemblerTest.data), - Unpooled.wrappedBuffer(FrameReassemblerTest.data), - Unpooled.wrappedBuffer(FrameReassemblerTest.data), - Unpooled.wrappedBuffer(FrameReassemblerTest.data), - Unpooled.wrappedBuffer(FrameReassemblerTest.data)); - - StepVerifier.create(assembled) - .assertNext( - byteBuf -> { - Assert.assertEquals(data, RequestResponseFrameCodec.data(byteBuf)); - ReferenceCountUtil.safeRelease(byteBuf); - }) - .verifyComplete(); - ReferenceCountUtil.safeRelease(data); - } - - @DisplayName("pass through frames without follows") - @Test - void passthrough() { - List byteBufs = - Arrays.asList( - RequestResponseFrameCodec.encode( - allocator, 1, false, null, Unpooled.wrappedBuffer(data))); - - FrameReassembler reassembler = new FrameReassembler(allocator, Integer.MAX_VALUE); - - Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); - - CompositeByteBuf data = - allocator - .compositeDirectBuffer() - .addComponents(true, Unpooled.wrappedBuffer(FrameReassemblerTest.data)); - - StepVerifier.create(assembled) - .assertNext( - byteBuf -> { - Assert.assertEquals(data, RequestResponseFrameCodec.data(byteBuf)); - ReferenceCountUtil.safeRelease(byteBuf); - }) - .verifyComplete(); - ReferenceCountUtil.safeRelease(data); - } - - @DisplayName("reassembles metadata") - @Test - void reassembleMetadata() { - List byteBufs = - Arrays.asList( - RequestResponseFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - false, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER)); - - FrameReassembler reassembler = new FrameReassembler(allocator, Integer.MAX_VALUE); - - Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); - - CompositeByteBuf metadata = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata)); - - StepVerifier.create(assembled) - .assertNext( - byteBuf -> { - System.out.println(byteBuf.readableBytes()); - ByteBuf m = RequestResponseFrameCodec.metadata(byteBuf); - Assert.assertEquals(metadata, m); - }) - .verifyComplete(); - } - - @DisplayName("reassembles metadata request channel") - @Test - void reassembleMetadataChannel() { - List byteBufs = - Arrays.asList( - RequestChannelFrameCodec.encode( - allocator, - 1, - true, - false, - 100, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - false, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER)); - - FrameReassembler reassembler = new FrameReassembler(allocator, Integer.MAX_VALUE); - - Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); - - CompositeByteBuf metadata = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata)); - - StepVerifier.create(assembled) - .assertNext( - byteBuf -> { - System.out.println(byteBuf.readableBytes()); - ByteBuf m = RequestChannelFrameCodec.metadata(byteBuf); - Assert.assertEquals(metadata, m); - Assert.assertEquals(100, RequestChannelFrameCodec.initialRequestN(byteBuf)); - ReferenceCountUtil.safeRelease(byteBuf); - }) - .verifyComplete(); - - ReferenceCountUtil.safeRelease(metadata); - } - - @DisplayName("reassembles metadata request stream") - @Test - void reassembleMetadataStream() { - List byteBufs = - Arrays.asList( - RequestStreamFrameCodec.encode( - allocator, 1, true, 250, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - false, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER)); - - FrameReassembler reassembler = new FrameReassembler(allocator, Integer.MAX_VALUE); - - Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); - - CompositeByteBuf metadata = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata)); - - StepVerifier.create(assembled) - .assertNext( - byteBuf -> { - System.out.println(byteBuf.readableBytes()); - ByteBuf m = RequestStreamFrameCodec.metadata(byteBuf); - Assert.assertEquals(metadata, m); - Assert.assertEquals(250, RequestChannelFrameCodec.initialRequestN(byteBuf)); - ReferenceCountUtil.safeRelease(byteBuf); - }) - .verifyComplete(); - - ReferenceCountUtil.safeRelease(metadata); - } - - @DisplayName("reassembles metadata and data") - @Test - void reassembleMetadataAndData() { - - List byteBufs = - Arrays.asList( - RequestResponseFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.wrappedBuffer(data)), - PayloadFrameCodec.encode( - allocator, 1, false, false, true, null, Unpooled.wrappedBuffer(data))); - - FrameReassembler reassembler = new FrameReassembler(allocator, Integer.MAX_VALUE); - - Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); - - CompositeByteBuf data = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(FrameReassemblerTest.data), - Unpooled.wrappedBuffer(FrameReassemblerTest.data)); - - CompositeByteBuf metadata = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), - Unpooled.wrappedBuffer(FrameReassemblerTest.metadata)); - - StepVerifier.create(assembled) - .assertNext( - byteBuf -> { - Assert.assertEquals(data, RequestResponseFrameCodec.data(byteBuf)); - Assert.assertEquals(metadata, RequestResponseFrameCodec.metadata(byteBuf)); - }) - .verifyComplete(); - ReferenceCountUtil.safeRelease(data); - ReferenceCountUtil.safeRelease(metadata); - } - - @DisplayName("cancel removes inflight frames") - @Test - public void cancelBeforeAssembling() { - List byteBufs = - Arrays.asList( - RequestResponseFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.wrappedBuffer(data))); - - FrameReassembler reassembler = new FrameReassembler(allocator, Integer.MAX_VALUE); - Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame).blockLast(); - - Assert.assertTrue(reassembler.headers.containsKey(1)); - Assert.assertTrue(reassembler.metadata.containsKey(1)); - Assert.assertTrue(reassembler.data.containsKey(1)); - - Flux.just(CancelFrameCodec.encode(allocator, 1)) - .handle(reassembler::reassembleFrame) - .blockLast(); - - Assert.assertFalse(reassembler.headers.containsKey(1)); - Assert.assertFalse(reassembler.metadata.containsKey(1)); - Assert.assertFalse(reassembler.data.containsKey(1)); - } - - @ParameterizedTest(name = "throws error if reassembling payload size exist {0}") - @ValueSource(ints = {64, 1024, 2048, 4096}) - public void errorTooBigPayload(int maxFrameLength) { - List byteBufs = - Arrays.asList( - RequestResponseFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.wrappedBuffer(data))); - - FrameReassembler reassembler = new FrameReassembler(allocator, maxFrameLength); - - Assertions.assertThatThrownBy( - Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame)::blockLast) - .hasMessage("Reassembled payload went out of allowed size") - .isExactlyInstanceOf(IllegalStateException.class); - } - - @DisplayName("dispose should clean up maps") - @Test - public void dispose() { - List byteBufs = - Arrays.asList( - RequestResponseFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.wrappedBuffer(data))); - - FrameReassembler reassembler = new FrameReassembler(allocator, Integer.MAX_VALUE); - Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame).blockLast(); - - Assert.assertTrue(reassembler.headers.containsKey(1)); - Assert.assertTrue(reassembler.metadata.containsKey(1)); - Assert.assertTrue(reassembler.data.containsKey(1)); - - reassembler.dispose(); - - Assert.assertFalse(reassembler.headers.containsKey(1)); - Assert.assertFalse(reassembler.metadata.containsKey(1)); - Assert.assertFalse(reassembler.data.containsKey(1)); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/ReassembleDuplexConnectionTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/ReassembleDuplexConnectionTest.java deleted file mode 100644 index 061c17ada..000000000 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/ReassembleDuplexConnectionTest.java +++ /dev/null @@ -1,334 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.fragmentation; - -import static org.mockito.Mockito.RETURNS_SMART_NULLS; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.CompositeByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.util.ReferenceCountUtil; -import io.netty.util.ReferenceCounted; -import io.rsocket.DuplexConnection; -import io.rsocket.buffer.LeaksTrackingByteBufAllocator; -import io.rsocket.frame.CancelFrameCodec; -import io.rsocket.frame.FrameHeaderCodec; -import io.rsocket.frame.FrameType; -import io.rsocket.frame.PayloadFrameCodec; -import io.rsocket.frame.RequestResponseFrameCodec; -import java.time.Duration; -import java.util.Arrays; -import java.util.List; -import java.util.concurrent.ThreadLocalRandom; -import org.assertj.core.api.Assertions; -import org.junit.Assert; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; -import reactor.test.StepVerifier; - -final class ReassembleDuplexConnectionTest { - private static byte[] data = new byte[1024]; - private static byte[] metadata = new byte[1024]; - - static { - ThreadLocalRandom.current().nextBytes(data); - ThreadLocalRandom.current().nextBytes(metadata); - } - - private final DuplexConnection delegate = mock(DuplexConnection.class, RETURNS_SMART_NULLS); - - private LeaksTrackingByteBufAllocator allocator = - LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); - - @DisplayName("reassembles data") - @Test - void reassembleData() { - List byteBufs = - Arrays.asList( - RequestResponseFrameCodec.encode( - allocator, 1, true, null, Unpooled.wrappedBuffer(data)), - PayloadFrameCodec.encode( - allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), - PayloadFrameCodec.encode( - allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), - PayloadFrameCodec.encode( - allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), - PayloadFrameCodec.encode( - allocator, 1, false, false, true, null, Unpooled.wrappedBuffer(data))); - - CompositeByteBuf data = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data)); - - when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); - when(delegate.onClose()).thenReturn(Mono.never()); - when(delegate.alloc()).thenReturn(allocator); - - new ReassemblyDuplexConnection(delegate, Integer.MAX_VALUE) - .receive() - .as(StepVerifier::create) - .assertNext( - byteBuf -> { - Assert.assertEquals(data, RequestResponseFrameCodec.data(byteBuf)); - }) - .verifyComplete(); - } - - @DisplayName("reassembles metadata") - @Test - void reassembleMetadata() { - List byteBufs = - Arrays.asList( - RequestResponseFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - false, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER)); - - CompositeByteBuf metadata = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata)); - - when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); - when(delegate.onClose()).thenReturn(Mono.never()); - when(delegate.alloc()).thenReturn(allocator); - - new ReassemblyDuplexConnection(delegate, Integer.MAX_VALUE) - .receive() - .as(StepVerifier::create) - .assertNext( - byteBuf -> { - System.out.println(byteBuf.readableBytes()); - ByteBuf m = RequestResponseFrameCodec.metadata(byteBuf); - Assert.assertEquals(metadata, m); - }) - .verifyComplete(); - } - - @DisplayName("reassembles metadata and data") - @Test - void reassembleMetadataAndData() { - List byteBufs = - Arrays.asList( - RequestResponseFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.wrappedBuffer(data)), - PayloadFrameCodec.encode( - allocator, 1, false, false, true, null, Unpooled.wrappedBuffer(data))); - - CompositeByteBuf data = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data)); - - CompositeByteBuf metadata = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata)); - - when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); - when(delegate.onClose()).thenReturn(Mono.never()); - when(delegate.alloc()).thenReturn(allocator); - - new ReassemblyDuplexConnection(delegate, Integer.MAX_VALUE) - .receive() - .as(StepVerifier::create) - .assertNext( - byteBuf -> { - Assert.assertEquals(data, RequestResponseFrameCodec.data(byteBuf)); - Assert.assertEquals(metadata, RequestResponseFrameCodec.metadata(byteBuf)); - }) - .verifyComplete(); - } - - @DisplayName("does not reassemble a non-fragment frame") - @Test - void reassembleNonFragment() { - ByteBuf encode = - RequestResponseFrameCodec.encode(allocator, 1, false, null, Unpooled.wrappedBuffer(data)); - - when(delegate.receive()).thenReturn(Flux.just(encode)); - when(delegate.onClose()).thenReturn(Mono.never()); - when(delegate.alloc()).thenReturn(allocator); - - new ReassemblyDuplexConnection(delegate, Integer.MAX_VALUE) - .receive() - .as(StepVerifier::create) - .assertNext( - byteBuf -> { - Assert.assertEquals( - Unpooled.wrappedBuffer(data), RequestResponseFrameCodec.data(byteBuf)); - }) - .verifyComplete(); - } - - @DisplayName("does not reassemble non fragmentable frame") - @Test - void reassembleNonFragmentableFrame() { - ByteBuf encode = CancelFrameCodec.encode(allocator, 2); - - when(delegate.receive()).thenReturn(Flux.just(encode)); - when(delegate.onClose()).thenReturn(Mono.never()); - when(delegate.alloc()).thenReturn(allocator); - - new ReassemblyDuplexConnection(delegate, Integer.MAX_VALUE) - .receive() - .as(StepVerifier::create) - .assertNext( - byteBuf -> { - Assert.assertEquals(FrameType.CANCEL, FrameHeaderCodec.frameType(byteBuf)); - }) - .verifyComplete(); - } - - @ParameterizedTest(name = "throws error if reassembling payload size exist {0}") - @ValueSource(ints = {64, 1024, 2048, 4096}) - public void errorTooBigPayload(int maxFrameLength) { - List byteBufs = - Arrays.asList( - RequestResponseFrameCodec.encode( - allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.EMPTY_BUFFER), - PayloadFrameCodec.encode( - allocator, - 1, - true, - false, - true, - Unpooled.wrappedBuffer(metadata), - Unpooled.wrappedBuffer(data))); - - MonoProcessor onClose = MonoProcessor.create(); - - when(delegate.receive()) - .thenReturn( - Flux.fromIterable(byteBufs) - .doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::release)); - when(delegate.onClose()).thenReturn(onClose); - when(delegate.alloc()).thenReturn(allocator); - - new ReassemblyDuplexConnection(delegate, maxFrameLength) - .receive() - .doFinally(__ -> onClose.onComplete()) - .as(StepVerifier::create) - .expectErrorSatisfies( - t -> - Assertions.assertThat(t) - .hasMessage("Reassembled payload went out of allowed size") - .isExactlyInstanceOf(IllegalStateException.class)) - .verify(Duration.ofSeconds(1)); - - allocator.assertHasNoLeaks(); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java b/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java index 63300c718..b12d72b51 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java @@ -18,21 +18,35 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; import io.netty.util.IllegalReferenceCountException; +import org.assertj.core.api.Assertions; import org.assertj.core.presentation.StandardRepresentation; +import org.junit.jupiter.api.extension.BeforeAllCallback; +import org.junit.jupiter.api.extension.ExtensionContext; -public final class ByteBufRepresentation extends StandardRepresentation { +public final class ByteBufRepresentation extends StandardRepresentation + implements BeforeAllCallback { + + @Override + public void beforeAll(ExtensionContext context) { + Assertions.useRepresentation(this); + } @Override protected String fallbackToStringOf(Object object) { if (object instanceof ByteBuf) { try { String normalBufferString = object.toString(); - String prettyHexDump = ByteBufUtil.prettyHexDump((ByteBuf) object); - return new StringBuilder() - .append(normalBufferString) - .append("\n") - .append(prettyHexDump) - .toString(); + ByteBuf byteBuf = (ByteBuf) object; + if (byteBuf.readableBytes() <= 128) { + String prettyHexDump = ByteBufUtil.prettyHexDump(byteBuf); + return new StringBuilder() + .append(normalBufferString) + .append("\n") + .append(prettyHexDump) + .toString(); + } else { + return normalBufferString; + } } catch (IllegalReferenceCountException e) { // noops } diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ResumeFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ResumeFrameCodecTest.java index fe05335d2..4815bfb8e 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/ResumeFrameCodecTest.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/ResumeFrameCodecTest.java @@ -16,11 +16,12 @@ package io.rsocket.frame; +import static org.assertj.core.api.Assertions.assertThat; + import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import java.util.Arrays; -import org.junit.Assert; import org.junit.jupiter.api.Test; public class ResumeFrameCodecTest { @@ -31,10 +32,10 @@ void testEncoding() { Arrays.fill(tokenBytes, (byte) 1); ByteBuf token = Unpooled.wrappedBuffer(tokenBytes); ByteBuf byteBuf = ResumeFrameCodec.encode(ByteBufAllocator.DEFAULT, token, 21, 12); - Assert.assertEquals(ResumeFrameCodec.CURRENT_VERSION, ResumeFrameCodec.version(byteBuf)); - Assert.assertEquals(token, ResumeFrameCodec.token(byteBuf)); - Assert.assertEquals(21, ResumeFrameCodec.lastReceivedServerPos(byteBuf)); - Assert.assertEquals(12, ResumeFrameCodec.firstAvailableClientPos(byteBuf)); + assertThat(ResumeFrameCodec.version(byteBuf)).isEqualTo(ResumeFrameCodec.CURRENT_VERSION); + assertThat(ResumeFrameCodec.token(byteBuf)).isEqualTo(token); + assertThat(ResumeFrameCodec.lastReceivedServerPos(byteBuf)).isEqualTo(21); + assertThat(ResumeFrameCodec.firstAvailableClientPos(byteBuf)).isEqualTo(12); byteBuf.release(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ResumeOkFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ResumeOkFrameCodecTest.java index 33dd8eb70..b818d579d 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/ResumeOkFrameCodecTest.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/ResumeOkFrameCodecTest.java @@ -1,16 +1,17 @@ package io.rsocket.frame; +import static org.assertj.core.api.Assertions.assertThat; + import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Test; public class ResumeOkFrameCodecTest { @Test public void testEncoding() { ByteBuf byteBuf = ResumeOkFrameCodec.encode(ByteBufAllocator.DEFAULT, 42); - Assert.assertEquals(42, ResumeOkFrameCodec.lastReceivedClientPos(byteBuf)); + assertThat(ResumeOkFrameCodec.lastReceivedClientPos(byteBuf)).isEqualTo(42); byteBuf.release(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameCodecTest.java index 9607ad327..3317b4618 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameCodecTest.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameCodecTest.java @@ -25,8 +25,8 @@ void testEncodingNoResume() { assertEquals(0, SetupFrameCodec.resumeToken(frame).readableBytes()); assertEquals("metadata_type", SetupFrameCodec.metadataMimeType(frame)); assertEquals("data_type", SetupFrameCodec.dataMimeType(frame)); - assertEquals(metadata, SetupFrameCodec.metadata(frame)); - assertEquals(data, SetupFrameCodec.data(frame)); + assertEquals(payload.metadata(), SetupFrameCodec.metadata(frame)); + assertEquals(payload.data(), SetupFrameCodec.data(frame)); assertEquals(SetupFrameCodec.CURRENT_VERSION, SetupFrameCodec.version(frame)); frame.release(); } @@ -49,8 +49,8 @@ void testEncodingResume() { assertEquals(token, SetupFrameCodec.resumeToken(frame)); assertEquals("metadata_type", SetupFrameCodec.metadataMimeType(frame)); assertEquals("data_type", SetupFrameCodec.dataMimeType(frame)); - assertEquals(metadata, SetupFrameCodec.metadata(frame)); - assertEquals(data, SetupFrameCodec.data(frame)); + assertEquals(payload.metadata(), SetupFrameCodec.metadata(frame)); + assertEquals(payload.data(), SetupFrameCodec.data(frame)); assertEquals(SetupFrameCodec.CURRENT_VERSION, SetupFrameCodec.version(frame)); frame.release(); } diff --git a/rsocket-core/src/test/java/io/rsocket/internal/ClientServerInputMultiplexerTest.java b/rsocket-core/src/test/java/io/rsocket/internal/ClientServerInputMultiplexerTest.java deleted file mode 100644 index 63acc40aa..000000000 --- a/rsocket-core/src/test/java/io/rsocket/internal/ClientServerInputMultiplexerTest.java +++ /dev/null @@ -1,229 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.internal; - -import static org.junit.Assert.assertEquals; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.rsocket.buffer.LeaksTrackingByteBufAllocator; -import io.rsocket.frame.*; -import io.rsocket.plugins.InitializingInterceptorRegistry; -import io.rsocket.test.util.TestDuplexConnection; -import io.rsocket.util.DefaultPayload; -import java.util.concurrent.atomic.AtomicInteger; -import org.junit.Before; -import org.junit.Test; - -public class ClientServerInputMultiplexerTest { - private TestDuplexConnection source; - private ClientServerInputMultiplexer clientMultiplexer; - private LeaksTrackingByteBufAllocator allocator = - LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); - private ClientServerInputMultiplexer serverMultiplexer; - - @Before - public void setup() { - source = new TestDuplexConnection(allocator); - clientMultiplexer = - new ClientServerInputMultiplexer(source, new InitializingInterceptorRegistry(), true); - serverMultiplexer = - new ClientServerInputMultiplexer(source, new InitializingInterceptorRegistry(), false); - } - - @Test - public void clientSplits() { - AtomicInteger clientFrames = new AtomicInteger(); - AtomicInteger serverFrames = new AtomicInteger(); - AtomicInteger setupFrames = new AtomicInteger(); - - clientMultiplexer - .asClientConnection() - .receive() - .doOnNext(f -> clientFrames.incrementAndGet()) - .subscribe(); - clientMultiplexer - .asServerConnection() - .receive() - .doOnNext(f -> serverFrames.incrementAndGet()) - .subscribe(); - clientMultiplexer - .asSetupConnection() - .receive() - .doOnNext(f -> setupFrames.incrementAndGet()) - .subscribe(); - - source.addToReceivedBuffer(errorFrame(1)); - assertEquals(1, clientFrames.get()); - assertEquals(0, serverFrames.get()); - assertEquals(0, setupFrames.get()); - - source.addToReceivedBuffer(errorFrame(1)); - assertEquals(2, clientFrames.get()); - assertEquals(0, serverFrames.get()); - assertEquals(0, setupFrames.get()); - - source.addToReceivedBuffer(leaseFrame()); - assertEquals(3, clientFrames.get()); - assertEquals(0, serverFrames.get()); - assertEquals(0, setupFrames.get()); - - source.addToReceivedBuffer(keepAliveFrame()); - assertEquals(4, clientFrames.get()); - assertEquals(0, serverFrames.get()); - assertEquals(0, setupFrames.get()); - - source.addToReceivedBuffer(errorFrame(2)); - assertEquals(4, clientFrames.get()); - assertEquals(1, serverFrames.get()); - assertEquals(0, setupFrames.get()); - - source.addToReceivedBuffer(errorFrame(0)); - assertEquals(5, clientFrames.get()); - assertEquals(1, serverFrames.get()); - assertEquals(0, setupFrames.get()); - - source.addToReceivedBuffer(metadataPushFrame()); - assertEquals(5, clientFrames.get()); - assertEquals(2, serverFrames.get()); - assertEquals(0, setupFrames.get()); - - source.addToReceivedBuffer(setupFrame()); - assertEquals(5, clientFrames.get()); - assertEquals(2, serverFrames.get()); - assertEquals(1, setupFrames.get()); - - source.addToReceivedBuffer(resumeFrame()); - assertEquals(5, clientFrames.get()); - assertEquals(2, serverFrames.get()); - assertEquals(2, setupFrames.get()); - - source.addToReceivedBuffer(resumeOkFrame()); - assertEquals(5, clientFrames.get()); - assertEquals(2, serverFrames.get()); - assertEquals(3, setupFrames.get()); - } - - @Test - public void serverSplits() { - AtomicInteger clientFrames = new AtomicInteger(); - AtomicInteger serverFrames = new AtomicInteger(); - AtomicInteger setupFrames = new AtomicInteger(); - - serverMultiplexer - .asClientConnection() - .receive() - .doOnNext(f -> clientFrames.incrementAndGet()) - .subscribe(); - serverMultiplexer - .asServerConnection() - .receive() - .doOnNext(f -> serverFrames.incrementAndGet()) - .subscribe(); - serverMultiplexer - .asSetupConnection() - .receive() - .doOnNext(f -> setupFrames.incrementAndGet()) - .subscribe(); - - source.addToReceivedBuffer(errorFrame(1)); - assertEquals(1, clientFrames.get()); - assertEquals(0, serverFrames.get()); - assertEquals(0, setupFrames.get()); - - source.addToReceivedBuffer(errorFrame(1)); - assertEquals(2, clientFrames.get()); - assertEquals(0, serverFrames.get()); - assertEquals(0, setupFrames.get()); - - source.addToReceivedBuffer(leaseFrame()); - assertEquals(2, clientFrames.get()); - assertEquals(1, serverFrames.get()); - assertEquals(0, setupFrames.get()); - - source.addToReceivedBuffer(keepAliveFrame()); - assertEquals(2, clientFrames.get()); - assertEquals(2, serverFrames.get()); - assertEquals(0, setupFrames.get()); - - source.addToReceivedBuffer(errorFrame(2)); - assertEquals(2, clientFrames.get()); - assertEquals(3, serverFrames.get()); - assertEquals(0, setupFrames.get()); - - source.addToReceivedBuffer(errorFrame(0)); - assertEquals(2, clientFrames.get()); - assertEquals(4, serverFrames.get()); - assertEquals(0, setupFrames.get()); - - source.addToReceivedBuffer(metadataPushFrame()); - assertEquals(3, clientFrames.get()); - assertEquals(4, serverFrames.get()); - assertEquals(0, setupFrames.get()); - - source.addToReceivedBuffer(setupFrame()); - assertEquals(3, clientFrames.get()); - assertEquals(4, serverFrames.get()); - assertEquals(1, setupFrames.get()); - - source.addToReceivedBuffer(resumeFrame()); - assertEquals(3, clientFrames.get()); - assertEquals(4, serverFrames.get()); - assertEquals(2, setupFrames.get()); - - source.addToReceivedBuffer(resumeOkFrame()); - assertEquals(3, clientFrames.get()); - assertEquals(4, serverFrames.get()); - assertEquals(3, setupFrames.get()); - } - - private ByteBuf resumeFrame() { - return ResumeFrameCodec.encode(allocator, Unpooled.EMPTY_BUFFER, 0, 0); - } - - private ByteBuf setupFrame() { - return SetupFrameCodec.encode( - ByteBufAllocator.DEFAULT, - false, - 0, - 42, - "application/octet-stream", - "application/octet-stream", - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.EMPTY_BUFFER)); - } - - private ByteBuf leaseFrame() { - return LeaseFrameCodec.encode(allocator, 1_000, 1, Unpooled.EMPTY_BUFFER); - } - - private ByteBuf errorFrame(int i) { - return ErrorFrameCodec.encode(allocator, i, new Exception()); - } - - private ByteBuf resumeOkFrame() { - return ResumeOkFrameCodec.encode(allocator, 0); - } - - private ByteBuf keepAliveFrame() { - return KeepAliveFrameCodec.encode(allocator, false, 0, Unpooled.EMPTY_BUFFER); - } - - private ByteBuf metadataPushFrame() { - return MetadataPushFrameCodec.encode(allocator, Unpooled.EMPTY_BUFFER); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java b/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java index 7bf975543..343a93beb 100644 --- a/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java +++ b/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-present the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,115 +16,351 @@ package io.rsocket.internal; -import io.rsocket.Payload; -import io.rsocket.util.ByteBufPayload; -import io.rsocket.util.EmptyPayload; -import java.util.concurrent.CountDownLatch; -import org.junit.Assert; -import org.junit.Test; +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.RaceTestConstants; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.internal.subscriber.AssertSubscriber; +import java.time.Duration; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.Fuseable; +import reactor.core.publisher.Hooks; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; +import reactor.test.util.RaceTestUtils; public class UnboundedProcessorTest { - @Test - public void testOnNextBeforeSubscribe_10() { - testOnNextBeforeSubscribeN(10); - } - @Test - public void testOnNextBeforeSubscribe_100() { - testOnNextBeforeSubscribeN(100); + @BeforeAll + public static void setup() { + Hooks.onErrorDropped(__ -> {}); } - @Test - public void testOnNextBeforeSubscribe_10_000() { - testOnNextBeforeSubscribeN(10_000); + @AfterAll + public static void teardown() { + Hooks.resetOnErrorDropped(); } - @Test - public void testOnNextBeforeSubscribe_100_000() { - testOnNextBeforeSubscribeN(100_000); - } + @ParameterizedTest( + name = + "Test that emitting {0} onNext before subscribe and requestN should deliver all the signals once the subscriber is available") + @ValueSource(ints = {10, 100, 10_000, 100_000, 1_000_000, 10_000_000}) + public void testOnNextBeforeSubscribeN(int n) { + UnboundedProcessor processor = new UnboundedProcessor(); - @Test - public void testOnNextBeforeSubscribe_1_000_000() { - testOnNextBeforeSubscribeN(1_000_000); - } + for (int i = 0; i < n; i++) { + processor.onNext(Unpooled.EMPTY_BUFFER); + } + + processor.onComplete(); - @Test - public void testOnNextBeforeSubscribe_10_000_000() { - testOnNextBeforeSubscribeN(10_000_000); + StepVerifier.create(processor.count()).expectNext(Long.valueOf(n)).verifyComplete(); } - public void testOnNextBeforeSubscribeN(int n) { - UnboundedProcessor processor = new UnboundedProcessor<>(); + @ParameterizedTest( + name = + "Test that emitting {0} onNext after subscribe and requestN should deliver all the signals") + @ValueSource(ints = {10, 100, 10_000}) + public void testOnNextAfterSubscribeN(int n) { + UnboundedProcessor processor = new UnboundedProcessor(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + processor.subscribe(assertSubscriber); for (int i = 0; i < n; i++) { - processor.onNext(EmptyPayload.INSTANCE); + processor.onNext(Unpooled.EMPTY_BUFFER); } - processor.onComplete(); + assertSubscriber.awaitAndAssertNextValueCount(n); + } - long count = processor.count().block(); + @ParameterizedTest( + name = + "Test that prioritized value sending deliver prioritized signals before the others mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void testPrioritizedSending(boolean fusedCase) { + UnboundedProcessor processor = new UnboundedProcessor(); - Assert.assertEquals(n, count); - } + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + processor.onNext(Unpooled.EMPTY_BUFFER); + } - @Test - public void testOnNextAfterSubscribe_10() throws Exception { - testOnNextAfterSubscribeN(10); - } + processor.onNextPrioritized(Unpooled.copiedBuffer("test", CharsetUtil.UTF_8)); - @Test - public void testOnNextAfterSubscribe_100() throws Exception { - testOnNextAfterSubscribeN(100); + assertThat(fusedCase ? processor.poll() : processor.next().block()) + .isNotNull() + .extracting(bb -> bb.toString(CharsetUtil.UTF_8)) + .isEqualTo("test"); } - @Test - public void testOnNextAfterSubscribe_1000() throws Exception { - testOnNextAfterSubscribeN(1000); - } + @ParameterizedTest( + name = + "Ensures that racing between onNext | dispose | cancel | request(n) will not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void ensureUnboundedProcessorDisposesQueueProperly(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); + + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); - @Test - public void testPrioritizedSending() { - UnboundedProcessor processor = new UnboundedProcessor<>(); + final AssertSubscriber assertSubscriber = + new AssertSubscriber(0) + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); - for (int i = 0; i < 1000; i++) { - processor.onNext(EmptyPayload.INSTANCE); + unboundedProcessor.subscribe(assertSubscriber); + + RaceTestUtils.race( + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNext(buffer2); + }, + unboundedProcessor::dispose, + assertSubscriber::cancel, + () -> { + assertSubscriber.request(1); + assertSubscriber.request(1); + }); + + assertSubscriber.values().forEach(ReferenceCountUtil::release); + + allocator.assertHasNoLeaks(); } + } + + @ParameterizedTest( + name = + "Ensures that racing between onNext | dispose | cancel | request(n) | terminal will not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void smokeTest1(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + final RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); + + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); + final ByteBuf buffer3 = allocator.buffer(3); + final ByteBuf buffer4 = allocator.buffer(4); - processor.onNextPrioritized(ByteBufPayload.create("test")); + final AssertSubscriber assertSubscriber = + new AssertSubscriber(0) + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); - Payload closestPayload = processor.next().block(); + unboundedProcessor.subscribe(assertSubscriber); - Assert.assertEquals(closestPayload.getDataUtf8(), "test"); + RaceTestUtils.race( + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNextPrioritized(buffer2); + }, + () -> { + unboundedProcessor.onNextPrioritized(buffer3); + unboundedProcessor.onNext(buffer4); + }, + unboundedProcessor::dispose, + unboundedProcessor::onComplete, + () -> unboundedProcessor.onError(runtimeException), + assertSubscriber::cancel, + () -> { + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + }); + + assertSubscriber.values().forEach(ReferenceCountUtil::release); + + allocator.assertHasNoLeaks(); + } } - @Test - public void testPrioritizedFused() { - UnboundedProcessor processor = new UnboundedProcessor<>(); + @ParameterizedTest( + name = + "Ensures that racing between onNext | dispose | subscribe | request(n) | terminal will not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void smokeTest2(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + final RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); + + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); + final ByteBuf buffer3 = allocator.buffer(3); + final ByteBuf buffer4 = allocator.buffer(4); + + final AssertSubscriber assertSubscriber = + new AssertSubscriber(0) + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); + + RaceTestUtils.race( + Schedulers.boundedElastic(), + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNextPrioritized(buffer2); + }, + () -> { + unboundedProcessor.onNextPrioritized(buffer3); + unboundedProcessor.onNext(buffer4); + }, + unboundedProcessor::dispose, + unboundedProcessor::onComplete, + () -> unboundedProcessor.onError(runtimeException), + () -> { + unboundedProcessor.subscribe(assertSubscriber); + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + }); - for (int i = 0; i < 1000; i++) { - processor.onNext(EmptyPayload.INSTANCE); + assertSubscriber.values().forEach(ReferenceCountUtil::release); + + allocator.assertHasNoLeaks(); } + } + + @ParameterizedTest( + name = + "Ensures that racing between onNext | dispose | subscribe(cancelled) | terminal will not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void smokeTest3(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + final RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); + + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); + final ByteBuf buffer3 = allocator.buffer(3); + final ByteBuf buffer4 = allocator.buffer(4); + + final AssertSubscriber assertSubscriber = + new AssertSubscriber(0) + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); + + assertSubscriber.cancel(); - processor.onNextPrioritized(ByteBufPayload.create("test")); + RaceTestUtils.race( + Schedulers.boundedElastic(), + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNextPrioritized(buffer2); + }, + () -> { + unboundedProcessor.onNextPrioritized(buffer3); + unboundedProcessor.onNext(buffer4); + }, + unboundedProcessor::dispose, + unboundedProcessor::onComplete, + () -> unboundedProcessor.onError(runtimeException), + () -> unboundedProcessor.subscribe(assertSubscriber)); - Payload closestPayload = processor.poll(); + assertSubscriber.values().forEach(ReferenceCountUtil::release); - Assert.assertEquals(closestPayload.getDataUtf8(), "test"); + allocator.assertHasNoLeaks(); + } } - public void testOnNextAfterSubscribeN(int n) throws Exception { - CountDownLatch latch = new CountDownLatch(n); - UnboundedProcessor processor = new UnboundedProcessor<>(); - processor.log().doOnNext(integer -> latch.countDown()).subscribe(); + @ParameterizedTest( + name = + "Ensures that racing between onNext | dispose | subscribe(cancelled) | terminal will not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void smokeTest31(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + final RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); - for (int i = 0; i < n; i++) { - System.out.println("onNexting -> " + i); - processor.onNext(EmptyPayload.INSTANCE); + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); + final ByteBuf buffer3 = allocator.buffer(3); + final ByteBuf buffer4 = allocator.buffer(4); + + final AssertSubscriber assertSubscriber = + new AssertSubscriber(0) + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); + + RaceTestUtils.race( + Schedulers.boundedElastic(), + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNextPrioritized(buffer2); + }, + () -> { + unboundedProcessor.onNextPrioritized(buffer3); + unboundedProcessor.onNext(buffer4); + }, + unboundedProcessor::dispose, + unboundedProcessor::onComplete, + () -> unboundedProcessor.onError(runtimeException), + () -> unboundedProcessor.subscribe(assertSubscriber), + () -> { + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + }, + assertSubscriber::cancel); + + assertSubscriber.values().forEach(ReferenceCountUtil::release); + allocator.assertHasNoLeaks(); } + } - processor.drain(); + @ParameterizedTest( + name = + "Ensures that racing between onNext + dispose | downstream async drain should not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void ensuresAsyncFusionAndDisposureHasNoDeadlock(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); + final ByteBuf buffer3 = allocator.buffer(3); + final ByteBuf buffer4 = allocator.buffer(4); + final ByteBuf buffer5 = allocator.buffer(5); + final ByteBuf buffer6 = allocator.buffer(6); + + final AssertSubscriber assertSubscriber = + new AssertSubscriber() + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); + + unboundedProcessor.subscribe(assertSubscriber); + + RaceTestUtils.race( + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNext(buffer2); + unboundedProcessor.onNext(buffer3); + unboundedProcessor.onNext(buffer4); + unboundedProcessor.onNext(buffer5); + unboundedProcessor.onNext(buffer6); + unboundedProcessor.dispose(); + }, + unboundedProcessor::dispose); + + assertSubscriber.await(Duration.ofSeconds(50)).values().forEach(ReferenceCountUtil::release); + } - latch.await(); + allocator.assertHasNoLeaks(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java b/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java index 84a589a8d..b6eac9835 100644 --- a/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java +++ b/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2011-2017 Pivotal Software Inc, All Rights Reserved. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import java.util.function.BooleanSupplier; @@ -36,6 +37,7 @@ import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; import reactor.core.Fuseable; +import reactor.core.Scannable; import reactor.core.publisher.Operators; import reactor.util.annotation.NonNull; import reactor.util.context.Context; @@ -77,7 +79,7 @@ * @author Stephane Maldini * @author Brian Clozel */ -public class AssertSubscriber implements CoreSubscriber, Subscription { +public class AssertSubscriber implements CoreSubscriber, Subscription, Scannable { /** Default timeout for waiting next values to be received */ public static final Duration DEFAULT_VALUES_TIMEOUT = Duration.ofSeconds(3); @@ -86,6 +88,10 @@ public class AssertSubscriber implements CoreSubscriber, Subscription { private static final AtomicLongFieldUpdater REQUESTED = AtomicLongFieldUpdater.newUpdater(AssertSubscriber.class, "requested"); + @SuppressWarnings("rawtypes") + private static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(AssertSubscriber.class, "wip"); + @SuppressWarnings("rawtypes") private static final AtomicReferenceFieldUpdater NEXT_VALUES = AtomicReferenceFieldUpdater.newUpdater(AssertSubscriber.class, List.class, "values"); @@ -100,10 +106,14 @@ public class AssertSubscriber implements CoreSubscriber, Subscription { private final CountDownLatch cdl = new CountDownLatch(1); + volatile boolean done; + volatile Subscription s; volatile long requested; + volatile int wip; + volatile List values = new LinkedList<>(); /** The fusion mode to request. */ @@ -377,7 +387,7 @@ public final AssertSubscriber assertError(Class clazz) { } } if (s > 1) { - throw new AssertionError("Multiple errors: " + s, null); + throw new AssertionError("Multiple errors: " + errors, null); } return this; } @@ -854,6 +864,13 @@ public void cancel() { a = S.getAndSet(this, Operators.cancelledSubscription()); if (a != null && a != Operators.cancelledSubscription()) { a.cancel(); + + if (establishedFusionMode == Fuseable.ASYNC) { + final int previousState = markWorkAdded(); + if (!isWorkInProgress(previousState)) { + clearAndFinalize(); + } + } } } } @@ -868,37 +885,121 @@ public final boolean isTerminated() { @Override public void onComplete() { + done = true; completionCount++; + + if (establishedFusionMode == Fuseable.ASYNC) { + drain(); + return; + } + cdl.countDown(); } @Override public void onError(Throwable t) { + done = true; errors.add(t); + + if (establishedFusionMode == Fuseable.ASYNC) { + drain(); + return; + } + cdl.countDown(); } @Override public void onNext(T t) { if (establishedFusionMode == Fuseable.ASYNC) { - for (; ; ) { - t = qs.poll(); - if (t == null) { - break; - } - valueCount++; - if (valuesStorage) { - List nextValuesSnapshot; - for (; ; ) { - nextValuesSnapshot = values; - nextValuesSnapshot.add(t); - if (NEXT_VALUES.compareAndSet(this, nextValuesSnapshot, nextValuesSnapshot)) { - break; - } + drain(); + } else { + valueCount++; + if (valuesStorage) { + List nextValuesSnapshot; + for (; ; ) { + nextValuesSnapshot = values; + nextValuesSnapshot.add(t); + if (NEXT_VALUES.compareAndSet(this, nextValuesSnapshot, nextValuesSnapshot)) { + break; } } } - } else { + } + } + + static boolean isFinalized(int state) { + return state == Integer.MIN_VALUE; + } + + static boolean isWorkInProgress(int state) { + return state > 0; + } + + int markWorkAdded() { + for (; ; ) { + int state = this.wip; + + if (isFinalized(state)) { + return state; + } + + if ((state & Integer.MAX_VALUE) == Integer.MAX_VALUE) { + return state; + } + int nextState = state + 1; + + if (WIP.compareAndSet(this, state, nextState)) { + return state; + } + } + } + + void clearAndFinalize() { + final Fuseable.QueueSubscription qs = this.qs; + for (; ; ) { + int state = this.wip; + + qs.clear(); + + if (WIP.compareAndSet(this, state, Integer.MIN_VALUE)) { + return; + } + } + } + + void drain() { + final int previousState = markWorkAdded(); + if (isWorkInProgress(previousState)) { + return; + } + + if (isFinalized(previousState)) { + qs.clear(); + return; + } + + T t; + int m = 1; + for (; ; ) { + if (isCancelled()) { + clearAndFinalize(); + break; + } + boolean done = this.done; + t = qs.poll(); + if (t == null) { + if (done) { + clearAndFinalize(); + cdl.countDown(); + return; + } + m = WIP.addAndGet(this, -m); + if (m == 0) { + break; + } + continue; + } valueCount++; if (valuesStorage) { List nextValuesSnapshot; @@ -919,39 +1020,41 @@ public void onSubscribe(Subscription s) { subscriptionCount++; int requestMode = requestedFusionMode; if (requestMode >= 0) { - if (!setWithoutRequesting(s)) { - if (!isCancelled()) { - errors.add(new IllegalStateException("Subscription already set: " + subscriptionCount)); - } - } else { - if (s instanceof Fuseable.QueueSubscription) { - this.qs = (Fuseable.QueueSubscription) s; + if (s instanceof Fuseable.QueueSubscription) { + this.qs = (Fuseable.QueueSubscription) s; - int m = qs.requestFusion(requestMode); - establishedFusionMode = m; + int m = qs.requestFusion(requestMode); + establishedFusionMode = m; - if (m == Fuseable.SYNC) { - for (; ; ) { - T v = qs.poll(); - if (v == null) { - onComplete(); - break; - } + if (!setWithoutRequesting(s)) { + qs.clear(); + if (!isCancelled()) { + errors.add(new IllegalStateException("Subscription already set: " + subscriptionCount)); + } + return; + } - onNext(v); + if (m == Fuseable.SYNC) { + for (; ; ) { + T v = qs.poll(); + if (v == null) { + onComplete(); + break; } - } else { - requestDeferred(); + + onNext(v); } } else { requestDeferred(); } + + return; } - } else { - if (!set(s)) { - if (!isCancelled()) { - errors.add(new IllegalStateException("Subscription already set: " + subscriptionCount)); - } + } + + if (!set(s)) { + if (!isCancelled()) { + errors.add(new IllegalStateException("Subscription already set: " + subscriptionCount)); } } } @@ -1143,6 +1246,10 @@ public List values() { return values; } + public List errors() { + return errors; + } + public final AssertSubscriber assertNoEvents() { return assertNoValues().assertNoError().assertNotComplete(); } @@ -1151,4 +1258,20 @@ public final AssertSubscriber assertNoEvents() { public final AssertSubscriber assertIncomplete(T... values) { return assertValues(values).assertNotComplete().assertNoError(); } + + @Override + public Object scanUnsafe(Attr key) { + if (key == Attr.PARENT) { + return upstream(); + } + + boolean t = isTerminated(); + if (key == Attr.TERMINATED) return t; + if (key == Attr.ERROR) return (!errors.isEmpty() ? errors.get(0) : null); + if (key == Attr.PREFETCH) return Integer.MAX_VALUE; + if (key == Attr.CANCELLED) return isCancelled(); + if (key == Attr.RUN_STYLE) return Attr.RunStyle.SYNC; + + return null; + } } diff --git a/rsocket-core/src/test/java/io/rsocket/lease/LeaseImplTest.java b/rsocket-core/src/test/java/io/rsocket/lease/LeaseImplTest.java index d5b2eeb41..9ebca34f7 100644 --- a/rsocket-core/src/test/java/io/rsocket/lease/LeaseImplTest.java +++ b/rsocket-core/src/test/java/io/rsocket/lease/LeaseImplTest.java @@ -16,71 +16,63 @@ package io.rsocket.lease; -import static org.junit.Assert.*; - -import io.netty.buffer.Unpooled; -import java.time.Duration; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; -import reactor.core.publisher.Mono; - public class LeaseImplTest { - - @Test - public void emptyLeaseNoAvailability() { - LeaseImpl empty = LeaseImpl.empty(); - Assertions.assertTrue(empty.isEmpty()); - Assertions.assertFalse(empty.isValid()); - Assertions.assertEquals(0.0, empty.availability(), 1e-5); - } - - @Test - public void emptyLeaseUseNoAvailability() { - LeaseImpl empty = LeaseImpl.empty(); - boolean success = empty.use(); - assertFalse(success); - Assertions.assertEquals(0.0, empty.availability(), 1e-5); - } - - @Test - public void leaseAvailability() { - LeaseImpl lease = LeaseImpl.create(2, 100, Unpooled.EMPTY_BUFFER); - Assertions.assertEquals(1.0, lease.availability(), 1e-5); - } - - @Test - public void leaseUseDecreasesAvailability() { - LeaseImpl lease = LeaseImpl.create(30_000, 2, Unpooled.EMPTY_BUFFER); - boolean success = lease.use(); - Assertions.assertTrue(success); - Assertions.assertEquals(0.5, lease.availability(), 1e-5); - Assertions.assertTrue(lease.isValid()); - success = lease.use(); - Assertions.assertTrue(success); - Assertions.assertEquals(0.0, lease.availability(), 1e-5); - Assertions.assertFalse(lease.isValid()); - Assertions.assertEquals(0, lease.getAllowedRequests()); - success = lease.use(); - Assertions.assertFalse(success); - } - - @Test - public void leaseTimeout() { - int numberOfRequests = 1; - LeaseImpl lease = LeaseImpl.create(1, numberOfRequests, Unpooled.EMPTY_BUFFER); - Mono.delay(Duration.ofMillis(100)).block(); - boolean success = lease.use(); - Assertions.assertFalse(success); - Assertions.assertTrue(lease.isExpired()); - Assertions.assertEquals(numberOfRequests, lease.getAllowedRequests()); - Assertions.assertFalse(lease.isValid()); - } - - @Test - public void useLeaseChangesAllowedRequests() { - int numberOfRequests = 2; - LeaseImpl lease = LeaseImpl.create(30_000, numberOfRequests, Unpooled.EMPTY_BUFFER); - lease.use(); - assertEquals(numberOfRequests - 1, lease.getAllowedRequests()); - } + // + // @Test + // public void emptyLeaseNoAvailability() { + // LeaseImpl empty = LeaseImpl.empty(); + // Assertions.assertTrue(empty.isEmpty()); + // Assertions.assertFalse(empty.isValid()); + // Assertions.assertEquals(0.0, empty.availability(), 1e-5); + // } + // + // @Test + // public void emptyLeaseUseNoAvailability() { + // LeaseImpl empty = LeaseImpl.empty(); + // boolean success = empty.use(); + // assertFalse(success); + // Assertions.assertEquals(0.0, empty.availability(), 1e-5); + // } + // + // @Test + // public void leaseAvailability() { + // LeaseImpl lease = LeaseImpl.create(2, 100, Unpooled.EMPTY_BUFFER); + // Assertions.assertEquals(1.0, lease.availability(), 1e-5); + // } + // + // @Test + // public void leaseUseDecreasesAvailability() { + // LeaseImpl lease = LeaseImpl.create(30_000, 2, Unpooled.EMPTY_BUFFER); + // boolean success = lease.use(); + // Assertions.assertTrue(success); + // Assertions.assertEquals(0.5, lease.availability(), 1e-5); + // Assertions.assertTrue(lease.isValid()); + // success = lease.use(); + // Assertions.assertTrue(success); + // Assertions.assertEquals(0.0, lease.availability(), 1e-5); + // Assertions.assertFalse(lease.isValid()); + // Assertions.assertEquals(0, lease.getAllowedRequests()); + // success = lease.use(); + // Assertions.assertFalse(success); + // } + // + // @Test + // public void leaseTimeout() { + // int numberOfRequests = 1; + // LeaseImpl lease = LeaseImpl.create(1, numberOfRequests, Unpooled.EMPTY_BUFFER); + // Mono.delay(Duration.ofMillis(100)).block(); + // boolean success = lease.use(); + // Assertions.assertFalse(success); + // Assertions.assertTrue(lease.isExpired()); + // Assertions.assertEquals(numberOfRequests, lease.getAllowedRequests()); + // Assertions.assertFalse(lease.isValid()); + // } + // + // @Test + // public void useLeaseChangesAllowedRequests() { + // int numberOfRequests = 2; + // LeaseImpl lease = LeaseImpl.create(30_000, numberOfRequests, Unpooled.EMPTY_BUFFER); + // lease.use(); + // assertEquals(numberOfRequests - 1, lease.getAllowedRequests()); + // } } diff --git a/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceRSocketClientTest.java b/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceRSocketClientTest.java new file mode 100644 index 000000000..a35e89391 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceRSocketClientTest.java @@ -0,0 +1,94 @@ +package io.rsocket.loadbalance; + +import static java.util.Collections.singletonList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketClient; +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.ClientTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +@ExtendWith(MockitoExtension.class) +class LoadbalanceRSocketClientTest { + + @Mock private ClientTransport clientTransport; + @Mock private RSocketConnector rSocketConnector; + + public static final Duration SHORT_DURATION = Duration.ofMillis(25); + public static final Duration LONG_DURATION = Duration.ofMillis(75); + + private static final Publisher SOURCE = + Flux.interval(SHORT_DURATION) + .onBackpressureBuffer() + .map(String::valueOf) + .map(DefaultPayload::create); + + private static final Mono PROGRESSING_HANDLER = + Mono.just( + new RSocket() { + private final AtomicInteger i = new AtomicInteger(); + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads) + .delayElements(SHORT_DURATION) + .map(Payload::getDataUtf8) + .map(DefaultPayload::create) + .take(i.incrementAndGet()); + } + }); + + @Test + void testChannelReconnection() { + when(rSocketConnector.connect(clientTransport)).thenReturn(PROGRESSING_HANDLER); + + RSocketClient client = + LoadbalanceRSocketClient.create( + rSocketConnector, + Mono.just(singletonList(LoadbalanceTarget.from("key", clientTransport)))); + + Publisher result = + client + .requestChannel(SOURCE) + .repeatWhen(longFlux -> longFlux.delayElements(LONG_DURATION).take(5)) + .map(Payload::getDataUtf8) + .log(); + + StepVerifier.create(result) + .expectSubscription() + .assertNext(s -> assertThat(s).isEqualTo("0")) + .assertNext(s -> assertThat(s).isEqualTo("0")) + .assertNext(s -> assertThat(s).isEqualTo("1")) + .assertNext(s -> assertThat(s).isEqualTo("0")) + .assertNext(s -> assertThat(s).isEqualTo("1")) + .assertNext(s -> assertThat(s).isEqualTo("2")) + .assertNext(s -> assertThat(s).isEqualTo("0")) + .assertNext(s -> assertThat(s).isEqualTo("1")) + .assertNext(s -> assertThat(s).isEqualTo("2")) + .assertNext(s -> assertThat(s).isEqualTo("3")) + .assertNext(s -> assertThat(s).isEqualTo("0")) + .assertNext(s -> assertThat(s).isEqualTo("1")) + .assertNext(s -> assertThat(s).isEqualTo("2")) + .assertNext(s -> assertThat(s).isEqualTo("3")) + .assertNext(s -> assertThat(s).isEqualTo("4")) + .verifyComplete(); + + verify(rSocketConnector).connect(clientTransport); + verifyNoMoreInteractions(rSocketConnector, clientTransport); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceTest.java b/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceTest.java new file mode 100644 index 000000000..c1b509297 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceTest.java @@ -0,0 +1,470 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.RaceTestConstants; +import io.rsocket.core.RSocketConnector; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.TestClientTransport; +import io.rsocket.transport.ClientTransport; +import io.rsocket.util.EmptyPayload; +import io.rsocket.util.RSocketProxy; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; +import reactor.util.context.Context; + +public class LoadbalanceTest { + + @BeforeEach + void setUp() { + Hooks.onErrorDropped((__) -> {}); + } + + @AfterAll + static void afterAll() { + Hooks.resetOnErrorDropped(); + } + + @Test + public void shouldDeliverAllTheRequestsWithRoundRobinStrategy() { + final AtomicInteger counter = new AtomicInteger(); + final ClientTransport mockTransport = new TestClientTransport(); + final RSocket rSocket = + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter.incrementAndGet(); + return Mono.empty(); + } + }; + + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + final ClientTransport mockTransport1 = Mockito.mock(ClientTransport.class); + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then(im -> Mono.just(new TestRSocket(rSocket))); + + final List collectionOfDestination1 = + Collections.singletonList(LoadbalanceTarget.from("1", mockTransport)); + final List collectionOfDestination2 = + Collections.singletonList(LoadbalanceTarget.from("2", mockTransport)); + final List collectionOfDestinations1And2 = + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport), LoadbalanceTarget.from("2", mockTransport)); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final Sinks.Many> source = + Sinks.unsafe().many().unicast().onBackpressureError(); + final RSocketPool rSocketPool = + new RSocketPool( + rSocketConnectorMock, source.asFlux(), new RoundRobinLoadbalanceStrategy()); + final Mono fnfSource = + Mono.defer(() -> rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE)); + + RaceTestUtils.race( + () -> { + for (int j = 0; j < 1000; j++) { + fnfSource.subscribe(new RetrySubscriber(fnfSource)); + } + }, + () -> { + for (int j = 0; j < 100; j++) { + source.emitNext(Collections.emptyList(), Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination1, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestinations1And2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination1, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(Collections.emptyList(), Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestinations1And2, Sinks.EmitFailureHandler.FAIL_FAST); + } + }); + + Assertions.assertThat(counter.get()).isEqualTo(1000); + counter.set(0); + } + } + + @Test + public void shouldDeliverAllTheRequestsWithWeightedStrategy() throws InterruptedException { + final AtomicInteger counter = new AtomicInteger(); + + final ClientTransport mockTransport1 = Mockito.mock(ClientTransport.class); + final ClientTransport mockTransport2 = Mockito.mock(ClientTransport.class); + + final LoadbalanceTarget target1 = LoadbalanceTarget.from("1", mockTransport1); + final LoadbalanceTarget target2 = LoadbalanceTarget.from("2", mockTransport2); + + final WeightedRSocket weightedRSocket1 = new WeightedRSocket(counter); + final WeightedRSocket weightedRSocket2 = new WeightedRSocket(counter); + + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + Mockito.when(rSocketConnectorMock.connect(mockTransport1)) + .then(im -> Mono.just(new TestRSocket(weightedRSocket1))); + Mockito.when(rSocketConnectorMock.connect(mockTransport2)) + .then(im -> Mono.just(new TestRSocket(weightedRSocket2))); + final List collectionOfDestination1 = Collections.singletonList(target1); + final List collectionOfDestination2 = Collections.singletonList(target2); + final List collectionOfDestinations1And2 = Arrays.asList(target1, target2); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final Sinks.Many> source = + Sinks.unsafe().many().unicast().onBackpressureError(); + final RSocketPool rSocketPool = + new RSocketPool( + rSocketConnectorMock, + source.asFlux(), + WeightedLoadbalanceStrategy.builder() + .weightedStatsResolver( + rsocket -> { + if (rsocket instanceof TestRSocket) { + return (WeightedRSocket) ((TestRSocket) rsocket).source(); + } + return ((PooledRSocket) rsocket).target() == target1 + ? weightedRSocket1 + : weightedRSocket2; + }) + .build()); + final Mono fnfSource = + Mono.defer(() -> rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE)); + + RaceTestUtils.race( + () -> { + for (int j = 0; j < 1000; j++) { + fnfSource.subscribe(new RetrySubscriber(fnfSource)); + } + }, + () -> { + for (int j = 0; j < 100; j++) { + source.emitNext(Collections.emptyList(), Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination1, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestinations1And2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination1, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(Collections.emptyList(), Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestinations1And2, Sinks.EmitFailureHandler.FAIL_FAST); + } + }); + + Assertions.assertThat(counter.get()).isEqualTo(1000); + counter.set(0); + } + } + + @Test + public void ensureRSocketIsCleanedFromThePoolIfSourceRSocketIsDisposed() { + final AtomicInteger counter = new AtomicInteger(); + final ClientTransport mockTransport = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + + final TestRSocket testRSocket = + new TestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter.incrementAndGet(); + return Mono.empty(); + } + }); + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then(im -> Mono.delay(Duration.ofMillis(200)).map(__ -> testRSocket)); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool(rSocketConnectorMock, source, new RoundRobinLoadbalanceStrategy()); + + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport))); + + StepVerifier.create(rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE)) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(2)); + + testRSocket.dispose(); + + Assertions.assertThatThrownBy( + () -> + rSocketPool + .select() + .fireAndForget(EmptyPayload.INSTANCE) + .block(Duration.ofSeconds(2))) + .isExactlyInstanceOf(IllegalStateException.class) + .hasMessage("Timeout on blocking read for 2000000000 NANOSECONDS"); + + Assertions.assertThat(counter.get()).isOne(); + } + + @Test + public void ensureContextIsPropagatedCorrectlyForRequestChannel() { + final AtomicInteger counter = new AtomicInteger(); + final ClientTransport mockTransport = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then( + im -> + Mono.delay(Duration.ofMillis(200)) + .map( + __ -> + new TestRSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher source) { + counter.incrementAndGet(); + return Flux.from(source); + } + }))); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool(rSocketConnectorMock, source, new RoundRobinLoadbalanceStrategy()); + + // check that context is propagated when there is no rsocket + StepVerifier.create( + rSocketPool + .select() + .requestChannel( + Flux.deferContextual( + cv -> { + if (cv.hasKey("test") && cv.get("test").equals("test")) { + return Flux.just(EmptyPayload.INSTANCE); + } else { + return Flux.error( + new IllegalStateException("Expected context to be propagated")); + } + })) + .contextWrite(Context.of("test", "test"))) + .expectSubscription() + .then( + () -> + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport)))) + .expectNextCount(1) + .expectComplete() + .verify(Duration.ofSeconds(2)); + + source.next(Collections.singletonList(LoadbalanceTarget.from("2", mockTransport))); + // check that context is propagated when there is an RSocket but it is unresolved + StepVerifier.create( + rSocketPool + .select() + .requestChannel( + Flux.deferContextual( + cv -> { + if (cv.hasKey("test") && cv.get("test").equals("test")) { + return Flux.just(EmptyPayload.INSTANCE); + } else { + return Flux.error( + new IllegalStateException("Expected context to be propagated")); + } + })) + .contextWrite(Context.of("test", "test"))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(Duration.ofSeconds(2)); + + // check that context is propagated when there is an RSocket and it is resolved + StepVerifier.create( + rSocketPool + .select() + .requestChannel( + Flux.deferContextual( + cv -> { + if (cv.hasKey("test") && cv.get("test").equals("test")) { + return Flux.just(EmptyPayload.INSTANCE); + } else { + return Flux.error( + new IllegalStateException("Expected context to be propagated")); + } + })) + .contextWrite(Context.of("test", "test"))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(Duration.ofSeconds(2)); + + Assertions.assertThat(counter.get()).isEqualTo(3); + } + + @Test + public void shouldNotifyOnCloseWhenAllTheActiveSubscribersAreClosed() { + final AtomicInteger counter = new AtomicInteger(); + final ClientTransport mockTransport = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + + Sinks.Empty onCloseSocket1 = Sinks.empty(); + Sinks.Empty onCloseSocket2 = Sinks.empty(); + + RSocket socket1 = + new RSocket() { + @Override + public Mono onClose() { + return onCloseSocket1.asMono(); + } + + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + }; + RSocket socket2 = + new RSocket() { + @Override + public Mono onClose() { + return onCloseSocket2.asMono(); + } + + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + }; + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then(im -> Mono.just(socket1)) + .then(im -> Mono.just(socket2)) + .then(im -> Mono.never().doOnCancel(() -> counter.incrementAndGet())); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool(rSocketConnectorMock, source, new RoundRobinLoadbalanceStrategy()); + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport), + LoadbalanceTarget.from("2", mockTransport), + LoadbalanceTarget.from("3", mockTransport))); + + StepVerifier.create(rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE)) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(2)); + + StepVerifier.create(rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE)) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(2)); + + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + + rSocketPool.dispose(); + + AssertSubscriber onCloseSubscriber = + rSocketPool.onClose().subscribeWith(AssertSubscriber.create()); + + onCloseSubscriber.assertNotTerminated(); + + onCloseSocket1.tryEmitEmpty(); + + onCloseSubscriber.assertNotTerminated(); + + onCloseSocket2.tryEmitEmpty(); + + onCloseSubscriber.assertTerminated().assertComplete(); + + Assertions.assertThat(counter.get()).isOne(); + } + + static class TestRSocket extends RSocketProxy { + + final Sinks.Empty sink = Sinks.empty(); + + public TestRSocket(RSocket rSocket) { + super(rSocket); + } + + @Override + public Mono onClose() { + return sink.asMono(); + } + + @Override + public void dispose() { + sink.tryEmitEmpty(); + } + + public RSocket source() { + return source; + } + } + + private static class WeightedRSocket extends BaseWeightedStats implements RSocket { + + private final AtomicInteger counter; + + public WeightedRSocket(AtomicInteger counter) { + this.counter = counter; + } + + @Override + public Mono fireAndForget(Payload payload) { + final long startTime = startRequest(); + counter.incrementAndGet(); + return Mono.empty() + .doFinally( + (__) -> { + final long stopTime = stopRequest(startTime); + record(stopTime - startTime); + }); + } + } + + static class RetrySubscriber implements CoreSubscriber { + + final Publisher source; + + private RetrySubscriber(Publisher source) { + this.source = source; + } + + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(Void unused) {} + + @Override + public void onError(Throwable t) { + source.subscribe(this); + } + + @Override + public void onComplete() {} + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategyTest.java b/rsocket-core/src/test/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategyTest.java new file mode 100644 index 000000000..e43068dbd --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategyTest.java @@ -0,0 +1,170 @@ +package io.rsocket.loadbalance; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.RaceTestConstants; +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.ClientTransport; +import io.rsocket.util.EmptyPayload; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.assertj.core.api.Assertions; +import org.assertj.core.data.Offset; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.test.publisher.TestPublisher; + +public class RoundRobinLoadbalanceStrategyTest { + + @BeforeEach + void setUp() { + Hooks.onErrorDropped((__) -> {}); + } + + @AfterAll + static void afterAll() { + Hooks.resetOnErrorDropped(); + } + + @Test + public void shouldDeliverValuesProportionally() { + final AtomicInteger counter1 = new AtomicInteger(); + final AtomicInteger counter2 = new AtomicInteger(); + final ClientTransport mockTransport = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then( + im -> + Mono.just( + new LoadbalanceTest.TestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter1.incrementAndGet(); + return Mono.empty(); + } + }))) + .then( + im -> + Mono.just( + new LoadbalanceTest.TestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter2.incrementAndGet(); + return Mono.empty(); + } + }))); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool(rSocketConnectorMock, source, new RoundRobinLoadbalanceStrategy()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport), + LoadbalanceTarget.from("2", mockTransport))); + + Assertions.assertThat(counter1.get()).isCloseTo(500, Offset.offset(1)); + Assertions.assertThat(counter2.get()).isCloseTo(500, Offset.offset(1)); + } + + @Test + public void shouldDeliverValuesToNewlyConnectedSockets() { + final AtomicInteger counter1 = new AtomicInteger(); + final AtomicInteger counter2 = new AtomicInteger(); + final ClientTransport mockTransport1 = Mockito.mock(ClientTransport.class); + final ClientTransport mockTransport2 = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then( + im -> + Mono.just( + new LoadbalanceTest.TestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + if (im.getArgument(0) == mockTransport1) { + counter1.incrementAndGet(); + } else { + counter2.incrementAndGet(); + } + return Mono.empty(); + } + }))); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool(rSocketConnectorMock, source, new RoundRobinLoadbalanceStrategy()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport1))); + + Assertions.assertThat(counter1.get()).isCloseTo(RaceTestConstants.REPEATS, Offset.offset(1)); + + source.next(Collections.emptyList()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport1))); + + Assertions.assertThat(counter1.get()) + .isCloseTo(RaceTestConstants.REPEATS * 2, Offset.offset(1)); + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport1), + LoadbalanceTarget.from("2", mockTransport2))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo(RaceTestConstants.REPEATS * 2 + RaceTestConstants.REPEATS / 2, Offset.offset(1)); + Assertions.assertThat(counter2.get()) + .isCloseTo(RaceTestConstants.REPEATS / 2, Offset.offset(1)); + + source.next(Collections.singletonList(LoadbalanceTarget.from("2", mockTransport1))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo(RaceTestConstants.REPEATS * 2 + RaceTestConstants.REPEATS / 2, Offset.offset(1)); + Assertions.assertThat(counter2.get()) + .isCloseTo(RaceTestConstants.REPEATS + RaceTestConstants.REPEATS / 2, Offset.offset(1)); + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport1), + LoadbalanceTarget.from("2", mockTransport2))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo(RaceTestConstants.REPEATS * 3, Offset.offset(1)); + Assertions.assertThat(counter2.get()) + .isCloseTo(RaceTestConstants.REPEATS * 2, Offset.offset(1)); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategyTest.java b/rsocket-core/src/test/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategyTest.java new file mode 100644 index 000000000..8cc254cbb --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategyTest.java @@ -0,0 +1,254 @@ +package io.rsocket.loadbalance; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.RaceTestConstants; +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.ClientTransport; +import io.rsocket.util.Clock; +import io.rsocket.util.EmptyPayload; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.assertj.core.api.Assertions; +import org.assertj.core.data.Offset; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.publisher.TestPublisher; + +public class WeightedLoadbalanceStrategyTest { + + @BeforeEach + void setUp() { + Hooks.onErrorDropped((__) -> {}); + } + + @AfterAll + static void afterAll() { + Hooks.resetOnErrorDropped(); + } + + @Test + public void allRequestsShouldGoToTheSocketWithHigherWeight() { + final AtomicInteger counter1 = new AtomicInteger(); + final AtomicInteger counter2 = new AtomicInteger(); + final ClientTransport mockTransport = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + final WeightedTestRSocket rSocket1 = + new WeightedTestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter1.incrementAndGet(); + return Mono.empty(); + } + }); + final WeightedTestRSocket rSocket2 = + new WeightedTestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter2.incrementAndGet(); + return Mono.empty(); + } + }); + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then(im -> Mono.just(rSocket1)) + .then(im -> Mono.just(rSocket2)); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool( + rSocketConnectorMock, + source, + WeightedLoadbalanceStrategy.builder() + .weightedStatsResolver(r -> r instanceof WeightedStats ? (WeightedStats) r : null) + .build()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport), + LoadbalanceTarget.from("2", mockTransport))); + + Assertions.assertThat(counter1.get()) + .describedAs("c1=" + counter1.get() + " c2=" + counter2.get()) + .isCloseTo( + RaceTestConstants.REPEATS, Offset.offset(Math.round(RaceTestConstants.REPEATS * 0.1f))); + Assertions.assertThat(counter2.get()) + .describedAs("c1=" + counter1.get() + " c2=" + counter2.get()) + .isCloseTo(0, Offset.offset(Math.round(RaceTestConstants.REPEATS * 0.1f))); + } + + @Test + public void shouldDeliverValuesToTheSocketWithTheHighestCalculatedWeight() { + final AtomicInteger counter1 = new AtomicInteger(); + final AtomicInteger counter2 = new AtomicInteger(); + final ClientTransport mockTransport1 = Mockito.mock(ClientTransport.class); + final ClientTransport mockTransport2 = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + final WeightedTestRSocket rSocket1 = + new WeightedTestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter1.incrementAndGet(); + return Mono.empty(); + } + }); + final WeightedTestRSocket rSocket2 = + new WeightedTestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter1.incrementAndGet(); + return Mono.empty(); + } + }); + final WeightedTestRSocket rSocket3 = + new WeightedTestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter2.incrementAndGet(); + return Mono.empty(); + } + }); + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then(im -> Mono.just(rSocket1)) + .then(im -> Mono.just(rSocket2)) + .then(im -> Mono.just(rSocket3)); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool( + rSocketConnectorMock, + source, + WeightedLoadbalanceStrategy.builder() + .weightedStatsResolver(r -> r instanceof WeightedStats ? (WeightedStats) r : null) + .build()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport1))); + + Assertions.assertThat(counter1.get()).isCloseTo(RaceTestConstants.REPEATS, Offset.offset(1)); + + source.next(Collections.emptyList()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + rSocket1.updateAvailability(0.0); + + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport1))); + + Assertions.assertThat(counter1.get()) + .isCloseTo(RaceTestConstants.REPEATS * 2, Offset.offset(1)); + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport1), + LoadbalanceTarget.from("2", mockTransport2))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + final RSocket rSocket = rSocketPool.select(); + rSocket.fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo( + RaceTestConstants.REPEATS * 3, + Offset.offset(Math.round(RaceTestConstants.REPEATS * 3 * 0.1f))); + Assertions.assertThat(counter2.get()) + .isCloseTo(0, Offset.offset(Math.round(RaceTestConstants.REPEATS * 3 * 0.1f))); + + rSocket2.updateAvailability(0.0); + + source.next(Collections.singletonList(LoadbalanceTarget.from("2", mockTransport1))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo( + RaceTestConstants.REPEATS * 3, + Offset.offset(Math.round(RaceTestConstants.REPEATS * 4 * 0.1f))); + Assertions.assertThat(counter2.get()) + .isCloseTo( + RaceTestConstants.REPEATS, + Offset.offset(Math.round(RaceTestConstants.REPEATS * 4 * 0.1f))); + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport1), + LoadbalanceTarget.from("2", mockTransport2))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + final RSocket rSocket = rSocketPool.select(); + rSocket.fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo( + RaceTestConstants.REPEATS * 3, + Offset.offset(Math.round(RaceTestConstants.REPEATS * 5 * 0.1f))); + Assertions.assertThat(counter2.get()) + .isCloseTo( + RaceTestConstants.REPEATS * 2, + Offset.offset(Math.round(RaceTestConstants.REPEATS * 5 * 0.1f))); + } + + static class WeightedTestRSocket extends BaseWeightedStats implements RSocket { + + final Sinks.Empty sink = Sinks.empty(); + + final RSocket rSocket; + + public WeightedTestRSocket(RSocket rSocket) { + this.rSocket = rSocket; + } + + @Override + public Mono fireAndForget(Payload payload) { + startRequest(); + final long startTime = Clock.now(); + return this.rSocket + .fireAndForget(payload) + .doFinally( + __ -> { + stopRequest(startTime); + record(Clock.now() - startTime); + updateAvailability(1.0); + }); + } + + @Override + public Mono onClose() { + return sink.asMono(); + } + + @Override + public void dispose() { + sink.tryEmitEmpty(); + } + + public RSocket source() { + return rSocket; + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/security/AuthMetadataFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/AuthMetadataCodecTest.java similarity index 79% rename from rsocket-core/src/test/java/io/rsocket/metadata/security/AuthMetadataFlyweightTest.java rename to rsocket-core/src/test/java/io/rsocket/metadata/AuthMetadataCodecTest.java index 13d910e15..58ab30021 100644 --- a/rsocket-core/src/test/java/io/rsocket/metadata/security/AuthMetadataFlyweightTest.java +++ b/rsocket-core/src/test/java/io/rsocket/metadata/AuthMetadataCodecTest.java @@ -1,4 +1,4 @@ -package io.rsocket.metadata.security; +package io.rsocket.metadata; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; @@ -8,10 +8,10 @@ import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; -class AuthMetadataFlyweightTest { +public class AuthMetadataCodecTest { public static final int AUTH_TYPE_ID_LENGTH = 1; - public static final int USER_NAME_BYTES_LENGTH = 1; + public static final int USER_NAME_BYTES_LENGTH = 2; public static final String TEST_BEARER_TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJpYXQxIjoxNTE2MjM5MDIyLCJpYXQyIjoxNTE2MjM5MDIyLCJpYXQzIjoxNTE2MjM5MDIyLCJpYXQ0IjoxNTE2MjM5MDIyfQ.ljYuH-GNyyhhLcx-rHMchRkGbNsR2_4aSxo8XjrYrSM"; @@ -24,7 +24,7 @@ void shouldCorrectlyEncodeData() { int passwordLength = password.length(); ByteBuf byteBuf = - AuthMetadataFlyweight.encodeSimpleMetadata( + AuthMetadataCodec.encodeSimpleMetadata( ByteBufAllocator.DEFAULT, username.toCharArray(), password.toCharArray()); byteBuf.markReaderIndex(); @@ -44,7 +44,7 @@ void shouldCorrectlyEncodeData1() { int passwordLength = password.length(); ByteBuf byteBuf = - AuthMetadataFlyweight.encodeSimpleMetadata( + AuthMetadataCodec.encodeSimpleMetadata( ByteBufAllocator.DEFAULT, username.toCharArray(), password.toCharArray()); byteBuf.markReaderIndex(); @@ -64,7 +64,7 @@ void shouldCorrectlyEncodeData2() { int passwordLength = password.length(); ByteBuf byteBuf = - AuthMetadataFlyweight.encodeSimpleMetadata( + AuthMetadataCodec.encodeSimpleMetadata( ByteBufAllocator.DEFAULT, username.toCharArray(), password.toCharArray()); byteBuf.markReaderIndex(); @@ -82,7 +82,7 @@ private static void checkSimpleAuthMetadataEncoding( Assertions.assertThat(byteBuf.readUnsignedByte() & ~0x80) .isEqualTo(WellKnownAuthType.SIMPLE.getIdentifier()); - Assertions.assertThat(byteBuf.readUnsignedByte()).isEqualTo((short) usernameLength); + Assertions.assertThat(byteBuf.readUnsignedShort()).isEqualTo((short) usernameLength); Assertions.assertThat(byteBuf.readCharSequence(usernameLength, CharsetUtil.UTF_8)) .isEqualTo(username); @@ -97,18 +97,18 @@ private static void checkSimpleAuthMetadataEncodingUsingDecoders( Assertions.assertThat(byteBuf.capacity()) .isEqualTo(AUTH_TYPE_ID_LENGTH + USER_NAME_BYTES_LENGTH + usernameLength + passwordLength); - Assertions.assertThat(AuthMetadataFlyweight.decodeWellKnownAuthType(byteBuf)) + Assertions.assertThat(AuthMetadataCodec.readWellKnownAuthType(byteBuf)) .isEqualTo(WellKnownAuthType.SIMPLE); byteBuf.markReaderIndex(); - Assertions.assertThat(AuthMetadataFlyweight.decodeUsername(byteBuf).toString(CharsetUtil.UTF_8)) + Assertions.assertThat(AuthMetadataCodec.readUsername(byteBuf).toString(CharsetUtil.UTF_8)) .isEqualTo(username); - Assertions.assertThat(AuthMetadataFlyweight.decodePassword(byteBuf).toString(CharsetUtil.UTF_8)) + Assertions.assertThat(AuthMetadataCodec.readPassword(byteBuf).toString(CharsetUtil.UTF_8)) .isEqualTo(password); byteBuf.resetReaderIndex(); - Assertions.assertThat(new String(AuthMetadataFlyweight.decodeUsernameAsCharArray(byteBuf))) + Assertions.assertThat(new String(AuthMetadataCodec.readUsernameAsCharArray(byteBuf))) .isEqualTo(username); - Assertions.assertThat(new String(AuthMetadataFlyweight.decodePasswordAsCharArray(byteBuf))) + Assertions.assertThat(new String(AuthMetadataCodec.readPasswordAsCharArray(byteBuf))) .isEqualTo(password); ReferenceCountUtil.release(byteBuf); @@ -116,16 +116,22 @@ private static void checkSimpleAuthMetadataEncodingUsingDecoders( @Test void shouldThrowExceptionIfUsernameLengthExitsAllowedBounds() { - String username = + StringBuilder usernameBuilder = new StringBuilder(); + String usernamePart = "𠜎𠜱𠝹𠱓𠱸𠲖𠳏𠳕𠴕𠵼𠵿𠸎𠸏𠹷𠺝𠺢𠻗𠻹𠻺𠼭𠼮𠽌𠾴𠾼𠿪𡁜𡁯𡁵𡁶𡁻𡃁𡃉𡇙𢃇𢞵𢫕𢭃𢯊𢱑𢱕𢳂𢴈𢵌𢵧𢺳𣲷𤓓𤶸𤷪𥄫𦉘𦟌𦧲𦧺𧨾𨅝𨈇𨋢𨳊𨳍𨳒𩶘𠜎𠜱𠝹"; + for (int i = 0; i < 65535 / usernamePart.length(); i++) { + usernameBuilder.append(usernamePart); + } String password = "tset1234"; Assertions.assertThatThrownBy( () -> - AuthMetadataFlyweight.encodeSimpleMetadata( - ByteBufAllocator.DEFAULT, username.toCharArray(), password.toCharArray())) + AuthMetadataCodec.encodeSimpleMetadata( + ByteBufAllocator.DEFAULT, + usernameBuilder.toString().toCharArray(), + password.toCharArray())) .hasMessage( - "Username should be shorter than or equal to 255 bytes length in UTF-8 encoding"); + "Username should be shorter than or equal to 65535 bytes length in UTF-8 encoding"); } @Test @@ -133,8 +139,7 @@ void shouldEncodeBearerMetadata() { String testToken = TEST_BEARER_TOKEN; ByteBuf byteBuf = - AuthMetadataFlyweight.encodeBearerMetadata( - ByteBufAllocator.DEFAULT, testToken.toCharArray()); + AuthMetadataCodec.encodeBearerMetadata(ByteBufAllocator.DEFAULT, testToken.toCharArray()); byteBuf.markReaderIndex(); checkBearerAuthMetadataEncoding(testToken, byteBuf); @@ -146,7 +151,7 @@ private static void checkBearerAuthMetadataEncoding(String testToken, ByteBuf by Assertions.assertThat(byteBuf.capacity()) .isEqualTo(testToken.getBytes(CharsetUtil.UTF_8).length + AUTH_TYPE_ID_LENGTH); Assertions.assertThat( - byteBuf.readUnsignedByte() & ~AuthMetadataFlyweight.STREAM_METADATA_KNOWN_MASK) + byteBuf.readUnsignedByte() & ~AuthMetadataCodec.STREAM_METADATA_KNOWN_MASK) .isEqualTo(WellKnownAuthType.BEARER.getIdentifier()); Assertions.assertThat(byteBuf.readSlice(byteBuf.capacity() - 1).toString(CharsetUtil.UTF_8)) .isEqualTo(testToken); @@ -156,15 +161,15 @@ private static void checkBearerAuthMetadataEncodingUsingDecoders( String testToken, ByteBuf byteBuf) { Assertions.assertThat(byteBuf.capacity()) .isEqualTo(testToken.getBytes(CharsetUtil.UTF_8).length + AUTH_TYPE_ID_LENGTH); - Assertions.assertThat(AuthMetadataFlyweight.isWellKnownAuthType(byteBuf)).isTrue(); - Assertions.assertThat(AuthMetadataFlyweight.decodeWellKnownAuthType(byteBuf)) + Assertions.assertThat(AuthMetadataCodec.isWellKnownAuthType(byteBuf)).isTrue(); + Assertions.assertThat(AuthMetadataCodec.readWellKnownAuthType(byteBuf)) .isEqualTo(WellKnownAuthType.BEARER); byteBuf.markReaderIndex(); - Assertions.assertThat(new String(AuthMetadataFlyweight.decodeBearerTokenAsCharArray(byteBuf))) + Assertions.assertThat(new String(AuthMetadataCodec.readBearerTokenAsCharArray(byteBuf))) .isEqualTo(testToken); byteBuf.resetReaderIndex(); Assertions.assertThat( - AuthMetadataFlyweight.decodePayload(byteBuf).toString(CharsetUtil.UTF_8).toString()) + AuthMetadataCodec.readPayload(byteBuf).toString(CharsetUtil.UTF_8).toString()) .isEqualTo(testToken); } @@ -176,7 +181,7 @@ void shouldEncodeCustomAuth() { String customAuthType = "myownauthtype"; ByteBuf buffer = - AuthMetadataFlyweight.encodeMetadata( + AuthMetadataCodec.encodeMetadata( ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload); checkCustomAuthMetadataEncoding(testSecurityPayload, customAuthType, buffer); @@ -204,7 +209,7 @@ void shouldThrowOnNonASCIIChars() { Assertions.assertThatThrownBy( () -> - AuthMetadataFlyweight.encodeMetadata( + AuthMetadataCodec.encodeMetadata( ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload)) .hasMessage("custom auth type must be US_ASCII characters only"); } @@ -218,7 +223,7 @@ void shouldThrowOnOutOfAllowedSizeType() { Assertions.assertThatThrownBy( () -> - AuthMetadataFlyweight.encodeMetadata( + AuthMetadataCodec.encodeMetadata( ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload)) .hasMessage( "custom auth type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); @@ -231,7 +236,7 @@ void shouldThrowOnOutOfAllowedSizeType1() { Assertions.assertThatThrownBy( () -> - AuthMetadataFlyweight.encodeMetadata( + AuthMetadataCodec.encodeMetadata( ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload)) .hasMessage( "custom auth type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); @@ -240,10 +245,10 @@ void shouldThrowOnOutOfAllowedSizeType1() { @Test void shouldEncodeUsingWellKnownAuthType() { ByteBuf byteBuf = - AuthMetadataFlyweight.encodeMetadata( + AuthMetadataCodec.encodeMetadata( ByteBufAllocator.DEFAULT, WellKnownAuthType.SIMPLE, - ByteBufAllocator.DEFAULT.buffer(3, 3).writeByte(1).writeByte('u').writeByte('p')); + ByteBufAllocator.DEFAULT.buffer().writeShort(1).writeByte('u').writeByte('p')); checkSimpleAuthMetadataEncoding("u", "p", 1, 1, byteBuf); } @@ -251,10 +256,10 @@ void shouldEncodeUsingWellKnownAuthType() { @Test void shouldEncodeUsingWellKnownAuthType1() { ByteBuf byteBuf = - AuthMetadataFlyweight.encodeMetadata( + AuthMetadataCodec.encodeMetadata( ByteBufAllocator.DEFAULT, WellKnownAuthType.SIMPLE, - ByteBufAllocator.DEFAULT.buffer().writeByte(1).writeByte('u').writeByte('p')); + ByteBufAllocator.DEFAULT.buffer().writeShort(1).writeByte('u').writeByte('p')); checkSimpleAuthMetadataEncoding("u", "p", 1, 1, byteBuf); } @@ -262,7 +267,7 @@ void shouldEncodeUsingWellKnownAuthType1() { @Test void shouldEncodeUsingWellKnownAuthType2() { ByteBuf byteBuf = - AuthMetadataFlyweight.encodeMetadata( + AuthMetadataCodec.encodeMetadata( ByteBufAllocator.DEFAULT, WellKnownAuthType.BEARER, Unpooled.copiedBuffer(TEST_BEARER_TOKEN, CharsetUtil.UTF_8)); @@ -279,13 +284,13 @@ void shouldThrowIfWellKnownAuthTypeIsUnsupportedOrUnknown() { Assertions.assertThatThrownBy( () -> - AuthMetadataFlyweight.encodeMetadata( + AuthMetadataCodec.encodeMetadata( ByteBufAllocator.DEFAULT, WellKnownAuthType.UNPARSEABLE_AUTH_TYPE, buffer)) .hasMessage("only allowed AuthType should be used"); Assertions.assertThatThrownBy( () -> - AuthMetadataFlyweight.encodeMetadata( + AuthMetadataCodec.encodeMetadata( ByteBufAllocator.DEFAULT, WellKnownAuthType.UNPARSEABLE_AUTH_TYPE, buffer)) .hasMessage("only allowed AuthType should be used"); @@ -295,10 +300,10 @@ void shouldThrowIfWellKnownAuthTypeIsUnsupportedOrUnknown() { @Test void shouldCompressMetadata() { ByteBuf byteBuf = - AuthMetadataFlyweight.encodeMetadataWithCompression( + AuthMetadataCodec.encodeMetadataWithCompression( ByteBufAllocator.DEFAULT, "simple", - ByteBufAllocator.DEFAULT.buffer().writeByte(1).writeByte('u').writeByte('p')); + ByteBufAllocator.DEFAULT.buffer().writeShort(1).writeByte('u').writeByte('p')); checkSimpleAuthMetadataEncoding("u", "p", 1, 1, byteBuf); } @@ -306,7 +311,7 @@ void shouldCompressMetadata() { @Test void shouldCompressMetadata1() { ByteBuf byteBuf = - AuthMetadataFlyweight.encodeMetadataWithCompression( + AuthMetadataCodec.encodeMetadataWithCompression( ByteBufAllocator.DEFAULT, "bearer", Unpooled.copiedBuffer(TEST_BEARER_TOKEN, CharsetUtil.UTF_8)); @@ -323,7 +328,7 @@ void shouldNotCompressMetadata() { Unpooled.wrappedBuffer(TEST_BEARER_TOKEN.getBytes(CharsetUtil.UTF_8)); String customAuthType = "testauthtype"; ByteBuf byteBuf = - AuthMetadataFlyweight.encodeMetadataWithCompression( + AuthMetadataCodec.encodeMetadataWithCompression( ByteBufAllocator.DEFAULT, customAuthType, testMetadataPayload); checkCustomAuthMetadataEncoding(testMetadataPayload, customAuthType, byteBuf); @@ -332,12 +337,12 @@ void shouldNotCompressMetadata() { @Test void shouldConfirmWellKnownAuthType() { ByteBuf metadata = - AuthMetadataFlyweight.encodeMetadataWithCompression( + AuthMetadataCodec.encodeMetadataWithCompression( ByteBufAllocator.DEFAULT, "simple", Unpooled.EMPTY_BUFFER); int initialReaderIndex = metadata.readerIndex(); - Assertions.assertThat(AuthMetadataFlyweight.isWellKnownAuthType(metadata)).isTrue(); + Assertions.assertThat(AuthMetadataCodec.isWellKnownAuthType(metadata)).isTrue(); Assertions.assertThat(metadata.readerIndex()).isEqualTo(initialReaderIndex); ReferenceCountUtil.release(metadata); @@ -346,12 +351,12 @@ void shouldConfirmWellKnownAuthType() { @Test void shouldConfirmGivenMetadataIsNotAWellKnownAuthType() { ByteBuf metadata = - AuthMetadataFlyweight.encodeMetadataWithCompression( + AuthMetadataCodec.encodeMetadataWithCompression( ByteBufAllocator.DEFAULT, "simple/afafgafadf", Unpooled.EMPTY_BUFFER); int initialReaderIndex = metadata.readerIndex(); - Assertions.assertThat(AuthMetadataFlyweight.isWellKnownAuthType(metadata)).isFalse(); + Assertions.assertThat(AuthMetadataCodec.isWellKnownAuthType(metadata)).isFalse(); Assertions.assertThat(metadata.readerIndex()).isEqualTo(initialReaderIndex); ReferenceCountUtil.release(metadata); @@ -360,7 +365,7 @@ void shouldConfirmGivenMetadataIsNotAWellKnownAuthType() { @Test void shouldReadSimpleWellKnownAuthType() { ByteBuf metadata = - AuthMetadataFlyweight.encodeMetadataWithCompression( + AuthMetadataCodec.encodeMetadataWithCompression( ByteBufAllocator.DEFAULT, "simple", Unpooled.EMPTY_BUFFER); WellKnownAuthType expectedType = WellKnownAuthType.SIMPLE; checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); @@ -369,7 +374,7 @@ void shouldReadSimpleWellKnownAuthType() { @Test void shouldReadSimpleWellKnownAuthType1() { ByteBuf metadata = - AuthMetadataFlyweight.encodeMetadataWithCompression( + AuthMetadataCodec.encodeMetadataWithCompression( ByteBufAllocator.DEFAULT, "bearer", Unpooled.EMPTY_BUFFER); WellKnownAuthType expectedType = WellKnownAuthType.BEARER; checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); @@ -380,7 +385,7 @@ void shouldReadSimpleWellKnownAuthType2() { ByteBuf metadata = ByteBufAllocator.DEFAULT .buffer() - .writeByte(3 | AuthMetadataFlyweight.STREAM_METADATA_KNOWN_MASK); + .writeByte(3 | AuthMetadataCodec.STREAM_METADATA_KNOWN_MASK); WellKnownAuthType expectedType = WellKnownAuthType.UNKNOWN_RESERVED_AUTH_TYPE; checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); } @@ -395,7 +400,7 @@ void shouldNotReadSimpleWellKnownAuthTypeIfEncodedLength() { @Test void shouldNotReadSimpleWellKnownAuthTypeIfEncodedLength1() { ByteBuf metadata = - AuthMetadataFlyweight.encodeMetadata( + AuthMetadataCodec.encodeMetadata( ByteBufAllocator.DEFAULT, "testmetadataauthtype", Unpooled.EMPTY_BUFFER); WellKnownAuthType expectedType = WellKnownAuthType.UNPARSEABLE_AUTH_TYPE; checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); @@ -404,7 +409,7 @@ void shouldNotReadSimpleWellKnownAuthTypeIfEncodedLength1() { @Test void shouldThrowExceptionIsNotEnoughReadableBytes() { Assertions.assertThatThrownBy( - () -> AuthMetadataFlyweight.decodeWellKnownAuthType(Unpooled.EMPTY_BUFFER)) + () -> AuthMetadataCodec.readWellKnownAuthType(Unpooled.EMPTY_BUFFER)) .hasMessage("Unable to decode Well Know Auth type. Not enough readable bytes"); } @@ -412,7 +417,7 @@ private static void checkDecodeWellKnowAuthTypeCorrectly( ByteBuf metadata, WellKnownAuthType expectedType) { int initialReaderIndex = metadata.readerIndex(); - WellKnownAuthType wellKnownAuthType = AuthMetadataFlyweight.decodeWellKnownAuthType(metadata); + WellKnownAuthType wellKnownAuthType = AuthMetadataCodec.readWellKnownAuthType(metadata); Assertions.assertThat(wellKnownAuthType).isEqualTo(expectedType); Assertions.assertThat(metadata.readerIndex()) @@ -426,15 +431,14 @@ private static void checkDecodeWellKnowAuthTypeCorrectly( void shouldReadCustomEncodedAuthType() { String testAuthType = "TestAuthType"; ByteBuf byteBuf = - AuthMetadataFlyweight.encodeMetadata( + AuthMetadataCodec.encodeMetadata( ByteBufAllocator.DEFAULT, testAuthType, Unpooled.EMPTY_BUFFER); checkDecodeCustomAuthTypeCorrectly(testAuthType, byteBuf); } @Test void shouldThrowExceptionOnEmptyMetadata() { - Assertions.assertThatThrownBy( - () -> AuthMetadataFlyweight.decodeCustomAuthType(Unpooled.EMPTY_BUFFER)) + Assertions.assertThatThrownBy(() -> AuthMetadataCodec.readCustomAuthType(Unpooled.EMPTY_BUFFER)) .hasMessage("Unable to decode custom Auth type. Not enough readable bytes"); } @@ -442,8 +446,8 @@ void shouldThrowExceptionOnEmptyMetadata() { void shouldThrowExceptionOnMalformedMetadata_wellknowninstead() { Assertions.assertThatThrownBy( () -> - AuthMetadataFlyweight.decodeCustomAuthType( - AuthMetadataFlyweight.encodeMetadata( + AuthMetadataCodec.readCustomAuthType( + AuthMetadataCodec.encodeMetadata( ByteBufAllocator.DEFAULT, WellKnownAuthType.BEARER, Unpooled.copiedBuffer(new byte[] {'a', 'b'})))) @@ -454,7 +458,7 @@ void shouldThrowExceptionOnMalformedMetadata_wellknowninstead() { void shouldThrowExceptionOnMalformedMetadata_length() { Assertions.assertThatThrownBy( () -> - AuthMetadataFlyweight.decodeCustomAuthType( + AuthMetadataCodec.readCustomAuthType( ByteBufAllocator.DEFAULT.buffer().writeByte(127).writeChar('a').writeChar('b'))) .hasMessage("Unable to decode custom Auth type. Malformed length or auth type string"); } @@ -462,7 +466,7 @@ void shouldThrowExceptionOnMalformedMetadata_length() { private static void checkDecodeCustomAuthTypeCorrectly(String testAuthType, ByteBuf byteBuf) { int initialReaderIndex = byteBuf.readerIndex(); - Assertions.assertThat(AuthMetadataFlyweight.decodeCustomAuthType(byteBuf).toString()) + Assertions.assertThat(AuthMetadataCodec.readCustomAuthType(byteBuf).toString()) .isEqualTo(testAuthType); Assertions.assertThat(byteBuf.readerIndex()) .isEqualTo(initialReaderIndex + testAuthType.length() + 1); diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataCodecTest.java similarity index 77% rename from rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataFlyweightTest.java rename to rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataCodecTest.java index bd5e4295a..a4e8fb2d8 100644 --- a/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataFlyweightTest.java +++ b/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataCodecTest.java @@ -16,18 +16,28 @@ package io.rsocket.metadata; -import static io.rsocket.metadata.CompositeMetadataFlyweight.decodeMimeAndContentBuffersSlices; -import static io.rsocket.metadata.CompositeMetadataFlyweight.decodeMimeIdFromMimeBuffer; -import static io.rsocket.metadata.CompositeMetadataFlyweight.decodeMimeTypeFromMimeBuffer; +import static io.rsocket.metadata.CompositeMetadataCodec.decodeMimeAndContentBuffersSlices; +import static io.rsocket.metadata.CompositeMetadataCodec.decodeMimeIdFromMimeBuffer; +import static io.rsocket.metadata.CompositeMetadataCodec.decodeMimeTypeFromMimeBuffer; import static org.assertj.core.api.Assertions.*; import io.netty.buffer.*; import io.netty.util.CharsetUtil; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.test.util.ByteBufUtils; import io.rsocket.util.NumberUtils; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; -class CompositeMetadataFlyweightTest { +class CompositeMetadataCodecTest { + + final LeaksTrackingByteBufAllocator testAllocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + @AfterEach + void tearDownAndCheckForLeaks() { + testAllocator.assertHasNoLeaks(); + } static String byteToBitsString(byte b) { return String.format("%8s", Integer.toBinaryString(b & 0xFF)).replace(' ', '0'); @@ -48,17 +58,14 @@ void customMimeHeaderLatin1_encodingFails() { assertThatIllegalArgumentException() .isThrownBy( - () -> - CompositeMetadataFlyweight.encodeMetadataHeader( - ByteBufAllocator.DEFAULT, mimeNotAscii, 0)) + () -> CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeNotAscii, 0)) .withMessage("custom mime type must be US_ASCII characters only"); } @Test void customMimeHeaderLength0_encodingFails() { assertThatIllegalArgumentException() - .isThrownBy( - () -> CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, "", 0)) + .isThrownBy(() -> CompositeMetadataCodec.encodeMetadataHeader(testAllocator, "", 0)) .withMessage( "custom mime type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); } @@ -70,8 +77,7 @@ void customMimeHeaderLength127() { builder.append('a'); } String mimeString = builder.toString(); - ByteBuf encoded = - CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, mimeString, 0); + ByteBuf encoded = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeString, 0); // remember actual length = encoded length + 1 assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("01111110"); @@ -94,11 +100,12 @@ void customMimeHeaderLength127() { .hasToString(mimeString); header.resetReaderIndex(); - assertThat(CompositeMetadataFlyweight.decodeMimeTypeFromMimeBuffer(header)) + assertThat(CompositeMetadataCodec.decodeMimeTypeFromMimeBuffer(header)) .as("decoded mime string") .hasToString(mimeString); assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); } @Test @@ -108,8 +115,7 @@ void customMimeHeaderLength128() { builder.append('a'); } String mimeString = builder.toString(); - ByteBuf encoded = - CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, mimeString, 0); + ByteBuf encoded = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeString, 0); // remember actual length = encoded length + 1 assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("01111111"); @@ -132,11 +138,12 @@ void customMimeHeaderLength128() { .hasToString(mimeString); header.resetReaderIndex(); - assertThat(CompositeMetadataFlyweight.decodeMimeTypeFromMimeBuffer(header)) + assertThat(CompositeMetadataCodec.decodeMimeTypeFromMimeBuffer(header)) .as("decoded mime string") .hasToString(mimeString); assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); } @Test @@ -148,9 +155,7 @@ void customMimeHeaderLength129_encodingFails() { assertThatIllegalArgumentException() .isThrownBy( - () -> - CompositeMetadataFlyweight.encodeMetadataHeader( - ByteBufAllocator.DEFAULT, builder.toString(), 0)) + () -> CompositeMetadataCodec.encodeMetadataHeader(testAllocator, builder.toString(), 0)) .withMessage( "custom mime type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); } @@ -158,8 +163,7 @@ void customMimeHeaderLength129_encodingFails() { @Test void customMimeHeaderLengthOne() { String mimeString = "w"; - ByteBuf encoded = - CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, mimeString, 0); + ByteBuf encoded = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeString, 0); // remember actual length = encoded length + 1 assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("00000000"); @@ -180,18 +184,18 @@ void customMimeHeaderLengthOne() { .hasToString(mimeString); header.resetReaderIndex(); - assertThat(CompositeMetadataFlyweight.decodeMimeTypeFromMimeBuffer(header)) + assertThat(CompositeMetadataCodec.decodeMimeTypeFromMimeBuffer(header)) .as("decoded mime string") .hasToString(mimeString); assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); } @Test void customMimeHeaderLengthTwo() { String mimeString = "ww"; - ByteBuf encoded = - CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, mimeString, 0); + ByteBuf encoded = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeString, 0); // remember actual length = encoded length + 1 assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("00000001"); @@ -214,11 +218,12 @@ void customMimeHeaderLengthTwo() { .hasToString(mimeString); header.resetReaderIndex(); - assertThat(CompositeMetadataFlyweight.decodeMimeTypeFromMimeBuffer(header)) + assertThat(CompositeMetadataCodec.decodeMimeTypeFromMimeBuffer(header)) .as("decoded mime string") .hasToString(mimeString); assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); } @Test @@ -227,9 +232,7 @@ void customMimeHeaderUtf8_encodingFails() { "mime/tyࠒe"; // this is the SAMARITAN LETTER QUF u+0812 represented on 3 bytes assertThatIllegalArgumentException() .isThrownBy( - () -> - CompositeMetadataFlyweight.encodeMetadataHeader( - ByteBufAllocator.DEFAULT, mimeNotAscii, 0)) + () -> CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeNotAscii, 0)) .withMessage("custom mime type must be US_ASCII characters only"); } @@ -317,72 +320,73 @@ void decodeTypeSkipsFirstByte() { @Test void encodeMetadataCustomTypeDelegates() { - ByteBuf expected = - CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, "foo", 2); + ByteBuf expected = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, "foo", 2); - CompositeByteBuf test = ByteBufAllocator.DEFAULT.compositeBuffer(); + CompositeByteBuf test = testAllocator.compositeBuffer(); - CompositeMetadataFlyweight.encodeAndAddMetadata( - test, ByteBufAllocator.DEFAULT, "foo", ByteBufUtils.getRandomByteBuf(2)); + CompositeMetadataCodec.encodeAndAddMetadata( + test, testAllocator, "foo", ByteBufUtils.getRandomByteBuf(2)); assertThat((Iterable) test).hasSize(2).first().isEqualTo(expected); + test.release(); + expected.release(); } @Test void encodeMetadataKnownTypeDelegates() { ByteBuf expected = - CompositeMetadataFlyweight.encodeMetadataHeader( - ByteBufAllocator.DEFAULT, - WellKnownMimeType.APPLICATION_OCTET_STREAM.getIdentifier(), - 2); + CompositeMetadataCodec.encodeMetadataHeader( + testAllocator, WellKnownMimeType.APPLICATION_OCTET_STREAM.getIdentifier(), 2); - CompositeByteBuf test = ByteBufAllocator.DEFAULT.compositeBuffer(); + CompositeByteBuf test = testAllocator.compositeBuffer(); - CompositeMetadataFlyweight.encodeAndAddMetadata( + CompositeMetadataCodec.encodeAndAddMetadata( test, - ByteBufAllocator.DEFAULT, + testAllocator, WellKnownMimeType.APPLICATION_OCTET_STREAM, ByteBufUtils.getRandomByteBuf(2)); assertThat((Iterable) test).hasSize(2).first().isEqualTo(expected); + test.release(); + expected.release(); } @Test void encodeMetadataReservedTypeDelegates() { - ByteBuf expected = - CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, (byte) 120, 2); + ByteBuf expected = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, (byte) 120, 2); - CompositeByteBuf test = ByteBufAllocator.DEFAULT.compositeBuffer(); + CompositeByteBuf test = testAllocator.compositeBuffer(); - CompositeMetadataFlyweight.encodeAndAddMetadata( - test, ByteBufAllocator.DEFAULT, (byte) 120, ByteBufUtils.getRandomByteBuf(2)); + CompositeMetadataCodec.encodeAndAddMetadata( + test, testAllocator, (byte) 120, ByteBufUtils.getRandomByteBuf(2)); assertThat((Iterable) test).hasSize(2).first().isEqualTo(expected); + test.release(); + expected.release(); } @Test void encodeTryCompressWithCompressableType() { ByteBuf metadata = ByteBufUtils.getRandomByteBuf(2); - CompositeByteBuf target = UnpooledByteBufAllocator.DEFAULT.compositeBuffer(); + CompositeByteBuf target = testAllocator.compositeBuffer(); - CompositeMetadataFlyweight.encodeAndAddMetadataWithCompression( - target, - UnpooledByteBufAllocator.DEFAULT, - WellKnownMimeType.APPLICATION_AVRO.getString(), - metadata); + CompositeMetadataCodec.encodeAndAddMetadataWithCompression( + target, testAllocator, WellKnownMimeType.APPLICATION_AVRO.getString(), metadata); assertThat(target.readableBytes()).as("readableBytes 1 + 3 + 2").isEqualTo(6); + target.release(); } @Test void encodeTryCompressWithCustomType() { ByteBuf metadata = ByteBufUtils.getRandomByteBuf(2); - CompositeByteBuf target = UnpooledByteBufAllocator.DEFAULT.compositeBuffer(); + CompositeByteBuf target = testAllocator.compositeBuffer(); - CompositeMetadataFlyweight.encodeAndAddMetadataWithCompression( - target, UnpooledByteBufAllocator.DEFAULT, "custom/example", metadata); + CompositeMetadataCodec.encodeAndAddMetadataWithCompression( + target, testAllocator, "custom/example", metadata); assertThat(target.readableBytes()).as("readableBytes 1 + 14 + 3 + 2").isEqualTo(20); + target.release(); } @Test @@ -390,35 +394,35 @@ void hasEntry() { WellKnownMimeType mime = WellKnownMimeType.APPLICATION_AVRO; CompositeByteBuf buffer = - Unpooled.compositeBuffer() + testAllocator + .compositeBuffer() .addComponent( true, - CompositeMetadataFlyweight.encodeMetadataHeader( - ByteBufAllocator.DEFAULT, mime.getIdentifier(), 0)) + CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mime.getIdentifier(), 0)) .addComponent( true, - CompositeMetadataFlyweight.encodeMetadataHeader( - ByteBufAllocator.DEFAULT, mime.getIdentifier(), 0)); + CompositeMetadataCodec.encodeMetadataHeader( + testAllocator, mime.getIdentifier(), 0)); - assertThat(CompositeMetadataFlyweight.hasEntry(buffer, 0)).isTrue(); - assertThat(CompositeMetadataFlyweight.hasEntry(buffer, 4)).isTrue(); - assertThat(CompositeMetadataFlyweight.hasEntry(buffer, 8)).isFalse(); + assertThat(CompositeMetadataCodec.hasEntry(buffer, 0)).isTrue(); + assertThat(CompositeMetadataCodec.hasEntry(buffer, 4)).isTrue(); + assertThat(CompositeMetadataCodec.hasEntry(buffer, 8)).isFalse(); + buffer.release(); } @Test void isWellKnownMimeType() { ByteBuf wellKnown = Unpooled.buffer().writeByte(0); - assertThat(CompositeMetadataFlyweight.isWellKnownMimeType(wellKnown)).isTrue(); + assertThat(CompositeMetadataCodec.isWellKnownMimeType(wellKnown)).isTrue(); ByteBuf explicit = Unpooled.buffer().writeByte(2).writeChar('a'); - assertThat(CompositeMetadataFlyweight.isWellKnownMimeType(explicit)).isFalse(); + assertThat(CompositeMetadataCodec.isWellKnownMimeType(explicit)).isFalse(); } @Test void knownMimeHeader120_reserved() { byte mime = (byte) 120; - ByteBuf encoded = - CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, mime, 0); + ByteBuf encoded = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mime, 0); assertThat(mime) .as("smoke test RESERVED_120 unsigned 7 bits representation") @@ -443,6 +447,7 @@ void knownMimeHeader120_reserved() { assertThat(decodeMimeIdFromMimeBuffer(header)).as("decoded mime id").isEqualTo(mime); assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); } @Test @@ -453,8 +458,7 @@ void knownMimeHeader127_compositeMetadata() { .isEqualTo((byte) 127) .isEqualTo((byte) 0b01111111); ByteBuf encoded = - CompositeMetadataFlyweight.encodeMetadataHeader( - ByteBufAllocator.DEFAULT, mime.getIdentifier(), 0); + CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mime.getIdentifier(), 0); assertThat(toHeaderBits(encoded)) .startsWith("1") @@ -480,6 +484,7 @@ void knownMimeHeader127_compositeMetadata() { .isEqualTo(mime.getIdentifier()); assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); } @Test @@ -490,8 +495,7 @@ void knownMimeHeaderZero_avro() { .isEqualTo((byte) 0) .isEqualTo((byte) 0b00000000); ByteBuf encoded = - CompositeMetadataFlyweight.encodeMetadataHeader( - ByteBufAllocator.DEFAULT, mime.getIdentifier(), 0); + CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mime.getIdentifier(), 0); assertThat(toHeaderBits(encoded)) .startsWith("1") @@ -517,6 +521,7 @@ void knownMimeHeaderZero_avro() { .isEqualTo(mime.getIdentifier()); assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); } @Test @@ -543,8 +548,7 @@ protected ByteBuf newDirectBuffer(int initialCapacity, int maxCapacity) { } }; - assertThatCode( - () -> CompositeMetadataFlyweight.encodeMetadataHeader(allocator, "custom/type", 0)) + assertThatCode(() -> CompositeMetadataCodec.encodeMetadataHeader(allocator, "custom/type", 0)) .doesNotThrowAnyException(); assertThat(badBuf.readByte()).isEqualTo((byte) 10); diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataTest.java index f06bdcc0c..0b81ab4b0 100644 --- a/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataTest.java +++ b/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataTest.java @@ -108,11 +108,11 @@ void decodeThreeEntries() { metadata3.writeByte(88); CompositeByteBuf compositeMetadataBuffer = ByteBufAllocator.DEFAULT.compositeBuffer(); - CompositeMetadataFlyweight.encodeAndAddMetadata( + CompositeMetadataCodec.encodeAndAddMetadata( compositeMetadataBuffer, ByteBufAllocator.DEFAULT, mimeType1, metadata1); - CompositeMetadataFlyweight.encodeAndAddMetadata( + CompositeMetadataCodec.encodeAndAddMetadata( compositeMetadataBuffer, ByteBufAllocator.DEFAULT, mimeType2, metadata2); - CompositeMetadataFlyweight.encodeAndAddMetadata( + CompositeMetadataCodec.encodeAndAddMetadata( compositeMetadataBuffer, ByteBufAllocator.DEFAULT, reserved, metadata3); Iterator iterator = new CompositeMetadata(compositeMetadataBuffer, true).iterator(); diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/MimeTypeMetadataCodecTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/MimeTypeMetadataCodecTest.java new file mode 100644 index 000000000..5c8d40306 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/MimeTypeMetadataCodecTest.java @@ -0,0 +1,68 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.metadata; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import java.util.List; +import org.assertj.core.util.Lists; +import org.junit.jupiter.api.Test; + +/** Unit tests for {@link MimeTypeMetadataCodec}. */ +public class MimeTypeMetadataCodecTest { + + @Test + public void wellKnownMimeType() { + WellKnownMimeType mimeType = WellKnownMimeType.APPLICATION_HESSIAN; + ByteBuf byteBuf = MimeTypeMetadataCodec.encode(ByteBufAllocator.DEFAULT, mimeType); + try { + List mimeTypes = MimeTypeMetadataCodec.decode(byteBuf); + + assertThat(mimeTypes.size()).isEqualTo(1); + assertThat(WellKnownMimeType.fromString(mimeTypes.get(0))).isEqualTo(mimeType); + } finally { + byteBuf.release(); + } + } + + @Test + public void customMimeType() { + String mimeType = "aaa/bb"; + ByteBuf byteBuf = MimeTypeMetadataCodec.encode(ByteBufAllocator.DEFAULT, mimeType); + try { + List mimeTypes = MimeTypeMetadataCodec.decode(byteBuf); + + assertThat(mimeTypes.size()).isEqualTo(1); + assertThat(mimeTypes.get(0)).isEqualTo(mimeType); + } finally { + byteBuf.release(); + } + } + + @Test + public void multipleMimeTypes() { + List mimeTypes = Lists.newArrayList("aaa/bbb", "application/x-hessian"); + ByteBuf byteBuf = MimeTypeMetadataCodec.encode(ByteBufAllocator.DEFAULT, mimeTypes); + + try { + assertThat(MimeTypeMetadataCodec.decode(byteBuf)).isEqualTo(mimeTypes); + } finally { + byteBuf.release(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/TaggingMetadataTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/TaggingMetadataTest.java index d1fbb50b0..b65ffafee 100644 --- a/rsocket-core/src/test/java/io/rsocket/metadata/TaggingMetadataTest.java +++ b/rsocket-core/src/test/java/io/rsocket/metadata/TaggingMetadataTest.java @@ -23,7 +23,7 @@ public void testParseTags() { Arrays.asList( "ws://localhost:8080/rsocket", String.join("", Collections.nCopies(129, "x"))); TaggingMetadata taggingMetadata = - TaggingMetadataFlyweight.createTaggingMetadata( + TaggingMetadataCodec.createTaggingMetadata( byteBufAllocator, "message/x.rsocket.routing.v0", tags); TaggingMetadata taggingMetadataCopy = new TaggingMetadata("message/x.rsocket.routing.v0", taggingMetadata.getContent()); @@ -37,7 +37,7 @@ public void testEmptyTagAndOverLengthTag() { Arrays.asList( "ws://localhost:8080/rsocket", "", String.join("", Collections.nCopies(256, "x"))); TaggingMetadata taggingMetadata = - TaggingMetadataFlyweight.createTaggingMetadata( + TaggingMetadataCodec.createTaggingMetadata( byteBufAllocator, "message/x.rsocket.routing.v0", tags); TaggingMetadata taggingMetadataCopy = new TaggingMetadata("message/x.rsocket.routing.v0", taggingMetadata.getContent()); diff --git a/rsocket-core/src/test/java/io/rsocket/plugins/RequestInterceptorTest.java b/rsocket-core/src/test/java/io/rsocket/plugins/RequestInterceptorTest.java new file mode 100644 index 000000000..9a19050f9 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/plugins/RequestInterceptorTest.java @@ -0,0 +1,790 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.frame.FrameType; +import io.rsocket.transport.local.LocalClientTransport; +import io.rsocket.transport.local.LocalServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.util.annotation.Nullable; + +public class RequestInterceptorTest { + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheClientRequesterSide(boolean errorOutcome) { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .bindNow(LocalServerTransport.create("test")); + + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final RSocket rSocket = + RSocketConnector.create() + .interceptors( + ir -> + ir.forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor)) + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel(Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheClientResponderSide(boolean errorOutcome) + throws InterruptedException { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + + CountDownLatch latch = new CountDownLatch(1); + final Closeable closeable = + RSocketServer.create( + (setup, rSocket) -> + Mono.just(new RSocket() {}) + .doAfterTerminate( + () -> { + new Thread( + () -> { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel( + Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + latch.countDown(); + }) + .start(); + })) + .bindNow(LocalServerTransport.create("test")); + + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final RSocket rSocket = + RSocketConnector.create() + .acceptor( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .interceptors( + ir -> + ir.forRequestsInResponder( + (Function) + (__) -> testRequestInterceptor)) + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + Assertions.assertThat(latch.await(1, TimeUnit.SECONDS)).isTrue(); + + testRequestInterceptor + .expectOnStart(2, FrameType.REQUEST_FNF) + .expectOnComplete(2) + .expectOnStart(4, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 4) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(6, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 6) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(8, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 8) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheServerRequesterSide(boolean errorOutcome) { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .interceptors( + ir -> + ir.forRequestsInResponder( + (Function) + (__) -> testRequestInterceptor)) + .bindNow(LocalServerTransport.create("test")); + final RSocket rSocket = + RSocketConnector.create() + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel(Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheServerResponderSide(boolean errorOutcome) + throws InterruptedException { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + + CountDownLatch latch = new CountDownLatch(1); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final Closeable closeable = + RSocketServer.create( + (setup, rSocket) -> + Mono.just(new RSocket() {}) + .doAfterTerminate( + () -> { + new Thread( + () -> { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel( + Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + latch.countDown(); + }) + .start(); + })) + .interceptors( + ir -> + ir.forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor)) + .bindNow(LocalServerTransport.create("test")); + final RSocket rSocket = + RSocketConnector.create() + .acceptor( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + Assertions.assertThat(latch.await(1, TimeUnit.SECONDS)).isTrue(); + + testRequestInterceptor + .expectOnStart(2, FrameType.REQUEST_FNF) + .expectOnComplete(2) + .expectOnStart(4, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 4) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(6, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 6) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(8, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 8) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } + + @Test + void ensuresExceptionInTheInterceptorIsHandledProperly() { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads); + } + })) + .bindNow(LocalServerTransport.create("test")); + + final RequestInterceptor testRequestInterceptor = + new RequestInterceptor() { + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnStart"); + } + + @Override + public void onTerminate( + int streamId, FrameType requestType, @Nullable Throwable terminalSignal) { + throw new RuntimeException("testOnTerminate"); + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + throw new RuntimeException("testOnCancel"); + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnReject"); + } + + @Override + public void dispose() {} + }; + final RSocket rSocket = + RSocketConnector.create() + .interceptors( + ir -> + ir.forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor)) + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + StepVerifier.create(rSocket.fireAndForget(DefaultPayload.create("test"))) + .expectSubscription() + .expectComplete() + .verify(); + + StepVerifier.create(rSocket.requestResponse(DefaultPayload.create("test"))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(); + + StepVerifier.create(rSocket.requestStream(DefaultPayload.create("test"))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(); + + StepVerifier.create(rSocket.requestChannel(Flux.just(DefaultPayload.create("test")))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(); + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void shouldSupportMultipleInterceptors(boolean errorOutcome) { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .bindNow(LocalServerTransport.create("test")); + + final RequestInterceptor testRequestInterceptor1 = + new RequestInterceptor() { + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnStart"); + } + + @Override + public void onTerminate( + int streamId, FrameType requestType, @Nullable Throwable terminalSignal) { + throw new RuntimeException("testOnTerminate"); + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + throw new RuntimeException("testOnTerminate"); + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnReject"); + } + + @Override + public void dispose() {} + }; + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequestInterceptor testRequestInterceptor2 = new TestRequestInterceptor(); + final RSocket rSocket = + RSocketConnector.create() + .interceptors( + ir -> + ir.forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor) + .forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor1) + .forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor2)) + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel(Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + + testRequestInterceptor2 + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java b/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java new file mode 100644 index 000000000..8261b3f49 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java @@ -0,0 +1,142 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.jctools.queues.MpscUnboundedArrayQueue; +import java.util.Queue; +import java.util.function.Consumer; +import org.assertj.core.api.Assertions; +import org.assertj.core.api.Condition; +import reactor.util.annotation.Nullable; + +public class TestRequestInterceptor implements RequestInterceptor { + + final Queue events = new MpscUnboundedArrayQueue<>(128); + + @Override + public void dispose() {} + + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + events.add(new Event(EventType.ON_START, streamId, requestType, null)); + } + + @Override + public void onTerminate(int streamId, FrameType requestType, @Nullable Throwable t) { + events.add( + new Event( + t == null ? EventType.ON_COMPLETE : EventType.ON_ERROR, streamId, requestType, t)); + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + events.add(new Event(EventType.ON_CANCEL, streamId, requestType, null)); + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + events.add(new Event(EventType.ON_REJECT, -1, requestType, rejectionReason)); + } + + public TestRequestInterceptor expectOnStart(int streamId, FrameType requestType) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_START) + .hasFieldOrPropertyWithValue("streamId", streamId) + .hasFieldOrPropertyWithValue("requestType", requestType); + + return this; + } + + public TestRequestInterceptor expectOnComplete(int streamId) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_COMPLETE) + .hasFieldOrPropertyWithValue("streamId", streamId); + + return this; + } + + public TestRequestInterceptor expectOnError(int streamId) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_ERROR) + .hasFieldOrPropertyWithValue("streamId", streamId); + + return this; + } + + public TestRequestInterceptor expectOnCancel(int streamId) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_CANCEL) + .hasFieldOrPropertyWithValue("streamId", streamId); + + return this; + } + + public TestRequestInterceptor assertNext(Consumer consumer) { + final Event event = events.poll(); + Assertions.assertThat(event).isNotNull(); + + consumer.accept(event); + + return this; + } + + public TestRequestInterceptor expectOnReject(FrameType requestType, Throwable rejectionReason) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_REJECT) + .has( + new Condition<>( + e -> { + Assertions.assertThat(e.error) + .isExactlyInstanceOf(rejectionReason.getClass()) + .hasMessage(rejectionReason.getMessage()) + .hasCause(rejectionReason.getCause()); + return true; + }, + "Has rejection reason which matches to %s", + rejectionReason)) + .hasFieldOrPropertyWithValue("requestType", requestType); + + return this; + } + + public TestRequestInterceptor expectNothing() { + final Event event = events.poll(); + + Assertions.assertThat(event).isNull(); + + return this; + } + + public static final class Event { + public final EventType eventType; + public final int streamId; + public final FrameType requestType; + public final Throwable error; + + Event(EventType eventType, int streamId, FrameType requestType, Throwable error) { + this.eventType = eventType; + this.streamId = streamId; + this.requestType = requestType; + this.error = error; + } + } + + public enum EventType { + ON_START, + ON_COMPLETE, + ON_ERROR, + ON_CANCEL, + ON_REJECT + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/resume/ClientRSocketSessionTest.java b/rsocket-core/src/test/java/io/rsocket/resume/ClientRSocketSessionTest.java new file mode 100644 index 000000000..8229bf42b --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/resume/ClientRSocketSessionTest.java @@ -0,0 +1,470 @@ +package io.rsocket.resume; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.ReferenceCounted; +import io.rsocket.FrameAssert; +import io.rsocket.exceptions.ConnectionCloseException; +import io.rsocket.exceptions.RejectedResumeException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.ResumeOkFrameCodec; +import io.rsocket.keepalive.KeepAliveSupport; +import io.rsocket.test.util.TestClientTransport; +import io.rsocket.test.util.TestDuplexConnection; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Operators; +import reactor.test.StepVerifier; +import reactor.test.scheduler.VirtualTimeScheduler; +import reactor.util.function.Tuples; +import reactor.util.retry.Retry; + +public class ClientRSocketSessionTest { + + @Test + void sessionTimeoutSmokeTest() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ClientRSocketSession session = + new ClientRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.connect().delaySubscription(Duration.ofMillis(1)), + c -> { + AtomicBoolean firstHandled = new AtomicBoolean(); + return ((TestDuplexConnection) c) + .receive() + .next() + .doOnNext(__ -> firstHandled.set(true)) + .doOnCancel( + () -> { + if (firstHandled.compareAndSet(false, true)) { + c.dispose(); + } + }) + .map(b -> Tuples.of(b, c)); + }, + framesStore, + Duration.ofMinutes(1), + Retry.indefinitely(), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + // deactivate connection + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time so new connection is received + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(1)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(50)); + // timeout should not terminate current connection + assertThat(transport.testConnection().isDisposed()).isFalse(); + + // send RESUME_OK frame + transport + .testConnection() + .addToReceivedBuffer(ResumeOkFrameCodec.encode(transport.alloc(), 0)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be terminated + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(15)); + + // disconnects for the second time + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time so new connection is received + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(1)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + transport + .testConnection() + .addToReceivedBuffer( + ErrorFrameCodec.encode( + transport.alloc(), 0, new ConnectionCloseException("some message"))); + // connection should be closed because of the wrong first frame + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout is still in progress + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(30)); + // should obtain new connection + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_OK frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(30)); + + assertThat(session.s).isEqualTo(Operators.cancelledSubscription()); + assertThat(transport.testConnection().isDisposed()).isTrue(); + + assertThat(session.isDisposed()).isTrue(); + + resumableDuplexConnection.onClose().as(StepVerifier::create).expectComplete().verify(); + keepAliveSupport.dispose(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } + + @Test + void sessionTerminationOnWrongFrameTest() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ClientRSocketSession session = + new ClientRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.connect().delaySubscription(Duration.ofMillis(1)), + c -> { + AtomicBoolean firstHandled = new AtomicBoolean(); + return ((TestDuplexConnection) c) + .receive() + .next() + .doOnNext(__ -> firstHandled.set(true)) + .doOnCancel( + () -> { + if (firstHandled.compareAndSet(false, true)) { + c.dispose(); + } + }) + .map(b -> Tuples.of(b, c)); + }, + framesStore, + Duration.ofMinutes(1), + Retry.indefinitely(), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + // deactivate connection + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time so new connection is received + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(1)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(50)); + // timeout should not terminate current connection + assertThat(transport.testConnection().isDisposed()).isFalse(); + + // send RESUME_OK frame + transport + .testConnection() + .addToReceivedBuffer(ResumeOkFrameCodec.encode(transport.alloc(), 0)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be terminated + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(15)); + + // disconnects for the second time + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time so new connection is received + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(1)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + // Send KEEPALIVE frame as a first frame + transport + .testConnection() + .addToReceivedBuffer( + KeepAliveFrameCodec.encode(transport.alloc(), false, 0, Unpooled.EMPTY_BUFFER)); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(30)); + + assertThat(session.s).isEqualTo(Operators.cancelledSubscription()); + assertThat(transport.testConnection().isDisposed()).isTrue(); + assertThat(session.isDisposed()).isTrue(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.ERROR) + .matches(ReferenceCounted::release); + + resumableDuplexConnection + .onClose() + .as(StepVerifier::create) + .expectErrorMessage("RESUME_OK frame must be received before any others") + .verify(); + keepAliveSupport.dispose(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } + + @Test + void shouldErrorWithNoRetriesOnErrorFrameTest() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ClientRSocketSession session = + new ClientRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.connect().delaySubscription(Duration.ofMillis(1)), + c -> { + AtomicBoolean firstHandled = new AtomicBoolean(); + return ((TestDuplexConnection) c) + .receive() + .next() + .doOnNext(__ -> firstHandled.set(true)) + .doOnCancel( + () -> { + if (firstHandled.compareAndSet(false, true)) { + c.dispose(); + } + }) + .map(b -> Tuples.of(b, c)); + }, + framesStore, + Duration.ofMinutes(1), + Retry.indefinitely(), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + // deactivate connection + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time so new connection is received + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(1)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(50)); + // timeout should not terminate current connection + assertThat(transport.testConnection().isDisposed()).isFalse(); + + // send REJECTED_RESUME_ERROR frame + transport + .testConnection() + .addToReceivedBuffer( + ErrorFrameCodec.encode( + transport.alloc(), 0, new RejectedResumeException("failed resumption"))); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // timeout should be terminated + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isTrue(); + + resumableDuplexConnection + .onClose() + .as(StepVerifier::create) + .expectError(RejectedResumeException.class) + .verify(); + keepAliveSupport.dispose(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } + + @Test + void shouldTerminateConnectionOnIllegalStateInKeepAliveFrame() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ClientRSocketSession session = + new ClientRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.connect().delaySubscription(Duration.ofMillis(1)), + c -> { + AtomicBoolean firstHandled = new AtomicBoolean(); + return ((TestDuplexConnection) c) + .receive() + .next() + .doOnNext(__ -> firstHandled.set(true)) + .doOnCancel( + () -> { + if (firstHandled.compareAndSet(false, true)) { + c.dispose(); + } + }) + .map(b -> Tuples.of(b, c)); + }, + framesStore, + Duration.ofMinutes(1), + Retry.indefinitely(), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + keepAliveSupport.resumeState(session); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + final ByteBuf keepAliveFrame = + KeepAliveFrameCodec.encode(transport.alloc(), false, 1529, Unpooled.EMPTY_BUFFER); + keepAliveSupport.receive(keepAliveFrame); + keepAliveFrame.release(); + + assertThat(transport.testConnection().isDisposed()).isTrue(); + // timeout should be terminated + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isTrue(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.ERROR) + .matches(ReferenceCounted::release); + + resumableDuplexConnection.onClose().as(StepVerifier::create).expectError().verify(); + keepAliveSupport.dispose(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/resume/InMemoryResumeStoreTest.java b/rsocket-core/src/test/java/io/rsocket/resume/InMemoryResumeStoreTest.java index 9da66d424..bba40d674 100644 --- a/rsocket-core/src/test/java/io/rsocket/resume/InMemoryResumeStoreTest.java +++ b/rsocket-core/src/test/java/io/rsocket/resume/InMemoryResumeStoreTest.java @@ -1,80 +1,528 @@ package io.rsocket.resume; +import static org.assertj.core.api.Assertions.assertThat; + import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; +import io.netty.util.ReferenceCounted; +import io.rsocket.RaceTestConstants; +import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.internal.subscriber.AssertSubscriber; import java.util.Arrays; -import org.junit.Assert; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import org.junit.jupiter.api.Test; -import reactor.core.publisher.Flux; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.Disposable; +import reactor.core.publisher.Hooks; +import reactor.test.util.RaceTestUtils; public class InMemoryResumeStoreTest { + @Test + void saveNonResumableFrame() { + final InMemoryResumableFramesStore store = inMemoryStore(25); + final UnboundedProcessor sender = new UnboundedProcessor(); + + store.saveFrames(sender).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + + final ByteBuf frame1 = fakeConnectionFrame(10); + final ByteBuf frame2 = fakeConnectionFrame(35); + + sender.onNext(frame1); + sender.onNext(frame2); + + assertThat(store.cachedFrames.size()).isZero(); + assertThat(store.cacheSize).isZero(); + assertThat(store.firstAvailableFramePosition).isZero(); + + assertSubscriber.assertValueCount(2).values().forEach(ByteBuf::release); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + } + @Test void saveWithoutTailRemoval() { - InMemoryResumableFramesStore store = inMemoryStore(25); - ByteBuf frame = frameMock(10); - store.saveFrames(Flux.just(frame)).block(); - Assert.assertEquals(1, store.cachedFrames.size()); - Assert.assertEquals(frame.readableBytes(), store.cacheSize); - Assert.assertEquals(0, store.position); + final InMemoryResumableFramesStore store = inMemoryStore(25); + final UnboundedProcessor sender = new UnboundedProcessor(); + + store.saveFrames(sender).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + + final ByteBuf frame = fakeResumableFrame(10); + + sender.onNext(frame); + + assertThat(store.cachedFrames.size()).isEqualTo(1); + assertThat(store.cacheSize).isEqualTo(frame.readableBytes()); + assertThat(store.firstAvailableFramePosition).isZero(); + + assertSubscriber.assertValueCount(1).values().forEach(ByteBuf::release); + + assertThat(frame.refCnt()).isOne(); } @Test void saveRemoveOneFromTail() { - InMemoryResumableFramesStore store = inMemoryStore(25); - ByteBuf frame1 = frameMock(20); - ByteBuf frame2 = frameMock(10); - store.saveFrames(Flux.just(frame1, frame2)).block(); - Assert.assertEquals(1, store.cachedFrames.size()); - Assert.assertEquals(frame2.readableBytes(), store.cacheSize); - Assert.assertEquals(frame1.readableBytes(), store.position); + final InMemoryResumableFramesStore store = inMemoryStore(25); + final UnboundedProcessor sender = new UnboundedProcessor(); + + store.saveFrames(sender).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + final ByteBuf frame1 = fakeResumableFrame(20); + final ByteBuf frame2 = fakeResumableFrame(10); + + sender.onNext(frame1); + sender.onNext(frame2); + + assertThat(store.cachedFrames.size()).isOne(); + assertThat(store.cacheSize).isEqualTo(frame2.readableBytes()); + assertThat(store.firstAvailableFramePosition).isEqualTo(frame1.readableBytes()); + + assertSubscriber.assertValueCount(2).values().forEach(ByteBuf::release); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isOne(); } @Test void saveRemoveTwoFromTail() { - InMemoryResumableFramesStore store = inMemoryStore(25); - ByteBuf frame1 = frameMock(10); - ByteBuf frame2 = frameMock(10); - ByteBuf frame3 = frameMock(20); - store.saveFrames(Flux.just(frame1, frame2, frame3)).block(); - Assert.assertEquals(1, store.cachedFrames.size()); - Assert.assertEquals(frame3.readableBytes(), store.cacheSize); - Assert.assertEquals(size(frame1, frame2), store.position); + final InMemoryResumableFramesStore store = inMemoryStore(25); + final UnboundedProcessor sender = new UnboundedProcessor(); + + store.saveFrames(sender).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + + final ByteBuf frame1 = fakeResumableFrame(10); + final ByteBuf frame2 = fakeResumableFrame(10); + final ByteBuf frame3 = fakeResumableFrame(20); + + sender.onNext(frame1); + sender.onNext(frame2); + sender.onNext(frame3); + + assertThat(store.cachedFrames.size()).isOne(); + assertThat(store.cacheSize).isEqualTo(frame3.readableBytes()); + assertThat(store.firstAvailableFramePosition).isEqualTo(size(frame1, frame2)); + + assertSubscriber.assertValueCount(3).values().forEach(ByteBuf::release); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + assertThat(frame3.refCnt()).isOne(); } @Test void saveBiggerThanStore() { - InMemoryResumableFramesStore store = inMemoryStore(25); - ByteBuf frame1 = frameMock(10); - ByteBuf frame2 = frameMock(10); - ByteBuf frame3 = frameMock(30); - store.saveFrames(Flux.just(frame1, frame2, frame3)).block(); - Assert.assertEquals(0, store.cachedFrames.size()); - Assert.assertEquals(0, store.cacheSize); - Assert.assertEquals(size(frame1, frame2, frame3), store.position); + final InMemoryResumableFramesStore store = inMemoryStore(25); + final UnboundedProcessor sender = new UnboundedProcessor(); + + store.saveFrames(sender).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + final ByteBuf frame1 = fakeResumableFrame(10); + final ByteBuf frame2 = fakeResumableFrame(10); + final ByteBuf frame3 = fakeResumableFrame(30); + + sender.onNext(frame1); + sender.onNext(frame2); + sender.onNext(frame3); + + assertThat(store.cachedFrames.size()).isZero(); + assertThat(store.cacheSize).isZero(); + assertThat(store.firstAvailableFramePosition).isEqualTo(size(frame1, frame2, frame3)); + + assertSubscriber.assertValueCount(3).values().forEach(ByteBuf::release); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + assertThat(frame3.refCnt()).isZero(); } @Test void releaseFrames() { - InMemoryResumableFramesStore store = inMemoryStore(100); - ByteBuf frame1 = frameMock(10); - ByteBuf frame2 = frameMock(10); - ByteBuf frame3 = frameMock(30); - store.saveFrames(Flux.just(frame1, frame2, frame3)).block(); + final InMemoryResumableFramesStore store = inMemoryStore(100); + + final UnboundedProcessor producer = new UnboundedProcessor(); + store.saveFrames(producer).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + + final ByteBuf frame1 = fakeResumableFrame(10); + final ByteBuf frame2 = fakeResumableFrame(10); + final ByteBuf frame3 = fakeResumableFrame(30); + + producer.onNext(frame1); + producer.onNext(frame2); + producer.onNext(frame3); + store.releaseFrames(20); - Assert.assertEquals(1, store.cachedFrames.size()); - Assert.assertEquals(frame3.readableBytes(), store.cacheSize); - Assert.assertEquals(size(frame1, frame2), store.position); + + assertThat(store.cachedFrames.size()).isOne(); + assertThat(store.cacheSize).isEqualTo(frame3.readableBytes()); + assertThat(store.firstAvailableFramePosition).isEqualTo(size(frame1, frame2)); + + assertSubscriber.assertValueCount(3).values().forEach(ByteBuf::release); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + assertThat(frame3.refCnt()).isOne(); } @Test void receiveImpliedPosition() { - InMemoryResumableFramesStore store = inMemoryStore(100); - ByteBuf frame1 = frameMock(10); - ByteBuf frame2 = frameMock(30); + final InMemoryResumableFramesStore store = inMemoryStore(100); + + ByteBuf frame1 = fakeResumableFrame(10); + ByteBuf frame2 = fakeResumableFrame(30); + store.resumableFrameReceived(frame1); store.resumableFrameReceived(frame2); - Assert.assertEquals(size(frame1, frame2), store.frameImpliedPosition()); + + assertThat(store.frameImpliedPosition()).isEqualTo(size(frame1, frame2)); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void ensuresCleansOnTerminal(boolean hasSubscriber) { + final InMemoryResumableFramesStore store = inMemoryStore(100); + + final UnboundedProcessor producer = new UnboundedProcessor(); + store.saveFrames(producer).subscribe(); + + final AssertSubscriber assertSubscriber = + hasSubscriber ? store.resumeStream().subscribeWith(AssertSubscriber.create()) : null; + + final ByteBuf frame1 = fakeResumableFrame(10); + final ByteBuf frame2 = fakeResumableFrame(10); + final ByteBuf frame3 = fakeResumableFrame(30); + + producer.onNext(frame1); + producer.onNext(frame2); + producer.onNext(frame3); + producer.onComplete(); + + assertThat(store.cachedFrames.size()).isZero(); + assertThat(store.cacheSize).isZero(); + + assertThat(producer.isDisposed()).isTrue(); + + if (hasSubscriber) { + assertSubscriber.assertValueCount(3).assertTerminated().values().forEach(ByteBuf::release); + } + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + assertThat(frame3.refCnt()).isZero(); + } + + @Test + void ensuresCleansOnTerminalLateSubscriber() { + final InMemoryResumableFramesStore store = inMemoryStore(100); + + final UnboundedProcessor producer = new UnboundedProcessor(); + store.saveFrames(producer).subscribe(); + + final ByteBuf frame1 = fakeResumableFrame(10); + final ByteBuf frame2 = fakeResumableFrame(10); + final ByteBuf frame3 = fakeResumableFrame(30); + + producer.onNext(frame1); + producer.onNext(frame2); + producer.onNext(frame3); + producer.onComplete(); + + assertThat(store.cachedFrames.size()).isZero(); + assertThat(store.cacheSize).isZero(); + + assertThat(producer.isDisposed()).isTrue(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + assertSubscriber.assertTerminated(); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + assertThat(frame3.refCnt()).isZero(); + } + + @ParameterizedTest(name = "Sending vs Reconnect Race Test. WithLateSubscriber[{0}]") + @ValueSource(booleans = {true, false}) + void sendingVsReconnectRaceTest(boolean withLateSubscriber) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final InMemoryResumableFramesStore store = inMemoryStore(Integer.MAX_VALUE); + final UnboundedProcessor frames = new UnboundedProcessor(); + final BlockingQueue receivedFrames = new ArrayBlockingQueue<>(10); + final AtomicInteger receivedPosition = new AtomicInteger(); + + store.saveFrames(frames).subscribe(); + + final Consumer consumer = + f -> { + if (ResumableDuplexConnection.isResumableFrame(f)) { + receivedPosition.addAndGet(f.readableBytes()); + receivedFrames.offer(f); + return; + } + f.release(); + }; + final AtomicReference disposableReference = + new AtomicReference<>( + withLateSubscriber ? null : store.resumeStream().subscribe(consumer)); + + final ByteBuf byteBuf1 = fakeResumableFrame(5); + final ByteBuf byteBuf11 = fakeConnectionFrame(5); + final ByteBuf byteBuf2 = fakeResumableFrame(6); + final ByteBuf byteBuf21 = fakeConnectionFrame(5); + final ByteBuf byteBuf3 = fakeResumableFrame(7); + final ByteBuf byteBuf31 = fakeConnectionFrame(5); + final ByteBuf byteBuf4 = fakeResumableFrame(8); + final ByteBuf byteBuf41 = fakeConnectionFrame(5); + final ByteBuf byteBuf5 = fakeResumableFrame(25); + final ByteBuf byteBuf51 = fakeConnectionFrame(35); + + RaceTestUtils.race( + () -> { + if (withLateSubscriber) { + disposableReference.set(store.resumeStream().subscribe(consumer)); + } + + // disconnect + disposableReference.get().dispose(); + + while (InMemoryResumableFramesStore.isWorkInProgress(store.state)) { + // ignore + } + + // mimic RESUME_OK frame received + store.releaseFrames(receivedPosition.get()); + disposableReference.set(store.resumeStream().subscribe(consumer)); + + // disconnect + disposableReference.get().dispose(); + + while (InMemoryResumableFramesStore.isWorkInProgress(store.state)) { + // ignore + } + + // mimic RESUME_OK frame received + store.releaseFrames(receivedPosition.get()); + disposableReference.set(store.resumeStream().subscribe(consumer)); + }, + () -> { + frames.onNext(byteBuf1); + frames.onNextPrioritized(byteBuf11); + frames.onNext(byteBuf2); + frames.onNext(byteBuf3); + frames.onNextPrioritized(byteBuf31); + frames.onNext(byteBuf4); + frames.onNext(byteBuf5); + }, + () -> { + frames.onNextPrioritized(byteBuf21); + frames.onNextPrioritized(byteBuf41); + frames.onNextPrioritized(byteBuf51); + }); + + store.releaseFrames(receivedFrames.stream().mapToInt(ByteBuf::readableBytes).sum()); + + assertThat(store.cacheSize).isZero(); + assertThat(store.cachedFrames).isEmpty(); + + assertThat(receivedFrames) + .hasSize(5) + .containsSequence(byteBuf1, byteBuf2, byteBuf3, byteBuf4, byteBuf5); + receivedFrames.forEach(ReferenceCounted::release); + + assertThat(byteBuf1.refCnt()).isZero(); + assertThat(byteBuf11.refCnt()).isZero(); + assertThat(byteBuf2.refCnt()).isZero(); + assertThat(byteBuf21.refCnt()).isZero(); + assertThat(byteBuf3.refCnt()).isZero(); + assertThat(byteBuf31.refCnt()).isZero(); + assertThat(byteBuf4.refCnt()).isZero(); + assertThat(byteBuf41.refCnt()).isZero(); + assertThat(byteBuf5.refCnt()).isZero(); + assertThat(byteBuf51.refCnt()).isZero(); + } + } + + @ParameterizedTest( + name = "Sending vs Reconnect with incorrect position Race Test. WithLateSubscriber[{0}]") + @ValueSource(booleans = {true, false}) + void incorrectReleaseFramesWithOnNextRaceTest(boolean withLateSubscriber) { + Hooks.onErrorDropped(t -> {}); + try { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final InMemoryResumableFramesStore store = inMemoryStore(Integer.MAX_VALUE); + final UnboundedProcessor frames = new UnboundedProcessor(); + + store.saveFrames(frames).subscribe(); + + final AtomicInteger terminationCnt = new AtomicInteger(); + final Consumer consumer = ReferenceCounted::release; + final Consumer errorConsumer = __ -> terminationCnt.incrementAndGet(); + final AtomicReference disposableReference = + new AtomicReference<>( + withLateSubscriber + ? null + : store.resumeStream().subscribe(consumer, errorConsumer)); + + final ByteBuf byteBuf1 = fakeResumableFrame(5); + final ByteBuf byteBuf11 = fakeConnectionFrame(5); + final ByteBuf byteBuf2 = fakeResumableFrame(6); + final ByteBuf byteBuf21 = fakeConnectionFrame(5); + final ByteBuf byteBuf3 = fakeResumableFrame(7); + final ByteBuf byteBuf31 = fakeConnectionFrame(5); + final ByteBuf byteBuf4 = fakeResumableFrame(8); + final ByteBuf byteBuf41 = fakeConnectionFrame(5); + final ByteBuf byteBuf5 = fakeResumableFrame(25); + final ByteBuf byteBuf51 = fakeConnectionFrame(35); + + RaceTestUtils.race( + () -> { + if (withLateSubscriber) { + disposableReference.set(store.resumeStream().subscribe(consumer, errorConsumer)); + } + // disconnect + disposableReference.get().dispose(); + + // mimic RESUME_OK frame received but with incorrect position + store.releaseFrames(25); + disposableReference.set(store.resumeStream().subscribe(consumer, errorConsumer)); + }, + () -> { + frames.onNext(byteBuf1); + frames.onNextPrioritized(byteBuf11); + frames.onNext(byteBuf2); + frames.onNext(byteBuf3); + frames.onNextPrioritized(byteBuf31); + frames.onNext(byteBuf4); + frames.onNext(byteBuf5); + }, + () -> { + frames.onNextPrioritized(byteBuf21); + frames.onNextPrioritized(byteBuf41); + frames.onNextPrioritized(byteBuf51); + }); + + assertThat(store.cacheSize).isZero(); + assertThat(store.cachedFrames).isEmpty(); + assertThat(disposableReference.get().isDisposed()).isTrue(); + assertThat(terminationCnt).hasValue(1); + + assertThat(byteBuf1.refCnt()).isZero(); + assertThat(byteBuf11.refCnt()).isZero(); + assertThat(byteBuf2.refCnt()).isZero(); + assertThat(byteBuf21.refCnt()).isZero(); + assertThat(byteBuf3.refCnt()).isZero(); + assertThat(byteBuf31.refCnt()).isZero(); + assertThat(byteBuf4.refCnt()).isZero(); + assertThat(byteBuf41.refCnt()).isZero(); + assertThat(byteBuf5.refCnt()).isZero(); + assertThat(byteBuf51.refCnt()).isZero(); + } + } finally { + Hooks.resetOnErrorDropped(); + } + } + + @ParameterizedTest( + name = + "Dispose vs Sending vs Reconnect with incorrect position Race Test. WithLateSubscriber[{0}]") + @ValueSource(booleans = {true, false}) + void incorrectReleaseFramesWithOnNextWithDisposeRaceTest(boolean withLateSubscriber) { + Hooks.onErrorDropped(t -> {}); + try { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final InMemoryResumableFramesStore store = inMemoryStore(Integer.MAX_VALUE); + final UnboundedProcessor frames = new UnboundedProcessor(); + + store.saveFrames(frames).subscribe(); + + final AtomicInteger terminationCnt = new AtomicInteger(); + final Consumer consumer = ReferenceCounted::release; + final Consumer errorConsumer = __ -> terminationCnt.incrementAndGet(); + final AtomicReference disposableReference = + new AtomicReference<>( + withLateSubscriber + ? null + : store.resumeStream().subscribe(consumer, errorConsumer)); + + final ByteBuf byteBuf1 = fakeResumableFrame(5); + final ByteBuf byteBuf11 = fakeConnectionFrame(5); + final ByteBuf byteBuf2 = fakeResumableFrame(6); + final ByteBuf byteBuf21 = fakeConnectionFrame(5); + final ByteBuf byteBuf3 = fakeResumableFrame(7); + final ByteBuf byteBuf31 = fakeConnectionFrame(5); + final ByteBuf byteBuf4 = fakeResumableFrame(8); + final ByteBuf byteBuf41 = fakeConnectionFrame(5); + final ByteBuf byteBuf5 = fakeResumableFrame(25); + final ByteBuf byteBuf51 = fakeConnectionFrame(35); + + RaceTestUtils.race( + () -> { + if (withLateSubscriber) { + disposableReference.set(store.resumeStream().subscribe(consumer, errorConsumer)); + } + // disconnect + disposableReference.get().dispose(); + + // mimic RESUME_OK frame received but with incorrect position + store.releaseFrames(25); + disposableReference.set(store.resumeStream().subscribe(consumer, errorConsumer)); + }, + () -> { + frames.onNext(byteBuf1); + frames.onNextPrioritized(byteBuf11); + frames.onNext(byteBuf2); + frames.onNext(byteBuf3); + frames.onNextPrioritized(byteBuf31); + frames.onNext(byteBuf4); + frames.onNext(byteBuf5); + }, + () -> { + frames.onNextPrioritized(byteBuf21); + frames.onNextPrioritized(byteBuf41); + frames.onNextPrioritized(byteBuf51); + }, + store::dispose); + + assertThat(store.cacheSize).isZero(); + assertThat(store.cachedFrames).isEmpty(); + assertThat(disposableReference.get().isDisposed()).isTrue(); + assertThat(terminationCnt).hasValueGreaterThanOrEqualTo(1).hasValueLessThanOrEqualTo(2); + + assertThat(byteBuf1.refCnt()).isZero(); + assertThat(byteBuf11.refCnt()).isZero(); + assertThat(byteBuf2.refCnt()).isZero(); + assertThat(byteBuf21.refCnt()).isZero(); + assertThat(byteBuf3.refCnt()).isZero(); + assertThat(byteBuf31.refCnt()).isZero(); + assertThat(byteBuf4.refCnt()).isZero(); + assertThat(byteBuf41.refCnt()).isZero(); + assertThat(byteBuf5.refCnt()).isZero(); + assertThat(byteBuf51.refCnt()).isZero(); + } + } finally { + Hooks.resetOnErrorDropped(); + } } private int size(ByteBuf... byteBufs) { @@ -82,12 +530,18 @@ private int size(ByteBuf... byteBufs) { } private static InMemoryResumableFramesStore inMemoryStore(int size) { - return new InMemoryResumableFramesStore("test", size); + return new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, size); } - private static ByteBuf frameMock(int size) { + private static ByteBuf fakeResumableFrame(int size) { byte[] bytes = new byte[size]; Arrays.fill(bytes, (byte) 7); return Unpooled.wrappedBuffer(bytes); } + + private static ByteBuf fakeConnectionFrame(int size) { + byte[] bytes = new byte[size]; + Arrays.fill(bytes, (byte) 0); + return Unpooled.wrappedBuffer(bytes); + } } diff --git a/rsocket-core/src/test/java/io/rsocket/resume/ResumeCalculatorTest.java b/rsocket-core/src/test/java/io/rsocket/resume/ResumeCalculatorTest.java deleted file mode 100644 index 7d2a7bcc8..000000000 --- a/rsocket-core/src/test/java/io/rsocket/resume/ResumeCalculatorTest.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright 2015-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.resume; - -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -public class ResumeCalculatorTest { - - @BeforeEach - void setUp() {} - - @Test - void clientResumeSuccess() { - long position = ResumableDuplexConnection.calculateRemoteImpliedPos(1, 42, -1, 3); - Assertions.assertEquals(3, position); - } - - @Test - void clientResumeError() { - long position = ResumableDuplexConnection.calculateRemoteImpliedPos(4, 42, -1, 3); - Assertions.assertEquals(-1, position); - } - - @Test - void serverResumeSuccess() { - long position = ResumableDuplexConnection.calculateRemoteImpliedPos(1, 42, 4, 23); - Assertions.assertEquals(23, position); - } - - @Test - void serverResumeErrorClientState() { - long position = ResumableDuplexConnection.calculateRemoteImpliedPos(1, 3, 4, 23); - Assertions.assertEquals(-1, position); - } - - @Test - void serverResumeErrorServerState() { - long position = ResumableDuplexConnection.calculateRemoteImpliedPos(4, 42, 4, 1); - Assertions.assertEquals(-1, position); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/resume/ResumeExpBackoffTest.java b/rsocket-core/src/test/java/io/rsocket/resume/ResumeExpBackoffTest.java deleted file mode 100644 index d86276466..000000000 --- a/rsocket-core/src/test/java/io/rsocket/resume/ResumeExpBackoffTest.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright 2015-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.resume; - -import static org.junit.jupiter.api.Assertions.assertThrows; - -import java.time.Duration; -import java.util.List; -import org.assertj.core.api.Assertions; -import org.junit.jupiter.api.Test; -import reactor.core.publisher.Flux; - -public class ResumeExpBackoffTest { - - @Test - void backOffSeries() { - Duration firstBackoff = Duration.ofSeconds(1); - Duration maxBackoff = Duration.ofSeconds(32); - int factor = 2; - ExponentialBackoffResumeStrategy strategy = - new ExponentialBackoffResumeStrategy(firstBackoff, maxBackoff, factor); - - List expected = - Flux.just(1, 2, 4, 8, 16, 32, 32).map(Duration::ofSeconds).collectList().block(); - - List actual = Flux.range(1, 7).map(v -> strategy.next()).collectList().block(); - - Assertions.assertThat(actual).isEqualTo(expected); - } - - @Test - void nullFirstBackoff() { - assertThrows( - NullPointerException.class, - () -> { - ExponentialBackoffResumeStrategy strategy = - new ExponentialBackoffResumeStrategy(Duration.ofSeconds(1), null, 42); - }); - } - - @Test - void nullMaxBackoff() { - assertThrows( - NullPointerException.class, - () -> { - ExponentialBackoffResumeStrategy strategy = - new ExponentialBackoffResumeStrategy(null, Duration.ofSeconds(1), 42); - }); - } - - @Test - void negativeFactor() { - assertThrows( - IllegalArgumentException.class, - () -> { - ExponentialBackoffResumeStrategy strategy = - new ExponentialBackoffResumeStrategy( - Duration.ofSeconds(1), Duration.ofSeconds(32), -1); - }); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/resume/ServerRSocketSessionTest.java b/rsocket-core/src/test/java/io/rsocket/resume/ServerRSocketSessionTest.java new file mode 100644 index 000000000..b5625bf8e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/resume/ServerRSocketSessionTest.java @@ -0,0 +1,190 @@ +package io.rsocket.resume; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.ReferenceCounted; +import io.rsocket.FrameAssert; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.ResumeFrameCodec; +import io.rsocket.keepalive.KeepAliveSupport; +import io.rsocket.test.util.TestClientTransport; +import java.time.Duration; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Operators; +import reactor.test.StepVerifier; +import reactor.test.scheduler.VirtualTimeScheduler; + +public class ServerRSocketSessionTest { + + @Test + void sessionTimeoutSmokeTest() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ServerRSocketSession session = + new ServerRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.testConnection(), + framesStore, + Duration.ofMinutes(1), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + // deactivate connection + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // resubscribe so a new connection is generated + transport.connect().subscribe(); + + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(50)); + // timeout should not terminate current connection + assertThat(transport.testConnection().isDisposed()).isFalse(); + + // send RESUME frame + final ByteBuf resumeFrame = + ResumeFrameCodec.encode(transport.alloc(), Unpooled.EMPTY_BUFFER, 0, 0); + session.resumeWith(resumeFrame, transport.testConnection()); + resumeFrame.release(); + + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be terminated + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME_OK) + .matches(ReferenceCounted::release); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(15)); + + // disconnects for the second time + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + transport.connect().subscribe(); + + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(61)); + + final ByteBuf resumeFrame1 = + ResumeFrameCodec.encode(transport.alloc(), Unpooled.EMPTY_BUFFER, 0, 0); + session.resumeWith(resumeFrame1, transport.testConnection()); + resumeFrame1.release(); + + // should obtain new connection + assertThat(transport.testConnection().isDisposed()).isTrue(); + // timeout should be still active since no RESUME_OK frame has been received yet + assertThat(session.s).isEqualTo(Operators.cancelledSubscription()); + assertThat(session.isDisposed()).isTrue(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.ERROR) + .matches(ReferenceCounted::release); + + resumableDuplexConnection.onClose().as(StepVerifier::create).expectComplete().verify(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } + + @Test + void shouldTerminateConnectionOnIllegalStateInKeepAliveFrame() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ServerRSocketSession session = + new ServerRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.testConnection(), + framesStore, + Duration.ofMinutes(1), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + keepAliveSupport.resumeState(session); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + final ByteBuf keepAliveFrame = + KeepAliveFrameCodec.encode(transport.alloc(), false, 1529, Unpooled.EMPTY_BUFFER); + keepAliveSupport.receive(keepAliveFrame); + keepAliveFrame.release(); + + assertThat(transport.testConnection().isDisposed()).isTrue(); + // timeout should be terminated + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isTrue(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.ERROR) + .matches(ReferenceCounted::release); + + resumableDuplexConnection.onClose().as(StepVerifier::create).expectError().verify(); + keepAliveSupport.dispose(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java b/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java index 58323c066..cdfcefdc8 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,43 +19,81 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; -import org.reactivestreams.Publisher; -import reactor.core.publisher.DirectProcessor; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import java.net.SocketAddress; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Scannable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; public class LocalDuplexConnection implements DuplexConnection { private final ByteBufAllocator allocator; - private final DirectProcessor send; - private final DirectProcessor receive; - private final MonoProcessor onClose; + private final Sinks.Many send; + private final Sinks.Many receive; + private final Sinks.Empty onClose; private final String name; public LocalDuplexConnection( String name, ByteBufAllocator allocator, - DirectProcessor send, - DirectProcessor receive) { + Sinks.Many send, + Sinks.Many receive) { this.name = name; this.allocator = allocator; this.send = send; this.receive = receive; - this.onClose = MonoProcessor.create(); + this.onClose = Sinks.empty(); } @Override - public Mono send(Publisher frame) { - return Flux.from(frame) - .doOnNext(f -> System.out.println(name + " - " + f.toString())) - .doOnNext(send::onNext) - .doOnError(send::onError) - .then(); + public void sendFrame(int streamId, ByteBuf frame) { + System.out.println(name + " - " + frame.toString()); + send.tryEmitNext(frame); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, 0, e); + System.out.println(name + " - " + errorFrame.toString()); + send.tryEmitNext(errorFrame); + onClose.tryEmitEmpty(); } @Override public Flux receive() { - return receive.doOnNext(f -> System.out.println(name + " - " + f.toString())); + return receive + .asFlux() + .doOnNext(f -> System.out.println(name + " - " + f.toString())) + .transform( + Operators.lift( + (__, actual) -> + new CoreSubscriber() { + + @Override + public void onSubscribe(Subscription s) { + actual.onSubscribe(s); + } + + @Override + public void onNext(ByteBuf byteBuf) { + actual.onNext(byteBuf); + byteBuf.release(); + } + + @Override + public void onError(Throwable t) { + actual.onError(t); + } + + @Override + public void onComplete() { + actual.onComplete(); + } + })); } @Override @@ -63,18 +101,24 @@ public ByteBufAllocator alloc() { return allocator; } + @Override + public SocketAddress remoteAddress() { + return new TestLocalSocketAddress(name); + } + @Override public void dispose() { - onClose.onComplete(); + onClose.tryEmitEmpty(); } @Override + @SuppressWarnings("ConstantConditions") public boolean isDisposed() { - return onClose.isDisposed(); + return onClose.scan(Scannable.Attr.TERMINATED) || onClose.scan(Scannable.Attr.CANCELLED); } @Override public Mono onClose() { - return onClose; + return onClose.asMono(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/MockRSocket.java b/rsocket-core/src/test/java/io/rsocket/test/util/MockRSocket.java index 179afff58..a33c4c4b3 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/MockRSocket.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/MockRSocket.java @@ -16,8 +16,7 @@ package io.rsocket.test.util; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.is; +import static org.assertj.core.api.Assertions.assertThat; import io.rsocket.Payload; import io.rsocket.RSocket; @@ -116,6 +115,8 @@ public void assertMetadataPushCount(int expected) { } private static void assertCount(int expected, String type, AtomicInteger counter) { - assertThat("Unexpected invocations for " + type + '.', counter.get(), is(expected)); + assertThat(counter.get()) + .describedAs("Unexpected invocations for " + type + '.') + .isEqualTo(expected); } } diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java index 88694d209..f02bc99a4 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java @@ -6,18 +6,21 @@ import io.rsocket.DuplexConnection; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.transport.ClientTransport; +import java.time.Duration; import reactor.core.publisher.Mono; public class TestClientTransport implements ClientTransport { private final LeaksTrackingByteBufAllocator allocator = - LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); - private final TestDuplexConnection testDuplexConnection = new TestDuplexConnection(allocator); + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "client"); + + private volatile TestDuplexConnection testDuplexConnection; int maxFrameLength = FRAME_LENGTH_MASK; @Override public Mono connect() { - return Mono.just(testDuplexConnection); + return Mono.fromSupplier(() -> testDuplexConnection = new TestDuplexConnection(allocator)); } public TestDuplexConnection testConnection() { diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java index 17a19b8c9..8793d6ca4 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,20 +17,25 @@ package io.rsocket.test.util; import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; -import java.util.Collection; -import java.util.concurrent.ConcurrentLinkedQueue; +import io.rsocket.RSocketErrorException; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.ErrorFrameCodec; +import java.net.SocketAddress; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; import reactor.core.publisher.DirectProcessor; import reactor.core.publisher.Flux; import reactor.core.publisher.FluxSink; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; /** * An implementation of {@link DuplexConnection} that provides functionality to modify the behavior @@ -41,56 +46,89 @@ public class TestDuplexConnection implements DuplexConnection { private static final Logger logger = LoggerFactory.getLogger(TestDuplexConnection.class); private final LinkedBlockingQueue sent; + private final DirectProcessor sentPublisher; private final FluxSink sendSink; private final DirectProcessor received; private final FluxSink receivedSink; private final MonoProcessor onClose; - private final ConcurrentLinkedQueue> sendSubscribers; - private final ByteBufAllocator allocator; + private final LeaksTrackingByteBufAllocator allocator; private volatile double availability = 1; private volatile int initialSendRequestN = Integer.MAX_VALUE; - public TestDuplexConnection(ByteBufAllocator allocator) { + public TestDuplexConnection(LeaksTrackingByteBufAllocator allocator) { this.allocator = allocator; this.sent = new LinkedBlockingQueue<>(); this.received = DirectProcessor.create(); this.receivedSink = received.sink(); this.sentPublisher = DirectProcessor.create(); this.sendSink = sentPublisher.sink(); - this.sendSubscribers = new ConcurrentLinkedQueue<>(); this.onClose = MonoProcessor.create(); } @Override - public Mono send(Publisher frames) { + public void sendFrame(int streamId, ByteBuf frame) { if (availability <= 0) { - return Mono.error( - new IllegalStateException("RSocket not available. Availability: " + availability)); + throw new IllegalStateException("RSocket not available. Availability: " + availability); } - Subscriber subscriber = TestSubscriber.create(initialSendRequestN); - Flux.from(frames) - .doOnNext( - frame -> { - sent.offer(frame); - sendSink.next(frame); - }) - .doOnError(throwable -> logger.error("Error in send stream on test connection.", throwable)) - .subscribe(subscriber); - sendSubscribers.add(subscriber); - return Mono.empty(); + + sendSink.next(frame); + sent.offer(frame); } @Override public Flux receive() { - return received; + return received.transform( + Operators.lift( + (__, actual) -> + new CoreSubscriber() { + @Override + public void onSubscribe(Subscription s) { + actual.onSubscribe(s); + } + + @Override + public void onNext(ByteBuf byteBuf) { + actual.onNext(byteBuf); + byteBuf.release(); + } + + @Override + public void onError(Throwable t) { + actual.onError(t); + } + + @Override + public void onComplete() { + actual.onComplete(); + } + })); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, 0, e); + sendSink.next(errorFrame); + sent.offer(errorFrame); + + final Throwable cause = e.getCause(); + if (cause == null) { + onClose.onComplete(); + } else { + onClose.onError(cause); + } } @Override - public ByteBufAllocator alloc() { + public LeaksTrackingByteBufAllocator alloc() { return allocator; } + @Override + public SocketAddress remoteAddress() { + return new TestLocalSocketAddress("TestDuplexConnection"); + } + @Override public double availability() { return availability; @@ -111,15 +149,28 @@ public Mono onClose() { return onClose; } - public ByteBuf awaitSend() throws InterruptedException { - return sent.take(); + public boolean isEmpty() { + return sent.isEmpty(); + } + + @NonNull + public ByteBuf awaitFrame() { + try { + return sent.take(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + public ByteBuf pollFrame() { + return sent.poll(); } public void setAvailability(double availability) { this.availability = availability; } - public Collection getSent() { + public BlockingQueue getSent() { return sent; } @@ -135,14 +186,9 @@ public void addToReceivedBuffer(ByteBuf... received) { public void clearSendReceiveBuffers() { sent.clear(); - sendSubscribers.clear(); } public void setInitialSendRequestN(int initialSendRequestN) { this.initialSendRequestN = initialSendRequestN; } - - public Collection> getSendSubscribers() { - return sendSubscribers; - } } diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestLocalSocketAddress.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestLocalSocketAddress.java new file mode 100644 index 000000000..2dad2cc1f --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestLocalSocketAddress.java @@ -0,0 +1,46 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.test.util; + +import java.net.SocketAddress; +import java.util.Objects; + +public final class TestLocalSocketAddress extends SocketAddress { + + private static final long serialVersionUID = 2608695156052100164L; + + private final String name; + + /** + * Creates a new instance. + * + * @param name the name representing the address + * @throws NullPointerException if {@code name} is {@code null} + */ + public TestLocalSocketAddress(String name) { + this.name = Objects.requireNonNull(name, "name must not be null"); + } + + /** Return the name for this connection. */ + public String getName() { + return name; + } + + @Override + public String toString() { + return "[local address] " + name; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java index 0f9ea8e48..fa9331d3b 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java @@ -1,3 +1,18 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package io.rsocket.test.util; import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; @@ -6,11 +21,13 @@ import io.rsocket.Closeable; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.transport.ServerTransport; +import reactor.core.Scannable; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Sinks; public class TestServerTransport implements ServerTransport { - private final MonoProcessor conn = MonoProcessor.create(); + private final Sinks.One connSink = Sinks.one(); + private TestDuplexConnection connection; private final LeaksTrackingByteBufAllocator allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); @@ -18,29 +35,33 @@ public class TestServerTransport implements ServerTransport { @Override public Mono start(ConnectionAcceptor acceptor) { - conn.flatMap(acceptor::apply) + connSink + .asMono() + .flatMap(duplexConnection -> acceptor.apply(duplexConnection)) .subscribe(ignored -> {}, err -> disposeConnection(), this::disposeConnection); return Mono.just( new Closeable() { @Override public Mono onClose() { - return conn.then(); + return connSink.asMono().then(); } @Override public void dispose() { - conn.onComplete(); + connSink.tryEmitEmpty(); } @Override + @SuppressWarnings("ConstantConditions") public boolean isDisposed() { - return conn.isTerminated(); + return connSink.scan(Scannable.Attr.TERMINATED) + || connSink.scan(Scannable.Attr.CANCELLED); } }); } private void disposeConnection() { - TestDuplexConnection c = conn.peek(); + TestDuplexConnection c = connection; if (c != null) { c.dispose(); } @@ -48,7 +69,8 @@ private void disposeConnection() { public TestDuplexConnection connect() { TestDuplexConnection c = new TestDuplexConnection(allocator); - conn.onNext(c); + connection = c; + connSink.tryEmitValue(c); return c; } diff --git a/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java b/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java index 6bae0886b..f04de78b6 100644 --- a/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java +++ b/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java @@ -16,14 +16,16 @@ package io.rsocket.util; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.equalTo; +import static org.assertj.core.api.Assertions.assertThat; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import io.rsocket.Payload; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import java.nio.ByteBuffer; -import org.assertj.core.api.Assertions; -import org.junit.Test; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.jupiter.api.Test; public class DefaultPayloadTest { public static final String DATA_VAL = "data"; @@ -37,12 +39,12 @@ public void testReuse() { } public void assertDataAndMetadata(Payload p, String dataVal, String metadataVal) { - assertThat("Unexpected data.", p.getDataUtf8(), equalTo(dataVal)); + assertThat(p.getDataUtf8()).describedAs("Unexpected data.").isEqualTo(dataVal); if (metadataVal == null) { - assertThat("Non-null metadata", p.hasMetadata(), equalTo(false)); + assertThat(p.hasMetadata()).describedAs("Non-null metadata").isEqualTo(false); } else { - assertThat("Null metadata", p.hasMetadata(), equalTo(true)); - assertThat("Unexpected metadata.", p.getMetadataUtf8(), equalTo(metadataVal)); + assertThat(p.hasMetadata()).describedAs("Null metadata").isEqualTo(true); + assertThat(p.getMetadataUtf8()).describedAs("Unexpected metadata.").isEqualTo(metadataVal); } } @@ -56,7 +58,7 @@ public void staticMethods() { public void shouldIndicateThatItHasNotMetadata() { Payload payload = DefaultPayload.create("data"); - Assertions.assertThat(payload.hasMetadata()).isFalse(); + assertThat(payload.hasMetadata()).isFalse(); } @Test @@ -64,7 +66,7 @@ public void shouldIndicateThatItHasMetadata1() { Payload payload = DefaultPayload.create(Unpooled.wrappedBuffer("data".getBytes()), Unpooled.EMPTY_BUFFER); - Assertions.assertThat(payload.hasMetadata()).isTrue(); + assertThat(payload.hasMetadata()).isTrue(); } @Test @@ -72,6 +74,34 @@ public void shouldIndicateThatItHasMetadata2() { Payload payload = DefaultPayload.create(ByteBuffer.wrap("data".getBytes()), ByteBuffer.allocate(0)); - Assertions.assertThat(payload.hasMetadata()).isTrue(); + assertThat(payload.hasMetadata()).isTrue(); + } + + @Test + public void shouldReleaseGivenByteBufDataAndMetadataUpOnPayloadCreation() { + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + for (byte i = 0; i < 126; i++) { + ByteBuf data = allocator.buffer(); + data.writeByte(i); + + boolean metadataPresent = ThreadLocalRandom.current().nextBoolean(); + ByteBuf metadata = null; + if (metadataPresent) { + metadata = allocator.buffer(); + metadata.writeByte(i + 1); + } + + Payload payload = DefaultPayload.create(data, metadata); + + assertThat(payload.getData()).isEqualTo(ByteBuffer.wrap(new byte[] {i})); + + assertThat(payload.getMetadata()) + .isEqualTo( + metadataPresent + ? ByteBuffer.wrap(new byte[] {(byte) (i + 1)}) + : DefaultPayload.EMPTY_BUFFER); + allocator.assertHasNoLeaks(); + } } } diff --git a/rsocket-core/src/test/resources/META-INF/services/org.junit.jupiter.api.extension.Extension b/rsocket-core/src/test/resources/META-INF/services/org.junit.jupiter.api.extension.Extension new file mode 100644 index 000000000..2b51ba0de --- /dev/null +++ b/rsocket-core/src/test/resources/META-INF/services/org.junit.jupiter.api.extension.Extension @@ -0,0 +1 @@ +io.rsocket.frame.ByteBufRepresentation \ No newline at end of file diff --git a/rsocket-examples/build.gradle b/rsocket-examples/build.gradle index 01e80cfa1..4059eb957 100644 --- a/rsocket-examples/build.gradle +++ b/rsocket-examples/build.gradle @@ -20,8 +20,19 @@ plugins { dependencies { implementation project(':rsocket-core') + implementation project(':rsocket-load-balancer') implementation project(':rsocket-transport-local') implementation project(':rsocket-transport-netty') + + implementation "io.micrometer:micrometer-core" + implementation "io.micrometer:micrometer-tracing" + implementation project(":rsocket-micrometer") + + implementation 'com.netflix.concurrency-limits:concurrency-limits-core' + implementation "io.micrometer:micrometer-core" + implementation "io.micrometer:micrometer-tracing" + implementation project(":rsocket-micrometer") + runtimeOnly 'ch.qos.logback:logback-classic' testImplementation project(':rsocket-test') @@ -29,11 +40,11 @@ dependencies { testImplementation 'org.mockito:mockito-core' testImplementation 'org.assertj:assertj-core' testImplementation 'io.projectreactor:reactor-test' + testImplementation 'org.awaitility:awaitility' + testImplementation "io.micrometer:micrometer-test" + testImplementation "io.micrometer:micrometer-tracing-integration-test" - // TODO: Remove after JUnit5 migration - testCompileOnly 'junit:junit' - testImplementation 'org.hamcrest:hamcrest-library' - testRuntimeOnly 'org.junit.vintage:junit-vintage-engine' + testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' } description = 'Example usage of the RSocket library' diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java index b532c0140..463043020 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java @@ -43,9 +43,7 @@ public static void main(String[] args) { .map(s -> "Echo: " + s) .map(DefaultPayload::create)); - RSocketServer.create(echoAcceptor) - .bind(TcpServerTransport.create("localhost", 7000)) - .subscribe(); + RSocketServer.create(echoAcceptor).bindNow(TcpServerTransport.create("localhost", 7000)); RSocket socket = RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000)).block(); diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/client/RSocketClientExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/client/RSocketClientExample.java index e1bf459b9..dfbbcde53 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/client/RSocketClientExample.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/client/RSocketClientExample.java @@ -1,8 +1,9 @@ package io.rsocket.examples.transport.tcp.client; import io.rsocket.Payload; -import io.rsocket.RSocketClient; +import io.rsocket.RSocket; import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketClient; import io.rsocket.core.RSocketConnector; import io.rsocket.core.RSocketServer; import io.rsocket.transport.netty.client.TcpClientTransport; @@ -33,14 +34,14 @@ public static void main(String[] args) { .bind(TcpServerTransport.create("localhost", 7000)) .delaySubscription(Duration.ofSeconds(5)) .doOnNext(cc -> logger.info("Server started on the address : {}", cc.address())) - .subscribe(); + .block(); - RSocketClient rSocketClient = + Mono source = RSocketConnector.create() .reconnect(Retry.backoff(50, Duration.ofMillis(500))) - .toRSocketClient(TcpClientTransport.create("localhost", 7000)); + .connect(TcpClientTransport.create("localhost", 7000)); - rSocketClient + RSocketClient.from(source) .requestResponse(Mono.just(DefaultPayload.create("Test Request"))) .doOnSubscribe(s -> logger.info("Executing Request")) .doOnNext( diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/fnf/TaskProcessingWithServerSideNotificationsExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/fnf/TaskProcessingWithServerSideNotificationsExample.java index cf68dcdde..89b22749f 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/fnf/TaskProcessingWithServerSideNotificationsExample.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/fnf/TaskProcessingWithServerSideNotificationsExample.java @@ -35,7 +35,7 @@ import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.UnicastProcessor; +import reactor.core.publisher.Sinks; import reactor.util.concurrent.Queues; /** @@ -48,12 +48,12 @@ public class TaskProcessingWithServerSideNotificationsExample { public static void main(String[] args) throws InterruptedException { - UnicastProcessor tasksProcessor = - UnicastProcessor.create(Queues.unboundedMultiproducer().get()); + Sinks.Many tasksProcessor = + Sinks.many().unicast().onBackpressureBuffer(Queues.unboundedMultiproducer().get()); ConcurrentMap> idToCompletedTasksMap = new ConcurrentHashMap<>(); ConcurrentMap idToRSocketMap = new ConcurrentHashMap<>(); BackgroundWorker backgroundWorker = - new BackgroundWorker(tasksProcessor, idToCompletedTasksMap, idToRSocketMap); + new BackgroundWorker(tasksProcessor.asFlux(), idToCompletedTasksMap, idToRSocketMap); RSocketServer.create(new TasksAcceptor(tasksProcessor, idToCompletedTasksMap, idToRSocketMap)) .bindNow(TcpServerTransport.create(9991)); @@ -132,12 +132,12 @@ static class TasksAcceptor implements SocketAcceptor { static final Logger logger = LoggerFactory.getLogger(TasksAcceptor.class); - final UnicastProcessor tasksToProcess; + final Sinks.Many tasksToProcess; final ConcurrentMap> idToCompletedTasksMap; final ConcurrentMap idToRSocketMap; TasksAcceptor( - UnicastProcessor tasksToProcess, + Sinks.Many tasksToProcess, ConcurrentMap> idToCompletedTasksMap, ConcurrentMap idToRSocketMap) { this.tasksToProcess = tasksToProcess; @@ -197,11 +197,11 @@ private static class RSocketTaskHandler implements RSocket { private final String id; private final RSocket sendingSocket; private ConcurrentMap idToRSocketMap; - private UnicastProcessor tasksToProcess; + private Sinks.Many tasksToProcess; public RSocketTaskHandler( ConcurrentMap idToRSocketMap, - UnicastProcessor tasksToProcess, + Sinks.Many tasksToProcess, String id, RSocket sendingSocket) { this.id = id; @@ -213,9 +213,9 @@ public RSocketTaskHandler( @Override public Mono fireAndForget(Payload payload) { logger.info("Received a Task[{}] from Client.ID[{}]", payload.getDataUtf8(), id); - tasksToProcess.onNext(new Task(id, payload.getDataUtf8())); + Sinks.EmitResult result = tasksToProcess.tryEmitNext(new Task(id, payload.getDataUtf8())); payload.release(); - return Mono.empty(); + return result.isFailure() ? Mono.error(new Sinks.EmissionException(result)) : Mono.empty(); } @Override diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LeaseManager.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LeaseManager.java new file mode 100644 index 000000000..272caf7a0 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LeaseManager.java @@ -0,0 +1,144 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.common; + +import java.util.concurrent.BlockingDeque; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public class LeaseManager implements Runnable { + + static final Logger logger = LoggerFactory.getLogger(LeaseManager.class); + + volatile int activeConnectionsCount; + static final AtomicIntegerFieldUpdater ACTIVE_CONNECTIONS_COUNT = + AtomicIntegerFieldUpdater.newUpdater(LeaseManager.class, "activeConnectionsCount"); + + volatile int stateAndInFlight; + static final AtomicIntegerFieldUpdater STATE_AND_IN_FLIGHT = + AtomicIntegerFieldUpdater.newUpdater(LeaseManager.class, "stateAndInFlight"); + + static final int MASK_PAUSED = 0b1_000_0000_0000_0000_0000_0000_0000_0000; + static final int MASK_IN_FLIGHT = 0b0_111_1111_1111_1111_1111_1111_1111_1111; + + final BlockingDeque sendersQueue = new LinkedBlockingDeque<>(); + final Scheduler worker = Schedulers.newSingle(LeaseManager.class.getName()); + + final int capacity; + final int ttl; + + public LeaseManager(int capacity, int ttl) { + this.capacity = capacity; + this.ttl = ttl; + } + + @Override + public void run() { + try { + LimitBasedLeaseSender leaseSender = sendersQueue.poll(); + + if (leaseSender == null) { + return; + } + + if (leaseSender.isDisposed()) { + logger.debug("Connection[" + leaseSender.connectionId + "]: LeaseSender is Disposed"); + worker.schedule(this); + return; + } + + int limit = leaseSender.limitAlgorithm.getLimit(); + + if (limit == 0) { + throw new IllegalStateException("Limit is 0"); + } + + if (pauseIfNoCapacity()) { + sendersQueue.addFirst(leaseSender); + logger.debug("Pause execution. Not enough capacity"); + return; + } + + leaseSender.sendLease(ttl, limit); + sendersQueue.offer(leaseSender); + + int activeConnections = activeConnectionsCount; + int nextDelay = activeConnections == 0 ? ttl : (ttl / activeConnections); + + logger.debug("Next check happens in " + nextDelay + "ms"); + + worker.schedule(this, nextDelay, TimeUnit.MILLISECONDS); + } catch (Throwable e) { + logger.error("LeaseSender failed to send lease", e); + } + } + + int incrementInFlightAndGet() { + for (; ; ) { + int state = stateAndInFlight; + int paused = state & MASK_PAUSED; + int inFlight = stateAndInFlight & MASK_IN_FLIGHT; + + // assume overflow is impossible due to max concurrency in RSocket it self + int nextInFlight = inFlight + 1; + + if (STATE_AND_IN_FLIGHT.compareAndSet(this, state, nextInFlight | paused)) { + return nextInFlight; + } + } + } + + void decrementInFlight() { + for (; ; ) { + int state = stateAndInFlight; + int paused = state & MASK_PAUSED; + int inFlight = stateAndInFlight & MASK_IN_FLIGHT; + + // assume overflow is impossible due to max concurrency in RSocket it self + int nextInFlight = inFlight - 1; + + if (inFlight == capacity && paused == MASK_PAUSED) { + if (STATE_AND_IN_FLIGHT.compareAndSet(this, state, nextInFlight)) { + logger.debug("Resume execution"); + worker.schedule(this); + return; + } + } else { + if (STATE_AND_IN_FLIGHT.compareAndSet(this, state, nextInFlight | paused)) { + return; + } + } + } + } + + boolean pauseIfNoCapacity() { + int capacity = this.capacity; + for (; ; ) { + int inFlight = stateAndInFlight; + + if (inFlight < capacity) { + return false; + } + + if (STATE_AND_IN_FLIGHT.compareAndSet(this, inFlight, inFlight | MASK_PAUSED)) { + return true; + } + } + } + + void unregister() { + ACTIVE_CONNECTIONS_COUNT.decrementAndGet(this); + } + + void register(LimitBasedLeaseSender sender) { + sendersQueue.offer(sender); + final int activeCount = ACTIVE_CONNECTIONS_COUNT.getAndIncrement(this); + + if (activeCount == 0) { + worker.schedule(this); + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedLeaseSender.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedLeaseSender.java new file mode 100644 index 000000000..8e1b27823 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedLeaseSender.java @@ -0,0 +1,54 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.common; + +import com.netflix.concurrency.limits.Limit; +import io.rsocket.lease.Lease; +import io.rsocket.lease.TrackingLeaseSender; +import java.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Sinks; +import reactor.util.concurrent.Queues; + +public class LimitBasedLeaseSender extends LimitBasedStatsCollector implements TrackingLeaseSender { + + static final Logger logger = LoggerFactory.getLogger(LimitBasedLeaseSender.class); + + final String connectionId; + final Sinks.Many sink = + Sinks.many().unicast().onBackpressureBuffer(Queues.one().get()); + + public LimitBasedLeaseSender( + String connectionId, LeaseManager leaseManager, Limit limitAlgorithm) { + super(leaseManager, limitAlgorithm); + this.connectionId = connectionId; + } + + @Override + public Flux send() { + logger.info("Received new leased Connection[" + connectionId + "]"); + + leaseManager.register(this); + + return sink.asFlux(); + } + + public void sendLease(int ttl, int amount) { + final Lease nextLease = Lease.create(Duration.ofMillis(ttl), amount); + final Sinks.EmitResult result = sink.tryEmitNext(nextLease); + + if (result.isFailure()) { + logger.warn( + "Connection[" + + connectionId + + "]. Issued Lease: [" + + nextLease + + "] was not sent due to " + + result); + } else { + if (logger.isDebugEnabled()) { + logger.debug("To Connection[" + connectionId + "]: Issued Lease: [" + nextLease + "]"); + } + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedStatsCollector.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedStatsCollector.java new file mode 100644 index 000000000..7f639ab87 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedStatsCollector.java @@ -0,0 +1,73 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.common; + +import com.netflix.concurrency.limits.Limit; +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.RequestInterceptor; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.LongSupplier; +import reactor.util.annotation.Nullable; + +public class LimitBasedStatsCollector extends AtomicBoolean implements RequestInterceptor { + + final LeaseManager leaseManager; + final Limit limitAlgorithm; + + final ConcurrentMap inFlightMap = new ConcurrentHashMap<>(); + final ConcurrentMap timeMap = new ConcurrentHashMap<>(); + + final LongSupplier clock = System::nanoTime; + + public LimitBasedStatsCollector(LeaseManager leaseManager, Limit limitAlgorithm) { + this.leaseManager = leaseManager; + this.limitAlgorithm = limitAlgorithm; + } + + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + long startTime = clock.getAsLong(); + + int currentInFlight = leaseManager.incrementInFlightAndGet(); + + inFlightMap.put(streamId, currentInFlight); + timeMap.put(streamId, startTime); + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) {} + + @Override + public void onTerminate(int streamId, FrameType requestType, @Nullable Throwable t) { + leaseManager.decrementInFlight(); + + Long startTime = timeMap.remove(streamId); + Integer currentInflight = inFlightMap.remove(streamId); + + limitAlgorithm.onSample(startTime, clock.getAsLong() - startTime, currentInflight, t != null); + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + leaseManager.decrementInFlight(); + + Long startTime = timeMap.remove(streamId); + Integer currentInflight = inFlightMap.remove(streamId); + + limitAlgorithm.onSample(startTime, clock.getAsLong() - startTime, currentInflight, true); + } + + @Override + public boolean isDisposed() { + return get(); + } + + @Override + public void dispose() { + if (!getAndSet(true)) { + leaseManager.unregister(); + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/Task.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/Task.java new file mode 100644 index 000000000..a18dd9484 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/Task.java @@ -0,0 +1,27 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.controller; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +// emulating a worker that process data from the queue +public class Task implements Runnable { + private static final Logger logger = LoggerFactory.getLogger(Task.class); + + final String message; + final int processingTime; + + Task(String message, int processingTime) { + this.message = message; + this.processingTime = processingTime; + } + + @Override + public void run() { + logger.info("Processing Task[{}]", message); + try { + Thread.sleep(processingTime); // emulating processing + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/TasksHandlingRSocket.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/TasksHandlingRSocket.java new file mode 100644 index 000000000..cbecadfc3 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/TasksHandlingRSocket.java @@ -0,0 +1,44 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.controller; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Scheduler; + +public class TasksHandlingRSocket implements RSocket { + + private static final Logger logger = LoggerFactory.getLogger(TasksHandlingRSocket.class); + + final Disposable terminatable; + final Scheduler workScheduler; + final int processingTime; + + public TasksHandlingRSocket(Disposable terminatable, Scheduler scheduler, int processingTime) { + this.terminatable = terminatable; + this.workScheduler = scheduler; + this.processingTime = processingTime; + } + + @Override + public Mono fireAndForget(Payload payload) { + + // specifically to show that lease can limit rate of fnf requests in + // that example + String message = payload.getDataUtf8(); + payload.release(); + + return Mono.fromRunnable(new Task(message, processingTime)) + // schedule task on specific, limited in size scheduler + .subscribeOn(workScheduler) + // if errors - terminates server + .doOnError( + t -> { + logger.error("Queue has been overflowed. Terminating server"); + terminatable.dispose(); + System.exit(9); + }); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/README.MD b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/README.MD new file mode 100644 index 000000000..e69de29bb diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RequestingServer.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RequestingServer.java new file mode 100644 index 000000000..30eb0c0e3 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RequestingServer.java @@ -0,0 +1,78 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.lease.advanced.invertmulticlient; + +import io.rsocket.RSocket; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.ByteBufPayload; +import java.util.Comparator; +import java.util.concurrent.PriorityBlockingQueue; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class RequestingServer { + + private static final Logger logger = LoggerFactory.getLogger(RequestingServer.class); + + public static void main(String[] args) { + PriorityBlockingQueue rSockets = + new PriorityBlockingQueue<>( + 16, Comparator.comparingDouble(RSocket::availability).reversed()); + + CloseableChannel server = + RSocketServer.create( + (setup, sendingSocket) -> { + logger.info("Received new connection"); + return Mono.just(new RSocket() {}) + .doAfterTerminate(() -> rSockets.put(sendingSocket)); + }) + .lease(spec -> spec.maxPendingRequests(Integer.MAX_VALUE)) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + logger.info("Server started on port {}", server.address().getPort()); + + // generate stream of fnfs + Flux.generate( + () -> 0L, + (state, sink) -> { + sink.next(state); + return state + 1; + }) + .flatMap( + tick -> { + logger.info("Requesting FireAndForget({})", tick); + + return Mono.fromCallable( + () -> { + RSocket rSocket = rSockets.take(); + rSockets.offer(rSocket); + return rSocket; + }) + .flatMap( + clientRSocket -> + clientRSocket.fireAndForget(ByteBufPayload.create("" + tick))) + .retry(); + }) + .blockLast(); + + server.onClose().block(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RespondingClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RespondingClient.java new file mode 100644 index 000000000..4a06855b2 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RespondingClient.java @@ -0,0 +1,67 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.invertmulticlient; + +import com.netflix.concurrency.limits.limit.VegasLimit; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.examples.transport.tcp.lease.advanced.common.LeaseManager; +import io.rsocket.examples.transport.tcp.lease.advanced.common.LimitBasedLeaseSender; +import io.rsocket.examples.transport.tcp.lease.advanced.controller.TasksHandlingRSocket; +import io.rsocket.transport.netty.client.TcpClientTransport; +import java.util.Objects; +import java.util.UUID; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public class RespondingClient { + private static final Logger logger = LoggerFactory.getLogger(RespondingClient.class); + + public static final int PROCESSING_TASK_TIME = 500; + public static final int CONCURRENT_WORKERS_COUNT = 1; + public static final int QUEUE_CAPACITY = 50; + + public static void main(String[] args) { + // Queue for incoming messages represented as Flux + // Imagine that every fireAndForget that is pushed is processed by a worker + BlockingQueue tasksQueue = new ArrayBlockingQueue<>(QUEUE_CAPACITY); + + ThreadPoolExecutor threadPoolExecutor = + new ThreadPoolExecutor(1, CONCURRENT_WORKERS_COUNT, 1, TimeUnit.MINUTES, tasksQueue); + + Scheduler workScheduler = Schedulers.fromExecutorService(threadPoolExecutor); + + LeaseManager periodicLeaseSender = + new LeaseManager(CONCURRENT_WORKERS_COUNT, PROCESSING_TASK_TIME); + + Disposable.Composite disposable = Disposables.composite(); + RSocket clientRSocket = + RSocketConnector.create() + .acceptor( + SocketAcceptor.with( + new TasksHandlingRSocket(disposable, workScheduler, PROCESSING_TASK_TIME))) + .lease( + (config) -> + config.sender( + new LimitBasedLeaseSender( + UUID.randomUUID().toString(), + periodicLeaseSender, + VegasLimit.newBuilder() + .initialLimit(CONCURRENT_WORKERS_COUNT) + .maxConcurrency(QUEUE_CAPACITY) + .build()))) + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + Objects.requireNonNull(clientRSocket); + disposable.add(clientRSocket); + clientRSocket.onClose().block(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/README.MD b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/README.MD new file mode 100644 index 000000000..e69de29bb diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RequestingClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RequestingClient.java new file mode 100644 index 000000000..c2fde38e3 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RequestingClient.java @@ -0,0 +1,41 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.multiclient; + +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.util.ByteBufPayload; +import java.util.Objects; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + +public class RequestingClient { + private static final Logger logger = LoggerFactory.getLogger(RequestingClient.class); + + public static void main(String[] args) { + + RSocket clientRSocket = + RSocketConnector.create() + .lease() + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + Objects.requireNonNull(clientRSocket); + + // generate stream of fnfs + Flux.generate( + () -> 0L, + (state, sink) -> { + sink.next(state); + return state + 1; + }) + .concatMap( + tick -> { + logger.info("Requesting FireAndForget({})", tick); + return clientRSocket.fireAndForget(ByteBufPayload.create("" + tick)); + }) + .blockLast(); + + clientRSocket.onClose().block(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RespondingServer.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RespondingServer.java new file mode 100644 index 000000000..b54330450 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RespondingServer.java @@ -0,0 +1,81 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.lease.advanced.multiclient; + +import com.netflix.concurrency.limits.limit.VegasLimit; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketServer; +import io.rsocket.examples.transport.tcp.lease.advanced.common.LeaseManager; +import io.rsocket.examples.transport.tcp.lease.advanced.common.LimitBasedLeaseSender; +import io.rsocket.examples.transport.tcp.lease.advanced.controller.TasksHandlingRSocket; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.util.UUID; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public class RespondingServer { + + private static final Logger logger = LoggerFactory.getLogger(RespondingServer.class); + + public static final int TASK_PROCESSING_TIME = 500; + public static final int CONCURRENT_WORKERS_COUNT = 1; + public static final int QUEUE_CAPACITY = 50; + + public static void main(String[] args) { + // Queue for incoming messages represented as Flux + // Imagine that every fireAndForget that is pushed is processed by a worker + BlockingQueue tasksQueue = new ArrayBlockingQueue<>(QUEUE_CAPACITY); + + ThreadPoolExecutor threadPoolExecutor = + new ThreadPoolExecutor(1, CONCURRENT_WORKERS_COUNT, 1, TimeUnit.MINUTES, tasksQueue); + + Scheduler workScheduler = Schedulers.fromExecutorService(threadPoolExecutor); + + LeaseManager leaseManager = new LeaseManager(CONCURRENT_WORKERS_COUNT, TASK_PROCESSING_TIME); + + Disposable.Composite disposable = Disposables.composite(); + CloseableChannel server = + RSocketServer.create( + SocketAcceptor.with( + new TasksHandlingRSocket(disposable, workScheduler, TASK_PROCESSING_TIME))) + .lease( + (config) -> + config.sender( + new LimitBasedLeaseSender( + UUID.randomUUID().toString(), + leaseManager, + VegasLimit.newBuilder() + .initialLimit(CONCURRENT_WORKERS_COUNT) + .maxConcurrency(QUEUE_CAPACITY) + .build()))) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + disposable.add(server); + + logger.info("Server started on port {}", server.address().getPort()); + server.onClose().block(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/LeaseExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/simple/LeaseExample.java similarity index 66% rename from rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/LeaseExample.java rename to rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/simple/LeaseExample.java index 49f683204..c54335ccc 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/LeaseExample.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/simple/LeaseExample.java @@ -14,33 +14,26 @@ * limitations under the License. */ -package io.rsocket.examples.transport.tcp.lease; +package io.rsocket.examples.transport.tcp.lease.simple; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.core.RSocketConnector; import io.rsocket.core.RSocketServer; import io.rsocket.lease.Lease; -import io.rsocket.lease.LeaseStats; -import io.rsocket.lease.Leases; -import io.rsocket.lease.MissingLeaseException; +import io.rsocket.lease.LeaseSender; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.CloseableChannel; import io.rsocket.transport.netty.server.TcpServerTransport; import io.rsocket.util.ByteBufPayload; import java.time.Duration; import java.util.Objects; -import java.util.Optional; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; -import java.util.function.Consumer; -import java.util.function.Function; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.ReplayProcessor; -import reactor.util.retry.Retry; public class LeaseExample { @@ -95,13 +88,12 @@ public Mono fireAndForget(Payload payload) { return Mono.empty(); } })) - .lease(() -> Leases.create().sender(new LeaseCalculator(SERVER_TAG, messagesQueue))) + .lease(leases -> leases.sender(new LeaseCalculator(SERVER_TAG, messagesQueue))) .bindNow(TcpServerTransport.create("localhost", 7000)); - LeaseReceiver receiver = new LeaseReceiver(CLIENT_TAG); RSocket clientRSocket = RSocketConnector.create() - .lease(() -> Leases.create().receiver(receiver)) + .lease((config) -> config.maxPendingRequests(1)) .connect(TcpClientTransport.create(server.address())) .block(); @@ -116,22 +108,10 @@ public Mono fireAndForget(Payload payload) { }) // here we wait for the first lease for the responder side and start execution // on if there is allowance - .delaySubscription(receiver.notifyWhenNewLease().then()) .concatMap( tick -> { logger.info("Requesting FireAndForget({})", tick); - return Mono.defer(() -> clientRSocket.fireAndForget(ByteBufPayload.create("" + tick))) - .retryWhen( - Retry.indefinitely() - // ensures that error is the result of missed lease - .filter(t -> t instanceof MissingLeaseException) - .doBeforeRetryAsync( - rs -> { - // here we create a mechanism to delay the retry until - // the new lease allowance comes in. - logger.info("Ran out of leases {}", rs); - return receiver.notifyWhenNewLease().then(); - })); + return clientRSocket.fireAndForget(ByteBufPayload.create("" + tick)); }) .blockLast(); @@ -146,7 +126,7 @@ public Mono fireAndForget(Payload payload) { * connection.
    * In real-world projects this class has to issue leases based on real metrics */ - private static class LeaseCalculator implements Function, Flux> { + private static class LeaseCalculator implements LeaseSender { final String tag; final BlockingQueue queue; @@ -156,8 +136,7 @@ public LeaseCalculator(String tag, BlockingQueue queue) { } @Override - public Flux apply(Optional leaseStats) { - logger.info("{} stats are {}", tag, leaseStats.isPresent() ? "present" : "absent"); + public Flux send() { Duration ttlDuration = Duration.ofSeconds(5); // The interval function is used only for the demo purpose and should not be // considered as the way to issue leases. @@ -173,45 +152,9 @@ public Flux apply(Optional leaseStats) { // reissue new lease only if queue has remaining capacity to // accept more requests if (requests > 0) { - long ttl = ttlDuration.toMillis(); - sink.next(Lease.create((int) ttl, requests)); + sink.next(Lease.create(ttlDuration, requests)); } }); } } - - /** - * Requester-side Lease listener.
    - * In the nutshell this class implements mechanism to listen (and do appropriate actions as - * needed) to incoming leases issued by the Responder - */ - private static class LeaseReceiver implements Consumer> { - final String tag; - final ReplayProcessor lastLeaseReplay = ReplayProcessor.cacheLast(); - - public LeaseReceiver(String tag) { - this.tag = tag; - } - - @Override - public void accept(Flux receivedLeases) { - receivedLeases.subscribe( - l -> { - logger.info( - "{} received leases - ttl: {}, requests: {}", - tag, - l.getTimeToLiveMillis(), - l.getAllowedRequests()); - lastLeaseReplay.onNext(l); - }); - } - - /** - * This method allows to listen to new incoming leases and delay some action (e.g . retry) until - * new valid lease has come in - */ - public Mono notifyWhenNewLease() { - return lastLeaseReplay.filter(l -> l.isValid()).next(); - } - } } diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/loadbalancer/RoundRobinRSocketLoadbalancerExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/loadbalancer/RoundRobinRSocketLoadbalancerExample.java new file mode 100644 index 000000000..abed4a52d --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/loadbalancer/RoundRobinRSocketLoadbalancerExample.java @@ -0,0 +1,110 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.examples.transport.tcp.loadbalancer; + +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketClient; +import io.rsocket.core.RSocketServer; +import io.rsocket.loadbalance.LoadbalanceRSocketClient; +import io.rsocket.loadbalance.LoadbalanceTarget; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class RoundRobinRSocketLoadbalancerExample { + + public static void main(String[] args) { + CloseableChannel server1 = + RSocketServer.create( + SocketAcceptor.forRequestResponse( + p -> { + System.out.println("Server 1 got fnf " + p.getDataUtf8()); + return Mono.just(DefaultPayload.create("Server 1 response")) + .delayElement(Duration.ofMillis(100)); + })) + .bindNow(TcpServerTransport.create(8080)); + + CloseableChannel server2 = + RSocketServer.create( + SocketAcceptor.forRequestResponse( + p -> { + System.out.println("Server 2 got fnf " + p.getDataUtf8()); + return Mono.just(DefaultPayload.create("Server 2 response")) + .delayElement(Duration.ofMillis(100)); + })) + .bindNow(TcpServerTransport.create(8081)); + + CloseableChannel server3 = + RSocketServer.create( + SocketAcceptor.forRequestResponse( + p -> { + System.out.println("Server 3 got fnf " + p.getDataUtf8()); + return Mono.just(DefaultPayload.create("Server 3 response")) + .delayElement(Duration.ofMillis(100)); + })) + .bindNow(TcpServerTransport.create(8082)); + + LoadbalanceTarget target8080 = LoadbalanceTarget.from("8080", TcpClientTransport.create(8080)); + LoadbalanceTarget target8081 = LoadbalanceTarget.from("8081", TcpClientTransport.create(8081)); + LoadbalanceTarget target8082 = LoadbalanceTarget.from("8082", TcpClientTransport.create(8082)); + + Flux> producer = + Flux.interval(Duration.ofSeconds(5)) + .log() + .map( + i -> { + int val = i.intValue(); + switch (val) { + case 0: + return Collections.emptyList(); + case 1: + return Collections.singletonList(target8080); + case 2: + return Arrays.asList(target8080, target8081); + case 3: + return Arrays.asList(target8080, target8082); + case 4: + return Arrays.asList(target8081, target8082); + case 5: + return Arrays.asList(target8080, target8081, target8082); + case 6: + return Collections.emptyList(); + case 7: + return Collections.emptyList(); + default: + return Arrays.asList(target8080, target8081, target8082); + } + }); + + RSocketClient rSocketClient = + LoadbalanceRSocketClient.builder(producer).roundRobinLoadbalanceStrategy().build(); + + for (int i = 0; i < 10000; i++) { + try { + rSocketClient.requestResponse(Mono.just(DefaultPayload.create("test" + i))).block(); + } catch (Throwable t) { + // no ops + } + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/CompositeMetadataExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/CompositeMetadataExample.java new file mode 100644 index 000000000..a0a02a946 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/CompositeMetadataExample.java @@ -0,0 +1,102 @@ +/* + * Copyright 2015-Present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.metadata.routing; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.metadata.CompositeMetadata; +import io.rsocket.metadata.CompositeMetadataCodec; +import io.rsocket.metadata.RoutingMetadata; +import io.rsocket.metadata.TaggingMetadataCodec; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.ByteBufPayload; +import java.util.Collections; +import java.util.Objects; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +public class CompositeMetadataExample { + static final Logger logger = LoggerFactory.getLogger(CompositeMetadataExample.class); + + public static void main(String[] args) { + RSocketServer.create( + SocketAcceptor.forRequestResponse( + payload -> { + final String route = decodeRoute(payload.sliceMetadata()); + + logger.info("Received RequestResponse[route={}]", route); + + payload.release(); + + if ("my.test.route".equals(route)) { + return Mono.just(ByteBufPayload.create("Hello From My Test Route")); + } + + return Mono.error(new IllegalArgumentException("Route " + route + " not found")); + })) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + RSocket socket = + RSocketConnector.create() + // here we specify that every metadata payload will be encoded using + // CompositeMetadata layout as specified in the following subspec + // https://github.com/rsocket/rsocket/blob/master/Extensions/CompositeMetadata.md + .metadataMimeType(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()) + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + final ByteBuf routeMetadata = + TaggingMetadataCodec.createTaggingContent( + ByteBufAllocator.DEFAULT, Collections.singletonList("my.test.route")); + final CompositeByteBuf compositeMetadata = ByteBufAllocator.DEFAULT.compositeBuffer(); + + CompositeMetadataCodec.encodeAndAddMetadata( + compositeMetadata, + ByteBufAllocator.DEFAULT, + WellKnownMimeType.MESSAGE_RSOCKET_ROUTING, + routeMetadata); + + socket + .requestResponse( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "HelloWorld"), compositeMetadata)) + .log() + .block(); + } + + static String decodeRoute(ByteBuf metadata) { + final CompositeMetadata compositeMetadata = new CompositeMetadata(metadata, false); + + for (CompositeMetadata.Entry metadatum : compositeMetadata) { + if (Objects.requireNonNull(metadatum.getMimeType()) + .equals(WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString())) { + return new RoutingMetadata(metadatum.getContent()).iterator().next(); + } + } + + return null; + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/RoutingMetadataExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/RoutingMetadataExample.java new file mode 100644 index 000000000..2aee18bf9 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/RoutingMetadataExample.java @@ -0,0 +1,83 @@ +/* + * Copyright 2015-Present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.metadata.routing; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.metadata.RoutingMetadata; +import io.rsocket.metadata.TaggingMetadataCodec; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.ByteBufPayload; +import java.util.Collections; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +public class RoutingMetadataExample { + static final Logger logger = LoggerFactory.getLogger(RoutingMetadataExample.class); + + public static void main(String[] args) { + RSocketServer.create( + SocketAcceptor.forRequestResponse( + payload -> { + final String route = decodeRoute(payload.sliceMetadata()); + + logger.info("Received RequestResponse[route={}]", route); + + payload.release(); + + if ("my.test.route".equals(route)) { + return Mono.just(ByteBufPayload.create("Hello From My Test Route")); + } + + return Mono.error(new IllegalArgumentException("Route " + route + " not found")); + })) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + RSocket socket = + RSocketConnector.create() + // here we specify that route will be encoded using + // Routing&Tagging Metadata layout specified at this + // subspec https://github.com/rsocket/rsocket/blob/master/Extensions/Routing.md + .metadataMimeType(WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()) + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + final ByteBuf routeMetadata = + TaggingMetadataCodec.createTaggingContent( + ByteBufAllocator.DEFAULT, Collections.singletonList("my.test.route")); + socket + .requestResponse( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "HelloWorld"), routeMetadata)) + .log() + .block(); + } + + static String decodeRoute(ByteBuf metadata) { + final RoutingMetadata routingMetadata = new RoutingMetadata(metadata); + + return routingMetadata.iterator().next(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/plugins/LimitRateInterceptorExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/plugins/LimitRateInterceptorExample.java index 67a85b67f..5491a1aab 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/plugins/LimitRateInterceptorExample.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/plugins/LimitRateInterceptorExample.java @@ -39,8 +39,7 @@ public Flux requestChannel(Publisher payloads) { } })) .interceptors(registry -> registry.forResponder(LimitRateInterceptor.forResponder(64))) - .bind(TcpServerTransport.create("localhost", 7000)) - .subscribe(); + .bindNow(TcpServerTransport.create("localhost", 7000)); RSocket socket = RSocketConnector.create() diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java index 85faeee82..0c372d2d8 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java @@ -50,8 +50,7 @@ public Mono requestResponse(Payload p) { }; RSocketServer.create(SocketAcceptor.with(rsocket)) - .bind(TcpServerTransport.create("localhost", 7000)) - .subscribe(); + .bindNow(TcpServerTransport.create("localhost", 7000)); RSocket socket = RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000)).block(); diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java index 93b54e146..ba82c7c93 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java @@ -62,11 +62,11 @@ public static void main(String[] args) { return Files.fileSource(fileName, chunkSize) .map(DefaultPayload::create) - .zipWith(ticks, (p, tick) -> p); + .zipWith(ticks, (p, tick) -> p) + .log("server"); })) .resume(resume) - .bind(TcpServerTransport.create("localhost", 8000)) - .block(); + .bindNow(TcpServerTransport.create("localhost", 8000)); RSocket client = RSocketConnector.create() @@ -76,8 +76,9 @@ public static void main(String[] args) { client .requestStream(codec.encode(new Request(16, "lorem.txt"))) + .log("client") .doFinally(s -> server.dispose()) - .subscribe(Files.fileSink("rsocket-examples/out/lorem_output.txt", PREFETCH_WINDOW_SIZE)); + .subscribe(Files.fileSink("rsocket-examples/build/lorem_output.txt", PREFETCH_WINDOW_SIZE)); server.onClose().block(); } diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ClientStreamingToServer.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ClientStreamingToServer.java index 8116ad4ae..af0df3be1 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ClientStreamingToServer.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ClientStreamingToServer.java @@ -33,20 +33,23 @@ public final class ClientStreamingToServer { private static final Logger logger = LoggerFactory.getLogger(ClientStreamingToServer.class); - public static void main(String[] args) { + public static void main(String[] args) throws InterruptedException { RSocketServer.create( SocketAcceptor.forRequestStream( payload -> Flux.interval(Duration.ofMillis(100)) .map(aLong -> DefaultPayload.create("Interval: " + aLong)))) - .bind(TcpServerTransport.create("localhost", 7000)) - .subscribe(); + .bindNow(TcpServerTransport.create("localhost", 7000)); RSocket socket = - RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000)).block(); + RSocketConnector.create() + .setupPayload(DefaultPayload.create("test", "test")) + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + final Payload payload = DefaultPayload.create("Hello"); socket - .requestStream(DefaultPayload.create("Hello")) + .requestStream(payload) .map(Payload::getDataUtf8) .doOnNext(logger::debug) .take(10) @@ -54,5 +57,7 @@ public static void main(String[] args) { .doFinally(signalType -> socket.dispose()) .then() .block(); + + Thread.sleep(1000000); } } diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ServerStreamingToClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ServerStreamingToClient.java index f5b1e00e5..10ed34553 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ServerStreamingToClient.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ServerStreamingToClient.java @@ -43,8 +43,7 @@ public static void main(String[] args) { return Mono.just(new RSocket() {}); }) - .bind(TcpServerTransport.create("localhost", 7000)) - .subscribe(); + .bindNow(TcpServerTransport.create("localhost", 7000)); RSocket rsocket = RSocketConnector.create() diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketAggregationSample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketAggregationSample.java new file mode 100644 index 000000000..89304853c --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketAggregationSample.java @@ -0,0 +1,80 @@ +/* + * Copyright 2015-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.ws; + +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.WebsocketDuplexConnection; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.netty.Connection; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +public class WebSocketAggregationSample { + + private static final Logger logger = LoggerFactory.getLogger(WebSocketAggregationSample.class); + + public static void main(String[] args) { + + ServerTransport.ConnectionAcceptor connectionAcceptor = + RSocketServer.create(SocketAcceptor.forRequestResponse(Mono::just)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .asConnectionAcceptor(); + + DisposableServer server = + HttpServer.create() + .host("localhost") + .port(0) + .handle( + (req, res) -> + res.sendWebsocket( + (in, out) -> + connectionAcceptor + .apply( + new WebsocketDuplexConnection( + (Connection) in.aggregateFrames())) + .then(out.neverComplete()))) + .bindNow(); + + WebsocketClientTransport transport = + WebsocketClientTransport.create(server.host(), server.port()); + + RSocket clientRSocket = + RSocketConnector.create() + .keepAlive(Duration.ofMinutes(10), Duration.ofMinutes(10)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(transport) + .block(); + + Flux.range(1, 100) + .concatMap(i -> clientRSocket.requestResponse(ByteBufPayload.create("Hello " + i))) + .doOnNext(payload -> logger.debug("Processed " + payload.getDataUtf8())) + .blockLast(); + clientRSocket.dispose(); + server.dispose(); + } +} diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java index e2471f2fc..ac311a231 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java @@ -16,9 +16,8 @@ package io.rsocket.integration; -import static org.hamcrest.Matchers.is; -import static org.junit.Assert.assertThat; -import static org.junit.Assert.assertTrue; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -38,9 +37,10 @@ import io.rsocket.util.RSocketProxy; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicInteger; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import reactor.core.publisher.Flux; @@ -108,7 +108,7 @@ public Mono requestResponse(Payload payload) { private CountDownLatch disconnectionCounter; private AtomicInteger errorCount; - @Before + @BeforeEach public void startup() { errorCount = new AtomicInteger(); requestCount = new AtomicInteger(); @@ -163,23 +163,26 @@ public Flux requestChannel(Publisher payloads) { .block(); } - @After + @AfterEach public void teardown() { server.dispose(); } - @Test(timeout = 5_000L) + @Test + @Timeout(5_000L) public void testRequest() { client.requestResponse(DefaultPayload.create("REQUEST", "META")).block(); - assertThat("Server did not see the request.", requestCount.get(), is(1)); - assertTrue(calledRequester); - assertTrue(calledResponder); - assertTrue(calledClientAcceptor); - assertTrue(calledServerAcceptor); - assertTrue(calledFrame); + assertThat(requestCount).as("Server did not see the request.").hasValue(1); + + assertThat(calledRequester).isTrue(); + assertThat(calledResponder).isTrue(); + assertThat(calledClientAcceptor).isTrue(); + assertThat(calledServerAcceptor).isTrue(); + assertThat(calledFrame).isTrue(); } - @Test(timeout = 5_000L) + @Test + @Timeout(5_000L) public void testStream() { Subscriber subscriber = TestSubscriber.createCancelling(); client.requestStream(DefaultPayload.create("start")).subscribe(subscriber); @@ -188,7 +191,8 @@ public void testStream() { verifyNoMoreInteractions(subscriber); } - @Test(timeout = 5_000L) + @Test + @Timeout(5_000L) public void testClose() throws InterruptedException { client.dispose(); disconnectionCounter.await(); @@ -196,10 +200,8 @@ public void testClose() throws InterruptedException { @Test // (timeout = 5_000L) public void testCallRequestWithErrorAndThenRequest() { - try { - client.requestChannel(Mono.error(new Throwable())).blockLast(); - } catch (Throwable t) { - } + assertThatThrownBy(client.requestChannel(Mono.error(new Throwable("test")))::blockLast) + .hasMessage("java.lang.Throwable: test"); testRequest(); } diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java index de27bcb9b..1924668fb 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,8 +16,7 @@ package io.rsocket.integration; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; +import static org.assertj.core.api.Assertions.assertThat; import io.rsocket.Payload; import io.rsocket.RSocket; @@ -31,12 +30,13 @@ import io.rsocket.util.RSocketProxy; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.UnicastProcessor; +import reactor.core.publisher.Sinks; import reactor.core.scheduler.Schedulers; public class TcpIntegrationTest { @@ -44,7 +44,7 @@ public class TcpIntegrationTest { private CloseableChannel server; - @Before + @BeforeEach public void startup() { server = RSocketServer.create((setup, sendingSocket) -> Mono.just(new RSocketProxy(handler))) @@ -56,12 +56,13 @@ private RSocket buildClient() { return RSocketConnector.connectWith(TcpClientTransport.create(server.address())).block(); } - @After + @AfterEach public void cleanup() { server.dispose(); } - @Test(timeout = 15_000L) + @Test + @Timeout(15_000L) public void testCompleteWithoutNext() { handler = new RSocket() { @@ -74,10 +75,11 @@ public Flux requestStream(Payload payload) { Boolean hasElements = client.requestStream(DefaultPayload.create("REQUEST", "META")).log().hasElements().block(); - assertFalse(hasElements); + assertThat(hasElements).isFalse(); } - @Test(timeout = 15_000L) + @Test + @Timeout(15_000L) public void testSingleStream() { handler = new RSocket() { @@ -91,10 +93,11 @@ public Flux requestStream(Payload payload) { Payload result = client.requestStream(DefaultPayload.create("REQUEST", "META")).blockLast(); - assertEquals("RESPONSE", result.getDataUtf8()); + assertThat(result.getDataUtf8()).isEqualTo("RESPONSE"); } - @Test(timeout = 15_000L) + @Test + @Timeout(15_000L) public void testZeroPayload() { handler = new RSocket() { @@ -108,10 +111,11 @@ public Flux requestStream(Payload payload) { Payload result = client.requestStream(DefaultPayload.create("REQUEST", "META")).blockFirst(); - assertEquals("", result.getDataUtf8()); + assertThat(result.getDataUtf8()).isEmpty(); } - @Test(timeout = 15_000L) + @Test + @Timeout(15_000L) public void testRequestResponseErrors() { handler = new RSocket() { @@ -141,23 +145,24 @@ public Mono requestResponse(Payload payload) { .onErrorReturn(DefaultPayload.create("ERROR")) .block(); - assertEquals("ERROR", response1.getDataUtf8()); - assertEquals("SUCCESS", response2.getDataUtf8()); + assertThat(response1.getDataUtf8()).isEqualTo("ERROR"); + assertThat(response2.getDataUtf8()).isEqualTo("SUCCESS"); } - @Test(timeout = 15_000L) + @Test + @Timeout(15_000L) public void testTwoConcurrentStreams() throws InterruptedException { - ConcurrentHashMap> map = new ConcurrentHashMap<>(); - UnicastProcessor processor1 = UnicastProcessor.create(); + ConcurrentHashMap> map = new ConcurrentHashMap<>(); + Sinks.Many processor1 = Sinks.many().unicast().onBackpressureBuffer(); map.put("REQUEST1", processor1); - UnicastProcessor processor2 = UnicastProcessor.create(); + Sinks.Many processor2 = Sinks.many().unicast().onBackpressureBuffer(); map.put("REQUEST2", processor2); handler = new RSocket() { @Override public Flux requestStream(Payload payload) { - return map.get(payload.getDataUtf8()); + return map.get(payload.getDataUtf8()).asFlux(); } }; @@ -177,13 +182,13 @@ public Flux requestStream(Payload payload) { .subscribeOn(Schedulers.newSingle("2")) .subscribe(c -> nextCountdown.countDown(), t -> {}, completeCountdown::countDown); - processor1.onNext(DefaultPayload.create("RESPONSE1A")); - processor2.onNext(DefaultPayload.create("RESPONSE2A")); + processor1.tryEmitNext(DefaultPayload.create("RESPONSE1A")); + processor2.tryEmitNext(DefaultPayload.create("RESPONSE2A")); nextCountdown.await(); - processor1.onComplete(); - processor2.onComplete(); + processor1.tryEmitComplete(); + processor2.tryEmitComplete(); completeCountdown.await(); } diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java b/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java index 7d34ba478..cd96584ed 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java @@ -27,13 +27,14 @@ import io.rsocket.util.DefaultPayload; import java.time.Duration; import java.util.concurrent.atomic.AtomicInteger; -import org.junit.Test; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; import reactor.core.publisher.Flux; public class TestingStreaming { LocalServerTransport serverTransport = LocalServerTransport.create("test"); - @Test(expected = ApplicationErrorException.class) + @Test public void testRangeButThrowException() { Closeable server = null; try { @@ -53,8 +54,9 @@ public void testRangeButThrowException() { .bind(serverTransport) .block(); - Flux.range(1, 6).flatMap(i -> consumer("connection number -> " + i)).blockLast(); - System.out.println("here"); + Assertions.assertThatThrownBy( + Flux.range(1, 6).flatMap(i -> consumer("connection number -> " + i))::blockLast) + .isInstanceOf(ApplicationErrorException.class); } finally { server.dispose(); @@ -76,8 +78,6 @@ public void testRangeOfConsumers() { .block(); Flux.range(1, 6).flatMap(i -> consumer("connection number -> " + i)).blockLast(); - System.out.println("here"); - } finally { server.dispose(); } diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/observation/ObservationIntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/observation/ObservationIntegrationTest.java new file mode 100644 index 000000000..870ecf0cd --- /dev/null +++ b/rsocket-examples/src/test/java/io/rsocket/integration/observation/ObservationIntegrationTest.java @@ -0,0 +1,246 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.integration.observation; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.Tag; +import io.micrometer.core.instrument.Tags; +import io.micrometer.core.instrument.observation.DefaultMeterObservationHandler; +import io.micrometer.core.instrument.simple.SimpleMeterRegistry; +import io.micrometer.core.tck.MeterRegistryAssert; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationHandler; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.tracing.test.SampleTestRunner; +import io.micrometer.tracing.test.reporter.BuildingBlocks; +import io.micrometer.tracing.test.simple.SpansAssert; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.micrometer.observation.ByteBufGetter; +import io.rsocket.micrometer.observation.ByteBufSetter; +import io.rsocket.micrometer.observation.ObservationRequesterRSocketProxy; +import io.rsocket.micrometer.observation.ObservationResponderRSocketProxy; +import io.rsocket.micrometer.observation.RSocketRequesterTracingObservationHandler; +import io.rsocket.micrometer.observation.RSocketResponderTracingObservationHandler; +import io.rsocket.plugins.RSocketInterceptor; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.Deque; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiConsumer; +import org.awaitility.Awaitility; +import org.junit.jupiter.api.AfterEach; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class ObservationIntegrationTest extends SampleTestRunner { + private static final MeterRegistry registry = new SimpleMeterRegistry(); + private static final ObservationRegistry observationRegistry = ObservationRegistry.create(); + + static { + observationRegistry + .observationConfig() + .observationHandler(new DefaultMeterObservationHandler(registry)); + } + + private final RSocketInterceptor requesterInterceptor; + private final RSocketInterceptor responderInterceptor; + + ObservationIntegrationTest() { + super(SampleRunnerConfig.builder().build()); + requesterInterceptor = + reactiveSocket -> new ObservationRequesterRSocketProxy(reactiveSocket, observationRegistry); + + responderInterceptor = + reactiveSocket -> new ObservationResponderRSocketProxy(reactiveSocket, observationRegistry); + } + + private CloseableChannel server; + private RSocket client; + private AtomicInteger counter; + + @Override + public BiConsumer>> + customizeObservationHandlers() { + return (buildingBlocks, observationHandlers) -> { + observationHandlers.addFirst( + new RSocketRequesterTracingObservationHandler( + buildingBlocks.getTracer(), + buildingBlocks.getPropagator(), + new ByteBufSetter(), + false)); + observationHandlers.addFirst( + new RSocketResponderTracingObservationHandler( + buildingBlocks.getTracer(), + buildingBlocks.getPropagator(), + new ByteBufGetter(), + false)); + }; + } + + @AfterEach + public void teardown() { + if (server != null) { + server.dispose(); + } + } + + private void testRequest() { + counter.set(0); + client.requestResponse(DefaultPayload.create("REQUEST", "META")).block(); + assertThat(counter).as("Server did not see the request.").hasValue(1); + } + + private void testStream() { + counter.set(0); + client.requestStream(DefaultPayload.create("start")).blockLast(); + + assertThat(counter).as("Server did not see the request.").hasValue(1); + } + + private void testRequestChannel() { + counter.set(0); + client.requestChannel(Mono.just(DefaultPayload.create("start"))).blockFirst(); + assertThat(counter).as("Server did not see the request.").hasValue(1); + } + + private void testFireAndForget() { + counter.set(0); + client.fireAndForget(DefaultPayload.create("start")).subscribe(); + Awaitility.await().atMost(Duration.ofSeconds(50)).until(() -> counter.get() == 1); + assertThat(counter).as("Server did not see the request.").hasValue(1); + } + + @Override + public SampleTestRunnerConsumer yourCode() { + return (bb, meterRegistry) -> { + counter = new AtomicInteger(); + server = + RSocketServer.create( + (setup, sendingSocket) -> { + sendingSocket.onClose().subscribe(); + + return Mono.just( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + payload.release(); + counter.incrementAndGet(); + return Mono.just(DefaultPayload.create("RESPONSE", "METADATA")); + } + + @Override + public Flux requestStream(Payload payload) { + payload.release(); + counter.incrementAndGet(); + return Flux.range(1, 10_000) + .map(i -> DefaultPayload.create("data -> " + i)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + counter.incrementAndGet(); + return Flux.from(payloads); + } + + @Override + public Mono fireAndForget(Payload payload) { + payload.release(); + counter.incrementAndGet(); + return Mono.empty(); + } + }); + }) + .interceptors(registry -> registry.forResponder(responderInterceptor)) + .bind(TcpServerTransport.create("localhost", 0)) + .block(); + + client = + RSocketConnector.create() + .interceptors(registry -> registry.forRequester(requesterInterceptor)) + .connect(TcpClientTransport.create(server.address())) + .block(); + + testRequest(); + + testStream(); + + testRequestChannel(); + + testFireAndForget(); + + // @formatter:off + SpansAssert.assertThat(bb.getFinishedSpans()) + .haveSameTraceId() + // "request_*" + "handle" x 4 + .hasNumberOfSpansEqualTo(8) + .hasNumberOfSpansWithNameEqualTo("handle", 4) + .forAllSpansWithNameEqualTo("handle", span -> span.hasTagWithKey("rsocket.request-type")) + .hasASpanWithNameIgnoreCase("request_stream") + .thenASpanWithNameEqualToIgnoreCase("request_stream") + .hasTag("rsocket.request-type", "REQUEST_STREAM") + .backToSpans() + .hasASpanWithNameIgnoreCase("request_channel") + .thenASpanWithNameEqualToIgnoreCase("request_channel") + .hasTag("rsocket.request-type", "REQUEST_CHANNEL") + .backToSpans() + .hasASpanWithNameIgnoreCase("request_fnf") + .thenASpanWithNameEqualToIgnoreCase("request_fnf") + .hasTag("rsocket.request-type", "REQUEST_FNF") + .backToSpans() + .hasASpanWithNameIgnoreCase("request_response") + .thenASpanWithNameEqualToIgnoreCase("request_response") + .hasTag("rsocket.request-type", "REQUEST_RESPONSE"); + + MeterRegistryAssert.assertThat(registry) + .hasTimerWithNameAndTags( + "rsocket.response", + Tags.of(Tag.of("error", "none"), Tag.of("rsocket.request-type", "REQUEST_RESPONSE"))) + .hasTimerWithNameAndTags( + "rsocket.fnf", + Tags.of(Tag.of("error", "none"), Tag.of("rsocket.request-type", "REQUEST_FNF"))) + .hasTimerWithNameAndTags( + "rsocket.request", + Tags.of(Tag.of("error", "none"), Tag.of("rsocket.request-type", "REQUEST_RESPONSE"))) + .hasTimerWithNameAndTags( + "rsocket.channel", + Tags.of(Tag.of("error", "none"), Tag.of("rsocket.request-type", "REQUEST_CHANNEL"))) + .hasTimerWithNameAndTags( + "rsocket.stream", + Tags.of(Tag.of("error", "none"), Tag.of("rsocket.request-type", "REQUEST_STREAM"))); + // @formatter:on + }; + } + + @Override + protected MeterRegistry getMeterRegistry() { + return registry; + } + + @Override + protected ObservationRegistry getObservationRegistry() { + return observationRegistry; + } +} diff --git a/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java index b2dad0022..5eb78fabe 100644 --- a/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java @@ -182,7 +182,7 @@ private static Mono newClientRSocket( .resume( new Resume() .sessionDuration(Duration.ofSeconds(sessionDurationSeconds)) - .storeFactory(t -> new InMemoryResumableFramesStore("client", 500_000)) + .storeFactory(t -> new InMemoryResumableFramesStore("client", t, 500_000)) .cleanupStoreOnKeepAlive() .retry(Retry.fixedDelay(Long.MAX_VALUE, Duration.ofSeconds(1)))) .keepAlive(Duration.ofSeconds(5), Duration.ofMinutes(5)) @@ -199,7 +199,7 @@ private static Mono newServerRSocket(int sessionDurationSecond new Resume() .sessionDuration(Duration.ofSeconds(sessionDurationSeconds)) .cleanupStoreOnKeepAlive() - .storeFactory(t -> new InMemoryResumableFramesStore("server", 500_000))) + .storeFactory(t -> new InMemoryResumableFramesStore("server", t, 500_000))) .bind(serverTransport(SERVER_HOST, SERVER_PORT)); } @@ -212,7 +212,7 @@ public Flux requestChannel(Publisher payloads) { return duplicate( Flux.interval(Duration.ofMillis(1)) .onBackpressureLatest() - .publishOn(Schedulers.elastic()), + .publishOn(Schedulers.boundedElastic()), 20) .map(v -> DefaultPayload.create(String.valueOf(counter.getAndIncrement()))) .takeUntilOther(Flux.from(payloads).then()); diff --git a/rsocket-load-balancer/build.gradle b/rsocket-load-balancer/build.gradle index 748f95de6..6d91324ae 100644 --- a/rsocket-load-balancer/build.gradle +++ b/rsocket-load-balancer/build.gradle @@ -17,8 +17,7 @@ plugins { id 'java-library' id 'maven-publish' - id 'com.jfrog.artifactory' - id 'com.jfrog.bintray' + id 'signing' } dependencies { @@ -28,12 +27,12 @@ dependencies { testImplementation project(':rsocket-test') testImplementation 'org.junit.jupiter:junit-jupiter-api' + testImplementation 'org.junit.jupiter:junit-jupiter-params' testImplementation 'org.mockito:mockito-core' + testImplementation 'org.assertj:assertj-core' + testImplementation 'io.projectreactor:reactor-test' - // TODO: Remove after JUnit5 migration - testCompileOnly 'junit:junit' - testImplementation 'org.hamcrest:hamcrest-library' - testRuntimeOnly 'org.junit.vintage:junit-vintage-engine' + testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' testRuntimeOnly 'ch.qos.logback:logback-classic' } diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java index 65ce80934..6329da826 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java @@ -34,7 +34,6 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -42,13 +41,19 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; +import reactor.util.retry.Retry; /** * An implementation of {@link Mono} that load balances across a pool of RSockets and emits one when * it is subscribed to * *

    It estimates the load of each RSocket based on statistics collected. + * + * @deprecated as of 1.1. in favor of {@link io.rsocket.loadbalance.LoadbalanceRSocketClient}. */ +@Deprecated public abstract class LoadBalancedRSocketMono extends Mono implements Availability, Closeable { @@ -589,7 +594,9 @@ private class WeightedSocket implements LoadBalancerSocketMetrics, RSocket { factory .get() - .retryBackoff(weightedSocketRetries, weightedSocketBackOff, weightedSocketMaxBackOff) + .retryWhen( + Retry.backoff(weightedSocketRetries, weightedSocketBackOff) + .maxBackoff(weightedSocketMaxBackOff)) .doOnError( throwable -> { logger.error( @@ -661,26 +668,28 @@ private class WeightedSocket implements LoadBalancerSocketMetrics, RSocket { @Override public Mono requestResponse(Payload payload) { return rSocketMono.flatMap( - source -> { - return Mono.from( - subscriber -> - source - .requestResponse(payload) - .subscribe(new LatencySubscriber<>(subscriber, this))); - }); + source -> + Mono.from( + subscriber -> + source + .requestResponse(payload) + .subscribe( + new LatencySubscriber<>( + Operators.toCoreSubscriber(subscriber), this)))); } @Override public Flux requestStream(Payload payload) { return rSocketMono.flatMapMany( - source -> { - return Flux.from( - subscriber -> - source - .requestStream(payload) - .subscribe(new CountingSubscriber<>(subscriber, this))); - }); + source -> + Flux.from( + subscriber -> + source + .requestStream(payload) + .subscribe( + new CountingSubscriber<>( + Operators.toCoreSubscriber(subscriber), this)))); } @Override @@ -692,7 +701,9 @@ public Mono fireAndForget(Payload payload) { subscriber -> source .fireAndForget(payload) - .subscribe(new CountingSubscriber<>(subscriber, this))); + .subscribe( + new CountingSubscriber<>( + Operators.toCoreSubscriber(subscriber), this))); }); } @@ -704,7 +715,9 @@ public Mono metadataPush(Payload payload) { subscriber -> source .metadataPush(payload) - .subscribe(new CountingSubscriber<>(subscriber, this))); + .subscribe( + new CountingSubscriber<>( + Operators.toCoreSubscriber(subscriber), this))); }); } @@ -712,13 +725,14 @@ public Mono metadataPush(Payload payload) { public Flux requestChannel(Publisher payloads) { return rSocketMono.flatMapMany( - source -> { - return Flux.from( - subscriber -> - source - .requestChannel(payloads) - .subscribe(new CountingSubscriber<>(subscriber, this))); - }); + source -> + Flux.from( + subscriber -> + source + .requestChannel(payloads) + .subscribe( + new CountingSubscriber<>( + Operators.toCoreSubscriber(subscriber), this)))); } synchronized double getPredictedLatency() { @@ -861,18 +875,23 @@ public long lastTimeUsedMillis() { * Subscriber wrapper used for request/response interaction model, measure and collect latency * information. */ - private class LatencySubscriber implements Subscriber { - private final Subscriber child; + private class LatencySubscriber implements CoreSubscriber { + private final CoreSubscriber child; private final WeightedSocket socket; private final AtomicBoolean done; private long start; - LatencySubscriber(Subscriber child, WeightedSocket socket) { + LatencySubscriber(CoreSubscriber child, WeightedSocket socket) { this.child = child; this.socket = socket; this.done = new AtomicBoolean(false); } + @Override + public Context currentContext() { + return child.currentContext(); + } + @Override public void onSubscribe(Subscription s) { start = incr(); @@ -925,15 +944,20 @@ public void onComplete() { * Subscriber wrapper used for stream like interaction model, it only counts the number of * active streams */ - private class CountingSubscriber implements Subscriber { - private final Subscriber child; + private class CountingSubscriber implements CoreSubscriber { + private final CoreSubscriber child; private final WeightedSocket socket; - CountingSubscriber(Subscriber child, WeightedSocket socket) { + CountingSubscriber(CoreSubscriber child, WeightedSocket socket) { this.child = child; this.socket = socket; } + @Override + public Context currentContext() { + return child.currentContext(); + } + @Override public void onSubscribe(Subscription s) { socket.pendingStreams.incrementAndGet(); diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancerSocketMetrics.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancerSocketMetrics.java index 00d1861af..0cb35d180 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancerSocketMetrics.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancerSocketMetrics.java @@ -18,6 +18,7 @@ import io.rsocket.Availability; +@Deprecated /** A contract for the metrics managed by {@link LoadBalancedRSocketMono} per socket. */ public interface LoadBalancerSocketMetrics extends Availability { diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/NoAvailableRSocketException.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/NoAvailableRSocketException.java index 08fb98d2d..295d25d75 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/client/NoAvailableRSocketException.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/NoAvailableRSocketException.java @@ -16,6 +16,7 @@ package io.rsocket.client; +@Deprecated /** An exception that indicates that no RSocket was available. */ public final class NoAvailableRSocketException extends Exception { diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/RSocketSupplierPool.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/RSocketSupplierPool.java index 1683ee125..8249083ad 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/client/RSocketSupplierPool.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/RSocketSupplierPool.java @@ -15,6 +15,7 @@ import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; +@Deprecated public class RSocketSupplierPool implements Supplier>, Consumer, Closeable { private static final Logger logger = LoggerFactory.getLogger(RSocketSupplierPool.class); diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/TimeoutException.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/TimeoutException.java index 7bd6c7135..a32ac2224 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/client/TimeoutException.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/TimeoutException.java @@ -16,6 +16,7 @@ package io.rsocket.client; +@Deprecated public final class TimeoutException extends Exception { public static final TimeoutException INSTANCE = new TimeoutException(); diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/TransportException.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/TransportException.java index 6823e3db8..4779c6d4d 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/client/TransportException.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/TransportException.java @@ -16,6 +16,7 @@ package io.rsocket.client; +@Deprecated public final class TransportException extends Throwable { private static final long serialVersionUID = -3339846338318701123L; diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/BackupRequestSocket.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/BackupRequestSocket.java index db7e18df4..beb424797 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/BackupRequestSocket.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/BackupRequestSocket.java @@ -33,6 +33,7 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +@Deprecated public class BackupRequestSocket implements RSocket { private final ScheduledExecutorService executor; private final RSocket child; diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/RSocketSupplier.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/RSocketSupplier.java index 7a94bdf07..aaf9f71e6 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/RSocketSupplier.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/RSocketSupplier.java @@ -31,6 +31,7 @@ import reactor.core.publisher.MonoProcessor; /** */ +@Deprecated public class RSocketSupplier implements Availability, Supplier>, Closeable { private static final double EPSILON = 1e-4; diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/RSockets.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/RSockets.java index d52cf87f1..89ff74143 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/RSockets.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/RSockets.java @@ -27,6 +27,7 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +@Deprecated public final class RSockets { private RSockets() { diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/stat/Ewma.java b/rsocket-load-balancer/src/main/java/io/rsocket/stat/Ewma.java index 4641f5411..3968ec0a4 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/stat/Ewma.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/stat/Ewma.java @@ -27,6 +27,7 @@ *

    e.g. with a half-life of 10 unit, if you insert 100 at t=0 and 200 at t=10 the ewma will be * equal to (200 - 100)/2 = 150 (half of the distance between the new and the old value) */ +@Deprecated public class Ewma { private final long tau; private volatile long stamp; diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/stat/FrugalQuantile.java b/rsocket-load-balancer/src/main/java/io/rsocket/stat/FrugalQuantile.java index c152c4a80..99c12e801 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/stat/FrugalQuantile.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/stat/FrugalQuantile.java @@ -25,6 +25,7 @@ * *

    More info: http://blog.aggregateknowledge.com/2013/09/16/sketch-of-the-day-frugal-streaming/ */ +@Deprecated public class FrugalQuantile implements Quantile { private final double increment; private double quantile; diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/stat/Median.java b/rsocket-load-balancer/src/main/java/io/rsocket/stat/Median.java index d0a982a26..00dd69de9 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/stat/Median.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/stat/Median.java @@ -17,6 +17,7 @@ package io.rsocket.stat; /** This implementation gives better results because it considers more data-point. */ +@Deprecated public class Median extends FrugalQuantile { public Median() { super(0.5, 1.0, null); diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/stat/Quantile.java b/rsocket-load-balancer/src/main/java/io/rsocket/stat/Quantile.java index fa62aedc9..aa3667e8f 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/stat/Quantile.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/stat/Quantile.java @@ -15,6 +15,7 @@ */ package io.rsocket.stat; +@Deprecated public interface Quantile { /** @return the estimation of the current value of the quantile */ double estimation(); diff --git a/rsocket-load-balancer/src/test/java/io/rsocket/client/LoadBalancedRSocketMonoTest.java b/rsocket-load-balancer/src/test/java/io/rsocket/client/LoadBalancedRSocketMonoTest.java index 4baa106c5..52bf89558 100644 --- a/rsocket-load-balancer/src/test/java/io/rsocket/client/LoadBalancedRSocketMonoTest.java +++ b/rsocket-load-balancer/src/test/java/io/rsocket/client/LoadBalancedRSocketMonoTest.java @@ -16,6 +16,9 @@ package io.rsocket.client; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.fail; + import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.client.filter.RSocketSupplier; @@ -24,8 +27,9 @@ import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.function.Function; -import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.mockito.Mockito; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; @@ -33,7 +37,8 @@ public class LoadBalancedRSocketMonoTest { - @Test(timeout = 10_000L) + @Test + @Timeout(10_000L) public void testNeverSelectFailingFactories() throws InterruptedException { TestingRSocket socket = new TestingRSocket(Function.identity()); RSocketSupplier failing = failingClient(); @@ -43,7 +48,8 @@ public void testNeverSelectFailingFactories() throws InterruptedException { testBalancer(factories); } - @Test(timeout = 10_000L) + @Test + @Timeout(10_000L) public void testNeverSelectFailingSocket() throws InterruptedException { TestingRSocket socket = new TestingRSocket(Function.identity()); TestingRSocket failingSocket = @@ -66,7 +72,9 @@ public double availability() { testBalancer(clients); } - @Test(timeout = 10_000L) + @Test + @Timeout(10_000L) + @Disabled public void testRefreshesSocketsOnSelectBeforeReturningFailedAfterNewFactoriesDelivered() { TestingRSocket socket = new TestingRSocket(Function.identity()); @@ -85,12 +93,12 @@ public void testRefreshesSocketsOnSelectBeforeReturningFailedAfterNewFactoriesDe LoadBalancedRSocketMono balancer = LoadBalancedRSocketMono.create(factories); - Assert.assertEquals(0.0, balancer.availability(), 0); + assertThat(balancer.availability()).isZero(); laterSupplier.complete(succeedingFactory(socket)); balancer.rSocketMono.block(); - Assert.assertEquals(1.0, balancer.availability(), 0); + assertThat(balancer.availability()).isEqualTo(1.0); } private void testBalancer(List factories) throws InterruptedException { @@ -126,7 +134,7 @@ private static RSocketSupplier failingClient() { Mockito.when(mock.get()) .thenAnswer( a -> { - Assert.fail(); + fail(); return null; }); diff --git a/rsocket-load-balancer/src/test/java/io/rsocket/client/RSocketSupplierTest.java b/rsocket-load-balancer/src/test/java/io/rsocket/client/RSocketSupplierTest.java index 887132f99..9e1982465 100644 --- a/rsocket-load-balancer/src/test/java/io/rsocket/client/RSocketSupplierTest.java +++ b/rsocket-load-balancer/src/test/java/io/rsocket/client/RSocketSupplierTest.java @@ -16,9 +16,8 @@ package io.rsocket.client; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.verify; @@ -31,7 +30,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiConsumer; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.mockito.Mockito; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; @@ -44,7 +43,7 @@ public class RSocketSupplierTest { public void testError() throws InterruptedException { testRSocket( (latch, socket) -> { - assertEquals(1.0, socket.availability(), 0.0); + assertThat(socket.availability()).isEqualTo(1.0); Publisher payloadPublisher = socket.requestResponse(EmptyPayload.INSTANCE); Subscriber subscriber = TestSubscriber.create(); @@ -64,7 +63,7 @@ public void testError() throws InterruptedException { payloadPublisher.subscribe(subscriber); verify(subscriber).onError(any(RuntimeException.class)); double bad = socket.availability(); - assertTrue(good > bad); + assertThat(good > bad).isTrue(); latch.countDown(); }); } @@ -73,7 +72,7 @@ public void testError() throws InterruptedException { public void testWidowReset() throws InterruptedException { testRSocket( (latch, socket) -> { - assertEquals(1.0, socket.availability(), 0.0); + assertThat(socket.availability()).isEqualTo(1.0); Publisher payloadPublisher = socket.requestResponse(EmptyPayload.INSTANCE); Subscriber subscriber = TestSubscriber.create(); @@ -87,7 +86,7 @@ public void testWidowReset() throws InterruptedException { verify(subscriber).onError(any(RuntimeException.class)); double bad = socket.availability(); - assertTrue(good > bad); + assertThat(good > bad).isTrue(); try { Thread.sleep(200); @@ -96,7 +95,7 @@ public void testWidowReset() throws InterruptedException { } double reset = socket.availability(); - assertTrue(reset > bad); + assertThat(reset > bad).isTrue(); latch.countDown(); }); } diff --git a/rsocket-load-balancer/src/test/java/io/rsocket/client/TestingRSocket.java b/rsocket-load-balancer/src/test/java/io/rsocket/client/TestingRSocket.java index 96982121b..2827c8ed4 100644 --- a/rsocket-load-balancer/src/test/java/io/rsocket/client/TestingRSocket.java +++ b/rsocket-load-balancer/src/test/java/io/rsocket/client/TestingRSocket.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,12 +24,13 @@ import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; +import reactor.core.Scannable; import reactor.core.publisher.*; public class TestingRSocket implements RSocket { private final AtomicInteger count; - private final MonoProcessor onClose = MonoProcessor.create(); + private final Sinks.Empty onClose = Sinks.empty(); private final BiFunction, Payload, Boolean> eachPayloadHandler; public TestingRSocket(Function responder) { @@ -128,16 +129,17 @@ public double availability() { @Override public void dispose() { - onClose.onComplete(); + onClose.tryEmitEmpty(); } @Override + @SuppressWarnings("ConstantConditions") public boolean isDisposed() { - return onClose.isDisposed(); + return onClose.scan(Scannable.Attr.TERMINATED) || onClose.scan(Scannable.Attr.CANCELLED); } @Override public Mono onClose() { - return onClose; + return onClose.asMono(); } } diff --git a/rsocket-load-balancer/src/test/java/io/rsocket/client/TimeoutClientTest.java b/rsocket-load-balancer/src/test/java/io/rsocket/client/TimeoutClientTest.java index 9a5ac644b..b8866b1f6 100644 --- a/rsocket-load-balancer/src/test/java/io/rsocket/client/TimeoutClientTest.java +++ b/rsocket-load-balancer/src/test/java/io/rsocket/client/TimeoutClientTest.java @@ -16,15 +16,14 @@ package io.rsocket.client; -import static org.hamcrest.Matchers.instanceOf; +import static org.assertj.core.api.Assertions.assertThat; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.client.filter.RSockets; import io.rsocket.util.EmptyPayload; import java.time.Duration; -import org.hamcrest.MatcherAssert; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; @@ -50,8 +49,9 @@ public void onNext(Payload payload) { @Override public void onError(Throwable t) { - MatcherAssert.assertThat( - "Unexpected exception in onError", t, instanceOf(TimeoutException.class)); + assertThat(t) + .describedAs("Unexpected exception in onError") + .isInstanceOf(TimeoutException.class); } @Override diff --git a/rsocket-load-balancer/src/test/java/io/rsocket/stat/MedianTest.java b/rsocket-load-balancer/src/test/java/io/rsocket/stat/MedianTest.java index 0aab4afd7..b214a725e 100644 --- a/rsocket-load-balancer/src/test/java/io/rsocket/stat/MedianTest.java +++ b/rsocket-load-balancer/src/test/java/io/rsocket/stat/MedianTest.java @@ -18,8 +18,8 @@ import java.util.Arrays; import java.util.Random; -import org.junit.Assert; -import org.junit.Test; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; public class MedianTest { private double errorSum = 0; @@ -59,7 +59,8 @@ private void testMedian(Random rng) { maxError = Math.max(maxError, error); minError = Math.min(minError, error); - Assert.assertTrue( - "p50=" + estimation + ", real=" + expected + ", error=" + error, error < 0.02); + Assertions.assertThat(error < 0.02) + .describedAs("p50=" + estimation + ", real=" + expected + ", error=" + error) + .isTrue(); } } diff --git a/rsocket-micrometer/build.gradle b/rsocket-micrometer/build.gradle index 4be616623..debf02f34 100644 --- a/rsocket-micrometer/build.gradle +++ b/rsocket-micrometer/build.gradle @@ -17,13 +17,14 @@ plugins { id 'java-library' id 'maven-publish' - id 'com.jfrog.artifactory' - id 'com.jfrog.bintray' + id 'signing' } dependencies { api project(':rsocket-core') + api 'io.micrometer:micrometer-observation' api 'io.micrometer:micrometer-core' + api 'io.micrometer:micrometer-tracing' implementation 'org.slf4j:slf4j-api' @@ -37,4 +38,10 @@ dependencies { testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' } +jar { + manifest { + attributes("Automatic-Module-Name": "rsocket.micrometer") + } +} + description = 'Transparent Metrics exposure to Micrometer' diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java index c8b22382a..7c7ac37b9 100644 --- a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,12 +22,13 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; import io.rsocket.plugins.DuplexConnectionInterceptor.Type; +import java.net.SocketAddress; import java.util.Objects; import java.util.function.Consumer; -import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; @@ -88,6 +89,11 @@ public ByteBufAllocator alloc() { return delegate.alloc(); } + @Override + public SocketAddress remoteAddress() { + return delegate.remoteAddress(); + } + @Override public void dispose() { delegate.dispose(); @@ -105,10 +111,14 @@ public Flux receive() { } @Override - public Mono send(Publisher frames) { - Objects.requireNonNull(frames, "frames must not be null"); + public void sendFrame(int streamId, ByteBuf frame) { + frameCounters.accept(frame); + delegate.sendFrame(streamId, frame); + } - return delegate.send(Flux.from(frames).doOnNext(frameCounters)); + @Override + public void sendErrorAndClose(RSocketErrorException e) { + delegate.sendErrorAndClose(e); } private static final class FrameCounters implements Consumer { diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufGetter.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufGetter.java new file mode 100644 index 000000000..09c8ba316 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufGetter.java @@ -0,0 +1,36 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.tracing.propagation.Propagator; +import io.netty.buffer.ByteBuf; +import io.netty.util.CharsetUtil; +import io.rsocket.metadata.CompositeMetadata; + +public class ByteBufGetter implements Propagator.Getter { + + @Override + public String get(ByteBuf carrier, String key) { + final CompositeMetadata compositeMetadata = new CompositeMetadata(carrier, false); + for (CompositeMetadata.Entry entry : compositeMetadata) { + if (key.equals(entry.getMimeType())) { + return entry.getContent().toString(CharsetUtil.UTF_8); + } + } + return null; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufSetter.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufSetter.java new file mode 100644 index 000000000..678bdb1ed --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufSetter.java @@ -0,0 +1,33 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.tracing.propagation.Propagator; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.metadata.CompositeMetadataCodec; + +public class ByteBufSetter implements Propagator.Setter { + + @Override + public void set(CompositeByteBuf carrier, String key, String value) { + final ByteBufAllocator alloc = carrier.alloc(); + CompositeMetadataCodec.encodeAndAddMetadataWithCompression( + carrier, alloc, key, ByteBufUtil.writeUtf8(alloc, value)); + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/CompositeMetadataUtils.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/CompositeMetadataUtils.java new file mode 100644 index 000000000..357be8f15 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/CompositeMetadataUtils.java @@ -0,0 +1,40 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.core.lang.Nullable; +import io.netty.buffer.ByteBuf; +import io.rsocket.metadata.CompositeMetadata; + +final class CompositeMetadataUtils { + + private CompositeMetadataUtils() { + throw new IllegalStateException("Can't instantiate a utility class"); + } + + @Nullable + static ByteBuf extract(ByteBuf metadata, String key) { + final CompositeMetadata compositeMetadata = new CompositeMetadata(metadata, false); + for (CompositeMetadata.Entry entry : compositeMetadata) { + final String entryKey = entry.getMimeType(); + if (key.equals(entryKey)) { + return entry.getContent(); + } + } + return null; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketObservationConvention.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketObservationConvention.java new file mode 100644 index 000000000..2c10fc78d --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketObservationConvention.java @@ -0,0 +1,49 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.rsocket.frame.FrameType; + +/** + * Default {@link RSocketRequesterObservationConvention} implementation. + * + * @author Marcin Grzejszczak + * @since 1.1.4 + */ +class DefaultRSocketObservationConvention { + + private final RSocketContext rSocketContext; + + public DefaultRSocketObservationConvention(RSocketContext rSocketContext) { + this.rSocketContext = rSocketContext; + } + + String getName() { + if (this.rSocketContext.frameType == FrameType.REQUEST_FNF) { + return "rsocket.fnf"; + } else if (this.rSocketContext.frameType == FrameType.REQUEST_STREAM) { + return "rsocket.stream"; + } else if (this.rSocketContext.frameType == FrameType.REQUEST_CHANNEL) { + return "rsocket.channel"; + } + return "%s"; + } + + protected RSocketContext getRSocketContext() { + return this.rSocketContext; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketRequesterObservationConvention.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketRequesterObservationConvention.java new file mode 100644 index 000000000..73e04b749 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketRequesterObservationConvention.java @@ -0,0 +1,62 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.KeyValues; +import io.micrometer.common.util.StringUtils; +import io.micrometer.observation.Observation; +import io.rsocket.frame.FrameType; + +/** + * Default {@link RSocketRequesterObservationConvention} implementation. + * + * @author Marcin Grzejszczak + * @since 1.1.4 + */ +public class DefaultRSocketRequesterObservationConvention + extends DefaultRSocketObservationConvention implements RSocketRequesterObservationConvention { + + public DefaultRSocketRequesterObservationConvention(RSocketContext rSocketContext) { + super(rSocketContext); + } + + @Override + public KeyValues getLowCardinalityKeyValues(RSocketContext context) { + KeyValues values = + KeyValues.of( + RSocketObservationDocumentation.ResponderTags.REQUEST_TYPE.withValue( + context.frameType.name())); + if (StringUtils.isNotBlank(context.route)) { + values = + values.and(RSocketObservationDocumentation.ResponderTags.ROUTE.withValue(context.route)); + } + return values; + } + + @Override + public boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext; + } + + @Override + public String getName() { + if (getRSocketContext().frameType == FrameType.REQUEST_RESPONSE) { + return "rsocket.request"; + } + return super.getName(); + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketResponderObservationConvention.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketResponderObservationConvention.java new file mode 100644 index 000000000..5318c1b37 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketResponderObservationConvention.java @@ -0,0 +1,61 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.KeyValues; +import io.micrometer.common.util.StringUtils; +import io.micrometer.observation.Observation; +import io.rsocket.frame.FrameType; + +/** + * Default {@link RSocketRequesterObservationConvention} implementation. + * + * @author Marcin Grzejszczak + * @since 1.1.4 + */ +public class DefaultRSocketResponderObservationConvention + extends DefaultRSocketObservationConvention implements RSocketResponderObservationConvention { + + public DefaultRSocketResponderObservationConvention(RSocketContext rSocketContext) { + super(rSocketContext); + } + + @Override + public KeyValues getLowCardinalityKeyValues(RSocketContext context) { + KeyValues tags = + KeyValues.of( + RSocketObservationDocumentation.ResponderTags.REQUEST_TYPE.withValue( + context.frameType.name())); + if (StringUtils.isNotBlank(context.route)) { + tags = tags.and(RSocketObservationDocumentation.ResponderTags.ROUTE.withValue(context.route)); + } + return tags; + } + + @Override + public boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext; + } + + @Override + public String getName() { + if (getRSocketContext().frameType == FrameType.REQUEST_RESPONSE) { + return "rsocket.response"; + } + return super.getName(); + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationRequesterRSocketProxy.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationRequesterRSocketProxy.java new file mode 100644 index 000000000..fb80ea317 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationRequesterRSocketProxy.java @@ -0,0 +1,208 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.util.StringUtils; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.docs.ObservationDocumentation; +import io.netty.buffer.ByteBuf; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import io.rsocket.metadata.RoutingMetadata; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.util.RSocketProxy; +import java.util.Iterator; +import java.util.function.Function; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; +import reactor.util.context.ContextView; + +/** + * Tracing representation of a {@link RSocketProxy} for the requester. + * + * @author Marcin Grzejszczak + * @author Oleh Dokuka + * @since 1.1.4 + */ +public class ObservationRequesterRSocketProxy extends RSocketProxy { + + /** Aligned with ObservationThreadLocalAccessor#KEY */ + private static final String MICROMETER_OBSERVATION_KEY = "micrometer.observation"; + + private final ObservationRegistry observationRegistry; + + @Nullable private final RSocketRequesterObservationConvention observationConvention; + + public ObservationRequesterRSocketProxy(RSocket source, ObservationRegistry observationRegistry) { + this(source, observationRegistry, null); + } + + public ObservationRequesterRSocketProxy( + RSocket source, + ObservationRegistry observationRegistry, + RSocketRequesterObservationConvention observationConvention) { + super(source); + this.observationRegistry = observationRegistry; + this.observationConvention = observationConvention; + } + + @Override + public Mono fireAndForget(Payload payload) { + return setObservation( + super::fireAndForget, + payload, + FrameType.REQUEST_FNF, + RSocketObservationDocumentation.RSOCKET_REQUESTER_FNF); + } + + @Override + public Mono requestResponse(Payload payload) { + return setObservation( + super::requestResponse, + payload, + FrameType.REQUEST_RESPONSE, + RSocketObservationDocumentation.RSOCKET_REQUESTER_REQUEST_RESPONSE); + } + + Mono setObservation( + Function> input, + Payload payload, + FrameType frameType, + ObservationDocumentation observation) { + return Mono.deferContextual( + contextView -> observe(input, payload, frameType, observation, contextView)); + } + + private String route(Payload payload) { + if (payload.hasMetadata()) { + try { + ByteBuf extracted = + CompositeMetadataUtils.extract( + payload.sliceMetadata(), WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()); + final RoutingMetadata routingMetadata = new RoutingMetadata(extracted); + final Iterator iterator = routingMetadata.iterator(); + return iterator.next(); + } catch (Exception e) { + + } + } + return null; + } + + private Mono observe( + Function> input, + Payload payload, + FrameType frameType, + ObservationDocumentation obs, + ContextView contextView) { + String route = route(payload); + RSocketContext rSocketContext = + new RSocketContext( + payload, payload.sliceMetadata(), frameType, route, RSocketContext.Side.REQUESTER); + Observation parentObservation = contextView.getOrDefault(MICROMETER_OBSERVATION_KEY, null); + Observation observation = + obs.observation( + this.observationConvention, + new DefaultRSocketRequesterObservationConvention(rSocketContext), + () -> rSocketContext, + observationRegistry) + .parentObservation(parentObservation); + setContextualName(frameType, route, observation); + observation.start(); + Payload newPayload = payload; + if (rSocketContext.modifiedPayload != null) { + newPayload = rSocketContext.modifiedPayload; + } + return input + .apply(newPayload) + .doOnError(observation::error) + .doFinally(signalType -> observation.stop()) + .contextWrite(context -> context.put(MICROMETER_OBSERVATION_KEY, observation)); + } + + @Override + public Flux requestStream(Payload payload) { + return observationFlux( + super::requestStream, + payload, + FrameType.REQUEST_STREAM, + RSocketObservationDocumentation.RSOCKET_REQUESTER_REQUEST_STREAM); + } + + @Override + public Flux requestChannel(Publisher inbound) { + return Flux.from(inbound) + .switchOnFirst( + (firstSignal, flux) -> { + final Payload firstPayload = firstSignal.get(); + if (firstPayload != null) { + return observationFlux( + p -> super.requestChannel(flux.skip(1).startWith(p)), + firstPayload, + FrameType.REQUEST_CHANNEL, + RSocketObservationDocumentation.RSOCKET_REQUESTER_REQUEST_CHANNEL); + } + return flux; + }); + } + + private Flux observationFlux( + Function> input, + Payload payload, + FrameType frameType, + ObservationDocumentation obs) { + return Flux.deferContextual( + contextView -> { + String route = route(payload); + RSocketContext rSocketContext = + new RSocketContext( + payload, + payload.sliceMetadata(), + frameType, + route, + RSocketContext.Side.REQUESTER); + Observation parentObservation = + contextView.getOrDefault(MICROMETER_OBSERVATION_KEY, null); + Observation newObservation = + obs.observation( + this.observationConvention, + new DefaultRSocketRequesterObservationConvention(rSocketContext), + () -> rSocketContext, + this.observationRegistry) + .parentObservation(parentObservation); + setContextualName(frameType, route, newObservation); + newObservation.start(); + return input + .apply(rSocketContext.modifiedPayload) + .doOnError(newObservation::error) + .doFinally(signalType -> newObservation.stop()) + .contextWrite(context -> context.put(MICROMETER_OBSERVATION_KEY, newObservation)); + }); + } + + private void setContextualName(FrameType frameType, String route, Observation newObservation) { + if (StringUtils.isNotBlank(route)) { + newObservation.contextualName(frameType.name() + " " + route); + } else { + newObservation.contextualName(frameType.name()); + } + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationResponderRSocketProxy.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationResponderRSocketProxy.java new file mode 100644 index 000000000..9ed27adf3 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationResponderRSocketProxy.java @@ -0,0 +1,179 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.util.StringUtils; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.netty.buffer.ByteBuf; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import io.rsocket.metadata.RoutingMetadata; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.util.RSocketProxy; +import java.util.Iterator; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; + +/** + * Tracing representation of a {@link RSocketProxy} for the responder. + * + * @author Marcin Grzejszczak + * @author Oleh Dokuka + * @since 1.1.4 + */ +public class ObservationResponderRSocketProxy extends RSocketProxy { + /** Aligned with ObservationThreadLocalAccessor#KEY */ + private static final String MICROMETER_OBSERVATION_KEY = "micrometer.observation"; + + private final ObservationRegistry observationRegistry; + + @Nullable private final RSocketResponderObservationConvention observationConvention; + + public ObservationResponderRSocketProxy(RSocket source, ObservationRegistry observationRegistry) { + this(source, observationRegistry, null); + } + + public ObservationResponderRSocketProxy( + RSocket source, + ObservationRegistry observationRegistry, + RSocketResponderObservationConvention observationConvention) { + super(source); + this.observationRegistry = observationRegistry; + this.observationConvention = observationConvention; + } + + @Override + public Mono fireAndForget(Payload payload) { + // called on Netty EventLoop + // there can't be observation in thread local here + ByteBuf sliceMetadata = payload.sliceMetadata(); + String route = route(payload, sliceMetadata); + RSocketContext rSocketContext = + new RSocketContext( + payload, + payload.sliceMetadata(), + FrameType.REQUEST_FNF, + route, + RSocketContext.Side.RESPONDER); + Observation newObservation = + startObservation(RSocketObservationDocumentation.RSOCKET_RESPONDER_FNF, rSocketContext); + return super.fireAndForget(rSocketContext.modifiedPayload) + .doOnError(newObservation::error) + .doFinally(signalType -> newObservation.stop()) + .contextWrite(context -> context.put(MICROMETER_OBSERVATION_KEY, newObservation)); + } + + private Observation startObservation( + RSocketObservationDocumentation observation, RSocketContext rSocketContext) { + return observation.start( + this.observationConvention, + new DefaultRSocketResponderObservationConvention(rSocketContext), + () -> rSocketContext, + this.observationRegistry); + } + + @Override + public Mono requestResponse(Payload payload) { + ByteBuf sliceMetadata = payload.sliceMetadata(); + String route = route(payload, sliceMetadata); + RSocketContext rSocketContext = + new RSocketContext( + payload, + payload.sliceMetadata(), + FrameType.REQUEST_RESPONSE, + route, + RSocketContext.Side.RESPONDER); + Observation newObservation = + startObservation( + RSocketObservationDocumentation.RSOCKET_RESPONDER_REQUEST_RESPONSE, rSocketContext); + return super.requestResponse(rSocketContext.modifiedPayload) + .doOnError(newObservation::error) + .doFinally(signalType -> newObservation.stop()) + .contextWrite(context -> context.put(MICROMETER_OBSERVATION_KEY, newObservation)); + } + + @Override + public Flux requestStream(Payload payload) { + ByteBuf sliceMetadata = payload.sliceMetadata(); + String route = route(payload, sliceMetadata); + RSocketContext rSocketContext = + new RSocketContext( + payload, sliceMetadata, FrameType.REQUEST_STREAM, route, RSocketContext.Side.RESPONDER); + Observation newObservation = + startObservation( + RSocketObservationDocumentation.RSOCKET_RESPONDER_REQUEST_STREAM, rSocketContext); + return super.requestStream(rSocketContext.modifiedPayload) + .doOnError(newObservation::error) + .doFinally(signalType -> newObservation.stop()) + .contextWrite(context -> context.put(MICROMETER_OBSERVATION_KEY, newObservation)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads) + .switchOnFirst( + (firstSignal, flux) -> { + final Payload firstPayload = firstSignal.get(); + if (firstPayload != null) { + ByteBuf sliceMetadata = firstPayload.sliceMetadata(); + String route = route(firstPayload, sliceMetadata); + RSocketContext rSocketContext = + new RSocketContext( + firstPayload, + firstPayload.sliceMetadata(), + FrameType.REQUEST_CHANNEL, + route, + RSocketContext.Side.RESPONDER); + Observation newObservation = + startObservation( + RSocketObservationDocumentation.RSOCKET_RESPONDER_REQUEST_CHANNEL, + rSocketContext); + if (StringUtils.isNotBlank(route)) { + newObservation.contextualName(rSocketContext.frameType.name() + " " + route); + } + return super.requestChannel(flux.skip(1).startWith(rSocketContext.modifiedPayload)) + .doOnError(newObservation::error) + .doFinally(signalType -> newObservation.stop()) + .contextWrite( + context -> context.put(MICROMETER_OBSERVATION_KEY, newObservation)); + } + return flux; + }); + } + + private String route(Payload payload, ByteBuf headers) { + if (payload.hasMetadata()) { + try { + final ByteBuf extract = + CompositeMetadataUtils.extract( + headers, WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()); + if (extract != null) { + final RoutingMetadata routingMetadata = new RoutingMetadata(extract); + final Iterator iterator = routingMetadata.iterator(); + return iterator.next(); + } + } catch (Exception e) { + + } + } + return null; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/PayloadUtils.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/PayloadUtils.java new file mode 100644 index 000000000..e5286a53f --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/PayloadUtils.java @@ -0,0 +1,73 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.Payload; +import io.rsocket.metadata.CompositeMetadata; +import io.rsocket.metadata.CompositeMetadata.Entry; +import io.rsocket.metadata.CompositeMetadataCodec; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import java.util.HashSet; +import java.util.Set; + +final class PayloadUtils { + + private PayloadUtils() { + throw new IllegalStateException("Can't instantiate a utility class"); + } + + static CompositeByteBuf cleanTracingMetadata(Payload payload, Set fields) { + Set fieldsWithDefaultZipkin = new HashSet<>(fields); + fieldsWithDefaultZipkin.add(WellKnownMimeType.MESSAGE_RSOCKET_TRACING_ZIPKIN.getString()); + final CompositeByteBuf metadata = ByteBufAllocator.DEFAULT.compositeBuffer(); + if (payload.hasMetadata()) { + try { + final CompositeMetadata entries = new CompositeMetadata(payload.metadata(), false); + for (Entry entry : entries) { + if (!fieldsWithDefaultZipkin.contains(entry.getMimeType())) { + CompositeMetadataCodec.encodeAndAddMetadataWithCompression( + metadata, + ByteBufAllocator.DEFAULT, + entry.getMimeType(), + entry.getContent().retain()); + } + } + } catch (Exception e) { + + } + } + return metadata; + } + + static Payload payload(Payload payload, CompositeByteBuf metadata) { + final Payload newPayload; + try { + if (payload instanceof ByteBufPayload) { + newPayload = ByteBufPayload.create(payload.data().retain(), metadata); + } else { + newPayload = DefaultPayload.create(payload.data().retain(), metadata); + } + } finally { + payload.release(); + } + return newPayload; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketContext.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketContext.java new file mode 100644 index 000000000..8622cdfa5 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketContext.java @@ -0,0 +1,76 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.lang.Nullable; +import io.micrometer.observation.Observation; +import io.netty.buffer.ByteBuf; +import io.rsocket.Payload; +import io.rsocket.frame.FrameType; + +public class RSocketContext extends Observation.Context { + + final Payload payload; + + final ByteBuf metadata; + + final FrameType frameType; + + final String route; + + final Side side; + + Payload modifiedPayload; + + RSocketContext( + Payload payload, ByteBuf metadata, FrameType frameType, @Nullable String route, Side side) { + this.payload = payload; + this.metadata = metadata; + this.frameType = frameType; + this.route = route; + this.side = side; + } + + public enum Side { + REQUESTER, + RESPONDER + } + + public Payload getPayload() { + return payload; + } + + public ByteBuf getMetadata() { + return metadata; + } + + public FrameType getFrameType() { + return frameType; + } + + public String getRoute() { + return route; + } + + public Side getSide() { + return side; + } + + public Payload getModifiedPayload() { + return modifiedPayload; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketObservationDocumentation.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketObservationDocumentation.java new file mode 100644 index 000000000..1be6b4599 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketObservationDocumentation.java @@ -0,0 +1,232 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.docs.KeyName; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationConvention; +import io.micrometer.observation.docs.ObservationDocumentation; + +enum RSocketObservationDocumentation implements ObservationDocumentation { + + /** Observation created on the RSocket responder side. */ + RSOCKET_RESPONDER { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketResponderObservationConvention.class; + } + }, + + /** Observation created on the RSocket requester side for Fire and Forget frame type. */ + RSOCKET_REQUESTER_FNF { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketRequesterObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return RequesterTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket responder side for Fire and Forget frame type. */ + RSOCKET_RESPONDER_FNF { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketResponderObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return ResponderTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket requester side for Request Response frame type. */ + RSOCKET_REQUESTER_REQUEST_RESPONSE { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketRequesterObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return RequesterTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket responder side for Request Response frame type. */ + RSOCKET_RESPONDER_REQUEST_RESPONSE { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketResponderObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return ResponderTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket requester side for Request Stream frame type. */ + RSOCKET_REQUESTER_REQUEST_STREAM { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketRequesterObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return RequesterTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket responder side for Request Stream frame type. */ + RSOCKET_RESPONDER_REQUEST_STREAM { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketResponderObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return ResponderTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket requester side for Request Channel frame type. */ + RSOCKET_REQUESTER_REQUEST_CHANNEL { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketRequesterObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return RequesterTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket responder side for Request Channel frame type. */ + RSOCKET_RESPONDER_REQUEST_CHANNEL { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketResponderObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return ResponderTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }; + + enum RequesterTags implements KeyName { + + /** Name of the RSocket route. */ + ROUTE { + @Override + public String asString() { + return "rsocket.route"; + } + }, + + /** Name of the RSocket request type. */ + REQUEST_TYPE { + @Override + public String asString() { + return "rsocket.request-type"; + } + }, + + /** Name of the RSocket content type. */ + CONTENT_TYPE { + @Override + public String asString() { + return "rsocket.content-type"; + } + } + } + + enum ResponderTags implements KeyName { + + /** Name of the RSocket route. */ + ROUTE { + @Override + public String asString() { + return "rsocket.route"; + } + }, + + /** Name of the RSocket request type. */ + REQUEST_TYPE { + @Override + public String asString() { + return "rsocket.request-type"; + } + } + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterObservationConvention.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterObservationConvention.java new file mode 100644 index 000000000..d795f81b5 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterObservationConvention.java @@ -0,0 +1,36 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationConvention; + +/** + * {@link ObservationConvention} for RSocket requester {@link RSocketContext}. + * + * @author Marcin Grzejszczak + * @since 1.1.4 + */ +public interface RSocketRequesterObservationConvention + extends ObservationConvention { + + @Override + default boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext + && ((RSocketContext) context).side == RSocketContext.Side.REQUESTER; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterTracingObservationHandler.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterTracingObservationHandler.java new file mode 100644 index 000000000..996267d4a --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterTracingObservationHandler.java @@ -0,0 +1,131 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.observation.Observation; +import io.micrometer.tracing.Span; +import io.micrometer.tracing.TraceContext; +import io.micrometer.tracing.Tracer; +import io.micrometer.tracing.handler.TracingObservationHandler; +import io.micrometer.tracing.internal.EncodingUtils; +import io.micrometer.tracing.propagation.Propagator; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.Payload; +import io.rsocket.metadata.TracingMetadataCodec; +import java.util.HashSet; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class RSocketRequesterTracingObservationHandler + implements TracingObservationHandler { + private static final Logger log = + LoggerFactory.getLogger(RSocketRequesterTracingObservationHandler.class); + + private final Propagator propagator; + + private final Propagator.Setter setter; + + private final Tracer tracer; + + private final boolean isZipkinPropagationEnabled; + + public RSocketRequesterTracingObservationHandler( + Tracer tracer, + Propagator propagator, + Propagator.Setter setter, + boolean isZipkinPropagationEnabled) { + this.tracer = tracer; + this.propagator = propagator; + this.setter = setter; + this.isZipkinPropagationEnabled = isZipkinPropagationEnabled; + } + + @Override + public boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext + && ((RSocketContext) context).side == RSocketContext.Side.REQUESTER; + } + + @Override + public Tracer getTracer() { + return this.tracer; + } + + @Override + public void onStart(RSocketContext context) { + Payload payload = context.payload; + Span.Builder spanBuilder = this.tracer.spanBuilder(); + Span parentSpan = getParentSpan(context); + if (parentSpan != null) { + spanBuilder.setParent(parentSpan.context()); + } + Span span = spanBuilder.kind(Span.Kind.PRODUCER).start(); + log.debug("Extracted result from context or thread local {}", span); + // TODO: newmetadata returns an empty composite byte buf + final CompositeByteBuf newMetadata = + PayloadUtils.cleanTracingMetadata(payload, new HashSet<>(propagator.fields())); + TraceContext traceContext = span.context(); + if (this.isZipkinPropagationEnabled) { + injectDefaultZipkinRSocketHeaders(newMetadata, traceContext); + } + this.propagator.inject(traceContext, newMetadata, this.setter); + context.modifiedPayload = PayloadUtils.payload(payload, newMetadata); + getTracingContext(context).setSpan(span); + } + + @Override + public void onError(RSocketContext context) { + Throwable error = context.getError(); + if (error != null) { + getRequiredSpan(context).error(error); + } + } + + @Override + public void onStop(RSocketContext context) { + Span span = getRequiredSpan(context); + tagSpan(context, span); + span.name(context.getContextualName()).end(); + } + + private void injectDefaultZipkinRSocketHeaders( + CompositeByteBuf newMetadata, TraceContext traceContext) { + TracingMetadataCodec.Flags flags = + traceContext.sampled() == null + ? TracingMetadataCodec.Flags.UNDECIDED + : traceContext.sampled() + ? TracingMetadataCodec.Flags.SAMPLE + : TracingMetadataCodec.Flags.NOT_SAMPLE; + String traceId = traceContext.traceId(); + long[] traceIds = EncodingUtils.fromString(traceId); + long[] spanId = EncodingUtils.fromString(traceContext.spanId()); + long[] parentSpanId = EncodingUtils.fromString(traceContext.parentId()); + boolean isTraceId128Bit = traceIds.length == 2; + if (isTraceId128Bit) { + TracingMetadataCodec.encode128( + newMetadata.alloc(), + traceIds[0], + traceIds[1], + spanId[0], + EncodingUtils.fromString(traceContext.parentId())[0], + flags); + } else { + TracingMetadataCodec.encode64( + newMetadata.alloc(), traceIds[0], spanId[0], parentSpanId[0], flags); + } + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderObservationConvention.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderObservationConvention.java new file mode 100644 index 000000000..a5d6808bd --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderObservationConvention.java @@ -0,0 +1,36 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationConvention; + +/** + * {@link ObservationConvention} for RSocket responder {@link RSocketContext}. + * + * @author Marcin Grzejszczak + * @since 1.1.4 + */ +public interface RSocketResponderObservationConvention + extends ObservationConvention { + + @Override + default boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext + && ((RSocketContext) context).side == RSocketContext.Side.RESPONDER; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderTracingObservationHandler.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderTracingObservationHandler.java new file mode 100644 index 000000000..e3975b577 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderTracingObservationHandler.java @@ -0,0 +1,152 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.observation.Observation; +import io.micrometer.tracing.Span; +import io.micrometer.tracing.TraceContext; +import io.micrometer.tracing.Tracer; +import io.micrometer.tracing.handler.TracingObservationHandler; +import io.micrometer.tracing.internal.EncodingUtils; +import io.micrometer.tracing.propagation.Propagator; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.Payload; +import io.rsocket.frame.FrameType; +import io.rsocket.metadata.RoutingMetadata; +import io.rsocket.metadata.TracingMetadata; +import io.rsocket.metadata.TracingMetadataCodec; +import io.rsocket.metadata.WellKnownMimeType; +import java.util.HashSet; +import java.util.Iterator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class RSocketResponderTracingObservationHandler + implements TracingObservationHandler { + + private static final Logger log = + LoggerFactory.getLogger(RSocketResponderTracingObservationHandler.class); + + private final Propagator propagator; + + private final Propagator.Getter getter; + + private final Tracer tracer; + + private final boolean isZipkinPropagationEnabled; + + public RSocketResponderTracingObservationHandler( + Tracer tracer, + Propagator propagator, + Propagator.Getter getter, + boolean isZipkinPropagationEnabled) { + this.tracer = tracer; + this.propagator = propagator; + this.getter = getter; + this.isZipkinPropagationEnabled = isZipkinPropagationEnabled; + } + + @Override + public void onStart(RSocketContext context) { + Span handle = consumerSpanBuilder(context.payload, context.metadata, context.frameType); + CompositeByteBuf bufs = + PayloadUtils.cleanTracingMetadata(context.payload, new HashSet<>(propagator.fields())); + context.modifiedPayload = PayloadUtils.payload(context.payload, bufs); + getTracingContext(context).setSpan(handle); + } + + @Override + public void onError(RSocketContext context) { + Throwable error = context.getError(); + if (error != null) { + getRequiredSpan(context).error(error); + } + } + + @Override + public void onStop(RSocketContext context) { + Span span = getRequiredSpan(context); + tagSpan(context, span); + span.end(); + } + + @Override + public boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext + && ((RSocketContext) context).side == RSocketContext.Side.RESPONDER; + } + + @Override + public Tracer getTracer() { + return this.tracer; + } + + private Span consumerSpanBuilder(Payload payload, ByteBuf headers, FrameType requestType) { + Span.Builder consumerSpanBuilder = consumerSpanBuilder(payload, headers); + log.debug("Extracted result from headers {}", consumerSpanBuilder); + String name = "handle"; + if (payload.hasMetadata()) { + try { + final ByteBuf extract = + CompositeMetadataUtils.extract( + headers, WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()); + if (extract != null) { + final RoutingMetadata routingMetadata = new RoutingMetadata(extract); + final Iterator iterator = routingMetadata.iterator(); + name = requestType.name() + " " + iterator.next(); + } + } catch (Exception e) { + + } + } + return consumerSpanBuilder.kind(Span.Kind.CONSUMER).name(name).start(); + } + + private Span.Builder consumerSpanBuilder(Payload payload, ByteBuf headers) { + if (this.isZipkinPropagationEnabled && payload.hasMetadata()) { + try { + ByteBuf extract = + CompositeMetadataUtils.extract( + headers, WellKnownMimeType.MESSAGE_RSOCKET_TRACING_ZIPKIN.getString()); + if (extract != null) { + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(extract); + Span.Builder builder = this.tracer.spanBuilder(); + String traceId = EncodingUtils.fromLong(tracingMetadata.traceId()); + long traceIdHigh = tracingMetadata.traceIdHigh(); + if (traceIdHigh != 0L) { + // ExtendedTraceId + traceId = EncodingUtils.fromLong(traceIdHigh) + traceId; + } + TraceContext.Builder parentBuilder = + this.tracer + .traceContextBuilder() + .sampled(tracingMetadata.isDebug() || tracingMetadata.isSampled()) + .traceId(traceId) + .spanId(EncodingUtils.fromLong(tracingMetadata.spanId())) + .parentId(EncodingUtils.fromLong(tracingMetadata.parentId())); + return builder.setParent(parentBuilder.build()); + } else { + return this.propagator.extract(headers, this.getter); + } + } catch (Exception e) { + + } + } + return this.propagator.extract(headers, this.getter); + } +} diff --git a/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerDuplexConnectionTest.java b/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerDuplexConnectionTest.java index 03abd2084..7806200dd 100644 --- a/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerDuplexConnectionTest.java +++ b/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerDuplexConnectionTest.java @@ -34,7 +34,7 @@ import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; -import org.reactivestreams.Publisher; +import org.mockito.Mockito; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.Operators; @@ -153,32 +153,29 @@ void receive() { @SuppressWarnings("unchecked") @Test void send() { - ArgumentCaptor> captor = ArgumentCaptor.forClass(Publisher.class); - when(delegate.send(captor.capture())).thenReturn(Mono.empty()); - - Flux frames = - Flux.just( - createTestCancelFrame(), - createTestErrorFrame(), - createTestKeepaliveFrame(), - createTestLeaseFrame(), - createTestMetadataPushFrame(), - createTestPayloadFrame(), - createTestRequestChannelFrame(), - createTestRequestFireAndForgetFrame(), - createTestRequestNFrame(), - createTestRequestResponseFrame(), - createTestRequestStreamFrame(), - createTestSetupFrame()); - - new MicrometerDuplexConnection( - SERVER, delegate, meterRegistry, Tag.of("test-key", "test-value")) - .send(frames) - .as(StepVerifier::create) + ArgumentCaptor captor = ArgumentCaptor.forClass(ByteBuf.class); + doNothing().when(delegate).sendFrame(Mockito.anyInt(), captor.capture()); + + final MicrometerDuplexConnection micrometerDuplexConnection = + new MicrometerDuplexConnection( + SERVER, delegate, meterRegistry, Tag.of("test-key", "test-value")); + micrometerDuplexConnection.sendFrame(1, createTestCancelFrame()); + micrometerDuplexConnection.sendFrame(1, createTestErrorFrame()); + micrometerDuplexConnection.sendFrame(1, createTestKeepaliveFrame()); + micrometerDuplexConnection.sendFrame(1, createTestLeaseFrame()); + micrometerDuplexConnection.sendFrame(1, createTestMetadataPushFrame()); + micrometerDuplexConnection.sendFrame(1, createTestPayloadFrame()); + micrometerDuplexConnection.sendFrame(1, createTestRequestChannelFrame()); + micrometerDuplexConnection.sendFrame(1, createTestRequestFireAndForgetFrame()); + micrometerDuplexConnection.sendFrame(1, createTestRequestNFrame()); + micrometerDuplexConnection.sendFrame(1, createTestRequestResponseFrame()); + micrometerDuplexConnection.sendFrame(1, createTestRequestStreamFrame()); + micrometerDuplexConnection.sendFrame(1, createTestSetupFrame()); + + StepVerifier.create(Flux.fromIterable(captor.getAllValues())) + .expectNextCount(12) .verifyComplete(); - StepVerifier.create(captor.getValue()).expectNextCount(12).verifyComplete(); - assertThat(findCounter(SERVER, CANCEL).count()).isEqualTo(1); assertThat(findCounter(SERVER, COMPLETE).count()).isEqualTo(1); assertThat(findCounter(SERVER, ERROR).count()).isEqualTo(1); @@ -193,15 +190,6 @@ void send() { assertThat(findCounter(SERVER, SETUP).count()).isEqualTo(1); } - @DisplayName("send throws NullPointerException with null frames") - @Test - void sendNullFrames() { - assertThatNullPointerException() - .isThrownBy( - () -> new MicrometerDuplexConnection(CLIENT, delegate, meterRegistry).send(null)) - .withMessage("frames must not be null"); - } - private Counter findCounter(Type connectionType, FrameType frameType) { return meterRegistry .get("rsocket.frame") diff --git a/rsocket-test/build.gradle b/rsocket-test/build.gradle index 5ec1a8061..bcdf88f28 100644 --- a/rsocket-test/build.gradle +++ b/rsocket-test/build.gradle @@ -17,8 +17,7 @@ plugins { id 'java-library' id 'maven-publish' - id 'com.jfrog.artifactory' - id 'com.jfrog.bintray' + id 'signing' } dependencies { @@ -29,9 +28,14 @@ dependencies { implementation 'io.projectreactor:reactor-test' implementation 'org.assertj:assertj-core' implementation 'org.mockito:mockito-core' + implementation 'org.awaitility:awaitility' + implementation 'org.slf4j:slf4j-api' +} - // TODO: Remove after JUnit5 migration - implementation 'junit:junit' +jar { + manifest { + attributes("Automatic-Module-Name": "rsocket.test") + } } description = 'Test utilities for RSocket projects' diff --git a/rsocket-test/src/main/java/io/rsocket/test/BaseClientServerTest.java b/rsocket-test/src/main/java/io/rsocket/test/BaseClientServerTest.java index d74f59fd8..e773b4a0d 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/BaseClientServerTest.java +++ b/rsocket-test/src/main/java/io/rsocket/test/BaseClientServerTest.java @@ -16,22 +16,35 @@ package io.rsocket.test; -import static org.junit.Assert.assertEquals; +import static org.assertj.core.api.Assertions.assertThat; import io.rsocket.Payload; import io.rsocket.util.DefaultPayload; import java.util.concurrent.atomic.AtomicInteger; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import reactor.core.publisher.Flux; public abstract class BaseClientServerTest> { - @Rule public final T setup = createClientServer(); + public final T setup = createClientServer(); protected abstract T createClientServer(); - @Test(timeout = 10000) + @BeforeEach + public void init() { + setup.init(); + } + + @AfterEach + public void teardown() { + setup.tearDown(); + } + + @Test + @Timeout(10000) public void testFireNForget10() { long outputCount = Flux.range(1, 10) @@ -40,10 +53,11 @@ public void testFireNForget10() { .count() .block(); - assertEquals(0, outputCount); + assertThat(outputCount).isZero(); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testPushMetadata10() { long outputCount = Flux.range(1, 10) @@ -52,7 +66,7 @@ public void testPushMetadata10() { .count() .block(); - assertEquals(0, outputCount); + assertThat(outputCount).isZero(); } @Test // (timeout = 10000) @@ -65,10 +79,11 @@ public void testRequestResponse1() { .count() .block(); - assertEquals(1, outputCount); + assertThat(outputCount).isZero(); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testRequestResponse10() { long outputCount = Flux.range(1, 10) @@ -78,7 +93,7 @@ public void testRequestResponse10() { .count() .block(); - assertEquals(10, outputCount); + assertThat(outputCount).isEqualTo(10); } private Payload testPayload(int metadataPresent) { @@ -97,7 +112,8 @@ private Payload testPayload(int metadataPresent) { return DefaultPayload.create("hello", metadata); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testRequestResponse100() { long outputCount = Flux.range(1, 100) @@ -107,10 +123,11 @@ public void testRequestResponse100() { .count() .block(); - assertEquals(100, outputCount); + assertThat(outputCount).isEqualTo(100); } - @Test(timeout = 20000) + @Test + @Timeout(20000) public void testRequestResponse10_000() { long outputCount = Flux.range(1, 10_000) @@ -120,28 +137,31 @@ public void testRequestResponse10_000() { .count() .block(); - assertEquals(10_000, outputCount); + assertThat(outputCount).isEqualTo(10_000); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testRequestStream() { Flux publisher = setup.getRSocket().requestStream(testPayload(3)); long count = publisher.take(5).count().block(); - assertEquals(5, count); + assertThat(count).isEqualTo(5); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testRequestStreamAll() { Flux publisher = setup.getRSocket().requestStream(testPayload(3)); long count = publisher.count().block(); - assertEquals(10000, count); + assertThat(count).isEqualTo(10000); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testRequestStreamWithRequestN() { CountdownBaseSubscriber ts = new CountdownBaseSubscriber(); ts.expect(5); @@ -149,16 +169,17 @@ public void testRequestStreamWithRequestN() { setup.getRSocket().requestStream(testPayload(3)).subscribe(ts); ts.await(); - assertEquals(5, ts.count()); + assertThat(ts.count()).isEqualTo(5); ts.expect(5); ts.await(); ts.cancel(); - assertEquals(10, ts.count()); + assertThat(ts.count()).isEqualTo(10); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testRequestStreamWithDelayedRequestN() { CountdownBaseSubscriber ts = new CountdownBaseSubscriber(); @@ -167,34 +188,37 @@ public void testRequestStreamWithDelayedRequestN() { ts.expect(5); ts.await(); - assertEquals(5, ts.count()); + assertThat(ts.count()).isEqualTo(5); ts.expect(5); ts.await(); ts.cancel(); - assertEquals(10, ts.count()); + assertThat(ts.count()).isEqualTo(10); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testChannel0() { Flux publisher = setup.getRSocket().requestChannel(Flux.empty()); long count = publisher.count().block(); - assertEquals(0, count); + assertThat(count).isZero(); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testChannel1() { Flux publisher = setup.getRSocket().requestChannel(Flux.just(testPayload(0))); long count = publisher.count().block(); - assertEquals(1, count); + assertThat(count).isOne(); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testChannel3() { Flux publisher = setup @@ -203,44 +227,48 @@ public void testChannel3() { long count = publisher.count().block(); - assertEquals(3, count); + assertThat(count).isEqualTo(3); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testChannel512() { Flux payloads = Flux.range(1, 512).map(i -> DefaultPayload.create("hello " + i)); long count = setup.getRSocket().requestChannel(payloads).count().block(); - assertEquals(512, count); + assertThat(count).isEqualTo(512); } - @Test(timeout = 30000) + @Test + @Timeout(30000) public void testChannel20_000() { Flux payloads = Flux.range(1, 20_000).map(i -> DefaultPayload.create("hello " + i)); long count = setup.getRSocket().requestChannel(payloads).count().block(); - assertEquals(20_000, count); + assertThat(count).isEqualTo(20_000); } - @Test(timeout = 60_000) + @Test + @Timeout(60_000) public void testChannel200_000() { Flux payloads = Flux.range(1, 200_000).map(i -> DefaultPayload.create("hello " + i)); long count = setup.getRSocket().requestChannel(payloads).count().block(); - assertEquals(200_000, count); + assertThat(count).isEqualTo(200_000); } - @Test(timeout = 60_000) - @Ignore + @Test + @Timeout(60_000) + @Disabled public void testChannel2_000_000() { AtomicInteger counter = new AtomicInteger(0); Flux payloads = Flux.range(1, 2_000_000).map(i -> DefaultPayload.create("hello " + i)); long count = setup.getRSocket().requestChannel(payloads).count().block(); - assertEquals(2_000_000, count); + assertThat(count).isEqualTo(2_000_000); } } diff --git a/rsocket-test/src/main/java/io/rsocket/test/ByteBufRepresentation.java b/rsocket-test/src/main/java/io/rsocket/test/ByteBufRepresentation.java new file mode 100644 index 000000000..d065f3d71 --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/ByteBufRepresentation.java @@ -0,0 +1,48 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.test; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.util.IllegalReferenceCountException; +import org.assertj.core.presentation.StandardRepresentation; + +public final class ByteBufRepresentation extends StandardRepresentation { + + @Override + protected String fallbackToStringOf(Object object) { + if (object instanceof ByteBuf) { + try { + String normalBufferString = object.toString(); + ByteBuf byteBuf = (ByteBuf) object; + if (byteBuf.readableBytes() <= 256) { + String prettyHexDump = ByteBufUtil.prettyHexDump(byteBuf); + return new StringBuilder() + .append(normalBufferString) + .append("\n") + .append(prettyHexDump) + .toString(); + } else { + return normalBufferString; + } + } catch (IllegalReferenceCountException e) { + // noops + } + } + + return super.fallbackToStringOf(object); + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java b/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java index 6f562875f..1d6b7f69e 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java +++ b/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java @@ -25,12 +25,9 @@ import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; -import org.junit.rules.ExternalResource; -import org.junit.runner.Description; -import org.junit.runners.model.Statement; import reactor.core.publisher.Mono; -public class ClientSetupRule extends ExternalResource { +public class ClientSetupRule { private static final String data = "hello world"; private static final String metadata = "metadata"; @@ -39,6 +36,7 @@ public class ClientSetupRule extends ExternalResource { private Function serverInit; private RSocket client; + private S server; public ClientSetupRule( Supplier addressSupplier, @@ -59,18 +57,14 @@ public ClientSetupRule( .block(); } - @Override - public Statement apply(Statement base, Description description) { - return new Statement() { - @Override - public void evaluate() throws Throwable { - T address = addressSupplier.get(); - S server = serverInit.apply(address); - client = clientConnector.apply(address, server); - base.evaluate(); - server.dispose(); - } - }; + public void init() { + T address = addressSupplier.get(); + S server = serverInit.apply(address); + client = clientConnector.apply(address, server); + } + + public void tearDown() { + server.dispose(); } public RSocket getRSocket() { diff --git a/rsocket-test/src/main/java/io/rsocket/test/LeaksTrackingByteBufAllocator.java b/rsocket-test/src/main/java/io/rsocket/test/LeaksTrackingByteBufAllocator.java new file mode 100644 index 000000000..46e807b09 --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/LeaksTrackingByteBufAllocator.java @@ -0,0 +1,294 @@ +package io.rsocket.test; + +import static java.util.concurrent.locks.LockSupport.parkNanos; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ResourceLeakDetector; +import java.lang.reflect.Field; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; +import org.assertj.core.api.Assertions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Additional Utils which allows to decorate a ByteBufAllocator and track/assertOnLeaks all created + * ByteBuffs + */ +public class LeaksTrackingByteBufAllocator implements ByteBufAllocator { + static final Logger LOGGER = LoggerFactory.getLogger(LeaksTrackingByteBufAllocator.class); + + /** + * Allows to instrument any given the instance of ByteBufAllocator + * + * @param allocator + * @return + */ + public static LeaksTrackingByteBufAllocator instrument(ByteBufAllocator allocator) { + return new LeaksTrackingByteBufAllocator(allocator, Duration.ZERO, ""); + } + + /** + * Allows to instrument any given the instance of ByteBufAllocator + * + * @param allocator + * @return + */ + public static LeaksTrackingByteBufAllocator instrument( + ByteBufAllocator allocator, Duration awaitZeroRefCntDuration, String tag) { + return new LeaksTrackingByteBufAllocator(allocator, awaitZeroRefCntDuration, tag); + } + + final ConcurrentLinkedQueue tracker = new ConcurrentLinkedQueue<>(); + + final ByteBufAllocator delegate; + + final Duration awaitZeroRefCntDuration; + + final String tag; + + private LeaksTrackingByteBufAllocator( + ByteBufAllocator delegate, Duration awaitZeroRefCntDuration, String tag) { + this.delegate = delegate; + this.awaitZeroRefCntDuration = awaitZeroRefCntDuration; + this.tag = tag; + } + + public LeaksTrackingByteBufAllocator assertHasNoLeaks() { + try { + ArrayList unreleased = new ArrayList<>(); + for (ByteBuf bb : tracker) { + if (bb.refCnt() != 0) { + unreleased.add(bb); + } + } + + final Duration awaitZeroRefCntDuration = this.awaitZeroRefCntDuration; + if (!unreleased.isEmpty() && !awaitZeroRefCntDuration.isZero()) { + final long startTime = System.currentTimeMillis(); + final long endTimeInMillis = startTime + awaitZeroRefCntDuration.toMillis(); + boolean hasUnreleased; + while (System.currentTimeMillis() <= endTimeInMillis) { + hasUnreleased = false; + for (ByteBuf bb : unreleased) { + if (bb.refCnt() != 0) { + hasUnreleased = true; + break; + } + } + + if (!hasUnreleased) { + return this; + } + + LOGGER.debug(tag + " await buffers to be released"); + for (int i = 0; i < 100; i++) { + System.gc(); + parkNanos(1000); + System.gc(); + } + } + } + + Set collected = new HashSet<>(); + for (ByteBuf buf : unreleased) { + if (buf.refCnt() != 0) { + try { + collected.add(buf); + } catch (IllegalReferenceCountException ignored) { + // fine to ignore if throws because of refCnt + } + } + } + + Assertions.assertThat( + collected + .stream() + .filter(bb -> bb.refCnt() != 0) + .peek( + bb -> { + try { + LOGGER.debug(tag + " " + resolveTrackingInfo(bb)); + } catch (Exception e) { + e.printStackTrace(); + } + })) + .describedAs("[" + tag + "] all buffers expected to be released but got ") + .isEmpty(); + } finally { + tracker.clear(); + } + return this; + } + + // Delegating logic with tracking of buffers + + @Override + public ByteBuf buffer() { + return track(delegate.buffer()); + } + + @Override + public ByteBuf buffer(int initialCapacity) { + return track(delegate.buffer(initialCapacity)); + } + + @Override + public ByteBuf buffer(int initialCapacity, int maxCapacity) { + return track(delegate.buffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf ioBuffer() { + return track(delegate.ioBuffer()); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity) { + return track(delegate.ioBuffer(initialCapacity)); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.ioBuffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf heapBuffer() { + return track(delegate.heapBuffer()); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity) { + return track(delegate.heapBuffer(initialCapacity)); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.heapBuffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf directBuffer() { + return track(delegate.directBuffer()); + } + + @Override + public ByteBuf directBuffer(int initialCapacity) { + return track(delegate.directBuffer(initialCapacity)); + } + + @Override + public ByteBuf directBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.directBuffer(initialCapacity, maxCapacity)); + } + + @Override + public CompositeByteBuf compositeBuffer() { + return track(delegate.compositeBuffer()); + } + + @Override + public CompositeByteBuf compositeBuffer(int maxNumComponents) { + return track(delegate.compositeBuffer(maxNumComponents)); + } + + @Override + public CompositeByteBuf compositeHeapBuffer() { + return track(delegate.compositeHeapBuffer()); + } + + @Override + public CompositeByteBuf compositeHeapBuffer(int maxNumComponents) { + return track(delegate.compositeHeapBuffer(maxNumComponents)); + } + + @Override + public CompositeByteBuf compositeDirectBuffer() { + return track(delegate.compositeDirectBuffer()); + } + + @Override + public CompositeByteBuf compositeDirectBuffer(int maxNumComponents) { + return track(delegate.compositeDirectBuffer(maxNumComponents)); + } + + @Override + public boolean isDirectBufferPooled() { + return delegate.isDirectBufferPooled(); + } + + @Override + public int calculateNewCapacity(int minNewCapacity, int maxCapacity) { + return delegate.calculateNewCapacity(minNewCapacity, maxCapacity); + } + + T track(T buffer) { + tracker.offer(buffer); + + return buffer; + } + + static final Class simpleLeakAwareCompositeByteBufClass; + static final Field leakFieldForComposite; + static final Class simpleLeakAwareByteBufClass; + static final Field leakFieldForNormal; + static final Field allLeaksField; + + static { + try { + { + final Class aClass = Class.forName("io.netty.buffer.SimpleLeakAwareCompositeByteBuf"); + final Field leakField = aClass.getDeclaredField("leak"); + + leakField.setAccessible(true); + + simpleLeakAwareCompositeByteBufClass = aClass; + leakFieldForComposite = leakField; + } + + { + final Class aClass = Class.forName("io.netty.buffer.SimpleLeakAwareByteBuf"); + final Field leakField = aClass.getDeclaredField("leak"); + + leakField.setAccessible(true); + + simpleLeakAwareByteBufClass = aClass; + leakFieldForNormal = leakField; + } + + { + final Class aClass = + Class.forName("io.netty.util.ResourceLeakDetector$DefaultResourceLeak"); + final Field field = aClass.getDeclaredField("allLeaks"); + + field.setAccessible(true); + + allLeaksField = field; + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @SuppressWarnings("unchecked") + static Set resolveTrackingInfo(ByteBuf byteBuf) throws Exception { + if (ResourceLeakDetector.getLevel().ordinal() + >= ResourceLeakDetector.Level.ADVANCED.ordinal()) { + if (simpleLeakAwareCompositeByteBufClass.isInstance(byteBuf)) { + return (Set) allLeaksField.get(leakFieldForComposite.get(byteBuf)); + } else if (simpleLeakAwareByteBufClass.isInstance(byteBuf)) { + return (Set) allLeaksField.get(leakFieldForNormal.get(byteBuf)); + } + } + + return Collections.emptySet(); + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/PingClient.java b/rsocket-test/src/main/java/io/rsocket/test/PingClient.java index 9017e854b..14740950a 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/PingClient.java +++ b/rsocket-test/src/main/java/io/rsocket/test/PingClient.java @@ -63,8 +63,8 @@ Flux pingPong( BiFunction> interaction, int count, final Recorder histogram) { - return client - .flatMapMany( + return Flux.usingWhen( + client, rsocket -> Flux.range(1, count) .flatMap( @@ -78,7 +78,11 @@ Flux pingPong( histogram.recordValue(diff); }); }, - 64)) + 64), + rsocket -> { + rsocket.dispose(); + return rsocket.onClose(); + }) .doOnError(Throwable::printStackTrace); } } diff --git a/rsocket-test/src/main/java/io/rsocket/test/TestDuplexConnection.java b/rsocket-test/src/main/java/io/rsocket/test/TestDuplexConnection.java new file mode 100644 index 000000000..57a00e229 --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/TestDuplexConnection.java @@ -0,0 +1,166 @@ +package io.rsocket.test; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.PayloadFrameCodec; +import java.net.SocketAddress; +import java.util.function.BiFunction; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; + +public class TestDuplexConnection implements DuplexConnection { + + final ByteBufAllocator allocator; + final Sinks.Many inbound = Sinks.unsafe().many().unicast().onBackpressureError(); + final Sinks.Many outbound = Sinks.unsafe().many().unicast().onBackpressureError(); + final Sinks.One close = Sinks.one(); + + public TestDuplexConnection( + CoreSubscriber outboundSubscriber, boolean trackLeaks) { + this.outbound.asFlux().subscribe(outboundSubscriber); + this.allocator = + trackLeaks + ? LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT) + : ByteBufAllocator.DEFAULT; + } + + @Override + public void dispose() { + this.inbound.tryEmitComplete(); + this.outbound.tryEmitComplete(); + this.close.tryEmitEmpty(); + } + + @Override + public Mono onClose() { + return this.close.asMono(); + } + + @Override + public void sendErrorAndClose(RSocketErrorException errorException) {} + + @Override + public Flux receive() { + return this.inbound + .asFlux() + .transform( + Operators.lift( + (BiFunction< + Scannable, + CoreSubscriber, + CoreSubscriber>) + ByteBufReleaserOperator::create)); + } + + @Override + public ByteBufAllocator alloc() { + return this.allocator; + } + + @Override + public SocketAddress remoteAddress() { + return new SocketAddress() { + @Override + public String toString() { + return "Test"; + } + }; + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + this.outbound.tryEmitNext(frame); + } + + public void sendPayloadFrame( + int streamId, ByteBuf data, @Nullable ByteBuf metadata, boolean complete) { + sendFrame( + streamId, + PayloadFrameCodec.encode(this.allocator, streamId, false, complete, true, metadata, data)); + } + + static class ByteBufReleaserOperator + implements CoreSubscriber, Subscription, Fuseable.QueueSubscription { + + static CoreSubscriber create( + Scannable scannable, CoreSubscriber actual) { + return new ByteBufReleaserOperator(actual); + } + + final CoreSubscriber actual; + + Subscription s; + + public ByteBufReleaserOperator(CoreSubscriber actual) { + this.actual = actual; + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + this.actual.onSubscribe(this); + } + } + + @Override + public void onNext(ByteBuf buf) { + this.actual.onNext(buf); + buf.release(); + } + + @Override + public void onError(Throwable t) { + actual.onError(t); + } + + @Override + public void onComplete() { + actual.onComplete(); + } + + @Override + public void request(long n) { + s.request(n); + } + + @Override + public void cancel() { + s.cancel(); + } + + @Override + public int requestFusion(int requestedMode) { + return Fuseable.NONE; + } + + @Override + public ByteBuf poll() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public int size() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public boolean isEmpty() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java b/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java index d48700445..1b294e394 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java +++ b/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java @@ -16,9 +16,13 @@ package io.rsocket.test; +import static java.util.concurrent.locks.LockSupport.parkNanos; + import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.util.DefaultPayload; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicLong; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -27,6 +31,9 @@ public class TestRSocket implements RSocket { private final String data; private final String metadata; + private final AtomicLong observedInteractions = new AtomicLong(); + private final AtomicLong activeInteractions = new AtomicLong(); + public TestRSocket(String data, String metadata) { this.data = data; this.metadata = metadata; @@ -34,27 +41,69 @@ public TestRSocket(String data, String metadata) { @Override public Mono requestResponse(Payload payload) { - return Mono.just(DefaultPayload.create(data, metadata)); + activeInteractions.getAndIncrement(); + payload.release(); + observedInteractions.getAndIncrement(); + return Mono.just(ByteBufPayload.create(data, metadata)) + .doFinally(__ -> activeInteractions.getAndDecrement()); } @Override public Flux requestStream(Payload payload) { - return Flux.range(1, 10_000).flatMap(l -> requestResponse(payload)); + activeInteractions.getAndIncrement(); + payload.release(); + observedInteractions.getAndIncrement(); + return Flux.range(1, 10_000) + .map(l -> ByteBufPayload.create(data, metadata)) + .doFinally(__ -> activeInteractions.getAndDecrement()); } @Override public Mono metadataPush(Payload payload) { - return Mono.empty(); + activeInteractions.getAndIncrement(); + payload.release(); + observedInteractions.getAndIncrement(); + return Mono.empty().doFinally(__ -> activeInteractions.getAndDecrement()); } @Override public Mono fireAndForget(Payload payload) { - return Mono.empty(); + activeInteractions.getAndIncrement(); + payload.release(); + observedInteractions.getAndIncrement(); + return Mono.empty().doFinally(__ -> activeInteractions.getAndDecrement()); } @Override public Flux requestChannel(Publisher payloads) { - // TODO is defensive copy neccesary? - return Flux.from(payloads).map(Payload::retain); + activeInteractions.getAndIncrement(); + observedInteractions.getAndIncrement(); + return Flux.from(payloads).doFinally(__ -> activeInteractions.getAndDecrement()); + } + + public boolean awaitAllInteractionTermination(Duration duration) { + long end = duration.plusNanos(System.nanoTime()).toNanos(); + long activeNow; + while ((activeNow = activeInteractions.get()) > 0) { + if (System.nanoTime() >= end) { + return false; + } + parkNanos(100); + } + + return activeNow == 0; + } + + public boolean awaitUntilObserved(int interactions, Duration duration) { + long end = System.nanoTime() + duration.toNanos(); + long observed; + while ((observed = observedInteractions.get()) < interactions) { + if (System.nanoTime() >= end) { + return false; + } + parkNanos(100); + } + + return observed >= interactions; } } diff --git a/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java b/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java index fc059c7d1..1fcca97db 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java +++ b/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,44 +16,73 @@ package io.rsocket.test; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.netty.util.ResourceLeakDetector; import io.rsocket.Closeable; +import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.RSocket; +import io.rsocket.RSocketErrorException; import io.rsocket.core.RSocketConnector; import io.rsocket.core.RSocketServer; +import io.rsocket.core.Resume; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.DuplexConnectionInterceptor; +import io.rsocket.resume.InMemoryResumableFramesStore; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; -import io.rsocket.util.DefaultPayload; +import io.rsocket.util.ByteBufPayload; import java.io.BufferedReader; import java.io.InputStreamReader; +import java.net.SocketAddress; import java.time.Duration; +import java.util.Arrays; +import java.util.concurrent.CancellationException; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.function.BiFunction; -import java.util.function.Function; +import java.util.function.Predicate; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.zip.GZIPInputStream; import org.assertj.core.api.Assertions; +import org.assertj.core.api.Assumptions; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.Exceptions; +import reactor.core.Fuseable; import reactor.core.publisher.Flux; import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; import reactor.test.StepVerifier; +import reactor.util.Logger; +import reactor.util.Loggers; public interface TransportTest { + Logger logger = Loggers.getLogger(TransportTest.class); + String MOCK_DATA = "test-data"; String MOCK_METADATA = "metadata"; String LARGE_DATA = read("words.shakespeare.txt.gz"); - Payload LARGE_PAYLOAD = DefaultPayload.create(LARGE_DATA, LARGE_DATA); + Payload LARGE_PAYLOAD = ByteBufPayload.create(LARGE_DATA, LARGE_DATA); static String read(String resourceName) { - try (BufferedReader br = new BufferedReader( new InputStreamReader( @@ -67,14 +96,55 @@ static String read(String resourceName) { } @BeforeEach - default void setUp() { + default void setup() { Hooks.onOperatorDebug(); } @AfterEach default void close() { - getTransportPair().dispose(); - Hooks.resetOnOperatorDebug(); + try { + logger.debug("------------------Awaiting communication to finish------------------"); + getTransportPair().responder.awaitAllInteractionTermination(getTimeout()); + logger.debug("---------------------Disposing Client And Server--------------------"); + getTransportPair().dispose(); + getTransportPair().awaitClosed(getTimeout()); + logger.debug("------------------------Disposing Schedulers-------------------------"); + Schedulers.parallel().disposeGracefully().timeout(getTimeout(), Mono.empty()).block(); + Schedulers.boundedElastic().disposeGracefully().timeout(getTimeout(), Mono.empty()).block(); + Schedulers.single().disposeGracefully().timeout(getTimeout(), Mono.empty()).block(); + logger.debug("---------------------------Leaks Checking----------------------------"); + RuntimeException throwable = + new RuntimeException() { + @Override + public synchronized Throwable fillInStackTrace() { + return this; + } + + @Override + public String getMessage() { + return Arrays.toString(getSuppressed()); + } + }; + + try { + getTransportPair().byteBufAllocator2.assertHasNoLeaks(); + } catch (Throwable t) { + throwable = Exceptions.addSuppressed(throwable, t); + } + + try { + getTransportPair().byteBufAllocator1.assertHasNoLeaks(); + } catch (Throwable t) { + throwable = Exceptions.addSuppressed(throwable, t); + } + + if (throwable.getSuppressed().length > 0) { + throw throwable; + } + } finally { + Hooks.resetOnOperatorDebug(); + Schedulers.resetOnHandleError(); + } } default Payload createTestPayload(int metadataPresent) { @@ -93,7 +163,7 @@ default Payload createTestPayload(int metadataPresent) { } String metadata = metadata1; - return DefaultPayload.create(MOCK_DATA, metadata); + return ByteBufPayload.create(MOCK_DATA, metadata); } @DisplayName("makes 10 fireAndForget requests") @@ -102,20 +172,22 @@ default void fireAndForget10() { Flux.range(1, 10) .flatMap(i -> getClient().fireAndForget(createTestPayload(i))) .as(StepVerifier::create) - .expectNextCount(0) .expectComplete() .verify(getTimeout()); + + getTransportPair().responder.awaitUntilObserved(10, getTimeout()); } @DisplayName("makes 10 fireAndForget with Large Payload in Requests") @Test default void largePayloadFireAndForget10() { Flux.range(1, 10) - .flatMap(i -> getClient().fireAndForget(LARGE_PAYLOAD)) + .flatMap(i -> getClient().fireAndForget(LARGE_PAYLOAD.retain())) .as(StepVerifier::create) - .expectNextCount(0) .expectComplete() .verify(getTimeout()); + + getTransportPair().responder.awaitUntilObserved(10, getTimeout()); } default RSocket getClient() { @@ -129,23 +201,27 @@ default RSocket getClient() { @DisplayName("makes 10 metadataPush requests") @Test default void metadataPush10() { + Assumptions.assumeThat(getTransportPair().withResumability).isFalse(); Flux.range(1, 10) - .flatMap(i -> getClient().metadataPush(DefaultPayload.create("", "test-metadata"))) + .flatMap(i -> getClient().metadataPush(ByteBufPayload.create("", "test-metadata"))) .as(StepVerifier::create) - .expectNextCount(0) .expectComplete() .verify(getTimeout()); + + getTransportPair().responder.awaitUntilObserved(10, getTimeout()); } @DisplayName("makes 10 metadataPush with Large Metadata in requests") @Test default void largePayloadMetadataPush10() { + Assumptions.assumeThat(getTransportPair().withResumability).isFalse(); Flux.range(1, 10) - .flatMap(i -> getClient().metadataPush(DefaultPayload.create("", LARGE_DATA))) + .flatMap(i -> getClient().metadataPush(ByteBufPayload.create("", LARGE_DATA))) .as(StepVerifier::create) - .expectNextCount(0) .expectComplete() .verify(getTimeout()); + + getTransportPair().responder.awaitUntilObserved(10, getTimeout()); } @DisplayName("makes 1 requestChannel request with 0 payloads") @@ -154,8 +230,11 @@ default void requestChannel0() { getClient() .requestChannel(Flux.empty()) .as(StepVerifier::create) - .expectNextCount(0) - .expectComplete() + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(CancellationException.class) + .hasMessage("Empty Source")) .verify(getTimeout()); } @@ -164,8 +243,9 @@ default void requestChannel0() { default void requestChannel1() { getClient() .requestChannel(Mono.just(createTestPayload(0))) + .doOnNext(Payload::release) .as(StepVerifier::create) - .expectNextCount(1) + .thenConsumeWhile(new PayloadPredicate(1)) .expectComplete() .verify(getTimeout()); } @@ -177,21 +257,24 @@ default void requestChannel200_000() { getClient() .requestChannel(payloads) + .doOnNext(Payload::release) + .limitRate(8) .as(StepVerifier::create) - .expectNextCount(200_000) + .thenConsumeWhile(new PayloadPredicate(200_000)) .expectComplete() .verify(getTimeout()); } - @DisplayName("makes 1 requestChannel request with 200 large payloads") + @DisplayName("makes 1 requestChannel request with 50 large payloads") @Test - default void largePayloadRequestChannel200() { - Flux payloads = Flux.range(0, 200).map(__ -> LARGE_PAYLOAD); + default void largePayloadRequestChannel50() { + Flux payloads = Flux.range(0, 50).map(__ -> LARGE_PAYLOAD.retain()); getClient() .requestChannel(payloads) + .doOnNext(Payload::release) .as(StepVerifier::create) - .expectNextCount(200) + .thenConsumeWhile(new PayloadPredicate(50)) .expectComplete() .verify(getTimeout()); } @@ -204,8 +287,9 @@ default void requestChannel20_000() { getClient() .requestChannel(payloads) .doOnNext(this::assertChannelPayload) + .doOnNext(Payload::release) .as(StepVerifier::create) - .expectNextCount(20_000) + .thenConsumeWhile(new PayloadPredicate(20_000)) .expectComplete() .verify(getTimeout()); } @@ -217,8 +301,10 @@ default void requestChannel2_000_000() { getClient() .requestChannel(payloads) + .doOnNext(Payload::release) + .limitRate(8) .as(StepVerifier::create) - .expectNextCount(2_000_000) + .thenConsumeWhile(new PayloadPredicate(2_000_000)) .expectComplete() .verify(getTimeout()); } @@ -232,31 +318,46 @@ default void requestChannel3() { getClient() .requestChannel(payloads) + .doOnNext(Payload::release) .as(publisher -> StepVerifier.create(publisher, 3)) - .expectNextCount(3) + .thenConsumeWhile(new PayloadPredicate(3)) .expectComplete() .verify(getTimeout()); Assertions.assertThat(requested.get()).isEqualTo(3L); } - @DisplayName("makes 1 requestChannel request with 512 payloads") + @DisplayName("makes 1 requestChannel request with 256 payloads") @Test - default void requestChannel512() { - Flux payloads = Flux.range(0, 512).map(this::createTestPayload); - - Flux.range(0, 1024) - .flatMap( - v -> Mono.fromRunnable(() -> check(payloads)).subscribeOn(Schedulers.elastic()), 12) - .blockLast(); + default void requestChannel256() { + AtomicInteger counter = new AtomicInteger(); + Flux payloads = + Flux.defer( + () -> { + final int subscription = counter.getAndIncrement(); + return Flux.range(0, 256) + .map(i -> "S{" + subscription + "}: Data{" + i + "}") + .map(data -> ByteBufPayload.create(data)); + }); + final Scheduler scheduler = Schedulers.fromExecutorService(Executors.newFixedThreadPool(12)); + + try { + Flux.range(0, 1024) + .flatMap(v -> Mono.fromRunnable(() -> check(payloads)).subscribeOn(scheduler), 12) + .blockLast(); + } finally { + scheduler.disposeGracefully().block(); + } } default void check(Flux payloads) { getClient() .requestChannel(payloads) + .doOnNext(ReferenceCounted::release) + .limitRate(8) .as(StepVerifier::create) - .expectNextCount(512) - .as("expected 512 items") + .thenConsumeWhile(new PayloadPredicate(256)) + .as("expected 256 items") .expectComplete() .verify(getTimeout()); } @@ -267,6 +368,7 @@ default void requestResponse1() { getClient() .requestResponse(createTestPayload(1)) .doOnNext(this::assertPayload) + .doOnNext(Payload::release) .as(StepVerifier::create) .expectNextCount(1) .expectComplete() @@ -278,7 +380,11 @@ default void requestResponse1() { default void requestResponse10() { Flux.range(1, 10) .flatMap( - i -> getClient().requestResponse(createTestPayload(i)).doOnNext(v -> assertPayload(v))) + i -> + getClient() + .requestResponse(createTestPayload(i)) + .doOnNext(v -> assertPayload(v)) + .doOnNext(Payload::release)) .as(StepVerifier::create) .expectNextCount(10) .expectComplete() @@ -289,20 +395,21 @@ default void requestResponse10() { @Test default void requestResponse100() { Flux.range(1, 100) - .flatMap(i -> getClient().requestResponse(createTestPayload(i)).map(Payload::getDataUtf8)) + .flatMap(i -> getClient().requestResponse(createTestPayload(i)).doOnNext(Payload::release)) .as(StepVerifier::create) .expectNextCount(100) .expectComplete() .verify(getTimeout()); } - @DisplayName("makes 100 requestResponse requests") + @DisplayName("makes 50 requestResponse requests") @Test - default void largePayloadRequestResponse100() { - Flux.range(1, 100) - .flatMap(i -> getClient().requestResponse(LARGE_PAYLOAD).map(Payload::getDataUtf8)) + default void largePayloadRequestResponse50() { + Flux.range(1, 50) + .flatMap( + i -> getClient().requestResponse(LARGE_PAYLOAD.retain()).doOnNext(Payload::release)) .as(StepVerifier::create) - .expectNextCount(100) + .expectNextCount(50) .expectComplete() .verify(getTimeout()); } @@ -311,7 +418,7 @@ default void largePayloadRequestResponse100() { @Test default void requestResponse10_000() { Flux.range(1, 10_000) - .flatMap(i -> getClient().requestResponse(createTestPayload(i)).map(Payload::getDataUtf8)) + .flatMap(i -> getClient().requestResponse(createTestPayload(i)).doOnNext(Payload::release)) .as(StepVerifier::create) .expectNextCount(10_000) .expectComplete() @@ -324,6 +431,7 @@ default void requestStream10_000() { getClient() .requestStream(createTestPayload(3)) .doOnNext(this::assertPayload) + .doOnNext(Payload::release) .as(StepVerifier::create) .expectNextCount(10_000) .expectComplete() @@ -336,6 +444,7 @@ default void requestStream5() { getClient() .requestStream(createTestPayload(3)) .doOnNext(this::assertPayload) + .doOnNext(Payload::release) .take(5) .as(StepVerifier::create) .expectNextCount(5) @@ -349,6 +458,7 @@ default void requestStreamDelayedRequestN() { getClient() .requestStream(createTestPayload(3)) .take(10) + .doOnNext(Payload::release) .as(StepVerifier::create) .thenRequest(5) .expectNextCount(5) @@ -372,35 +482,179 @@ default void assertChannelPayload(Payload p) { } } - final class TransportPair implements Disposable { + class TransportPair implements Disposable { + private static final String data = "hello world"; private static final String metadata = "metadata"; + private final boolean withResumability; + private final boolean runClientWithAsyncInterceptors; + private final boolean runServerWithAsyncInterceptors; + + private final LeaksTrackingByteBufAllocator byteBufAllocator1 = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofMinutes(1), "Client"); + private final LeaksTrackingByteBufAllocator byteBufAllocator2 = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofMinutes(1), "Server"); + + private final TestRSocket responder; + private final RSocket client; private final S server; public TransportPair( Supplier addressSupplier, - BiFunction clientTransportSupplier, - Function> serverTransportSupplier) { + TriFunction clientTransportSupplier, + BiFunction> serverTransportSupplier) { + this(addressSupplier, clientTransportSupplier, serverTransportSupplier, false); + } + + public TransportPair( + Supplier addressSupplier, + TriFunction clientTransportSupplier, + BiFunction> serverTransportSupplier, + boolean withRandomFragmentation) { + this( + addressSupplier, + clientTransportSupplier, + serverTransportSupplier, + withRandomFragmentation, + false); + } + + public TransportPair( + Supplier addressSupplier, + TriFunction clientTransportSupplier, + BiFunction> serverTransportSupplier, + boolean withRandomFragmentation, + boolean withResumability) { + Schedulers.onHandleError((t, e) -> e.printStackTrace()); + Schedulers.resetFactory(); + + this.withResumability = withResumability; T address = addressSupplier.get(); + this.runClientWithAsyncInterceptors = ThreadLocalRandom.current().nextBoolean(); + this.runServerWithAsyncInterceptors = ThreadLocalRandom.current().nextBoolean(); + + ByteBufAllocator allocatorToSupply1; + ByteBufAllocator allocatorToSupply2; + if (ResourceLeakDetector.getLevel() == ResourceLeakDetector.Level.ADVANCED + || ResourceLeakDetector.getLevel() == ResourceLeakDetector.Level.PARANOID) { + logger.info("Using LeakTrackingByteBufAllocator"); + allocatorToSupply1 = byteBufAllocator1; + allocatorToSupply2 = byteBufAllocator2; + } else { + allocatorToSupply1 = ByteBufAllocator.DEFAULT; + allocatorToSupply2 = ByteBufAllocator.DEFAULT; + } + responder = new TestRSocket(TransportPair.data, metadata); + final RSocketServer rSocketServer = + RSocketServer.create((setup, sendingSocket) -> Mono.just(responder)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .interceptors( + registry -> { + if (runServerWithAsyncInterceptors && !withResumability) { + logger.info( + "Perform Integration Test with Async Interceptors Enabled For Server"); + registry + .forConnection( + (type, duplexConnection) -> + new AsyncDuplexConnection(duplexConnection, "server")) + .forSocketAcceptor( + delegate -> + (connectionSetupPayload, sendingSocket) -> + delegate + .accept(connectionSetupPayload, sendingSocket) + .subscribeOn(Schedulers.parallel())); + } + + if (withResumability) { + registry.forConnection( + (type, duplexConnection) -> + type == DuplexConnectionInterceptor.Type.SOURCE + ? new DisconnectingDuplexConnection( + "Server", + duplexConnection, + Duration.ofMillis( + ThreadLocalRandom.current().nextInt(100, 1000))) + : duplexConnection); + } + }); + + if (withResumability) { + rSocketServer.resume( + new Resume() + .storeFactory( + token -> new InMemoryResumableFramesStore("server", token, Integer.MAX_VALUE))); + } + + if (withRandomFragmentation) { + rSocketServer.fragment(ThreadLocalRandom.current().nextInt(256, 512)); + } + server = - RSocketServer.create((setup, sendingSocket) -> Mono.just(new TestRSocket(data, metadata))) - .bind(serverTransportSupplier.apply(address)) - .block(); + rSocketServer.bind(serverTransportSupplier.apply(address, allocatorToSupply2)).block(); + + final RSocketConnector rSocketConnector = + RSocketConnector.create() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .keepAlive(Duration.ofMillis(10), Duration.ofHours(1)) + .interceptors( + registry -> { + if (runClientWithAsyncInterceptors && !withResumability) { + logger.info( + "Perform Integration Test with Async Interceptors Enabled For Client"); + registry + .forConnection( + (type, duplexConnection) -> + new AsyncDuplexConnection(duplexConnection, "client")) + .forSocketAcceptor( + delegate -> + (connectionSetupPayload, sendingSocket) -> + delegate + .accept(connectionSetupPayload, sendingSocket) + .subscribeOn(Schedulers.parallel())); + } + + if (withResumability) { + registry.forConnection( + (type, duplexConnection) -> + type == DuplexConnectionInterceptor.Type.SOURCE + ? new DisconnectingDuplexConnection( + "Client", + duplexConnection, + Duration.ofMillis( + ThreadLocalRandom.current().nextInt(10, 1500))) + : duplexConnection); + } + }); + + if (withResumability) { + rSocketConnector.resume( + new Resume() + .storeFactory( + token -> new InMemoryResumableFramesStore("client", token, Integer.MAX_VALUE))); + } + + if (withRandomFragmentation) { + rSocketConnector.fragment(ThreadLocalRandom.current().nextInt(256, 512)); + } client = - RSocketConnector.connectWith(clientTransportSupplier.apply(address, server)) + rSocketConnector + .connect(clientTransportSupplier.apply(address, server, allocatorToSupply1)) .doOnError(Throwable::printStackTrace) .block(); } @Override public void dispose() { - server.dispose(); + logger.info("terminating transport pair"); + client.dispose(); } RSocket getClient() { @@ -414,5 +668,317 @@ public String expectedPayloadData() { public String expectedPayloadMetadata() { return metadata; } + + public void awaitClosed(Duration timeout) { + logger.info("awaiting termination of transport pair"); + logger.info( + "wrappers combination: client{async=" + + runClientWithAsyncInterceptors + + "; resume=" + + withResumability + + "} server{async=" + + runServerWithAsyncInterceptors + + "; resume=" + + withResumability + + "}"); + client + .onClose() + .doOnSubscribe(s -> logger.info("Client termination stage=onSubscribe(" + s + ")")) + .doOnEach(s -> logger.info("Client termination stage=" + s)) + .onErrorResume(t -> Mono.empty()) + .doOnTerminate(() -> logger.info("Client terminated. Terminating Server")) + .then(Mono.fromRunnable(server::dispose)) + .then( + server + .onClose() + .doOnSubscribe( + s -> logger.info("Server termination stage=onSubscribe(" + s + ")")) + .doOnEach(s -> logger.info("Server termination stage=" + s))) + .onErrorResume(t -> Mono.empty()) + .block(timeout); + + logger.info("TransportPair has been terminated"); + } + + private static class AsyncDuplexConnection implements DuplexConnection { + + private final DuplexConnection duplexConnection; + private String tag; + private final ByteBufReleaserOperator bufReleaserOperator; + + public AsyncDuplexConnection(DuplexConnection duplexConnection, String tag) { + this.duplexConnection = duplexConnection; + this.tag = tag; + this.bufReleaserOperator = new ByteBufReleaserOperator(); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + duplexConnection.sendFrame(streamId, frame); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + duplexConnection.sendErrorAndClose(e); + } + + @Override + public Flux receive() { + return duplexConnection + .receive() + .doOnTerminate(() -> logger.info("[" + this + "] Receive is done before PO")) + .subscribeOn(Schedulers.boundedElastic()) + .doOnNext(ByteBuf::retain) + .publishOn(Schedulers.boundedElastic(), Integer.MAX_VALUE) + .doOnTerminate(() -> logger.info("[" + this + "] Receive is done after PO")) + .doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::safeRelease) + .transform( + Operators.lift( + (__, actual) -> { + bufReleaserOperator.actual = actual; + return bufReleaserOperator; + })); + } + + @Override + public ByteBufAllocator alloc() { + return duplexConnection.alloc(); + } + + @Override + public SocketAddress remoteAddress() { + return duplexConnection.remoteAddress(); + } + + @Override + public Mono onClose() { + return Mono.whenDelayError( + duplexConnection + .onClose() + .doOnTerminate(() -> logger.info("[" + this + "] Source Connection is done")), + bufReleaserOperator + .onClose() + .doOnTerminate(() -> logger.info("[" + this + "] BufferReleaser is done"))); + } + + @Override + public void dispose() { + duplexConnection.dispose(); + } + + @Override + public String toString() { + return "AsyncDuplexConnection{" + + "duplexConnection=" + + duplexConnection + + ", tag='" + + tag + + '\'' + + ", bufReleaserOperator=" + + bufReleaserOperator + + '}'; + } + } + + private static class DisconnectingDuplexConnection implements DuplexConnection { + + private final String tag; + final DuplexConnection source; + final Duration delay; + final Disposable.Swap disposables = Disposables.swap(); + + DisconnectingDuplexConnection(String tag, DuplexConnection source, Duration delay) { + this.tag = tag; + this.source = source; + this.delay = delay; + } + + @Override + public void dispose() { + disposables.dispose(); + source.dispose(); + } + + @Override + public Mono onClose() { + return source + .onClose() + .doOnTerminate(() -> logger.info("[" + this + "] Source Connection is done")); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + source.sendFrame(streamId, frame); + } + + @Override + public void sendErrorAndClose(RSocketErrorException errorException) { + source.sendErrorAndClose(errorException); + } + + boolean receivedFirst; + + @Override + public Flux receive() { + return source + .receive() + .doOnSubscribe( + __ -> logger.warn("Tag {}. Subscribing Connection[{}]", tag, source.hashCode())) + .doOnNext( + bb -> { + if (!receivedFirst) { + receivedFirst = true; + disposables.replace( + Mono.delay(delay) + .takeUntilOther(source.onClose()) + .subscribe( + __ -> { + logger.warn( + "Tag {}. Disposing Connection[{}]", tag, source.hashCode()); + source.dispose(); + })); + } + }); + } + + @Override + public ByteBufAllocator alloc() { + return source.alloc(); + } + + @Override + public SocketAddress remoteAddress() { + return source.remoteAddress(); + } + + @Override + public String toString() { + return "DisconnectingDuplexConnection{" + + "tag='" + + tag + + '\'' + + ", source=" + + source + + ", disposables=" + + disposables + + '}'; + } + } + + private static class ByteBufReleaserOperator + implements CoreSubscriber, Subscription, Fuseable.QueueSubscription { + + CoreSubscriber actual; + final Sinks.Empty closeableMonoSink; + + Subscription s; + + public ByteBufReleaserOperator() { + this.closeableMonoSink = Sinks.unsafe().empty(); + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + actual.onSubscribe(this); + } + } + + @Override + public void onNext(ByteBuf buf) { + try { + actual.onNext(buf); + } finally { + buf.release(); + } + } + + Mono onClose() { + return closeableMonoSink.asMono(); + } + + @Override + public void onError(Throwable t) { + actual.onError(t); + closeableMonoSink.tryEmitError(t); + } + + @Override + public void onComplete() { + actual.onComplete(); + closeableMonoSink.tryEmitEmpty(); + } + + @Override + public void request(long n) { + s.request(n); + } + + @Override + public void cancel() { + s.cancel(); + closeableMonoSink.tryEmitEmpty(); + } + + @Override + public int requestFusion(int requestedMode) { + return Fuseable.NONE; + } + + @Override + public ByteBuf poll() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public int size() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public boolean isEmpty() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public String toString() { + return "ByteBufReleaserOperator{" + + "isActualPresent=" + + (actual != null) + + ", " + + "isSubscriptionPresent=" + + (s != null) + + '}'; + } + } + } + + class PayloadPredicate implements Predicate { + final int expectedCnt; + int cnt; + + public PayloadPredicate(int expectedCnt) { + this.expectedCnt = expectedCnt; + } + + @Override + public boolean test(Payload p) { + boolean shouldConsume = cnt++ < expectedCnt; + if (!shouldConsume) { + logger.info( + "Metadata: \n\r{}\n\rData:{}", + p.hasMetadata() + ? new ByteBufRepresentation().fallbackToStringOf(p.sliceMetadata()) + : "Empty", + new ByteBufRepresentation().fallbackToStringOf(p.sliceData())); + } + return shouldConsume; + } } } diff --git a/rsocket-test/src/main/java/io/rsocket/test/TriFunction.java b/rsocket-test/src/main/java/io/rsocket/test/TriFunction.java new file mode 100644 index 000000000..87a1d4dbf --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/TriFunction.java @@ -0,0 +1,6 @@ +package io.rsocket.test; + +@FunctionalInterface +public interface TriFunction { + R apply(T1 t1, T2 t2, T3 t3); +} diff --git a/rsocket-test/src/main/resources/META-INF/services/org.assertj.core.presentation.Representation b/rsocket-test/src/main/resources/META-INF/services/org.assertj.core.presentation.Representation new file mode 100644 index 000000000..0c33b5ff7 --- /dev/null +++ b/rsocket-test/src/main/resources/META-INF/services/org.assertj.core.presentation.Representation @@ -0,0 +1,16 @@ +# +# Copyright 2015-2018 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +io.rsocket.test.ByteBufRepresentation \ No newline at end of file diff --git a/rsocket-transport-local/build.gradle b/rsocket-transport-local/build.gradle index a5ba84d5c..fc32125e2 100644 --- a/rsocket-transport-local/build.gradle +++ b/rsocket-transport-local/build.gradle @@ -17,8 +17,7 @@ plugins { id 'java-library' id 'maven-publish' - id 'com.jfrog.artifactory' - id 'com.jfrog.bintray' + id 'signing' } dependencies { @@ -33,4 +32,10 @@ dependencies { testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' } +jar { + manifest { + attributes("Automatic-Module-Name": "rsocket.transport.local") + } +} + description = 'Local RSocket transport implementation' diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java index b80fc2337..1b3779e85 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ package io.rsocket.transport.local; -import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import io.rsocket.internal.UnboundedProcessor; @@ -24,7 +23,7 @@ import io.rsocket.transport.ServerTransport; import java.util.Objects; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Sinks; /** * An implementation of {@link ClientTransport} that connects to a {@link ServerTransport} in the @@ -78,14 +77,17 @@ public Mono connect() { return Mono.error(new IllegalArgumentException("Could not find server: " + name)); } - UnboundedProcessor in = new UnboundedProcessor<>(); - UnboundedProcessor out = new UnboundedProcessor<>(); - MonoProcessor closeNotifier = MonoProcessor.create(); + Sinks.One inSink = Sinks.one(); + Sinks.One outSink = Sinks.one(); + UnboundedProcessor in = new UnboundedProcessor(inSink::tryEmitEmpty); + UnboundedProcessor out = new UnboundedProcessor(outSink::tryEmitEmpty); - server.apply(new LocalDuplexConnection(allocator, out, in, closeNotifier)).subscribe(); + Mono onClose = inSink.asMono().and(outSink.asMono()); - return Mono.just( - (DuplexConnection) new LocalDuplexConnection(allocator, in, out, closeNotifier)); + server.apply(new LocalDuplexConnection(name, allocator, out, in, onClose)).subscribe(); + + return Mono.just( + new LocalDuplexConnection(name, allocator, in, out, onClose)); }); } } diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java index afaa14f95..c1d0fd2a3 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,36 +19,45 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.internal.UnboundedProcessor; +import java.net.SocketAddress; import java.util.Objects; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Operators; /** An implementation of {@link DuplexConnection} that connects inside the same JVM. */ final class LocalDuplexConnection implements DuplexConnection { + private final LocalSocketAddress address; private final ByteBufAllocator allocator; - private final Flux in; + private final UnboundedProcessor in; - private final MonoProcessor onClose; + private final Mono onClose; - private final Subscriber out; + private final UnboundedProcessor out; /** * Creates a new instance. * + * @param name the name assigned to this local connection * @param in the inbound {@link ByteBuf}s * @param out the outbound {@link ByteBuf}s * @param onClose the closing notifier * @throws NullPointerException if {@code in}, {@code out}, or {@code onClose} are {@code null} */ LocalDuplexConnection( + String name, ByteBufAllocator allocator, - Flux in, - Subscriber out, - MonoProcessor onClose) { + UnboundedProcessor in, + UnboundedProcessor out, + Mono onClose) { + this.address = new LocalSocketAddress(name); this.allocator = Objects.requireNonNull(allocator, "allocator must not be null"); this.in = Objects.requireNonNull(in, "in must not be null"); this.out = Objects.requireNonNull(out, "out must not be null"); @@ -58,12 +67,11 @@ final class LocalDuplexConnection implements DuplexConnection { @Override public void dispose() { out.onComplete(); - onClose.onComplete(); } @Override public boolean isDisposed() { - return onClose.isDisposed(); + return out.isDisposed(); } @Override @@ -73,25 +81,118 @@ public Mono onClose() { @Override public Flux receive() { - return in; + return in.transform( + Operators.lift( + (__, actual) -> new ByteBufReleaserOperator(actual, this))); } @Override - public Mono send(Publisher frames) { - Objects.requireNonNull(frames, "frames must not be null"); - - return Flux.from(frames).doOnNext(out::onNext).then(); + public void sendFrame(int streamId, ByteBuf frame) { + if (streamId == 0) { + out.tryEmitPrioritized(frame); + } else { + out.tryEmitNormal(frame); + } } @Override - public Mono sendOne(ByteBuf frame) { - Objects.requireNonNull(frame, "frame must not be null"); - out.onNext(frame); - return Mono.empty(); + public void sendErrorAndClose(RSocketErrorException e) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, 0, e); + out.tryEmitFinal(errorFrame); } @Override public ByteBufAllocator alloc() { return allocator; } + + @Override + public SocketAddress remoteAddress() { + return address; + } + + @Override + public String toString() { + return "LocalDuplexConnection{" + "address=" + address + "hash=" + hashCode() + '}'; + } + + static class ByteBufReleaserOperator + implements CoreSubscriber, Subscription, Fuseable.QueueSubscription { + + final CoreSubscriber actual; + final LocalDuplexConnection parent; + + Subscription s; + + public ByteBufReleaserOperator( + CoreSubscriber actual, LocalDuplexConnection parent) { + this.actual = actual; + this.parent = parent; + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + actual.onSubscribe(this); + } + } + + @Override + public void onNext(ByteBuf buf) { + try { + actual.onNext(buf); + } finally { + buf.release(); + } + } + + @Override + public void onError(Throwable t) { + parent.out.onError(t); + actual.onError(t); + } + + @Override + public void onComplete() { + parent.out.onComplete(); + actual.onComplete(); + } + + @Override + public void request(long n) { + s.request(n); + } + + @Override + public void cancel() { + s.cancel(); + parent.out.onComplete(); + } + + @Override + public int requestFusion(int requestedMode) { + return Fuseable.NONE; + } + + @Override + public ByteBuf poll() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public int size() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public boolean isEmpty() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + } } diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java index c07713cb3..975cb6793 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,14 +17,18 @@ package io.rsocket.transport.local; import io.rsocket.Closeable; +import io.rsocket.DuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import java.util.Objects; +import java.util.Set; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.stream.Collectors; +import reactor.core.Scannable; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Sinks; import reactor.util.annotation.Nullable; /** @@ -33,7 +37,7 @@ */ public final class LocalServerTransport implements ServerTransport { - private static final ConcurrentMap registry = + private static final ConcurrentMap registry = new ConcurrentHashMap<>(); private final String name; @@ -71,7 +75,10 @@ public static LocalServerTransport createEphemeral() { */ public static void dispose(String name) { Objects.requireNonNull(name, "name must not be null"); - registry.remove(name); + ServerCloseableAcceptor sca = registry.remove(name); + if (sca != null) { + sca.dispose(); + } } /** @@ -106,44 +113,66 @@ public Mono start(ConnectionAcceptor acceptor) { Objects.requireNonNull(acceptor, "acceptor must not be null"); return Mono.create( sink -> { - ServerCloseable closeable = new ServerCloseable(name, acceptor); - if (registry.putIfAbsent(name, acceptor) != null) { - throw new IllegalStateException("name already registered: " + name); + ServerCloseableAcceptor closeable = new ServerCloseableAcceptor(name, acceptor); + if (registry.putIfAbsent(name, closeable) != null) { + sink.error(new IllegalStateException("name already registered: " + name)); } sink.success(closeable); }); } - static class ServerCloseable implements Closeable { + @SuppressWarnings({"ReactorTransformationOnMonoVoid", "CallingSubscribeInNonBlockingScope"}) + static class ServerCloseableAcceptor implements ConnectionAcceptor, Closeable { private final LocalSocketAddress address; private final ConnectionAcceptor acceptor; - private final MonoProcessor onClose = MonoProcessor.create(); + private final Set activeConnections = ConcurrentHashMap.newKeySet(); + + private final Sinks.Empty onClose = Sinks.unsafe().empty(); - ServerCloseable(String name, ConnectionAcceptor acceptor) { + ServerCloseableAcceptor(String name, ConnectionAcceptor acceptor) { Objects.requireNonNull(name, "name must not be null"); this.address = new LocalSocketAddress(name); this.acceptor = acceptor; } + @Override + public Mono apply(DuplexConnection duplexConnection) { + activeConnections.add(duplexConnection); + duplexConnection + .onClose() + .doFinally(__ -> activeConnections.remove(duplexConnection)) + .subscribe(); + return acceptor.apply(duplexConnection); + } + @Override public void dispose() { - if (!registry.remove(address.getName(), acceptor)) { - throw new AssertionError(); + if (!registry.remove(address.getName(), this)) { + // already disposed + return; } - onClose.onComplete(); + + Mono.whenDelayError( + activeConnections + .stream() + .peek(DuplexConnection::dispose) + .map(DuplexConnection::onClose) + .collect(Collectors.toList())) + .subscribe(null, onClose::tryEmitError, onClose::tryEmitEmpty); } @Override + @SuppressWarnings("ConstantConditions") public boolean isDisposed() { - return onClose.isDisposed(); + return onClose.scan(Scannable.Attr.TERMINATED) || onClose.scan(Scannable.Attr.CANCELLED); } @Override public Mono onClose() { - return onClose; + return onClose.asMono(); } } } diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalSocketAddress.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalSocketAddress.java index d04fd482e..4d0da126a 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalSocketAddress.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalSocketAddress.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ import java.util.Objects; /** An implementation of {@link SocketAddress} representing a local connection. */ -final class LocalSocketAddress extends SocketAddress { +public final class LocalSocketAddress extends SocketAddress { private static final long serialVersionUID = -7513338854585475473L; @@ -32,16 +32,17 @@ final class LocalSocketAddress extends SocketAddress { * @param name the name representing the address * @throws NullPointerException if {@code name} is {@code null} */ - LocalSocketAddress(String name) { + public LocalSocketAddress(String name) { this.name = Objects.requireNonNull(name, "name must not be null"); } - @Override - public String toString() { - return "[local server] " + name; + /** Return the name for this connection. */ + public String getName() { + return name; } - String getName() { - return name; + @Override + public String toString() { + return "[local address] " + name; } } diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientTransportTest.java index ac4c13efe..095de3f0e 100644 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientTransportTest.java +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientTransportTest.java @@ -19,9 +19,10 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatNullPointerException; +import io.rsocket.Closeable; +import java.time.Duration; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; -import reactor.core.publisher.Mono; import reactor.test.StepVerifier; final class LocalClientTransportTest { @@ -31,12 +32,20 @@ final class LocalClientTransportTest { void connect() { LocalServerTransport serverTransport = LocalServerTransport.createEphemeral(); - serverTransport - .start(duplexConnection -> Mono.empty()) - .flatMap(closeable -> LocalClientTransport.create(serverTransport.getName()).connect()) - .as(StepVerifier::create) - .expectNextCount(1) - .verifyComplete(); + Closeable closeable = + serverTransport.start(duplexConnection -> duplexConnection.receive().then()).block(); + + try { + LocalClientTransport.create(serverTransport.getName()) + .connect() + .doOnNext(d -> d.receive().subscribe()) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + } finally { + closeable.dispose(); + closeable.onClose().block(Duration.ofSeconds(5)); + } } @DisplayName("generates error if server not started") diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableTransportTest.java new file mode 100644 index 000000000..28c1dacac --- /dev/null +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableTransportTest.java @@ -0,0 +1,53 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.local; + +import io.rsocket.test.TransportTest; +import java.time.Duration; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestInfo; + +final class LocalResumableTransportTest implements TransportTest { + + private TransportPair transportPair; + + @BeforeEach + void createTestPair(TestInfo testInfo) { + transportPair = + new TransportPair<>( + () -> + "LocalResumableTransportTest-" + + testInfo.getDisplayName() + + "-" + + UUID.randomUUID(), + (address, server, allocator) -> LocalClientTransport.create(address, allocator), + (address, allocator) -> LocalServerTransport.create(address), + false, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(1); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableWithFragmentationTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableWithFragmentationTransportTest.java new file mode 100644 index 000000000..8ae16a0a5 --- /dev/null +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableWithFragmentationTransportTest.java @@ -0,0 +1,53 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.local; + +import io.rsocket.test.TransportTest; +import java.time.Duration; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestInfo; + +final class LocalResumableWithFragmentationTransportTest implements TransportTest { + + private TransportPair transportPair; + + @BeforeEach + void createTestPair(TestInfo testInfo) { + transportPair = + new TransportPair<>( + () -> + "LocalResumableWithFragmentationTransportTest-" + + testInfo.getDisplayName() + + "-" + + UUID.randomUUID(), + (address, server, allocator) -> LocalClientTransport.create(address, allocator), + (address, allocator) -> LocalServerTransport.create(address), + true, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(1); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalServerTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalServerTransportTest.java index ed906f65b..e4edafc39 100644 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalServerTransportTest.java +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalServerTransportTest.java @@ -96,11 +96,16 @@ void named() { @DisplayName("starts local server transport") @Test void start() { - LocalServerTransport.createEphemeral() - .start(duplexConnection -> Mono.empty()) - .as(StepVerifier::create) - .expectNextCount(1) - .verifyComplete(); + LocalServerTransport ephemeral = LocalServerTransport.createEphemeral(); + try { + ephemeral + .start(duplexConnection -> Mono.empty()) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + } finally { + LocalServerTransport.dispose(ephemeral.getName()); + } } @DisplayName("start throws NullPointerException with null acceptor") diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportTest.java index 7184dd645..87ad2105b 100644 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportTest.java +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,31 +16,32 @@ package io.rsocket.transport.local; -final class LocalTransportTest { // implements TransportTest { - /* - TODO // think this has a memory leak or something in the local connection now that needs to be checked into. the test - TODO // isn't very happy when run from commandline i the command line - private static final AtomicInteger UNIQUE_NAME_GENERATOR = new AtomicInteger(); +import io.rsocket.test.TransportTest; +import java.time.Duration; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestInfo; - private final TransportPair transportPair = - new TransportPair<>( - () -> "test" + UNIQUE_NAME_GENERATOR.incrementAndGet(), - (address, server) -> LocalClientTransport.create(address), - LocalServerTransport::create); +final class LocalTransportTest implements TransportTest { - @Override - @Test - public void requestChannel512() { + private TransportPair transportPair; - } + @BeforeEach + void createTestPair(TestInfo testInfo) { + transportPair = + new TransportPair<>( + () -> "LocalTransportTest-" + testInfo.getDisplayName() + "-" + UUID.randomUUID(), + (address, server, allocator) -> LocalClientTransport.create(address, allocator), + (address, allocator) -> LocalServerTransport.create(address)); + } - @Override - public Duration getTimeout() { - return Duration.ofSeconds(10); - } + @Override + public Duration getTimeout() { + return Duration.ofMinutes(1); + } - @Override - public TransportPair getTransportPair() { - return transportPair; - }*/ + @Override + public TransportPair getTransportPair() { + return transportPair; + } } diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportWithFragmentationTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportWithFragmentationTest.java new file mode 100644 index 000000000..3ca5f5911 --- /dev/null +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportWithFragmentationTest.java @@ -0,0 +1,52 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.local; + +import io.rsocket.test.TransportTest; +import java.time.Duration; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestInfo; + +final class LocalTransportWithFragmentationTest implements TransportTest { + + private TransportPair transportPair; + + @BeforeEach + void createTestPair(TestInfo testInfo) { + transportPair = + new TransportPair<>( + () -> + "LocalTransportWithFragmentationTest-" + + testInfo.getDisplayName() + + "-" + + UUID.randomUUID(), + (address, server, allocator) -> LocalClientTransport.create(address, allocator), + (address, allocator) -> LocalServerTransport.create(address), + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(1); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-local/src/test/resources/logback-test.xml b/rsocket-transport-local/src/test/resources/logback-test.xml index 01a7fa4cd..5c92235c2 100644 --- a/rsocket-transport-local/src/test/resources/logback-test.xml +++ b/rsocket-transport-local/src/test/resources/logback-test.xml @@ -23,10 +23,27 @@ + + ./test-out.log + false + + %-5relative %-5level %logger{35} - %msg%n + + + + + + + + + + + - + + diff --git a/rsocket-transport-netty/build.gradle b/rsocket-transport-netty/build.gradle index 64e483c90..39a5ceac5 100644 --- a/rsocket-transport-netty/build.gradle +++ b/rsocket-transport-netty/build.gradle @@ -17,19 +17,19 @@ plugins { id 'java-library' id 'maven-publish' - id 'com.jfrog.artifactory' - id 'com.jfrog.bintray' + id 'signing' id "com.google.osdetector" version "1.4.0" } def os_suffix = "" -if (osdetector.classifier in ["linux-x86_64"] || ["osx-x86_64"] || ["windows-x86_64"]) { +if (osdetector.classifier in ["linux-x86_64", "linux-aarch_64", "osx-x86_64", "osx-aarch_64", "windows-x86_64"]) { os_suffix = "::" + osdetector.classifier } dependencies { api project(':rsocket-core') - api 'io.projectreactor.netty:reactor-netty' + api "io.projectreactor.netty:reactor-netty-core" + api "io.projectreactor.netty:reactor-netty-http" api 'org.slf4j:slf4j-api' testImplementation project(':rsocket-test') @@ -40,9 +40,21 @@ dependencies { testImplementation 'org.junit.jupiter:junit-jupiter-api' testImplementation 'org.junit.jupiter:junit-jupiter-params' + testRuntimeOnly 'org.bouncycastle:bcpkix-jdk15on' testRuntimeOnly 'ch.qos.logback:logback-classic' testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' testRuntimeOnly 'io.netty:netty-tcnative-boringssl-static' + os_suffix } +jar { + manifest { + attributes("Automatic-Module-Name": "rsocket.transport.netty") + } +} + +test { + minHeapSize = "512m" + maxHeapSize = "4096m" +} + description = 'Reactor Netty RSocket transport implementations (TCP, Websocket)' diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java index 618708bf0..f5d36269c 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,19 +19,20 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; import io.rsocket.frame.FrameLengthCodec; import io.rsocket.internal.BaseDuplexConnection; +import java.net.SocketAddress; import java.util.Objects; -import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.netty.Connection; /** An implementation of {@link DuplexConnection} that connects via TCP. */ public final class TcpDuplexConnection extends BaseDuplexConnection { - + private final String side; private final Connection connection; - private final boolean encodeLength; /** * Creates a new instance @@ -39,29 +40,19 @@ public final class TcpDuplexConnection extends BaseDuplexConnection { * @param connection the {@link Connection} for managing the server */ public TcpDuplexConnection(Connection connection) { - this(connection, true); + this("unknown", connection); } /** * Creates a new instance * - * @param encodeLength indicates if this connection should encode the length or not. - * @param connection the {@link Connection} to for managing the server - * @deprecated as of 1.0.1 in favor of using {@link #TcpDuplexConnection(Connection)} and hence - * {@code encodeLength} should always be true. + * @param connection the {@link Connection} for managing the server */ - @Deprecated - public TcpDuplexConnection(Connection connection, boolean encodeLength) { - this.encodeLength = encodeLength; + public TcpDuplexConnection(String side, Connection connection) { this.connection = Objects.requireNonNull(connection, "connection must not be null"); + this.side = side; - connection - .channel() - .closeFuture() - .addListener( - future -> { - if (!isDisposed()) dispose(); - }); + connection.outbound().send(sender).then().doFinally(__ -> connection.dispose()).subscribe(); } @Override @@ -69,39 +60,39 @@ public ByteBufAllocator alloc() { return connection.channel().alloc(); } + @Override + public SocketAddress remoteAddress() { + return connection.channel().remoteAddress(); + } + @Override protected void doOnClose() { - if (!connection.isDisposed()) { - connection.dispose(); - } + connection.dispose(); } @Override - public Flux receive() { - return connection.inbound().receive().map(this::decode); + public Mono onClose() { + return Mono.whenDelayError(super.onClose(), connection.onTerminate()); } @Override - public Mono send(Publisher frames) { - if (frames instanceof Mono) { - return connection.outbound().sendObject(((Mono) frames).map(this::encode)).then(); - } - return connection.outbound().send(Flux.from(frames).map(this::encode)).then(); + public void sendErrorAndClose(RSocketErrorException e) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(alloc(), 0, e); + sender.tryEmitFinal(FrameLengthCodec.encode(alloc(), errorFrame.readableBytes(), errorFrame)); } - private ByteBuf encode(ByteBuf frame) { - if (encodeLength) { - return FrameLengthCodec.encode(alloc(), frame.readableBytes(), frame); - } else { - return frame; - } + @Override + public Flux receive() { + return connection.inbound().receive().map(FrameLengthCodec::frame); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + super.sendFrame(streamId, FrameLengthCodec.encode(alloc(), frame.readableBytes(), frame)); } - private ByteBuf decode(ByteBuf frame) { - if (encodeLength) { - return FrameLengthCodec.frame(frame).retain(); - } else { - return frame; - } + @Override + public String toString() { + return "TcpDuplexConnection{" + "side='" + side + '\'' + ", connection=" + connection + '}'; } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java index 0183ef19d..8f1170c5b 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,9 +19,11 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; import io.rsocket.internal.BaseDuplexConnection; +import java.net.SocketAddress; import java.util.Objects; -import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.netty.Connection; @@ -34,7 +36,7 @@ * stitched back on for frames received. */ public final class WebsocketDuplexConnection extends BaseDuplexConnection { - + private final String side; private final Connection connection; /** @@ -43,15 +45,24 @@ public final class WebsocketDuplexConnection extends BaseDuplexConnection { * @param connection the {@link Connection} to for managing the server */ public WebsocketDuplexConnection(Connection connection) { + this("unknown", connection); + } + + /** + * Creates a new instance + * + * @param connection the {@link Connection} to for managing the server + */ + public WebsocketDuplexConnection(String side, Connection connection) { this.connection = Objects.requireNonNull(connection, "connection must not be null"); + this.side = side; connection - .channel() - .closeFuture() - .addListener( - future -> { - if (!isDisposed()) dispose(); - }); + .outbound() + .sendObject(sender.map(BinaryWebSocketFrame::new)) + .then() + .doFinally(__ -> connection.dispose()) + .subscribe(); } @Override @@ -59,29 +70,40 @@ public ByteBufAllocator alloc() { return connection.channel().alloc(); } + @Override + public SocketAddress remoteAddress() { + return connection.channel().remoteAddress(); + } + @Override protected void doOnClose() { - if (!connection.isDisposed()) { - connection.dispose(); - } + connection.dispose(); + } + + @Override + public Mono onClose() { + return Mono.whenDelayError(super.onClose(), connection.onTerminate()); } @Override public Flux receive() { - return connection.inbound().receive().map(ByteBuf::retain); + return connection.inbound().receive(); } @Override - public Mono send(Publisher frames) { - if (frames instanceof Mono) { - return connection - .outbound() - .sendObject(((Mono) frames).map(BinaryWebSocketFrame::new)) - .then(); - } - return connection - .outbound() - .sendObject(Flux.from(frames).map(BinaryWebSocketFrame::new)) - .then(); + public void sendErrorAndClose(RSocketErrorException e) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(alloc(), 0, e); + sender.tryEmitFinal(errorFrame); + } + + @Override + public String toString() { + return "WebsocketDuplexConnection{" + + "side='" + + side + + '\'' + + ", connection=" + + connection + + '}'; } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java index f64c6063c..84214b98c 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java @@ -116,6 +116,6 @@ public Mono connect() { return client .doOnConnected(c -> c.addHandlerLast(new RSocketLengthCodec(maxFrameLength))) .connect() - .map(TcpDuplexConnection::new); + .map(connection -> new TcpDuplexConnection("client", connection)); } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java index dd6c535db..86be47893 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java @@ -27,10 +27,8 @@ import java.net.InetSocketAddress; import java.net.URI; import java.util.Arrays; -import java.util.Map; import java.util.Objects; import java.util.function.Consumer; -import java.util.function.Supplier; import reactor.core.publisher.Mono; import reactor.netty.http.client.HttpClient; import reactor.netty.http.client.WebsocketClientSpec; @@ -40,9 +38,7 @@ * An implementation of {@link ClientTransport} that connects to a {@link ServerTransport} over * WebSocket. */ -@SuppressWarnings("deprecation") -public final class WebsocketClientTransport - implements ClientTransport, io.rsocket.transport.TransportHeaderAware { +public final class WebsocketClientTransport implements ClientTransport { private static final String DEFAULT_PATH = "/"; @@ -164,13 +160,6 @@ public WebsocketClientTransport webSocketSpec(Consumer> transportHeaders) { - if (transportHeaders != null) { - transportHeaders.get().forEach((name, value) -> headers.add(name, value)); - } - } - @Override public int maxFrameLength() { return specBuilder.build().maxFramePayloadLength(); @@ -183,6 +172,6 @@ public Mono connect() { .websocket(specBuilder.build()) .uri(path) .connect() - .map(WebsocketDuplexConnection::new); + .map(connection -> new WebsocketDuplexConnection("client", connection)); } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/BaseWebsocketServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/BaseWebsocketServerTransport.java index 5f04eb575..33cff28b4 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/BaseWebsocketServerTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/BaseWebsocketServerTransport.java @@ -24,10 +24,7 @@ abstract class BaseWebsocketServerTransport< private static final ChannelHandler pongHandler = new PongHandler(); static Function serverConfigurer = - server -> - server.tcpConfiguration( - tcpServer -> - tcpServer.doOnConnection(connection -> connection.addHandlerLast(pongHandler))); + server -> server.doOnConnection(connection -> connection.addHandlerLast(pongHandler)); final WebsocketServerSpec.Builder specBuilder = WebsocketServerSpec.builder().maxFramePayloadLength(FRAME_LENGTH_MASK); diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/CloseableChannel.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/CloseableChannel.java index c4a257f76..7e98905ff 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/CloseableChannel.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/CloseableChannel.java @@ -60,8 +60,8 @@ public final class CloseableChannel implements Closeable { */ public InetSocketAddress address() { try { - return channel.address(); - } catch (NoSuchMethodError e) { + return (InetSocketAddress) channel.address(); + } catch (ClassCastException | NoSuchMethodError e) { try { return (InetSocketAddress) channelAddressMethod.invoke(this.channel); } catch (Exception ex) { diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java index effc7bed5..32562c4a4 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java @@ -114,7 +114,7 @@ public Mono start(ConnectionAcceptor acceptor) { c -> { c.addHandlerLast(new RSocketLengthCodec(maxFrameLength)); acceptor - .apply(new TcpDuplexConnection(c)) + .apply(new TcpDuplexConnection("server", c)) .then(Mono.never()) .subscribe(c.disposeSubscriber()); }) diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java index 38344c472..db13720e7 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java @@ -80,6 +80,8 @@ public Mono start(ConnectionAcceptor acceptor) { public static BiFunction> newHandler( ConnectionAcceptor acceptor) { return (in, out) -> - acceptor.apply(new WebsocketDuplexConnection((Connection) in)).then(out.neverComplete()); + acceptor + .apply(new WebsocketDuplexConnection("server", (Connection) in)) + .then(out.neverComplete()); } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java index 4fb6417c9..4fe736fad 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java @@ -23,9 +23,7 @@ import io.rsocket.transport.netty.WebsocketDuplexConnection; import java.net.InetSocketAddress; import java.util.Arrays; -import java.util.Map; import java.util.Objects; -import java.util.function.Supplier; import reactor.core.publisher.Mono; import reactor.netty.Connection; import reactor.netty.http.server.HttpServer; @@ -34,10 +32,8 @@ * An implementation of {@link ServerTransport} that connects to a {@link ClientTransport} via a * Websocket. */ -@SuppressWarnings("deprecation") public final class WebsocketServerTransport - extends BaseWebsocketServerTransport - implements io.rsocket.transport.TransportHeaderAware { + extends BaseWebsocketServerTransport { private final HttpServer server; @@ -111,13 +107,6 @@ public WebsocketServerTransport header(String name, String... values) { return this; } - @Override - public void setTransportHeaders(Supplier> transportHeaders) { - if (transportHeaders != null) { - transportHeaders.get().forEach((name, value) -> headers.add(name, value)); - } - } - @Override public Mono start(ConnectionAcceptor acceptor) { Objects.requireNonNull(acceptor, "acceptor must not be null"); @@ -128,7 +117,7 @@ public Mono start(ConnectionAcceptor acceptor) { return response.sendWebsocket( (in, out) -> acceptor - .apply(new WebsocketDuplexConnection((Connection) in)) + .apply(new WebsocketDuplexConnection("server", (Connection) in)) .then(out.neverComplete()), specBuilder.build()); }) diff --git a/rsocket-transport-netty/src/main/resources/META-INF/native-image/io.rsocket/rsocket-transport-netty/reflect-config.json b/rsocket-transport-netty/src/main/resources/META-INF/native-image/io.rsocket/rsocket-transport-netty/reflect-config.json new file mode 100644 index 000000000..3a2baa440 --- /dev/null +++ b/rsocket-transport-netty/src/main/resources/META-INF/native-image/io.rsocket/rsocket-transport-netty/reflect-config.json @@ -0,0 +1,16 @@ +[ + { + "condition": { + "typeReachable": "io.rsocket.transport.netty.RSocketLengthCodec" + }, + "name": "io.rsocket.transport.netty.RSocketLengthCodec", + "queryAllPublicMethods": true + }, + { + "condition": { + "typeReachable": "io.rsocket.transport.netty.server.BaseWebsocketServerTransport$PongHandler" + }, + "name": "io.rsocket.transport.netty.server.BaseWebsocketServerTransport$PongHandler", + "queryAllPublicMethods": true + } +] \ No newline at end of file diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/integration/KeepaliveTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/integration/KeepaliveTest.java new file mode 100644 index 000000000..f05713215 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/integration/KeepaliveTest.java @@ -0,0 +1,190 @@ +package io.rsocket.integration; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketClient; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; +import reactor.test.StepVerifier; +import reactor.util.retry.Retry; +import reactor.util.retry.RetryBackoffSpec; + +/** + * Test case that reproduces the following GitHub Issue + */ +public class KeepaliveTest { + + private static final Logger LOG = LoggerFactory.getLogger(KeepaliveTest.class); + private static final int PORT = 23200; + + private CloseableChannel server; + + @BeforeEach + void setUp() { + server = createServer().block(); + } + + @AfterEach + void tearDown() { + server.dispose(); + server.onClose().block(); + } + + @Test + void keepAliveTest() { + RSocketClient rsocketClient = createClient(); + + int expectedCount = 4; + AtomicBoolean sleepOnce = new AtomicBoolean(true); + StepVerifier.create( + Flux.range(0, expectedCount) + .delayElements(Duration.ofMillis(2000)) + .concatMap( + i -> + rsocketClient + .requestResponse(Mono.just(DefaultPayload.create(""))) + .doOnNext( + __ -> { + if (sleepOnce.getAndSet(false)) { + try { + LOG.info("Sleeping..."); + Thread.sleep(1_000); + LOG.info("Waking up."); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + }) + .log("id " + i) + .onErrorComplete())) + .expectSubscription() + .expectNextCount(expectedCount) + .verifyComplete(); + } + + @Test + void keepAliveTestLazy() { + Mono rsocketMono = createClientLazy(); + + int expectedCount = 4; + AtomicBoolean sleepOnce = new AtomicBoolean(true); + StepVerifier.create( + Flux.range(0, expectedCount) + .delayElements(Duration.ofMillis(2000)) + .concatMap( + i -> + rsocketMono.flatMap( + rsocket -> + rsocket + .requestResponse(DefaultPayload.create("")) + .doOnNext( + __ -> { + if (sleepOnce.getAndSet(false)) { + try { + LOG.info("Sleeping..."); + Thread.sleep(1_000); + LOG.info("Waking up."); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + }) + .log("id " + i) + .onErrorComplete()))) + .expectSubscription() + .expectNextCount(expectedCount) + .verifyComplete(); + } + + private static Mono createServer() { + LOG.info("Starting server at port {}", PORT); + + TcpServer tcpServer = TcpServer.create().host("localhost").port(PORT); + + return RSocketServer.create( + (setupPayload, rSocket) -> { + rSocket + .onClose() + .doFirst(() -> LOG.info("Connected on server side.")) + .doOnTerminate(() -> LOG.info("Connection closed on server side.")) + .subscribe(); + + return Mono.just(new MyServerRsocket()); + }) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(TcpServerTransport.create(tcpServer)) + .doOnNext(closeableChannel -> LOG.info("RSocket server started.")); + } + + private static RSocketClient createClient() { + LOG.info("Connecting...."); + + Function reconnectSpec = + reason -> + Retry.backoff(Long.MAX_VALUE, Duration.ofSeconds(10L)) + .doBeforeRetry(retrySignal -> LOG.info("Reconnecting. Reason: {}", reason)); + + Mono rsocketMono = + RSocketConnector.create() + .fragment(16384) + .reconnect(reconnectSpec.apply("connector-close")) + .keepAlive(Duration.ofMillis(100L), Duration.ofMillis(900L)) + .connect(TcpClientTransport.create(TcpClient.create().host("localhost").port(PORT))); + + RSocketClient client = RSocketClient.from(rsocketMono); + + client + .source() + .doOnNext(r -> LOG.info("Got RSocket")) + .flatMap(RSocket::onClose) + .doOnError(err -> LOG.error("Error during onClose.", err)) + .retryWhen(reconnectSpec.apply("client-close")) + .doFirst(() -> LOG.info("Connected on client side.")) + .doOnTerminate(() -> LOG.info("Connection closed on client side.")) + .repeat() + .subscribe(); + + return client; + } + + private static Mono createClientLazy() { + LOG.info("Connecting...."); + + Function reconnectSpec = + reason -> + Retry.backoff(Long.MAX_VALUE, Duration.ofSeconds(10L)) + .doBeforeRetry(retrySignal -> LOG.info("Reconnecting. Reason: {}", reason)); + + return RSocketConnector.create() + .fragment(16384) + .reconnect(reconnectSpec.apply("connector-close")) + .keepAlive(Duration.ofMillis(100L), Duration.ofMillis(900L)) + .connect(TcpClientTransport.create(TcpClient.create().host("localhost").port(PORT))); + } + + public static class MyServerRsocket implements RSocket { + + @Override + public Mono requestResponse(Payload payload) { + return Mono.just("Pong").map(DefaultPayload::create); + } + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java index 6fd3de791..76c352768 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java @@ -1,3 +1,18 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package io.rsocket.transport.netty; import io.rsocket.ConnectionSetupPayload; @@ -20,9 +35,9 @@ import java.util.function.Function; import java.util.stream.Stream; import org.junit.jupiter.params.provider.Arguments; -import reactor.core.publisher.EmitterProcessor; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; public class SetupRejectionTest { @@ -85,21 +100,21 @@ static Stream transports() { } static class ErrorConsumer implements Consumer { - private final EmitterProcessor errors = EmitterProcessor.create(); + private final Sinks.Many errors = Sinks.many().multicast().onBackpressureBuffer(); @Override public void accept(Throwable t) { - errors.onNext(t); + errors.tryEmitNext(t); } Flux errors() { - return errors; + return errors.asFlux(); } } private static class RejectingAcceptor implements SocketAcceptor { private final String msg; - private final EmitterProcessor requesters = EmitterProcessor.create(); + private final Sinks.Many requesters = Sinks.many().multicast().onBackpressureBuffer(); public RejectingAcceptor(String msg) { this.msg = msg; @@ -107,12 +122,12 @@ public RejectingAcceptor(String msg) { @Override public Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket) { - requesters.onNext(sendingSocket); + requesters.tryEmitNext(sendingSocket); return Mono.error(new RuntimeException(msg)); } public Mono requesterRSocket() { - return requesters.next(); + return requesters.asFlux().next(); } } } diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpFragmentationTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpFragmentationTransportTest.java new file mode 100644 index 000000000..b17da654f --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpFragmentationTransportTest.java @@ -0,0 +1,60 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; + +final class TcpFragmentationTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + TcpClientTransport.create( + TcpClient.create() + .remoteAddress(server::address) + .option(ChannelOption.ALLOCATOR, allocator)), + (address, allocator) -> { + return TcpServerTransport.create( + TcpServer.create() + .bindAddress(() -> address) + .option(ChannelOption.ALLOCATOR, allocator)); + }, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(2); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableTransportTest.java new file mode 100644 index 000000000..7be1c1c54 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableTransportTest.java @@ -0,0 +1,61 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; + +final class TcpResumableTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + TcpClientTransport.create( + TcpClient.create() + .remoteAddress(server::address) + .option(ChannelOption.ALLOCATOR, allocator)), + (address, allocator) -> { + return TcpServerTransport.create( + TcpServer.create() + .bindAddress(() -> address) + .option(ChannelOption.ALLOCATOR, allocator)); + }, + false, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(3); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableWithFragmentationTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableWithFragmentationTransportTest.java new file mode 100644 index 000000000..39b3cec67 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableWithFragmentationTransportTest.java @@ -0,0 +1,61 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; + +final class TcpResumableWithFragmentationTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + TcpClientTransport.create( + TcpClient.create() + .remoteAddress(server::address) + .option(ChannelOption.ALLOCATOR, allocator)), + (address, allocator) -> { + return TcpServerTransport.create( + TcpServer.create() + .bindAddress(() -> address) + .option(ChannelOption.ALLOCATOR, allocator)); + }, + true, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(3); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java index 95bebd6aa..ee49b83cd 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java @@ -1,5 +1,22 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.rsocket.transport.netty; +import io.netty.channel.ChannelOption; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.netty.handler.ssl.util.SelfSignedCertificate; @@ -9,39 +26,47 @@ import java.net.InetSocketAddress; import java.security.cert.CertificateException; import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; import reactor.core.Exceptions; import reactor.netty.tcp.TcpClient; import reactor.netty.tcp.TcpServer; public class TcpSecureTransportTest implements TransportTest { - private final TransportPair transportPair = - new TransportPair<>( - () -> new InetSocketAddress("localhost", 0), - (address, server) -> - TcpClientTransport.create( - TcpClient.create() - .remoteAddress(server::address) - .secure( - ssl -> - ssl.sslContext( - SslContextBuilder.forClient() - .trustManager(InsecureTrustManagerFactory.INSTANCE)))), - address -> { - try { - SelfSignedCertificate ssc = new SelfSignedCertificate(); - TcpServer server = - TcpServer.create() - .bindAddress(() -> address) - .secure( - ssl -> - ssl.sslContext( - SslContextBuilder.forServer( - ssc.certificate(), ssc.privateKey()))); - return TcpServerTransport.create(server); - } catch (CertificateException e) { - throw Exceptions.propagate(e); - } - }); + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> new InetSocketAddress("localhost", 0), + (address, server, allocator) -> + TcpClientTransport.create( + TcpClient.create() + .option(ChannelOption.ALLOCATOR, allocator) + .remoteAddress(server::address) + .secure( + ssl -> + ssl.sslContext( + SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE)))), + (address, allocator) -> { + try { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + TcpServer server = + TcpServer.create() + .option(ChannelOption.ALLOCATOR, allocator) + .bindAddress(() -> address) + .secure( + ssl -> + ssl.sslContext( + SslContextBuilder.forServer( + ssc.certificate(), ssc.privateKey()))); + return TcpServerTransport.create(server); + } catch (CertificateException e) { + throw Exceptions.propagate(e); + } + }); + } @Override public Duration getTimeout() { diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpTransportTest.java index 182be1d91..428681f3e 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpTransportTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,19 +16,36 @@ package io.rsocket.transport.netty; +import io.netty.channel.ChannelOption; import io.rsocket.test.TransportTest; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.TcpServerTransport; import java.net.InetSocketAddress; import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; final class TcpTransportTest implements TransportTest { + private TransportPair transportPair; - private final TransportPair transportPair = - new TransportPair<>( - () -> InetSocketAddress.createUnresolved("localhost", 0), - (address, server) -> TcpClientTransport.create(server.address()), - TcpServerTransport::create); + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + TcpClientTransport.create( + TcpClient.create() + .remoteAddress(server::address) + .option(ChannelOption.ALLOCATOR, allocator)), + (address, allocator) -> { + return TcpServerTransport.create( + TcpServer.create() + .bindAddress(() -> address) + .option(ChannelOption.ALLOCATOR, allocator)); + }); + } @Override public Duration getTimeout() { diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java index e2ee9e521..ff0fa75b4 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java @@ -1,3 +1,18 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package io.rsocket.transport.netty; import io.netty.buffer.Unpooled; @@ -25,8 +40,9 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.Scannable; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Sinks; import reactor.netty.http.client.HttpClient; import reactor.netty.http.server.HttpServer; import reactor.test.StepVerifier; @@ -100,13 +116,13 @@ private static Stream provideServerTransport() { } private static class PingSender extends ChannelInboundHandlerAdapter { - private final MonoProcessor channel = MonoProcessor.create(); - private final MonoProcessor pong = MonoProcessor.create(); + private final Sinks.One channel = Sinks.one(); + private final Sinks.One pong = Sinks.one(); @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { if (msg instanceof PongWebSocketFrame) { - pong.onNext(((PongWebSocketFrame) msg).content().toString(StandardCharsets.UTF_8)); + pong.tryEmitValue(((PongWebSocketFrame) msg).content().toString(StandardCharsets.UTF_8)); ReferenceCountUtil.safeRelease(msg); ctx.read(); } else { @@ -117,8 +133,8 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception @Override public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { Channel ch = ctx.channel(); - if (!channel.isTerminated() && ch.isWritable()) { - channel.onNext(ctx.channel()); + if (!(channel.scan(Scannable.Attr.TERMINATED)) && ch.isWritable()) { + channel.tryEmitValue(ctx.channel()); } super.channelWritabilityChanged(ctx); } @@ -127,7 +143,7 @@ public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exceptio public void handlerAdded(ChannelHandlerContext ctx) throws Exception { Channel ch = ctx.channel(); if (ch.isWritable()) { - channel.onNext(ch); + channel.tryEmitValue(ch); } super.handlerAdded(ctx); } @@ -142,11 +158,11 @@ public Mono sendPong() { } public Mono receivePong() { - return pong; + return pong.asMono(); } private Mono send(WebSocketFrame webSocketFrame) { - return channel.doOnNext(ch -> ch.writeAndFlush(webSocketFrame)).then(); + return channel.asMono().doOnNext(ch -> ch.writeAndFlush(webSocketFrame)).then(); } } } diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableTransportTest.java new file mode 100644 index 000000000..043f6bc64 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableTransportTest.java @@ -0,0 +1,64 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.http.client.HttpClient; +import reactor.netty.http.server.HttpServer; + +final class WebsocketResumableTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + WebsocketClientTransport.create( + HttpClient.create() + .host(server.address().getHostName()) + .port(server.address().getPort()) + .option(ChannelOption.ALLOCATOR, allocator), + ""), + (address, allocator) -> { + return WebsocketServerTransport.create( + HttpServer.create() + .host(address.getHostName()) + .port(address.getPort()) + .option(ChannelOption.ALLOCATOR, allocator)); + }, + false, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(3); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableWithFragmentationTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableWithFragmentationTransportTest.java new file mode 100644 index 000000000..b1ca65fcc --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableWithFragmentationTransportTest.java @@ -0,0 +1,64 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.http.client.HttpClient; +import reactor.netty.http.server.HttpServer; + +final class WebsocketResumableWithFragmentationTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + WebsocketClientTransport.create( + HttpClient.create() + .host(server.address().getHostName()) + .port(server.address().getPort()) + .option(ChannelOption.ALLOCATOR, allocator), + ""), + (address, allocator) -> { + return WebsocketServerTransport.create( + HttpServer.create() + .host(address.getHostName()) + .port(address.getPort()) + .option(ChannelOption.ALLOCATOR, allocator)); + }, + true, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(3); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java index ec33060b2..81f7ffb95 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package io.rsocket.transport.netty; +import io.netty.channel.ChannelOption; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.netty.handler.ssl.util.SelfSignedCertificate; @@ -25,49 +26,54 @@ import java.net.InetSocketAddress; import java.security.cert.CertificateException; import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; import reactor.core.Exceptions; import reactor.netty.http.client.HttpClient; import reactor.netty.http.server.HttpServer; -import reactor.netty.tcp.TcpServer; final class WebsocketSecureTransportTest implements TransportTest { + private TransportPair transportPair; - private final TransportPair transportPair = - new TransportPair<>( - () -> new InetSocketAddress("localhost", 0), - (address, server) -> - WebsocketClientTransport.create( - HttpClient.create() - .remoteAddress(server::address) - .secure( - ssl -> - ssl.sslContext( - SslContextBuilder.forClient() - .trustManager(InsecureTrustManagerFactory.INSTANCE))), - String.format( - "https://%s:%d/", - server.address().getHostName(), server.address().getPort())), - address -> { - try { - SelfSignedCertificate ssc = new SelfSignedCertificate(); - HttpServer server = - HttpServer.from( - TcpServer.create() - .bindAddress(() -> address) - .secure( - ssl -> - ssl.sslContext( - SslContextBuilder.forServer( - ssc.certificate(), ssc.privateKey())))); - return WebsocketServerTransport.create(server); - } catch (CertificateException e) { - throw Exceptions.propagate(e); - } - }); + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> new InetSocketAddress("localhost", 0), + (address, server, allocator) -> + WebsocketClientTransport.create( + HttpClient.create() + .option(ChannelOption.ALLOCATOR, allocator) + .remoteAddress(server::address) + .secure( + ssl -> + ssl.sslContext( + SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE))), + String.format( + "https://%s:%d/", + server.address().getHostName(), server.address().getPort())), + (address, allocator) -> { + try { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + HttpServer server = + HttpServer.create() + .option(ChannelOption.ALLOCATOR, allocator) + .bindAddress(() -> address) + .secure( + ssl -> + ssl.sslContext( + SslContextBuilder.forServer( + ssc.certificate(), ssc.privateKey()))); + return WebsocketServerTransport.create(server); + } catch (CertificateException e) { + throw Exceptions.propagate(e); + } + }); + } @Override public Duration getTimeout() { - return Duration.ofMinutes(3); + return Duration.ofMinutes(5); } @Override diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketTransportTest.java index 10d27daeb..cdd507456 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketTransportTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,19 +16,39 @@ package io.rsocket.transport.netty; +import io.netty.channel.ChannelOption; import io.rsocket.test.TransportTest; import io.rsocket.transport.netty.client.WebsocketClientTransport; import io.rsocket.transport.netty.server.WebsocketServerTransport; import java.net.InetSocketAddress; import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.http.client.HttpClient; +import reactor.netty.http.server.HttpServer; final class WebsocketTransportTest implements TransportTest { + private TransportPair transportPair; - private final TransportPair transportPair = - new TransportPair<>( - () -> InetSocketAddress.createUnresolved("localhost", 0), - (address, server) -> WebsocketClientTransport.create(server.address()), - address -> WebsocketServerTransport.create(address.getHostName(), address.getPort())); + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + WebsocketClientTransport.create( + HttpClient.create() + .host(server.address().getHostName()) + .port(server.address().getPort()) + .option(ChannelOption.ALLOCATOR, allocator), + ""), + (address, allocator) -> { + return WebsocketServerTransport.create( + HttpServer.create() + .host(address.getHostName()) + .port(address.getPort()) + .option(ChannelOption.ALLOCATOR, allocator)); + }); + } @Override public Duration getTimeout() { diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java index 944d20313..2a3670251 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java @@ -22,7 +22,6 @@ import io.rsocket.transport.netty.server.WebsocketServerTransport; import java.net.InetSocketAddress; import java.net.URI; -import java.util.Collections; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -150,10 +149,4 @@ void createUriPath() { .isNotNull() .hasFieldOrPropertyWithValue("path", "/test"); } - - @DisplayName("sets transport headers") - @Test - void setTransportHeader() { - WebsocketClientTransport.create(8000).setTransportHeaders(Collections::emptyMap); - } } diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/CloseableChannelTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/CloseableChannelTest.java index 308118955..bd53a9b3f 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/CloseableChannelTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/CloseableChannelTest.java @@ -19,7 +19,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatNullPointerException; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; @@ -57,9 +56,6 @@ void constructorNullContext() { .withMessage("channel must not be null"); } - @Disabled( - "NettyContext isDisposed() is not accurate\n" - + "https://github.com/reactor/reactor-netty/issues/360") @DisplayName("disposes context") @Test void dispose() { diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java index 7f7567dc8..540076704 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java @@ -19,9 +19,9 @@ import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatNullPointerException; +import static org.mockito.ArgumentMatchers.any; import java.net.InetSocketAddress; -import java.util.Collections; import java.util.function.BiFunction; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; @@ -31,29 +31,35 @@ import reactor.netty.http.server.HttpServer; import reactor.netty.http.server.HttpServerRequest; import reactor.netty.http.server.HttpServerResponse; +import reactor.netty.http.server.WebsocketServerSpec; import reactor.test.StepVerifier; final class WebsocketServerTransportTest { - // @Test + @Test public void testThatSetupWithUnSpecifiedFrameSizeShouldSetMaxFrameSize() { - ArgumentCaptor captor = ArgumentCaptor.forClass(BiFunction.class); - HttpServer httpServer = Mockito.spy(HttpServer.create()); - Mockito.doAnswer(a -> httpServer).when(httpServer).handle(captor.capture()); - Mockito.doAnswer(a -> Mono.empty()).when(httpServer).bind(); - - WebsocketServerTransport serverTransport = WebsocketServerTransport.create(httpServer); + ArgumentCaptor httpHandlerCaptor = ArgumentCaptor.forClass(BiFunction.class); + HttpServer server = Mockito.spy(HttpServer.create()); + Mockito.doAnswer(a -> server).when(server).handle(httpHandlerCaptor.capture()); + Mockito.doAnswer(a -> server).when(server).doOnConnection(any()); + Mockito.doAnswer(a -> Mono.empty()).when(server).bind(); + WebsocketServerTransport serverTransport = WebsocketServerTransport.create(server); serverTransport.start(c -> Mono.empty()).subscribe(); HttpServerRequest httpServerRequest = Mockito.mock(HttpServerRequest.class); HttpServerResponse httpServerResponse = Mockito.mock(HttpServerResponse.class); - captor.getValue().apply(httpServerRequest, httpServerResponse); + httpHandlerCaptor.getValue().apply(httpServerRequest, httpServerResponse); + + ArgumentCaptor handlerCaptor = ArgumentCaptor.forClass(BiFunction.class); + ArgumentCaptor specCaptor = + ArgumentCaptor.forClass(WebsocketServerSpec.class); + + Mockito.verify(httpServerResponse).sendWebsocket(handlerCaptor.capture(), specCaptor.capture()); - Mockito.verify(httpServerResponse) - .sendWebsocket( - Mockito.nullable(String.class), Mockito.eq(FRAME_LENGTH_MASK), Mockito.any()); + WebsocketServerSpec spec = specCaptor.getValue(); + assertThat(spec.maxFramePayloadLength()).isEqualTo(FRAME_LENGTH_MASK); } @DisplayName("creates server with BindAddress") @@ -107,12 +113,6 @@ void createPort() { assertThat(WebsocketServerTransport.create(8000)).isNotNull(); } - @DisplayName("sets transport headers") - @Test - void setTransportHeader() { - WebsocketServerTransport.create(8000).setTransportHeaders(Collections::emptyMap); - } - @DisplayName("starts server") @Test void start() { diff --git a/rsocket-transport-netty/src/test/resources/logback-test.xml b/rsocket-transport-netty/src/test/resources/logback-test.xml index f9dec2bbe..981d6d0b6 100644 --- a/rsocket-transport-netty/src/test/resources/logback-test.xml +++ b/rsocket-transport-netty/src/test/resources/logback-test.xml @@ -26,6 +26,14 @@ + + + + + + + +